Finally, let's put it all together and kick off the training process:
class GAN:
    """Build the semi-supervised GAN graph: inputs, losses and optimizers.

    Constructing an instance calls ``tf.reset_default_graph()``, so any graph
    built earlier (and tensors referring to it) is discarded.

    Args:
        real_size: shape of a real input sample; ``real_size[2]`` is passed
            to ``model_losses`` (presumably the channel count -- confirm
            against ``model_losses``).
        z_size: size of the latent noise input, forwarded to ``inputs``.
        learning_rate: initial learning rate, stored in a non-trainable
            ``tf.Variable`` so that an op (``shrink_learning_rate``) can
            update it during training.
        num_classes: number of classes, forwarded to ``model_losses``.
        alpha: forwarded to ``model_losses`` as ``leaky_alpha`` (presumably
            the leaky-ReLU slope).
        beta1: forwarded to ``model_optimizer`` (presumably Adam's beta1).
    """

    def __init__(self, real_size, z_size, learning_rate, num_classes=10,
                 alpha=0.2, beta1=0.5):
        tf.reset_default_graph()
        self.learning_rate = tf.Variable(learning_rate, trainable=False)

        (self.input_actual, self.input_latent_z,
         self.target, self.label_mask) = inputs(real_size, z_size)

        # Dropout rate used whenever the caller does not feed a value.
        self.drop_out_rate = tf.placeholder_with_default(
            .5, (), "drop_out_rate")

        (self.disc_loss, self.gen_loss, self.correct,
         self.masked_correct, self.samples) = model_losses(
            self.input_actual, self.input_latent_z,
            real_size[2], self.target, num_classes,
            label_mask=self.label_mask,
            # Honour the constructor argument instead of a hard-coded 0.2.
            leaky_alpha=alpha,
            drop_out_rate=self.drop_out_rate)

        (self.disc_opt, self.gen_opt,
         self.shrink_learning_rate) = model_optimizer(
            self.disc_loss, self.gen_loss, self.learning_rate, beta1)
def view_generated_samples(epoch, samples, nrows, ncols, figsize=(5, 5)):
    """Plot the generated images stored at ``samples[epoch]`` in a grid.

    Each image is independently min-max rescaled to the 0-255 ``uint8``
    range before display. An image whose pixels are all equal is shown
    as black instead of producing NaNs.

    Args:
        epoch: index into ``samples`` (``-1`` selects the latest entry).
        samples: sequence of per-epoch image batches.
        nrows, ncols: grid layout; if there are more images than cells,
            only the first ``nrows * ncols`` images are drawn.
        figsize: matplotlib figure size.

    Returns:
        ``(fig, axes)`` as returned by ``plt.subplots``.
    """
    fig, axes = plt.subplots(figsize=figsize, nrows=nrows, ncols=ncols,
                             sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch]):
        ax.axis('off')
        lo, hi = img.min(), img.max()
        span = hi - lo
        if span > 0:
            img = ((img - lo) * 255 / span).astype(np.uint8)
        else:
            # Constant image: avoid the 0/0 division (NaN -> garbage uint8).
            img = np.zeros_like(img, dtype=np.uint8)
        # 'box-forced' is no longer accepted by matplotlib; the default
        # adjustable ('box') gives the intended fixed-box behaviour.
        ax.imshow(img)
    plt.subplots_adjust(wspace=0, hspace=0)
    return fig, axes
def train(net, dataset, epochs, batch_size, figsize=(5, 5)):
    """Train ``net`` on ``dataset`` and record classifier accuracy per epoch.

    For each training batch, one discriminator step and one generator step
    run in a single ``sess.run``. After every epoch the function:

    * runs ``net.shrink_learning_rate``;
    * reports accuracy over the labelled training examples and the test set;
    * plots generator output for a fixed noise batch.

    When training finishes, the session is saved to
    ``./checkpoints/generator.ckpt``; the directory is created if it is
    missing. The per-epoch samples are then pickled to ``samples.pkl``.

    Relies on the module-level ``latent_space_z_size`` for the noise size.

    Args:
        net: a ``GAN`` instance whose graph is the current default graph.
        dataset: object whose ``batches(batch_size)`` yields
            ``(x, y, label_mask)`` and whose
            ``batches(batch_size, which_set="test")`` yields ``(x, y)``.
        epochs: number of passes over the training set.
        batch_size: batch size for both data and generator noise.
        figsize: figure size for the per-epoch sample grid.

    Returns:
        ``(train_accuracies, test_accuracies, samples)``, one entry per epoch.
        An accuracy is ``nan`` if no (labelled) examples were seen.
    """
    import os

    checkpoint_path = './checkpoints/generator.ckpt'
    saver = tf.train.Saver()
    # Fixed noise so the plotted samples are comparable across epochs.
    sample_z = np.random.normal(0, 1, size=(50, latent_space_z_size))
    samples, train_accuracies, test_accuracies = [], [], []

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for e in range(epochs):
            print("Epoch", e)

            # ---- training pass ----
            num_samples = 0
            num_correct_samples = 0
            for x, y, label_mask in dataset.batches(batch_size):
                assert 'int' in str(y.dtype)
                # Only labelled examples count toward training accuracy.
                num_samples += label_mask.sum()
                # Sample random noise for G
                batch_z = np.random.normal(
                    0, 1, size=(batch_size, latent_space_z_size))
                _, _, correct = sess.run(
                    [net.disc_opt, net.gen_opt, net.masked_correct],
                    feed_dict={net.input_actual: x,
                               net.input_latent_z: batch_z,
                               net.target: y,
                               net.label_mask: label_mask})
                num_correct_samples += correct
            sess.run([net.shrink_learning_rate])
            training_accuracy = (num_correct_samples / float(num_samples)
                                 if num_samples else float('nan'))
            print(" Classifier train accuracy: ", training_accuracy)

            # ---- evaluation pass (dropout disabled) ----
            num_samples = 0
            num_correct_samples = 0
            for x, y in dataset.batches(batch_size, which_set="test"):
                assert 'int' in str(y.dtype)
                num_samples += x.shape[0]
                # Use the attribute names GAN actually defines.
                correct, = sess.run(
                    [net.correct],
                    feed_dict={net.input_actual: x,
                               net.target: y,
                               net.drop_out_rate: 0.})
                num_correct_samples += correct
            testing_accuracy = (num_correct_samples / float(num_samples)
                                if num_samples else float('nan'))
            print(" Classifier test accuracy", testing_accuracy)

            gen_samples = sess.run(
                net.samples,
                feed_dict={net.input_latent_z: sample_z})
            samples.append(gen_samples)
            _ = view_generated_samples(-1, samples, 5, 10, figsize=figsize)
            plt.show()

            # Save history of accuracies to view after training
            train_accuracies.append(training_accuracy)
            test_accuracies.append(testing_accuracy)

        # Saver.save fails if the parent directory does not exist.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        saver.save(sess, checkpoint_path)

    with open('samples.pkl', 'wb') as f:
        pkl.dump(samples, f)
    return train_accuracies, test_accuracies, samples
Don't forget to create a directory called `checkpoints` before running the following code, because the trained model is saved there:
# Hyperparameters. `latent_space_z_size` is also read as a module-level
# global by `train`, so it must keep this name.
real_size = (32, 32, 3)
latent_space_z_size = 100
learning_rate = 0.0003

# Build the graph, then wrap the raw train/test data for batching.
net = GAN(real_size=real_size,
          z_size=latent_space_z_size,
          learning_rate=learning_rate)
dataset = Dataset(train_data, test_data)

train_batch_size = 128
num_epochs = 25

# Run training; the returned histories are plotted below.
train_accuracies, test_accuracies, samples = train(
    net=net,
    dataset=dataset,
    epochs=num_epochs,
    batch_size=train_batch_size,
    figsize=(10, 5),
)
At epoch 24 (the last of the 25 epochs), you should get something close to this:
Figure 8: Sample images created by the generator network using the feature matching loss
# Draw both accuracy histories on one explicit Axes object.
fig, ax = plt.subplots()
ax.plot(train_accuracies, label='Train', alpha=0.5)
ax.plot(test_accuracies, label='Test', alpha=0.5)
ax.set_title("Accuracy")
ax.legend()
Figure 9: Train versus Test accuracy over the training process
Although the feature matching loss performs well on the semi-supervised learning task, the images produced by the generator are not as good as the ones created in the previous chapter. However, this implementation was mainly intended to demonstrate how GANs can be used in semi-supervised learning setups.