We set the digit (label) to generate to 7, and build a 10x10 identity matrix whose rows serve as one-hot vectors for the ten digit classes:
label_to_generate = 7
onehot = np.eye(10)
Start training. Loop over the epochs and, within each epoch, over the mini-batches:
for epoch in range(num_epochs):
    for i in range(len(images) // batch_size):
Sample a batch of images:
        batch_image = images[i * batch_size:(i + 1) * batch_size]
Sample the matching conditions, that is, the digit labels of the images in this batch:
        batch_c = labels[i * batch_size:(i + 1) * batch_size]
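The two slices above use the same start and end indices, so every image stays paired with its own label. The following pure-Python sketch shows that pairing in isolation; the helper name iterate_batches and the toy data are hypothetical, not part of the chapter's code. Because of the floor division in len(images) // batch_size, any final partial batch is skipped.

```python
def iterate_batches(images, labels, batch_size):
    """Yield aligned (image_batch, label_batch) pairs.

    Hypothetical helper mirroring the slicing in the training loop:
    the same [start:end] range is applied to both sequences, and a
    trailing partial batch is skipped because of the floor division.
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    num_batches = len(images) // batch_size
    for i in range(num_batches):
        start, end = i * batch_size, (i + 1) * batch_size
        yield images[start:end], labels[start:end]


# Ten toy "images" (plain integers here) with matching labels.
toy_images = list(range(10))
toy_labels = [x % 3 for x in toy_images]
batches = list(iterate_batches(toy_images, toy_labels, batch_size=4))
print(len(batches))   # 2 full batches; the last 2 samples are dropped
print(batches[1])     # ([4, 5, 6, 7], [1, 2, 0, 1])
```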
Sample a batch of noise vectors from a standard normal distribution:
        batch_noise = np.random.normal(0, 1, (batch_size, 100))
Train the generator and compute the generator loss:
        generator_loss, _ = session.run([G_loss, G_optimizer], {x: batch_image, c: batch_c, z: batch_noise})
Train the discriminator and compute the discriminator loss:
        discriminator_loss, _ = session.run([D_loss, D_optimizer], {x: batch_image, c: batch_c, z: batch_noise})
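D_loss and G_loss are defined earlier in the chapter; their exact form is not repeated here. Assuming they are the standard GAN objectives conditioned on c (the discriminator loss and the common non-saturating generator loss), the arithmetic can be sketched on toy discriminator probabilities. When the discriminator separates real from fake well, its loss is small and the generator's loss is large, which is why the two optimizers are stepped alternately. The function names below are hypothetical.

```python
import math

# Assumed loss forms (standard GAN with non-saturating generator loss);
# the chapter's D_loss/G_loss may be written differently, e.g. via
# sigmoid cross-entropy on logits, which is numerically equivalent.

def discriminator_loss(d_real, d_fake):
    """-mean(log D(x|c) + log(1 - D(G(z|c)|c))) over the batch."""
    terms = [math.log(r) + math.log(1.0 - f) for r, f in zip(d_real, d_fake)]
    return -sum(terms) / len(terms)

def generator_loss(d_fake):
    """Non-saturating generator loss: -mean(log D(G(z|c)|c))."""
    return -sum(math.log(f) for f in d_fake) / len(d_fake)

# Toy discriminator outputs (probabilities) for a batch of two samples.
d_real = [0.9, 0.8]   # D is fairly sure the real images are real
d_fake = [0.2, 0.1]   # and fairly sure the generated images are fake
print(discriminator_loss(d_real, d_fake))  # small: D is doing well
print(generator_loss(d_fake))              # large: G is being caught
```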
After each epoch, sample a noise vector from the same standard normal distribution used during training:
    noise = np.random.normal(0, 1, (1, 100))
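The generator is trained on noise drawn with np.random.normal(0, 1, ...), so noise used for generation should come from the same distribution; np.random.rand instead draws uniformly from [0, 1) and never produces negative values, which gives the generator inputs it never saw during training. A minimal sketch of sharing one sampler between training and generation, using NumPy's Generator API; the names sample_noise and NOISE_DIM, and the batch size of 32, are hypothetical.

```python
import numpy as np

NOISE_DIM = 100  # width of the z input used in the training loop

def sample_noise(n, dim=NOISE_DIM, rng=None):
    """Hypothetical helper: draw n latent vectors from N(0, 1)."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.normal(0.0, 1.0, size=(n, dim))

rng = np.random.default_rng(0)
train_noise = sample_noise(32, rng=rng)  # like batch_noise, batch size 32
gen_noise = sample_noise(1, rng=rng)     # like the per-epoch noise vector

# Uniform noise, as np.random.rand would give, for comparison.
uniform_noise = np.random.default_rng(0).random((1, NOISE_DIM))
print(train_noise.shape, gen_noise.shape)  # (32, 100) (1, 100)
print(uniform_noise.min() >= 0.0)          # True: uniform noise is never negative
```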
Select the digit we want to generate:
    gen_label = np.array([[label_to_generate]]).reshape(-1)
Convert the selected digit into a one-hot encoded vector by picking the matching row of the identity matrix:
    one_hot_targets = onehot[gen_label]
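Indexing the rows of an identity matrix with an integer array is a compact way to one-hot encode labels: row k of np.eye(10) has a single 1 in column k. A small standalone check, assuming ten classes as for MNIST digits:

```python
import numpy as np

num_classes = 10          # MNIST digits 0-9
label_to_generate = 7

# Indexing rows of the identity matrix turns integer labels into one-hot rows.
gen_label = np.array([label_to_generate])          # shape (1,)
one_hot_targets = np.eye(num_classes)[gen_label]   # shape (1, 10)
print(one_hot_targets)
# [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]

# The same trick encodes a whole batch of labels at once.
batch_labels = np.array([3, 0, 7])
print(np.eye(num_classes)[batch_labels].argmax(axis=1))  # [3 0 7]
```

Note that gen_label in the loop, np.array([[7]]).reshape(-1), is the same one-element array as np.array([7]) here.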
Feed the noise and the one-hot encoded condition to the generator to produce a fake image, then reshape it to 28 x 28 pixels:
    _fake_x = session.run(fake_x, {z: noise, c: one_hot_targets})
    _fake_x = _fake_x.reshape(28, 28)
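A conditional generator commonly receives the noise z and the condition c joined along the feature axis, so the same z paired with a different one-hot c asks for a different digit. The chapter's generator is defined earlier and may combine them differently; this sketch assumes concatenation and a flat 784-unit output, and shows only the input construction and the final reshape. The helper generator_input is hypothetical.

```python
import numpy as np

NOISE_DIM, NUM_CLASSES, IMG_SIDE = 100, 10, 28

def generator_input(z, c):
    """Hypothetical helper: join noise and one-hot condition, as CGANs commonly do."""
    assert z.shape[0] == c.shape[0], "one condition row per noise row"
    return np.concatenate([z, c], axis=1)

z = np.random.default_rng(1).normal(size=(1, NOISE_DIM))
c = np.eye(NUM_CLASSES)[[7]]          # one-hot condition for digit 7
inp = generator_input(z, c)
print(inp.shape)  # (1, 110)

# A generator with a 784-unit output layer returns one flat vector per sample;
# reshaping it to 28x28 recovers the image grid, as _fake_x.reshape(28, 28) does.
flat_output = np.zeros((1, IMG_SIDE * IMG_SIDE))
image = flat_output.reshape(IMG_SIDE, IMG_SIDE)
print(image.shape)  # (28, 28)
```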
Print the discriminator and generator losses (from the last batch of the epoch) and plot the generated image:
    print("Epoch: {}, Discriminator Loss: {}, Generator Loss: {}".format(epoch, discriminator_loss, generator_loss))
    # clear the previous notebook output, then plot the generated image
    display.clear_output(wait=True)
    plt.imshow(_fake_x, cmap='gray')
    plt.show()
As you can see in the following plot, the generator has now learned to generate the digit 7 on request, instead of producing arbitrary digits: