CNN model with image augmentation

Let's improve upon our regularized CNN model by adding in more data using a proper image augmentation strategy. Since our previous model was trained on the same small sample of data points each time, it wasn't able to generalize well and ended up overfitting after a few epochs.

The idea behind image augmentation is that we follow a set process of taking in existing images from our training dataset and applying some image transformation operations to them, such as rotation, shearing, translation, zooming, and so on, to produce new, altered versions of existing images. Due to these random transformations, we don't get the same images each time, and we will leverage Python generators to feed in these new images to our model during training.
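To make the generator idea concrete, here is a minimal sketch of a Python generator that yields a freshly transformed batch on every call. It is not how Keras implements augmentation; the function name augmenting_batches, the toy 8x8 images, and the choice of flips and wrap-around shifts are all assumptions made for illustration.

```python
import numpy as np

def augmenting_batches(images, labels, batch_size, seed=None):
    """Yield endless (images, labels) batches with random flips and shifts.

    A toy stand-in for a real augmentation pipeline (hypothetical helper).
    """
    rng = np.random.default_rng(seed)
    n = len(images)
    while True:
        # Pick a random subset of the training images for this batch.
        idx = rng.choice(n, size=batch_size, replace=False)
        batch = images[idx].copy()
        for i in range(batch_size):
            # Flip left-to-right with probability 0.5.
            if rng.random() < 0.5:
                batch[i] = batch[i][:, ::-1].copy()
            # Shift horizontally by up to 2 pixels (np.roll wraps around).
            shift = int(rng.integers(-2, 3))
            batch[i] = np.roll(batch[i], shift, axis=1)
        yield batch, labels[idx]

# Toy data: 10 "images" of shape 8x8x3 with binary labels.
images = np.random.default_rng(0).random((10, 8, 8, 3))
labels = np.arange(10) % 2

gen = augmenting_batches(images, labels, batch_size=4, seed=42)
first_x, first_y = next(gen)
second_x, second_y = next(gen)
print(first_x.shape, first_y.shape)
```

Because the generator never terminates, the training loop decides how many batches make up an epoch, which is why a step count has to be supplied when training from generators.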

The Keras framework has an excellent utility called ImageDataGenerator that can help us with all the preceding operations. Let's initialize two data generators, one for our training dataset and one for our validation dataset:

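from keras.preprocessing.image import ImageDataGenerator

# training generator: rescale pixels and apply random augmentations
train_datagen = ImageDataGenerator(rescale=1./255, zoom_range=0.3,
                                   rotation_range=50,
                                   width_shift_range=0.2,
                                   height_shift_range=0.2,
                                   shear_range=0.2,
                                   horizontal_flip=True,
                                   fill_mode='nearest')

# validation generator: rescale only, no augmentation
val_datagen = ImageDataGenerator(rescale=1./255)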

There are a lot of options available in ImageDataGenerator and we have just utilized a few of them. Feel free to check out the documentation at https://keras.io/preprocessing/image/ to get a more detailed perspective. In our training data generator, we take in the raw images and then perform several transformations on them to generate new images. These include the following:

  • Zooming the image randomly by up to 0.3 in either direction (a zoom factor between 0.7 and 1.3) using the zoom_range parameter.
  • Rotating the image randomly by up to 50 degrees using the rotation_range parameter.
  • Translating the image randomly, horizontally or vertically, by up to 0.2 of the image's width or height using the width_shift_range and height_shift_range parameters.
  • Applying shear-based transformations randomly using the shear_range parameter.
  • Randomly flipping half of the images horizontally using the horizontal_flip parameter.
  • Leveraging the fill_mode parameter to fill in new pixels for images after we apply any of the preceding operations (especially rotation or translation). In this case, we just fill in the new pixels with their nearest surrounding pixel values.
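The following sketch illustrates how the ranges in the list above can be read as bounds on randomly drawn parameters, and what filling with the nearest pixel means on a single row of pixels. It is an illustration only, not the Keras implementation: the helpers sample_transform and shift_row_nearest are hypothetical, and drawing each value uniformly from its range is an assumption.

```python
import random

# Settings copied from the training data generator.
CONFIG = {
    "zoom_range": 0.3,
    "rotation_range": 50,
    "width_shift_range": 0.2,
    "height_shift_range": 0.2,
    "shear_range": 0.2,
    "horizontal_flip": True,
}

def sample_transform(config, width, height, rng):
    """Draw one random set of transform parameters (illustrative only)."""
    z = config["zoom_range"]
    r = config["rotation_range"]
    w = config["width_shift_range"]
    h = config["height_shift_range"]
    s = config["shear_range"]
    return {
        "zoom": rng.uniform(1 - z, 1 + z),          # 0.7 to 1.3
        "rotation_deg": rng.uniform(-r, r),         # -50 to 50 degrees
        "shift_x_px": rng.uniform(-w, w) * width,   # fraction of the width
        "shift_y_px": rng.uniform(-h, h) * height,  # fraction of the height
        "shear": rng.uniform(-s, s),
        "flip": config["horizontal_flip"] and rng.random() < 0.5,
    }

def shift_row_nearest(row, offset):
    """Shift a row of pixels right by `offset`, repeating the nearest edge pixel."""
    n = len(row)
    return [row[min(max(i - offset, 0), n - 1)] for i in range(n)]

params = sample_transform(CONFIG, width=150, height=150, rng=random.Random(0))
print(params)

print(shift_row_nearest([1, 2, 3, 4, 5], 2))   # [1, 1, 1, 2, 3]
print(shift_row_nearest([1, 2, 3, 4, 5], -1))  # [2, 3, 4, 5, 5]
```

In the shifted rows, the positions that no longer have a source pixel are filled by repeating the closest edge value, which is the one-dimensional analogue of fill_mode='nearest'.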

Let's see how some of these generated images might look so that you can understand them better. We will take two sample images from our training dataset to illustrate this. The first is an image of a cat:

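import matplotlib.pyplot as plt

img_id = 2595
# generator that yields augmented versions of a single training image
cat_generator = train_datagen.flow(train_imgs[img_id:img_id+1],
                                   train_labels[img_id:img_id+1],
                                   batch_size=1)
# draw five augmented (image batch, label batch) pairs
cat = [next(cat_generator) for i in range(0,5)]
fig, ax = plt.subplots(1,5, figsize=(16, 6))
print('Labels:', [item[1][0] for item in cat])
# show the first (and only) image of each batch
for i in range(0,5):
    ax[i].imshow(cat[i][0][0])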

As the output shows, we generate a new version of our training image each time (with translations, rotations, and zoom), and each version keeps the label cat, so the model can extract relevant features from these images and learn that they are all cats.

Now let's look at an image of a dog:

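img_id = 1991
# generator that yields augmented versions of a single training image
dog_generator = train_datagen.flow(train_imgs[img_id:img_id+1],
                                   train_labels[img_id:img_id+1],
                                   batch_size=1)
# draw five augmented (image batch, label batch) pairs
dog = [next(dog_generator) for i in range(0,5)]
fig, ax = plt.subplots(1,5, figsize=(15, 6))
print('Labels:', [item[1][0] for item in dog])
# show the first (and only) image of each batch
for i in range(0,5):
    ax[i].imshow(dog[i][0][0])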

This shows us how image augmentation helps in creating new images, and how training a model on them should help in combating overfitting.

Remember that for our validation generator, we only need to send the original validation images to the model for evaluation; hence, we just scale the pixel values to the range 0 to 1 and do not apply any transformations. Image augmentation is applied only to our training images:

train_generator = train_datagen.flow(train_imgs, train_labels_enc,  
                                     batch_size=30) 
val_generator = val_datagen.flow(validation_imgs,   
                                 validation_labels_enc,   
                                 batch_size=20) 
 
input_shape = (150, 150, 3)  

Let's now train a CNN model with regularization using the image augmentation data generators we created. We will use the same model architecture from before:

from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout 
from keras.models import Sequential 
from keras import optimizers 
 
model = Sequential() 
# convolution and pooling layers 
model.add(Conv2D(16, kernel_size=(3, 3), activation='relu',  
                 input_shape=input_shape)) 
model.add(MaxPooling2D(pool_size=(2, 2))) 
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu')) 
model.add(MaxPooling2D(pool_size=(2, 2))) 
model.add(Conv2D(128, kernel_size=(3, 3), activation='relu')) 
model.add(MaxPooling2D(pool_size=(2, 2))) 
model.add(Conv2D(128, kernel_size=(3, 3), activation='relu')) 
model.add(MaxPooling2D(pool_size=(2, 2))) 
 
model.add(Flatten()) 
model.add(Dense(512, activation='relu')) 
model.add(Dropout(0.3)) 
model.add(Dense(512, activation='relu')) 
model.add(Dropout(0.3)) 
model.add(Dense(1, activation='sigmoid')) 
 
model.compile(loss='binary_crossentropy', 
              optimizer=optimizers.RMSprop(lr=1e-4), 
              metrics=['accuracy']) 
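As a quick sanity check on the architecture above, the following sketch traces the spatial size of the feature maps through the four convolution and pooling blocks. It assumes the Keras defaults used in the listing (no padding and stride 1 for Conv2D, stride equal to the pool size for MaxPooling2D); the helper names conv_out and pool_out are illustrative.

```python
def conv_out(size, kernel=3):
    """Output side length of an unpadded ('valid') convolution with stride 1."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Output side length of max pooling with stride equal to the pool size."""
    return size // pool

side = 150                      # input images are 150 x 150 x 3
filters = [16, 64, 128, 128]    # filters in each Conv2D layer
shapes = []
for f in filters:
    side = pool_out(conv_out(side))
    shapes.append((side, side, f))

# Number of values fed into the first Dense layer after Flatten().
flattened = shapes[-1][0] * shapes[-1][1] * shapes[-1][2]

print(shapes)     # [(74, 74, 16), (36, 36, 64), (17, 17, 128), (7, 7, 128)]
print(flattened)  # 6272
```

The final 7 x 7 x 128 feature map flattens to 6,272 values, which is the input size of the first 512-unit dense layer.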

In the compile step, we set the RMSprop learning rate to 1e-4, a factor of 10 below the default, to reduce the risk of the model getting stuck in a local minimum or overfitting, since we will be sending it a lot of images with random transformations. To train the model, we need to slightly modify our approach, since we are now using data generators. We will leverage the fit_generator(...) function from Keras to train this model. The train_generator generates 30 images each time, so we set the steps_per_epoch parameter to 100 to train the model on 3,000 randomly generated images from the training data in each epoch. Our val_generator generates 20 images each time, so we set the validation_steps parameter to 50 to validate the model's accuracy on all 1,000 validation images (remember that we are not augmenting our validation dataset):

history = model.fit_generator(train_generator,  
                              steps_per_epoch=100, epochs=100, 
                              validation_data=val_generator,  
                              validation_steps=50, verbose=1) 
 
Epoch 1/100 
100/100 - 12s - loss: 0.6924 - acc: 0.5113 - val_loss: 0.6943 - val_acc: 0.5000 
Epoch 2/100 
100/100 - 11s - loss: 0.6855 - acc: 0.5490 - val_loss: 0.6711 - val_acc: 0.5780 
... 
... 
Epoch 99/100 
100/100 - 11s - loss: 0.3735 - acc: 0.8367 - val_loss: 0.4425 - val_acc: 0.8340 
Epoch 100/100 
100/100 - 11s - loss: 0.3733 - acc: 0.8257 - val_loss: 0.4046 - val_acc: 0.8200 
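The step counts used in the training call follow from simple arithmetic on the batch sizes. The short sketch below reproduces that arithmetic; the helper steps_to_cover is a hypothetical convenience for working out how many batches cover a dataset once.

```python
import math

def steps_to_cover(num_images, batch_size):
    """Number of batches needed to see `num_images` images once."""
    return math.ceil(num_images / batch_size)

train_batch_size, steps_per_epoch = 30, 100
val_batch_size, validation_steps = 20, 50
epochs = 100

images_per_epoch = train_batch_size * steps_per_epoch
val_images_per_pass = val_batch_size * validation_steps
total_training_images = images_per_epoch * epochs

print(images_per_epoch)           # 3000
print(val_images_per_pass)        # 1000
print(total_training_images)      # 300000
print(steps_to_cover(1000, 20))   # 50
```

Over 100 epochs the model therefore sees 300,000 augmented training images, even though each one is derived from the same underlying training set.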

Our validation accuracy jumps to around 82%, which is about 4-5 percentage points better than our previous model. Also, our training accuracy is very similar to our validation accuracy, indicating that our model isn't overfitting anymore. The following plots depict the model accuracy and loss per epoch:

While there are some spikes in the validation accuracy and loss, overall the validation accuracy is much closer to the training accuracy, and the loss indicates that we obtained a model that generalizes much better than our previous models. Let's save this model now so we can evaluate it later on our test dataset:

model.save('cats_dogs_cnn_img_aug.h5') 

We will now try and leverage the power of transfer learning to see if we can build a better model.
