Model training

In this section, we'll kick off the training process. We'll use the next_batch helper function of the mnist_dataset object to fetch a random batch of a given size from the dataset, and then run the optimizer on that batch of images.

Let's start by creating the session object, which will be responsible for executing the computational graph that we defined earlier:

# creating the session
sess = tf.Session()
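As a side note, a TensorFlow session can also be created with a context manager, which closes the session automatically when the block exits. We keep an explicit session variable here because we will reuse the same session for testing after training; the alternative style would look like this minimal sketch:

# alternative style: the session is closed automatically on exit
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... run the training loop here ...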

Next up, let's kick off the training process:

num_epochs = 20
train_batch_size = 200

# initialize all of the graph's variables
sess.run(tf.global_variables_initializer())

for e in range(num_epochs):
    for ii in range(mnist_dataset.train.num_examples // train_batch_size):
        # grab the next random batch; for an autoencoder, targets equal inputs
        input_batch = mnist_dataset.train.next_batch(train_batch_size)
        feed_dict = {inputs_values: input_batch[0], targets_values: input_batch[0]}
        input_batch_cost, _ = sess.run([model_cost, model_optimizier], feed_dict=feed_dict)

        print("Epoch: {}/{}...".format(e + 1, num_epochs),
              "Training loss: {:.3f}".format(input_batch_cost))
Output:
.
.
.
Epoch: 20/20... Training loss: 0.091
Epoch: 20/20... Training loss: 0.091
Epoch: 20/20... Training loss: 0.093
.
.
.
Epoch: 20/20... Training loss: 0.092
Epoch: 20/20... Training loss: 0.093
Epoch: 20/20... Training loss: 0.093

After running the preceding code snippet for 20 epochs, we will have a trained model that is able to reconstruct images from the MNIST test set. Bear in mind that if we feed the model images that are not similar to the ones it was trained on, the reconstruction will fail, because autoencoders are data-specific.
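You can observe this data-specificity for yourself by feeding the trained model something it has never seen, such as uniform random noise, and comparing the reconstruction loss against the loss on real digits. The following is a minimal sketch that reuses the inputs_values, targets_values, and model_cost tensors defined earlier, and assumes numpy is available:

import numpy as np

# images the model was trained on: the loss should be low
test_images = mnist_dataset.test.images[:200]
mnist_cost = sess.run(model_cost,
                      feed_dict={inputs_values: test_images,
                                 targets_values: test_images})

# uniform random noise in [0, 1]: the loss should be noticeably higher
noise_images = np.random.uniform(0.0, 1.0, size=test_images.shape).astype(np.float32)
noise_cost = sess.run(model_cost,
                      feed_dict={inputs_values: noise_images,
                                 targets_values: noise_images})

print("Loss on MNIST digits: {:.3f}".format(mnist_cost))
print("Loss on random noise: {:.3f}".format(noise_cost))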

Let's test the trained model by feeding it some images from the test set and seeing how well the decoder part reconstructs them:

fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(20, 4))

# take the first 10 test images and run them through the autoencoder
input_images = mnist_dataset.test.images[:10]
reconstructed_images, compressed_images = sess.run([decoding_layer, encoding_layer],
                                                   feed_dict={inputs_values: input_images})

# first row: original images; second row: their reconstructions
for imgs, row in zip([input_images, reconstructed_images], axes):
    for img, ax in zip(imgs, row):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

fig.tight_layout(pad=0.1)
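If you are running this code as a standalone script rather than in a Jupyter notebook, you may also need to display the figure explicitly:

plt.show()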

Output:

Figure 7: Examples of the original test images (first row) and their reconstructions (second row)

As you can see, the reconstructed images are very close to the input images, but we can likely get sharper reconstructions by using convolutional layers in the encoder and decoder.
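As a rough illustration of that idea, here is a minimal sketch of what a convolutional encoder-decoder could look like with the TensorFlow 1.x tf.layers API. Note that the placeholders conv_inputs and conv_targets are hypothetical names introduced only for this sketch, and that the images are fed as 28x28x1 tensors rather than the flattened 784-pixel vectors used by the fully connected model in this chapter:

# hypothetical placeholders for this sketch: images as 28x28x1 tensors
conv_inputs = tf.placeholder(tf.float32, (None, 28, 28, 1), name='conv_inputs')
conv_targets = tf.placeholder(tf.float32, (None, 28, 28, 1), name='conv_targets')

# encoder: two conv + pool stages compress 28x28x1 down to 7x7x8
conv1 = tf.layers.conv2d(conv_inputs, 16, (3, 3), padding='same', activation=tf.nn.relu)
pool1 = tf.layers.max_pooling2d(conv1, (2, 2), (2, 2), padding='same')    # 14x14x16
conv2 = tf.layers.conv2d(pool1, 8, (3, 3), padding='same', activation=tf.nn.relu)
encoded = tf.layers.max_pooling2d(conv2, (2, 2), (2, 2), padding='same')  # 7x7x8

# decoder: upsample back to 28x28 and map down to a single channel
upsample1 = tf.image.resize_images(encoded, (14, 14))
conv3 = tf.layers.conv2d(upsample1, 16, (3, 3), padding='same', activation=tf.nn.relu)
upsample2 = tf.image.resize_images(conv3, (28, 28))
logits = tf.layers.conv2d(upsample2, 1, (3, 3), padding='same', activation=None)

# reconstructed image and the training cost, as in the fully connected model
decoded = tf.nn.sigmoid(logits)
conv_cost = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=conv_targets, logits=logits))

The pooling layers give the convolutional encoder the same kind of compression that the dense bottleneck provides in the fully connected model, while preserving the spatial structure of the digits.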
