CNN basic example – MNIST digit classification

In this section, we will work through a complete example of implementing a CNN for digit classification using the MNIST dataset. We will build a simple model with two convolutional layers followed by fully connected layers.
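The layer settings are not fixed at this point in the text, so the following is only a sketch with hypothetical values (16 and 36 filters in the two convolution layers, 'SAME' padding, 2x2 max-pooling with stride 2, and a 128-unit fully connected layer). It traces how the 28x28 input shrinks through two convolution-plus-pooling stages and how many features reach the first fully connected layer.

```python
import math

# Hypothetical settings for illustration only (not taken from the text):
# two convolution layers with 16 and 36 filters, 'SAME' padding and stride 1,
# each followed by 2x2 max-pooling with stride 2, then a 128-unit fully
# connected layer and a 10-unit output layer.
image_size = 28
conv_filters = [16, 36]
pool_stride = 2
fc_size = 128
num_classes = 10

size = image_size
shapes = []
for filters in conv_filters:
    # A stride-1 convolution with 'SAME' padding keeps the spatial size;
    # 'SAME' pooling with stride s then gives ceil(size / s).
    size = math.ceil(size / pool_stride)
    shapes.append((size, size, filters))

# Flattening the last feature map gives the input width of the first dense layer.
num_flat_features = size * size * conv_filters[-1]
layer_widths = [num_flat_features, fc_size, num_classes]

print(shapes)        # [(14, 14, 16), (7, 7, 36)]
print(layer_widths)  # [1764, 128, 10]
```

With different filter counts or padding, only the numbers change; the same bookkeeping tells you the size of the flattened vector the fully connected layers must accept.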

Let's start off by importing the libraries that will be needed for this implementation:

%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
import math
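The `confusion_matrix` import is not used in this part of the section; presumably it serves to evaluate the trained model's predictions. To make its layout concrete, here is a minimal NumPy sketch of what it counts. The helper `confusion_matrix_np` is hypothetical and not part of scikit-learn; it follows the scikit-learn convention that rows are true classes and columns are predicted classes.

```python
import numpy as np

def confusion_matrix_np(cls_true, cls_pred, num_classes):
    # Hypothetical helper: cm[i, j] counts samples whose true class is i
    # and whose predicted class is j (same layout as sklearn's confusion_matrix).
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when (row, col) pairs repeat.
    np.add.at(cm, (cls_true, cls_pred), 1)
    return cm

cls_true = np.array([0, 1, 2, 2, 1])
cls_pred = np.array([0, 2, 2, 2, 1])
cm = confusion_matrix_np(cls_true, cls_pred, num_classes=3)
print(cm)
# The diagonal holds the correctly classified samples.
print(cm.trace())
```

Off-diagonal entries show which digits get confused with which, e.g. `cm[1, 2]` counts true 1s predicted as 2.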

Next, we will use TensorFlow helper functions to download and preprocess the MNIST dataset as follows:

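# Note: this helper module ships with TensorFlow 1.x; it was removed in TensorFlow 2.x.
from tensorflow.examples.tutorials.mnist import input_data

# one_hot=True stores each label as a 10-element one-hot vector.
mnist_data = input_data.read_data_sets('data/MNIST/', one_hot=True)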
Output:
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting data/MNIST/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting data/MNIST/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting data/MNIST/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting data/MNIST/t10k-labels-idx1-ubyte.gz

The dataset is split into three disjoint sets: training, validation, and testing. Let's print the number of images in each set:

print("- Number of images in the training set: {}".format(len(mnist_data.train.labels)))
print("- Number of images in the test set: {}".format(len(mnist_data.test.labels)))
print("- Number of images in the validation set: {}".format(len(mnist_data.validation.labels)))
Output:
- Number of images in the training set: 55000
- Number of images in the test set: 10000
- Number of images in the validation set: 5000
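The 55,000 training and 5,000 validation images add up to the 60,000 images in the original MNIST training file, while the 10,000 test images come from the separate t10k files downloaded above. The following minimal sketch shows one way to carve a disjoint validation set out of the training data by index; holding out the first 5,000 images is an assumption for illustration, not a claim about exactly how the TensorFlow helper chooses them.

```python
import numpy as np

# Stand-in for the original 60,000 training images: we only track indices here.
num_original_train = 60000
validation_size = 5000  # assumption: matches the validation count printed above

indices = np.arange(num_original_train)
validation_idx = indices[:validation_size]  # held-out images
train_idx = indices[validation_size:]       # remaining images used for training

print(len(train_idx), len(validation_idx))  # 55000 5000
# The two index sets share no elements, so the splits are disjoint.
print(np.intersect1d(train_idx, validation_idx).size)  # 0
```

Keeping the sets disjoint is what makes the validation and test accuracies honest estimates: the model never sees those images during training.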

The actual labels of the images are stored in one-hot encoding format, so each label is an array of 10 values that are all zero except at the index of the class the image represents. For later use, we need the class numbers of the dataset as integers:

mnist_data.test.cls_integer = np.argmax(mnist_data.test.labels, axis=1)
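To see why `np.argmax(..., axis=1)` recovers the class numbers, here is a small sketch on three toy one-hot labels (made-up data, not taken from MNIST). Each row has a single 1, and `argmax` along `axis=1` returns the column index of that 1.

```python
import numpy as np

# Three toy one-hot labels for the digits 7, 2 and 1 (10 classes each).
labels_one_hot = np.eye(10)[[7, 2, 1]]

# argmax along axis=1 returns, for each row, the index of its single 1.
cls_integer = np.argmax(labels_one_hot, axis=1)
print(cls_integer)  # [7 2 1]

# The inverse mapping: index rows of the identity matrix by class number.
print(np.array_equal(np.eye(10)[cls_integer], labels_one_hot))  # True
```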

Let's define some constants to be used later in our implementation:

# Default size (height and width) of the monochrome MNIST input images
image_size = 28

# Each image is stored as a flat vector of this size.
image_size_flat = image_size * image_size

# The shape of each image
image_shape = (image_size, image_size)

# All the images in the MNIST dataset are monochrome, with only 1 channel
num_channels = 1

# Number of classes in the MNIST dataset: the digits 0 through 9
num_classes = 10
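The constants above describe the same pixels in different layouts. The sketch below uses a fake flattened image (made-up values) to show the round trip: the 784-element vector reshapes to 28x28 for plotting, and to a 4-D batch of shape (batch, height, width, channels) for a convolution layer. The 4-D layout is an assumption about how the model will consume the images; it matches the default NHWC data format of TensorFlow's `tf.nn.conv2d`.

```python
import numpy as np

image_size = 28
image_size_flat = image_size * image_size  # 784
image_shape = (image_size, image_size)
num_channels = 1

# A fake flattened image: 784 pixel values in row-major order.
flat = np.arange(image_size_flat, dtype=np.float32)

# Back to a 2-D image, as plot_imgs does with image_shape.
img_2d = flat.reshape(image_shape)

# A 4-D batch (batch, height, width, channels); -1 lets NumPy infer the batch size.
batch_4d = flat.reshape(-1, image_size, image_size, num_channels)

print(image_size_flat, img_2d.shape, batch_4d.shape)  # 784 (28, 28) (1, 28, 28, 1)
# Row-major order: pixel (row r, col c) sits at flat index r * 28 + c.
print(img_2d[1, 0] == flat[28])  # True
```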

Next, we need to define a helper function to plot some images from the dataset. This helper function will plot the images in a grid of nine subplots:

def plot_imgs(imgs, cls_actual, cls_predicted=None):
    assert len(imgs) == len(cls_actual) == 9

    # Create a figure with 3x3 subplots to plot the images.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Plot the image at the ith index.
        ax.imshow(imgs[i].reshape(image_shape), cmap='binary')

        # Label the image with the actual (and, if given, predicted) class.
        if cls_predicted is None:
            xlabel = "True: {0}".format(cls_actual[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_actual[i], cls_predicted[i])

        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])

        # Show the classes as the label on the x-axis.
        ax.set_xlabel(xlabel)

    plt.show()
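The labeling logic inside `plot_imgs` can be checked without drawing anything. The sketch below pulls it into a hypothetical helper, `make_xlabels` (not part of the original code), and also shows how `axes.flat` walks the 3x3 grid row by row, so subplot `i` sits at row `i // 3`, column `i % 3`.

```python
def make_xlabels(cls_actual, cls_predicted=None):
    # Hypothetical helper: builds the same strings plot_imgs uses as x-axis labels.
    if cls_predicted is None:
        return ["True: {0}".format(c) for c in cls_actual]
    return ["True: {0}, Pred: {1}".format(a, p)
            for a, p in zip(cls_actual, cls_predicted)]

# axes.flat iterates the 3x3 grid in row-major order.
positions = [(i // 3, i % 3) for i in range(9)]

labels = make_xlabels([7, 2, 1], [7, 2, 7])
print(labels)        # ['True: 7, Pred: 7', 'True: 2, Pred: 2', 'True: 1, Pred: 7']
print(positions[4])  # (1, 1)
```

Separating string building from plotting like this is optional; it simply makes the label format easy to test.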

Let's plot some images from the test set and see what they look like:

# Visualize 9 images from the test set.
imgs = mnist_data.test.images[0:9]

# Get the actual classes of these 9 images.
cls_actual = mnist_data.test.cls_integer[0:9]

# Plot the images.
plot_imgs(imgs=imgs, cls_actual=cls_actual)

Here is the output:

Figure 9.12: A visualization of some examples from the MNIST dataset