Define the batch size and the number of epochs, then initialize all of the TensorFlow variables:
# Number of training epochs.
# NOTE(review): the accompanying text also mentions the batch size, but
# `batch_size` is not assigned in this snippet; presumably it is defined
# earlier in the file — confirm.
num_epochs = 100
# Create an interactive TensorFlow (1.x-style) session.
session = tf.InteractiveSession()
# Run the initializer op for all global TensorFlow variables in this session.
session.run(tf.global_variables_initializer())
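Define a helper function for visualizing the results: it sorts a batch of images by predicted label (the argmax of `c`) and tiles them into a single grid image: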
def plot(c, x, num_rows=4, image_size=28):
    """Tile a batch of flattened images into one grid, grouped by predicted label.

    Each sample's label is ``argmax(c[i])``. Samples are sorted by that label
    and laid out row-major in a ``num_rows`` x ``n // num_rows`` grid, so
    samples sharing a label end up next to each other. The returned array is
    meant to be displayed by the caller (e.g. with an image-plotting call).

    Args:
        c: Array-like of shape ``(n, k)`` with per-sample scores; only the
            argmax along axis 1 is used.
        x: Array-like holding ``n`` samples of ``image_size * image_size``
            values each (e.g. shape ``(n, 784)`` for the 28x28 default).
        num_rows: Number of rows in the output grid (default 4).
        image_size: Side length, in pixels, of each square image (default 28).

    Returns:
        A 2-D NumPy array of shape
        ``(num_rows * image_size, (n // num_rows) * image_size)``.

    Raises:
        ValueError: If ``num_rows`` is not positive, if ``n`` is not divisible
            by ``num_rows``, or if ``x`` cannot be reshaped into ``n`` square
            images of side ``image_size``.
    """
    labels = np.argmax(c, 1)
    # Stable sort: images with the same label keep their input order
    # instead of being shuffled arbitrarily within their group.
    order = np.argsort(labels, kind="mergesort")
    n = labels.shape[0]
    if num_rows <= 0 or n % num_rows != 0:
        raise ValueError(
            "number of samples (%d) must be divisible by num_rows (%d)"
            % (n, num_rows))
    cols = n // num_rows

    images = np.reshape(np.asarray(x)[order], (n, image_size, image_size))
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows*H, cols*W): same result
    # as concatenating each row's images side by side, then stacking the rows.
    grid = images.reshape(num_rows, cols, image_size, image_size)
    return grid.transpose(0, 2, 1, 3).reshape(num_rows * image_size,
                                              cols * image_size)