Load the CIFAR-10 dataset:
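from keras.datasets import cifar10
(x_train, y_train), _ = cifar10.load_data()
# Scale pixel values from the uint8 range [0, 255] to floats in [0.0, 1.0]
x_train = x_train.astype('float32') / 255.0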
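The scaling step only changes the dtype and value range, not the image shape. The sketch below shows this without downloading CIFAR-10. It uses a random uint8 array shaped like a small batch of 32x32 RGB images in place of the real data. The helper name `scale_pixels` and the batch size of 8 are hypothetical and exist only for illustration.

```python
import numpy as np

def scale_pixels(images):
    """Convert uint8 pixel values in [0, 255] to float32 values in [0.0, 1.0]."""
    return images.astype('float32') / 255.0

# Hypothetical stand-in for the cifar10.load_data() images: 8 random 32x32 RGB images
rng = np.random.default_rng(0)
fake_batch = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)

scaled = scale_pixels(fake_batch)
print(scaled.shape, scaled.dtype)                 # (8, 32, 32, 3) float32
print(scaled.min() >= 0.0, scaled.max() <= 1.0)   # True True
```

Dividing by 255.0 keeps the data in float32, so memory use matches what the training code will see later.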
Let's see what we have in our dataset. Define a helper function that plots images on a 4x4 grid:
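import matplotlib.pyplot as plt

def plot_images(X):
    # Draw the first 16 images of X on a 4x4 grid
    plt.figure(1)
    z = 0
    for i in range(0, 4):
        for j in range(0, 4):
            plt.subplot2grid((4, 4), (i, j))
            # X holds float32 RGB values in [0, 1], which imshow displays directly
            # (scipy.misc.toimage has been removed from recent SciPy releases)
            plt.imshow(X[z])
            z = z + 1
    plt.show()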
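The same grid can also be built with `plt.subplots`, which creates all the axes in one call and returns them as an array. This is a hedged sketch, not the original code. The function `plot_image_grid`, its `rows`/`cols` parameters, and the optional class-name titles are hypothetical additions. The class-name list assumes the standard CIFAR-10 label order (0 = airplane through 9 = truck).

```python
import numpy as np
import matplotlib.pyplot as plt

# CIFAR-10 class names, indexed by label 0-9 (standard label order assumed)
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

def plot_image_grid(X, y=None, rows=4, cols=4):
    """Show up to rows*cols images from X, optionally titled with class names."""
    # squeeze=False keeps axes 2-D even for a 1x1 grid, so .flat always works
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    for k, ax in enumerate(axes.flat):
        ax.axis('off')  # hide ticks on every cell, including empty ones
        if k >= len(X):
            continue  # fewer images than cells: leave this cell blank
        ax.imshow(X[k])  # float RGB values in [0, 1] are displayed as-is
        if y is not None:
            # Keras labels have shape (N, 1); ravel handles both (N,) and (N, 1)
            label = int(np.ravel(y[k])[0])
            ax.set_title(CIFAR10_CLASSES[label], fontsize=8)
    fig.tight_layout()
    return fig

# Demo on random data standing in for x_train / y_train
rng = np.random.default_rng(0)
demo_x = rng.random((16, 32, 32, 3), dtype=np.float32)
demo_y = rng.integers(0, 10, size=(16, 1))
fig = plot_image_grid(demo_x, demo_y)
```

With the real data, you would call `plot_image_grid(x_train[:16], y_train[:16])` and then `plt.show()`. Returning the figure instead of calling `plt.show()` inside the function makes it easier to save the grid with `fig.savefig(...)` or to inspect it programmatically.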
Let's plot the first 16 images, one for each cell of the 4x4 grid:
plot_images(x_train[:16])
The plotted images are shown as follows: