Model losses

Now it's time to define the model losses. First off, the discriminator loss will be divided into two parts:

  • One that represents the GAN problem (real versus fake), which is the unsupervised loss
  • A second one that computes the individual class probabilities for the real images, which is the supervised loss

For the unsupervised loss, the discriminator has to distinguish between actual training images and images produced by the generator.

As in a regular GAN, half of the time the discriminator gets unlabeled images from the training set as input, and the other half it gets fake, unlabeled images from the generator.
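To make the unsupervised part concrete, here is a minimal NumPy sketch (not the book's TensorFlow code) of the same computation: a sigmoid cross entropy that pushes the GAN logits of real images toward label 1 and those of generated images toward label 0. The function names are hypothetical; the per-element formula is the numerically stable form that TensorFlow documents for `tf.nn.sigmoid_cross_entropy_with_logits`.

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(logits, labels):
    # Stable form of -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x))
    x = np.asarray(logits, dtype=np.float64)
    z = np.asarray(labels, dtype=np.float64)
    return np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))


def unsupervised_disc_loss(gan_logits_real, gan_logits_fake):
    # Real (unlabeled) training images should be classified as real (label 1)...
    loss_real = np.mean(sigmoid_cross_entropy_with_logits(
        gan_logits_real, np.ones_like(gan_logits_real)))
    # ...and generated images as fake (label 0).
    loss_fake = np.mean(sigmoid_cross_entropy_with_logits(
        gan_logits_fake, np.zeros_like(gan_logits_fake)))
    return loss_real + loss_fake
```

With all-zero logits (a discriminator that cannot tell the two sources apart), each term equals log 2, so the sketch returns about 1.386.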

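For the second part of the discriminator loss, the supervised loss, we build on the class logits produced by the discriminator. Because this is a multiclass classification problem, we use softmax cross entropy. Only the labeled examples contribute to this term: a label mask zeroes out the cross entropy of the unlabeled ones, and the sum is divided by the number of labeled examples.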
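A minimal NumPy sketch of this masked supervised loss, assuming `targets` holds a valid class index for every example (the value stored for unlabeled examples is ignored) and `label_mask` is 1 for labeled and 0 for unlabeled examples. The function name is hypothetical; it mirrors the `disc_loss_class` computation in the listing at the end of this section.

```python
import numpy as np


def masked_softmax_cross_entropy(class_logits, targets, label_mask, num_classes):
    logits = np.asarray(class_logits, dtype=np.float64)
    # Log-softmax, subtracting the row maximum for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # One-hot encode the integer class labels
    one_hot = np.eye(num_classes)[np.asarray(targets)]
    per_example = -(one_hot * log_probs).sum(axis=1)
    mask = np.asarray(label_mask, dtype=np.float64)
    # Average over labeled examples only; max(1, ...) avoids dividing by zero
    return (mask * per_example).sum() / max(1.0, mask.sum())
```

With uniform logits over 10 classes, every labeled example costs log 10, and a batch with no labeled examples contributes a loss of 0.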

As mentioned in the Improved Techniques for Training GANs paper, we should use feature matching for the generator loss. The idea can be described as follows:

"Feature matching is the concept of penalizing the mean absolute error between the average value of some set of features on the training data and the average values of that set of features on the generated samples. To do that, we take some set of statistics (the moments) from two different sources and force them to be similar. First, we take the average of the features extracted from the discriminator when a real training minibatch is being processed. Second, we compute the moments in the same way, but now for when a minibatch composed of fake images that come from the generator was being analyzed by the discriminator. Finally, with these two sets of moments, the generator loss is the mean absolute difference between them. In other words, as the paper emphasizes: We train the generator to match the expected values of the features on an intermediate layer of the discriminator."
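The generator loss can be sketched in NumPy as follows, assuming `data_features` and `sample_features` are arrays of shape (batch_size, num_features) taken from the same intermediate layer of the discriminator. The function name is hypothetical.

```python
import numpy as np


def feature_matching_loss(data_features, sample_features):
    # First moments: average each feature over the minibatch (axis 0)
    data_moments = np.mean(data_features, axis=0)
    sample_moments = np.mean(sample_features, axis=0)
    # Generator loss: mean absolute difference between the two moment vectors
    return np.mean(np.abs(data_moments - sample_moments))
```

The mean absolute difference follows the description above. The paper itself states the objective as the squared L2 norm of the difference between the two feature means, which would correspond to `np.sum(np.square(data_moments - sample_moments))`.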

Finally, the complete model loss function looks like this:

```python
import tensorflow as tf  # TensorFlow 1.x API (tf.to_float, tf.nn.softmax_cross_entropy_with_logits)


def model_losses(input_actual, input_latent_z, output_dim, target, num_classes, label_mask,
                 leaky_alpha=0.2, drop_out_rate=0., extra_class=0):
    # `generator` and `discriminator` are the networks defined earlier in this chapter.
    # `extra_class` stays 0 unless an explicit "fake" class is added to the classifier output.

    # These numbers multiply the size of each layer of the generator and the discriminator,
    # respectively. You can reduce them to run your code faster for debugging purposes.
    gen_size_mult = 32
    disc_size_mult = 64

    # Here we run the generator and the discriminator
    gen_model = generator(input_latent_z, output_dim, leaky_alpha=leaky_alpha, size_mult=gen_size_mult)
    disc_on_data = discriminator(input_actual, leaky_alpha=leaky_alpha, drop_out_rate=drop_out_rate,
                                 size_mult=disc_size_mult)
    disc_model_real, class_logits_on_data, gan_logits_on_data, data_features = disc_on_data
    # reuse_vars=True shares the discriminator weights between real and generated inputs
    disc_on_samples = discriminator(gen_model, reuse_vars=True, leaky_alpha=leaky_alpha,
                                    drop_out_rate=drop_out_rate, size_mult=disc_size_mult)
    disc_model_fake, class_logits_on_samples, gan_logits_on_samples, sample_features = disc_on_samples

    # Here we compute `disc_loss`, the loss for the discriminator.
    # Unsupervised part: real images should get label 1, generated images label 0
    disc_loss_actual = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_on_data,
                                                labels=tf.ones_like(gan_logits_on_data)))
    disc_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_on_samples,
                                                labels=tf.zeros_like(gan_logits_on_samples)))
    # Supervised part: softmax cross entropy, counted only for labeled examples
    target = tf.squeeze(target)
    classes_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        logits=class_logits_on_data,
        labels=tf.one_hot(target, num_classes + extra_class, dtype=tf.float32))
    classes_cross_entropy = tf.squeeze(classes_cross_entropy)
    label_m = tf.squeeze(tf.to_float(label_mask))
    disc_loss_class = tf.reduce_sum(label_m * classes_cross_entropy) / tf.maximum(1., tf.reduce_sum(label_m))
    disc_loss = disc_loss_class + disc_loss_actual + disc_loss_fake

    # Here we set `gen_loss` to the "feature matching" loss introduced by Salimans et al.
    sampleMoments = tf.reduce_mean(sample_features, axis=0)
    dataMoments = tf.reduce_mean(data_features, axis=0)

    gen_loss = tf.reduce_mean(tf.abs(dataMoments - sampleMoments))

    # Accuracy counts: over the whole batch, and over the labeled examples only
    prediction_class = tf.cast(tf.argmax(class_logits_on_data, 1), tf.int32)
    check_prediction = tf.equal(target, prediction_class)
    correct = tf.reduce_sum(tf.to_float(check_prediction))
    masked_correct = tf.reduce_sum(label_m * tf.to_float(check_prediction))

    return disc_loss, gen_loss, correct, masked_correct, gen_model
```
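The two accuracy counters returned by `model_losses` differ only in the mask. A hypothetical NumPy sketch of the same counting: `correct` counts matching predictions over the whole batch, while `masked_correct` counts them over the labeled examples only.

```python
import numpy as np


def count_correct(class_logits, targets, label_mask):
    # Predicted class = index of the largest logit in each row
    predictions = np.argmax(class_logits, axis=1)
    check = (predictions == np.asarray(targets)).astype(np.float64)
    correct = check.sum()  # matches over all examples
    masked_correct = (np.asarray(label_mask, dtype=np.float64) * check).sum()  # labeled only
    return correct, masked_correct
```

For example, with predictions `[0, 1, 0]`, targets `[0, 1, 1]`, and mask `[1, 0, 1]`, two predictions are correct overall but only one of them belongs to a labeled example.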