Generative adversarial networks (GANs) are trained to generate new images that look similar to the original images. Say we have trained a GAN on the MNIST dataset, which consists of handwritten digits 0-9. If we now generate images from this trained GAN, it will randomly produce images of any digit between 0 and 9. But generating images of one particular digit only is difficult. One way would be to find a mapping between the random noise given as input to the generator and the images the network generates, but because of the variation in the random input noise, such a mapping is very hard to find. This is where conditional GANs come in.
A GAN becomes a conditional GAN (cGAN) if we train both the discriminator and the generator conditioned on some sort of auxiliary information. This information can be class labels, black-and-white images, or data from other modalities. In this blog, we will learn how to generate images from a cGAN conditioned on the class label.
Since the introduction of conditional GANs in 2014, a wide range of applications has been developed based on this network. Some of them are:
- Image-to-Image Translation: with cGANs there have been various implementations of image-to-image translation, such as day to night, black and white to color, and sketches to color photographs.
- Face Aging: uses conditional GANs to generate photographs of a face at different ages, from younger to older.
- Text to Image: inspired by the idea of conditional GANs, generates an image from a text description of that image.
That's enough introduction; now we will implement a conditional GAN to generate handwritten digits conditioned on class labels.
Here we will use the MNIST digits dataset to train this conditional GAN. The dataset consists of images of the digits 0-9 along with their corresponding labels. Create a cgan.py file and insert the following code:
```python
from keras.layers import Input, Dense, Reshape, BatchNormalization, LeakyReLU, Conv2DTranspose, Conv2D, AveragePooling2D, Flatten, Embedding, Concatenate
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import mnist
import numpy as np
```
Line 1 imports all the required layers from keras. Lines 2 and 3 import the Model class and the Adam optimizer, respectively. Line 4 imports the MNIST dataset loader from keras; if you haven't used it before, it will download the data first. Line 5 imports the numpy package.
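If you want to see what load_data() returns before going further, a quick optional check (not part of cgan.py) looks like this:

```python
# Optional sanity check of the MNIST data shapes
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape)  # (60000, 28, 28) (60000,)
print(x_test.shape, y_test.shape)    # (10000, 28, 28) (10000,)
print(y_train[:5])                   # integer class labels, e.g. [5 0 4 1 9]
```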
Now that we have imported all the necessary packages, we will create our cGAN architecture. First, we define a class and initialize all the required variables.
```python
class GAN():
    def __init__(self):

        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        self.batch_size = 16
        self.half_batch_size = 8
        self.latent_dim = 100
        self.iterations = 30000
        self.optimizer = Adam(0.0002, 0.5)
        self.generator_model = self.generator()
        self.discriminator_model = self.discriminator()
        self.combined_model = self.combined()
```
In the above code, line 1 creates a class named GAN. Line 2 defines the __init__ function, which initializes all the required variables. Line 4 loads the data, which consists of both training and test sets along with their labels. Lines 5-9 initialize the hyperparameters required for the network. Lines 10-12 call the generator, discriminator and combined functions, which we will define later in this class.
After initializing all the required variables, we next define the generator function of the GAN class.
```python
def generator(self):

    input_gen = Input(shape=(1,))
    embed_gen = Embedding(10, 50)(input_gen)
    dense_layer_gen = Dense(7 * 7)(embed_gen)
    reshaped_dense_gen = Reshape((7, 7, 1))(dense_layer_gen)

    input_gen_2 = Input(shape=(self.latent_dim,))
    dense1 = Reshape((7, 7, 16))(Dense(7 * 7 * 16)(input_gen_2))

    concat_layer_gen = Concatenate()([reshaped_dense_gen, dense1])
    batch_norm_1 = BatchNormalization()(concat_layer_gen)
    trans_1 = Conv2DTranspose(128, 3, padding='same', activation=LeakyReLU(alpha=0.2), strides=(2, 2))(batch_norm_1)
    batch_norm_2 = BatchNormalization()(trans_1)
    trans_2 = Conv2DTranspose(128, 3, padding='same', activation=LeakyReLU(alpha=0.2), strides=(2, 2))(batch_norm_2)
    output = Conv2D(1, (28, 28), activation='tanh', padding='same')(trans_2)
    gen_model = Model([input_gen, input_gen_2], output)
    gen_model.compile(loss='binary_crossentropy', optimizer=self.optimizer)

    print(gen_model.summary())
    return gen_model
```
The generator takes two inputs: random noise of shape (100,) and a class label of shape (1,), which will be an integer between 0 and 9. This extra class-label input is our condition to the GAN. At test time we will use the class label as a condition to generate images for that specific class only.
In the above code, lines 3-6 handle the class-label input. We attach an Embedding layer to this conditional input; its weights are learned during generator training. The embedding layer converts positive integers into dense vectors of a fixed size, and here we use an embedding of size 50. After the embedding layer we add a dense layer and reshape its output so that it can be concatenated with the noise branch.
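To make the label branch concrete, here is a small standalone sketch (separate from cgan.py) that builds just these layers and verifies the shapes:

```python
# Standalone sketch of the label branch: integer label -> 7x7 feature map
from keras.layers import Input, Embedding, Dense, Reshape
from keras.models import Model
import numpy as np

label_in = Input(shape=(1,))
x = Embedding(10, 50)(label_in)  # (batch, 1, 50): one 50-d vector per label
x = Dense(7 * 7)(x)              # (batch, 1, 49)
x = Reshape((7, 7, 1))(x)        # (batch, 7, 7, 1): ready to concatenate

branch = Model(label_in, x)
print(branch.predict(np.array([[3]])).shape)  # (1, 7, 7, 1)
```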
Lines 8-9 create the input layer for the random noise and reshape it. Lines 11 and 12 concatenate the two reshaped inputs and then apply batch normalization. Batch normalization is really helpful for improving the quality of the model and stabilizing training.
Lines 13-15 are two upsampling layers (transposed convolutions) with a batch normalization layer in between; each stride-2 layer doubles the spatial size, taking the 7x7 feature maps to 14x14 and then to 28x28. Line 16 is the output layer, whose shape matches the real images (28, 28, 1). In line 17 we create the generator model, and in line 18 we compile it with binary cross-entropy loss and the Adam optimizer.
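If the shape arithmetic is not obvious, this throwaway sketch (using the same layer settings as above) traces the upsampling path:

```python
# Throwaway shape trace of the generator's upsampling path
from keras.layers import Input, Conv2DTranspose, Conv2D
from keras.models import Model

x_in = Input(shape=(7, 7, 17))  # 16 noise channels + 1 label channel after concatenation
x = Conv2DTranspose(128, 3, padding='same', strides=(2, 2))(x_in)  # 7x7 -> 14x14
x = Conv2DTranspose(128, 3, padding='same', strides=(2, 2))(x)     # 14x14 -> 28x28
x = Conv2D(1, (28, 28), activation='tanh', padding='same')(x)      # 28x28x1, values in [-1, 1]
print(Model(x_in, x).output_shape)  # (None, 28, 28, 1)
```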
This GAN class also contains the discriminator network, which is likewise conditioned on the class labels.
```python
def discriminator(self):

    input_class = Input(shape=(1,))
    embed = Embedding(10, 50)(input_class)
    dense_layer = Dense(28 * 28)(embed)
    reshaped_dense = Reshape((28, 28, 1))(dense_layer)

    input_disc = Input(shape=(28, 28, 1))

    concat_layer = Concatenate()([input_disc, reshaped_dense])
    conv_1 = Conv2D(16, 3, padding='same', activation=LeakyReLU(alpha=0.2))(concat_layer)
    batch_norm1 = BatchNormalization()(conv_1)
    pool_1 = AveragePooling2D(strides=(2, 2))(batch_norm1)
    conv_2 = Conv2D(32, 3, padding='same', activation=LeakyReLU(alpha=0.2))(pool_1)
    batch_norm2 = BatchNormalization()(conv_2)
    pool_2 = AveragePooling2D(strides=(2, 2))(batch_norm2)
    conv_3 = Conv2D(64, 3, padding='same', activation=LeakyReLU(alpha=0.2))(pool_2)
    batch_norm3 = BatchNormalization()(conv_3)
    pool_3 = AveragePooling2D(strides=(2, 2))(batch_norm3)  # pool the batch-normalized maps
    flatten_1 = Flatten()(pool_3)
    output = Dense(1, activation='sigmoid')(flatten_1)
    disc_model = Model([input_class, input_disc], output)
    disc_model.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

    print(disc_model.summary())
    return disc_model
```
In the above code, lines 3-6 convert the class-label input to an embedding just as in the generator network, except that the result is reshaped to (28, 28, 1) instead of (7, 7, 1). Line 8 defines the second input layer, which is an image (either real or fake). Then in line 10 we concatenate both inputs so they can flow through the rest of the discriminator network.
Lines 11-19 are essentially a repeated block of conv layer -> batch norm layer -> average pooling layer. The convolution layers have 16, 32 and 64 filters. We use average pooling instead of max pooling, as max pooling layers are generally discouraged in GAN architectures.
Finally, in lines 20-21 we flatten the output of the previous layer and add a fully connected layer with a single output unit, which acts as the output layer of our discriminator and discriminates between real and fake images. In lines 22-23 we create the discriminator model, which takes two inputs and produces one output, and compile it with binary cross-entropy loss and the Adam optimizer.
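Similarly, a quick sketch of the downsampling path (assuming the default pool_size of (2, 2)) shows how the spatial size shrinks before flattening:

```python
# Throwaway shape trace of the discriminator's downsampling path
from keras.layers import Input, Conv2D, AveragePooling2D
from keras.models import Model

x_in = Input(shape=(28, 28, 2))  # image channel + label channel after concatenation
x = AveragePooling2D(strides=(2, 2))(Conv2D(16, 3, padding='same')(x_in))  # 28 -> 14
x = AveragePooling2D(strides=(2, 2))(Conv2D(32, 3, padding='same')(x))     # 14 -> 7
x = AveragePooling2D(strides=(2, 2))(Conv2D(64, 3, padding='same')(x))     # 7 -> 3
print(Model(x_in, x).output_shape)  # (None, 3, 3, 64) -> Flatten gives 576 features
```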
That was our discriminator model. Now we will create a combined model, consisting of both the generator and the discriminator, that is used to train the generator network.
```python
def combined(self):

    inputs = Input(shape=(self.latent_dim,))
    input_comb = Input(shape=(1,))
    gen_img = self.generator_model([input_comb, inputs])

    self.discriminator_model.trainable = False
    outs = self.discriminator_model([input_comb, gen_img])

    comb_model = Model([input_comb, inputs], outs)
    comb_model.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
    print(comb_model.summary())
    return comb_model
```
In the above code, we create a combined model that takes two inputs: random noise of shape (100,) and a class label of shape (1,). The generator takes these two inputs and generates a new image, which is then fed to the discriminator to predict real or fake. Only the generator is trained here: the discriminator is set to non-trainable before the combined model is compiled, so its weights are frozen inside this model. (In Keras the trainable flag is captured at compile time, which is why the discriminator still learns normally when we call train_on_batch on the model compiled earlier in discriminator().)
Next, we will train the whole GAN using these networks.
```python
def train(self):

    train_data = (self.x_train.astype(np.float32) - 127.5) / 127.5
    train_data = np.expand_dims(train_data, -1)
    train_data_y = self.y_train

    for i in range(self.iterations):

        batch_indx = np.random.randint(0, train_data.shape[0], size=(self.half_batch_size))
        batch_x = train_data[batch_indx]
        batch_y = train_data_y[batch_indx]

        real_loss = self.discriminator_model.train_on_batch([batch_y, batch_x], np.ones((self.half_batch_size, 1)))

        random_y = np.random.randint(0, 10, self.half_batch_size)
        input_noise = np.random.normal(0, 1, size=(self.half_batch_size, 100))
        gen_outs = self.generator_model.predict([random_y, input_noise])

        fake_loss = self.discriminator_model.train_on_batch([random_y, gen_outs], np.zeros((self.half_batch_size, 1)))

        full_batch_input_noise = np.random.normal(0, 1, size=(self.batch_size, 100))
        gan_loss = self.combined_model.train_on_batch([np.random.randint(0, 10, self.batch_size), full_batch_input_noise], np.array([1] * self.batch_size))

        print(i, fake_loss, real_loss, gan_loss)
```
In the above code, in lines 3-4 we first normalize the input images to the range -1 to 1 (to match the generator's tanh output) and then reshape them to (28, 28, 1). In lines 9-11 we randomly select real images and their corresponding labels, half a batch at a time. In line 13 we train the discriminator on these real images, conditioned on their real class labels.
Then in line 15 we pick random labels between 0 and 9 for half a batch, because during training there are no ground-truth class labels for the random noise fed to the generator. In lines 16-17 we draw random noise of shape (half_batch_size, 100) and generate images from the generator network; these serve as the fake input images for the discriminator. In line 19 we train the discriminator on these fake generated images, conditioned on the random class labels.
Finally, in lines 21-22, we train the generator through the combined model, feeding it random noise and random class labels as input. The targets are all ones because the generator improves by making the (frozen) discriminator classify its fake images as real.
We train this network for a number of iterations, until the generator learns to fool the discriminator. After training, we can discard the discriminator and use the generator alone to generate new images conditioned on class labels.
```python
# training the network
gan = GAN()
gan.train()

# generating new images from the trained network
import matplotlib.pyplot as plt

r, c = 10, 10
noise = np.random.normal(0, 1, (10, 100))
gen_imgs = []
for indx in range(10):
    gen_imgs.extend(gan.generator_model.predict([np.array([indx] * 10), noise]))

# rescale images from [-1, 1] to [0, 1]
gen_imgs = np.array(gen_imgs)
gen_imgs = 0.5 * gen_imgs + 0.5

fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
    for j in range(c):
        axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
        cnt += 1
plt.show()
fig.savefig("mnist.png")
plt.close()
```
The above code tests our trained cGAN. It produces a 10x10 grid in which each row is conditioned on a different class label, so the first row contains ten generated 0s, the second row ten generated 1s, and so on; the grid is also saved to mnist.png.
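Since the discriminator is no longer needed, the generator alone can produce any chosen digit on demand; here is a minimal sketch reusing the gan object from the code above:

```python
# Sample only a specific digit from the trained generator
digit = 7                                  # the condition: which digit to generate
noise = np.random.normal(0, 1, (5, 100))   # five fresh noise vectors
labels = np.array([digit] * 5)             # the same label for every sample
imgs = gan.generator_model.predict([labels, noise])  # -> (5, 28, 28, 1)
imgs = 0.5 * imgs + 0.5                    # map tanh output from [-1, 1] back to [0, 1]
```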
Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Goodbye until next time.
References:
- Conditional GANs: https://arxiv.org/abs/1411.1784
- Face Aging: https://arxiv.org/pdf/1702.01983.pdf
- Image-to-Image Translation: https://phillipi.github.io/pix2pix/