Training process

Next, we will train the model using 100 iterations. Let's go over the code for this, which is organized into five steps:

for (i in 1:100) {
  # 1. Generate 50 fake images from noise
  noise <- matrix(rnorm(b*l), nrow = b, ncol = l)
  fake <- g %>% predict(noise)

  # 2. Combine real & fake images
  stop <- start + b - 1
  real <- trainx[start:stop,,,]
  real <- array_reshape(real, c(nrow(real), 28, 28, 1))
  rows <- nrow(real)
  both <- array(0, dim = c(rows * 2, dim(real)[-1]))
  both[1:rows,,,] <- fake
  both[(rows+1):(rows*2),,,] <- real
  labels <- rbind(matrix(runif(b, 0.9, 1), nrow = b, ncol = 1),
                  matrix(runif(b, 0, 0.1), nrow = b, ncol = 1))
  start <- start + b

  # 3. Train discriminator
  dloss[i] <- d %>% train_on_batch(both, labels)

  # 4. Train generator using gan
  fakeAsReal <- array(runif(b, 0, 0.1), dim = c(b, 1))
  gloss[i] <- gan %>% train_on_batch(noise, fakeAsReal)

  # 5. Save fake image
  f <- fake[1,,,]
  dim(f) <- c(28, 28, 1)
  image_array_save(f, path = file.path(dir, paste0("f", i, ".png")))
}

In the preceding code, we can observe the following:

  1. We start by simulating random data points from the standard normal distribution and save the results as noise. Then, we use the generator network g to create fake images from this random noise. Note that noise is 50 x 28 in size and that fake is 50 x 28 x 28 x 1 in size, containing 50 fake images in each iteration.
  2. We update the values of start and stop based on the batch size. For the first iteration, start and stop are 1 and 50, respectively; for the second iteration, they are 51 and 100; and for the 100th iteration, they are 4,951 and 5,000. Since trainx, which contains images of the handwritten digit five, has more than 5,000 images, none of the images are repeated during these 100 iterations. Thus, in each iteration, 50 real images are selected and stored in real, which is 50 x 28 x 28 in size. We use array_reshape to change the dimensions to 50 x 28 x 28 x 1 so that they match the dimensions of the fake images.
  3. Then, we create an empty array called both that's 100 x 28 x 28 x 1 in size to store the real and fake image data. The first 50 images in both are fake, while the next 50 are real. We also generate 50 random numbers between 0.9 and 1 from a uniform distribution to use as labels for the fake images, and 50 random numbers between 0 and 0.1 to use as labels for the real images. Note that instead of using exactly 1 to represent fake images and 0 to represent real images, we introduce some randomness, or noise; artificially adding noise to the label values helps when training the network.
  4. We train the discriminator network using the image data contained in both and the correct category information contained in labels. We also store the discriminator loss values in dloss for all 100 iterations. If the discriminator learns to classify fake and real images well, this loss value will be low.
  5. We then try to fool the discriminator: we train gan on the noise input but label the resulting fake images with values between 0 and 0.1, the labels we used for real images. The resulting loss values are stored in gloss for all 100 iterations. If the generator learns to produce fake images that the discriminator classifies as real, this loss value will be low.
  6. We save the first fake image from each of the 100 iterations so that we can review it and observe the impact of the training process.
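As a quick, standalone illustration of the bookkeeping in steps 2 and 3 above, the following sketch (using a toy batch size of b = 3 rather than the chapter's 50) shows how the start/stop window advances through trainx and how the noisy labels are stacked, with fakes labeled near 1 and reals near 0:

```r
b <- 3  # toy batch size; the chapter uses b = 50

# The batch window advances by b rows in each iteration
start <- 1
for (i in 1:3) {
  stop <- start + b - 1
  cat("iteration", i, ": rows", start, "to", stop, "\n")
  start <- start + b
}

# Noisy labels: the first b rows (fakes) are near 1, the next b (reals) near 0
labels <- rbind(matrix(runif(b, 0.9, 1), nrow = b, ncol = 1),
                matrix(runif(b, 0, 0.1), nrow = b, ncol = 1))
dim(labels)  # 2*b rows, 1 column
```

With b = 3, the windows are rows 1-3, 4-6, and 7-9; with b = 50, the same arithmetic produces the 1-50, 51-100, ..., 4,951-5,000 windows described above.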

Note that training generative adversarial networks usually requires a significant amount of computational resources. However, the example used here is meant to quickly illustrate how the process works and to complete training in a reasonable amount of time. For 100 iterations on a computer with 8 GB of RAM, running all of this code should take less than a minute.
