InfoGAN is an extension of generative adversarial networks (GANs). A GAN is trained to generate new images that look similar to the images it was trained on, but it offers no control over what it generates. Suppose you have trained a GAN to generate faces that resemble those in a given dataset: you cannot control attributes of the generated faces, such as eye colour or hairstyle. InfoGAN makes this kind of control possible because it learns a disentangled representation, in which individual input variables correspond to interpretable factors of variation in the output.
Introduction
A generative adversarial network consists of two networks, a generator and a discriminator, which are trained in an adversarial manner. The generator tries to produce images similar to the original images, while the discriminator tries to distinguish the generator's images from the original ones. Ideally, training continues until the generator produces images so similar to the originals that the discriminator is fooled half the time.
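The "fooled half the time" condition comes from the standard GAN minimax objective of the original GAN formulation. In this objective, p_data is the distribution of real images, p_z the noise distribution and p_g the distribution of generated images. For a fixed generator the optimal discriminator is the ratio below, which equals 1/2 everywhere once the generator matches the data:

```latex
\begin{aligned}
\min_G \max_D V(D,G) &= \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big] \\
D^{*}_{G}(x) &= \frac{p_{\mathrm{data}}(x)}{p_{\mathrm{data}}(x) + p_g(x)} = \frac{1}{2} \quad \text{when } p_g = p_{\mathrm{data}}
\end{aligned}
```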
Control Variables
In a standard GAN, the generator receives only a random noise vector as input, and this vector carries no information about how the output should look. InfoGAN instead feeds the generator a latent code alongside the noise vector, so the generated images can be steered by the code. The input to the InfoGAN generator therefore has two parts:
- Continuous noise vector, z.
- Latent codes, c, which can be discrete, continuous, or both.
Say we have trained InfoGAN on the MNIST handwritten digit dataset. A discrete latent code with ten values (0-9) can then be used to generate a specific digit, while continuous latent codes can be used to vary properties such as the thickness and orientation of the strokes.
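The sketch below shows one way to assemble this two-part input with NumPy. The helper make_generator_input is hypothetical. It concatenates a 100-dimensional Gaussian noise vector z with a one-hot discrete code, matching the sizes used in the implementation later in this post. The optional continuous codes are drawn from Uniform(-1, 1). That choice follows the InfoGAN paper's MNIST setup and is an assumption here, not part of this post's model.

```python
import numpy as np

def make_generator_input(digits, noise_dim=100, n_continuous=0, rng=None):
    """Build generator inputs laid out as [z | one-hot c | optional continuous codes].

    Hypothetical helper for illustration. `digits` holds the discrete code
    value (0-9) for each sample. Continuous codes, if requested, are drawn
    from Uniform(-1, 1), an assumption borrowed from the InfoGAN paper.
    """
    rng = np.random.default_rng() if rng is None else rng
    digits = np.asarray(digits)
    batch = digits.shape[0]
    z = rng.normal(0.0, 1.0, size=(batch, noise_dim))   # noise vector z
    c_discrete = np.eye(10)[digits]                      # one-hot discrete code
    parts = [z, c_discrete]
    if n_continuous > 0:
        # e.g. two codes intended to capture rotation and stroke thickness
        parts.append(rng.uniform(-1.0, 1.0, size=(batch, n_continuous)))
    return np.hstack(parts)

x = make_generator_input([3, 3, 7], rng=np.random.default_rng(0))
print(x.shape)                     # (3, 110)
print(x[:, 100:].argmax(axis=1))   # [3 3 7]
```

With only the discrete code, each input has 100 + 10 = 110 values, which is the latent_dim used below.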
Mutual Information
InfoGAN stands for Information Maximizing GAN. As the name suggests, it maximizes mutual information. In information theory, the mutual information between X and Y, I(X; Y), measures the “amount of information” learned about the random variable X from knowledge of the random variable Y. InfoGAN aims for high mutual information between the latent code c and the generated images, so that the code can be recovered from the image it produced.
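For reference, this is how the InfoGAN paper (cited at the end of this post) formalizes the idea. V(D, G) is the usual GAN minimax objective and λ is a weighting hyperparameter. Computing I(c; G(z, c)) directly would require the intractable posterior P(c | x), so the paper instead maximizes a variational lower bound L_I built on an auxiliary distribution Q(c | x):

```latex
\begin{aligned}
I(X;Y) &= H(X) - H(X \mid Y) \\
\min_{G,Q}\max_{D}\; V_{\mathrm{InfoGAN}}(D,G,Q) &= V(D,G) - \lambda\, L_I(G,Q) \\
L_I(G,Q) &= \mathbb{E}_{c \sim P(c),\; x \sim G(z,c)}\big[\log Q(c \mid x)\big] + H(c) \;\le\; I\big(c;\, G(z,c)\big)
\end{aligned}
```

H(c) is constant for a fixed code distribution. Increasing L_I therefore means making Q(c | x) assign high probability to the code that actually produced x, which is a classification objective for the auxiliary network.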
To maximize this mutual information, InfoGAN adds an extra network, the auxiliary network (called Q in the paper). It shares all of its weights with the discriminator network except for the output layer. The discriminator's output layer predicts whether the input image is real or fake, while the auxiliary network's output layer predicts the latent codes.
InfoGAN therefore consists of three networks: the generator, the discriminator, and the auxiliary network. The discriminator and the auxiliary network both serve to train the generator. The discriminator pushes it to generate realistic images, and the auxiliary network pushes it to maximize the mutual information between the latent code and the generated images.
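The following NumPy sketch makes the link to the auxiliary network concrete. It estimates the lower bound L_I for a uniform 10-way categorical code from the auxiliary network's softmax outputs; the function names are hypothetical. Because H(c) = log 10 is constant, the estimate is simply log 10 minus the mean categorical cross-entropy, the loss the auxiliary model is compiled with in the implementation below.

```python
import numpy as np

def categorical_cross_entropy(c_onehot, q_probs, eps=1e-12):
    """Mean over the batch of -sum_k c_k * log q_k."""
    return float(-np.mean(np.sum(c_onehot * np.log(q_probs + eps), axis=1)))

def info_lower_bound(c_onehot, q_probs):
    """Monte Carlo estimate of L_I = E[log Q(c|x)] + H(c) for a uniform code.

    c_onehot: codes fed to the generator, shape (batch, K)
    q_probs:  auxiliary-network softmax outputs on G(z, c), shape (batch, K)
    """
    k = c_onehot.shape[1]
    entropy_c = np.log(k)   # H(c) of a uniform K-way categorical code
    return entropy_c - categorical_cross_entropy(c_onehot, q_probs)

codes = np.eye(10)[[0, 4, 9]]
uninformative = np.full((3, 10), 0.1)   # Q ignores the image entirely
perfect = codes.copy()                  # Q always recovers the code
print(info_lower_bound(codes, uninformative))   # ~0.0
print(info_lower_bound(codes, perfect))         # ~2.3026 (= log 10)
```

The bound ranges from about 0, when Q ignores the image, up to H(c) = log 10, when Q identifies the code exactly.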
Implementation
In this blog post, we implement InfoGAN on the MNIST handwritten digit dataset. To keep things simple, we use only a discrete latent code, which selects the digit to generate. You could additionally use two continuous codes to control the rotation and thickness of the generated digits.
Imports and Initialization
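```python
# Imports assume TensorFlow 2.x with its bundled Keras
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (AveragePooling2D, BatchNormalization, Conv2D,
                                     Conv2DTranspose, Dense, Flatten, Input,
                                     LeakyReLU, Reshape)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()

batch_size = 16          # samples per generator update
half_batch_size = 8      # real (and fake) samples per discriminator update
latent_dim = 100 + 10    # 100-d noise vector z + 10-d one-hot latent code c
iterations = 60000

# Adam with learning rate 0.0002 and beta_1 = 0.5, shared by all models.
# Newer Keras releases may refuse to reuse one optimizer instance across
# several models; if so, pass a fresh Adam(...) to each compile() call.
optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
```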
Generator Network
```python
def generator():
    # Input: 100-d noise vector z concatenated with the 10-d one-hot code c
    input_gen = Input(shape=(latent_dim,))
    # Project and reshape to a 7x7x16 feature map
    dense1 = Reshape((7, 7, 16))(Dense(7 * 7 * 16)(input_gen))
    batch_norm_1 = BatchNormalization()(dense1)
    # Upsample 7x7 -> 14x14
    trans_1 = Conv2DTranspose(128, 3, padding='same',
                              activation=LeakyReLU(alpha=0.2),
                              strides=(2, 2))(batch_norm_1)
    batch_norm_2 = BatchNormalization()(trans_1)
    # Upsample 14x14 -> 28x28
    trans_2 = Conv2DTranspose(128, 3, padding='same',
                              activation=LeakyReLU(alpha=0.2),
                              strides=(2, 2))(batch_norm_2)
    # Single-channel 28x28 image with values in [-1, 1]
    output = Conv2D(1, (28, 28), activation='tanh', padding='same')(trans_2)
    gen_model = Model(input_gen, output)
    # The generator is only ever updated through the combined model
    gen_model.compile(loss='binary_crossentropy', optimizer=optimizer)
    gen_model.summary()
    return gen_model
```
The generator input is a vector of length 110 (shape (110,)). It holds 100 values for the noise vector and 10 for the latent code, which here is a one-hot encoding of a digit between 0 and 9. Transposed convolution (often called deconvolution) layers upsample the features and finally produce an image of shape (28, 28, 1). Batch normalization is used to stabilize training and improve the quality of the generated images.
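The shape progression can be checked with a little arithmetic. With padding='same', a Keras Conv2DTranspose with stride s multiplies the spatial size by s, and a stride-1 Conv2D leaves it unchanged. The pure-Python sketch below traces the generator's feature-map shapes; the helper name is hypothetical.

```python
def trace_generator_shapes(start=(7, 7, 16), strides=(2, 2), filters=(128, 128),
                           out_channels=1):
    """Trace (height, width, channels) through the generator described above.

    Assumes padding='same' everywhere, so each stride-s transposed
    convolution scales height and width by s.
    """
    h, w, ch = start
    shapes = [(h, w, ch)]                  # after Dense + Reshape
    for s, f in zip(strides, filters):
        h, w, ch = h * s, w * s, f         # Conv2DTranspose(f, 3, strides=s, padding='same')
        shapes.append((h, w, ch))
    shapes.append((h, w, out_channels))    # final Conv2D: stride 1, padding='same'
    return shapes

print(trace_generator_shapes())
# [(7, 7, 16), (14, 14, 128), (28, 28, 128), (28, 28, 1)]
```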
Discriminator and Auxiliary Network
```python
def discriminator():
    input_disc = Input(shape=(28, 28, 1))
    # Shared trunk: three conv -> batch norm -> average-pool stages
    conv_1 = Conv2D(16, 3, padding='same', activation=LeakyReLU(alpha=0.2))(input_disc)
    batch_norm1 = BatchNormalization()(conv_1)
    pool_1 = AveragePooling2D(strides=(2, 2))(batch_norm1)    # 28x28 -> 14x14
    conv_2 = Conv2D(32, 3, padding='same', activation=LeakyReLU(alpha=0.2))(pool_1)
    batch_norm2 = BatchNormalization()(conv_2)
    pool_2 = AveragePooling2D(strides=(2, 2))(batch_norm2)    # 14x14 -> 7x7
    conv_3 = Conv2D(64, 3, padding='same', activation=LeakyReLU(alpha=0.2))(pool_2)
    batch_norm3 = BatchNormalization()(conv_3)
    pool_3 = AveragePooling2D(strides=(2, 2))(batch_norm3)    # 7x7 -> 3x3
    flatten_1 = Flatten()(pool_3)

    # Discriminator head: probability that the input image is real
    output = Dense(1, activation='sigmoid')(flatten_1)
    # Auxiliary head: softmax over the 10 values of the categorical code
    q_output_categorical = Dense(10, activation='softmax')(flatten_1)

    # Both models are built on the same trunk, so they share every layer except their heads
    disc_model = Model(input_disc, output)
    disc_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    q_model = Model(input_disc, q_output_categorical)
    q_model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    disc_model.summary()
    q_model.summary()
    return disc_model, q_model
```
Since the auxiliary network shares all the weights of the discriminator except the output layer, there is no need for two separate functions. A single function builds the shared trunk and returns two models, each with its own output head. Both models take images of shape (28, 28, 1) as input, and the trunk is built from convolutional, batch normalization and pooling layers. The discriminator has a single output, the probability that the input image is real. The auxiliary network has 10 outputs, a softmax over the possible values of the latent code.
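The trunk's output size likewise follows from the pooling arithmetic. AveragePooling2D defaults to pool_size=(2, 2) and padding='valid', so each pooling layer maps a size n to floor((n - 2) / 2) + 1. The third pooling layer therefore turns 7x7 into 3x3. A pure-Python sketch, with hypothetical helper names:

```python
def pooled_size(n, pool=2, stride=2):
    """Output length of a 'valid' pooling window along one axis."""
    return (n - pool) // stride + 1

def trace_discriminator_shapes(size=28, filters=(16, 32, 64)):
    """Trace (height, width, channels) through the shared trunk.

    Conv2D layers use padding='same' (size unchanged); each stage ends
    with AveragePooling2D(strides=(2, 2)) and its default 2x2 window.
    """
    shapes = []
    for f in filters:
        size = pooled_size(size)          # conv keeps the size, pooling shrinks it
        shapes.append((size, size, f))
    flat = size * size * filters[-1]      # length of the Flatten() output
    return shapes, flat

shapes, flat = trace_discriminator_shapes()
print(shapes)  # [(14, 14, 16), (7, 7, 32), (3, 3, 64)]
print(flat)    # 576
```

Both heads, Dense(1) and Dense(10), therefore read the same 576-dimensional feature vector.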
Combined Model
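```python
def combined():
    # Input: noise z concatenated with the one-hot latent code c
    inputs = Input(shape=(latent_dim,))
    gen_img = generator_model(inputs)
    # Freeze the discriminator inside the combined model; it is trained separately
    discriminator_model.trainable = False
    disc_outs = discriminator_model(gen_img)   # real/fake prediction
    q_outs = auxiliary_model(gen_img)          # predicted latent code
    comb_model = Model(inputs, [disc_outs, q_outs])
    # The two losses are summed with equal weight (no loss_weights given);
    # one metrics list per output
    comb_model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'],
                       optimizer=optimizer,
                       metrics=[['accuracy'], ['accuracy']])
    comb_model.summary()
    return comb_model
```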
The combined model is used to train the generator. The discriminator is marked as non-trainable inside it because the discriminator is trained separately. The auxiliary network's own output layer is not part of the discriminator model, so it stays trainable here. The combined model takes the random noise and latent code as input and feeds them to the generator. The generated image is then passed to both the discriminator and the auxiliary network.
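Because compile() receives no loss_weights, Keras adds the two losses with weight 1.0 each. The generator therefore minimizes the binary cross-entropy against the "real" label plus the categorical cross-entropy of the auxiliary prediction. In the paper's notation this corresponds to λ = 1. The NumPy sketch below approximates that per-batch objective. The function names and example probabilities are made up for illustration, and Keras's own loss implementations differ in small details such as clipping.

```python
import numpy as np

def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, clipping predictions away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

def categorical_cross_entropy(c_onehot, q_probs, eps=1e-7):
    """Mean categorical cross-entropy over the batch."""
    q_probs = np.clip(q_probs, eps, 1.0)
    return float(-np.mean(np.sum(c_onehot * np.log(q_probs), axis=1)))

def generator_objective(d_fake, c_onehot, q_probs, lam=1.0):
    """Adversarial loss on D(G(z, c)) with 'real' targets, plus lam times the Q loss."""
    adv = binary_cross_entropy(np.ones_like(d_fake), d_fake)
    info = categorical_cross_entropy(c_onehot, q_probs)
    return adv + lam * info

# Two generated samples: discriminator outputs and Q's softmax outputs (made-up values)
d_fake = np.array([[0.5], [0.5]])
codes = np.eye(10)[[2, 7]]
q_probs = np.full((2, 10), 0.1)
print(round(generator_objective(d_fake, codes, q_probs), 4))  # 2.9957 (= log 2 + log 10)
```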
Training InfoGAN
Training a GAN is notoriously difficult and usually requires careful hyperparameter tuning. We use the following steps to train InfoGAN:
1. Normalize the input images from the MNIST dataset to the range [-1, 1].
2. Train the discriminator on real images from the MNIST dataset.
3. Train the auxiliary network on real images and their corresponding labels. (This step uses the MNIST labels, which the original, unsupervised InfoGAN does not; it ties each code value to a specific digit.)
4. Train the discriminator on fake images produced by the generator.
5. Train the auxiliary network on fake images from the generator and the random latent codes used to produce them.
6. Train the generator through the combined model, with the discriminator frozen.
7. Repeat steps 2-6 for a number of iterations. I trained for 60000 iterations.
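Normalization maps pixel values from [0, 255] to [-1, 1], the range of the generator's tanh output. The generation code later undoes this with 0.5 * x + 0.5 to get back to [0, 1] for display. A quick NumPy check of that arithmetic, with hypothetical helper names:

```python
import numpy as np

def normalize(pixels):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (pixels.astype(np.float32) - 127.5) / 127.5

def to_display(x):
    """Map values in [-1, 1] (e.g. tanh outputs) back to [0, 1] for imshow."""
    return 0.5 * x + 0.5

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)
scaled = normalize(pixels)
print(scaled.min(), scaled.max())   # the endpoints of the tanh range
print(to_display(scaled))           # the same values as pixels / 255
```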
```python
generator_model = generator()
discriminator_model, auxiliary_model = discriminator()
combined_model = combined()

def train():
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    train_data = (x_train.astype(np.float32) - 127.5) / 127.5
    train_data = np.expand_dims(train_data, -1)   # (60000, 28, 28, 1)
    train_data_y = y_train

    for i in range(iterations):
        # Real half-batch with its digit labels
        batch_indx = np.random.randint(0, train_data.shape[0], size=half_batch_size)
        batch_x = train_data[batch_indx]
        batch_y = to_categorical(train_data_y[batch_indx], 10)
        real_loss = discriminator_model.train_on_batch(batch_x, np.ones((half_batch_size, 1)))
        q_real_loss = auxiliary_model.train_on_batch(batch_x, batch_y)

        # Fake half-batch from random noise and random one-hot codes
        random_y = to_categorical(np.random.randint(0, 10, half_batch_size), 10)
        input_noise = np.random.normal(0, 1, size=(half_batch_size, 100))
        gen_outs = generator_model.predict(np.hstack((input_noise, random_y)), verbose=0)
        fake_loss = discriminator_model.train_on_batch(gen_outs, np.zeros((half_batch_size, 1)))
        q_fake_loss = auxiliary_model.train_on_batch(gen_outs, random_y)

        # Generator update: fool the discriminator and keep the code recoverable
        noise = np.random.normal(0, 1, size=(batch_size, 100))
        latent_code = to_categorical(np.random.randint(0, 10, batch_size), 10)
        full_batch_input_noise = np.hstack((noise, latent_code))
        gan_loss = combined_model.train_on_batch(full_batch_input_noise,
                                                 [np.ones((batch_size, 1)), latent_code])

        if i % 5000 == 0:
            print(i, fake_loss, real_loss, gan_loss, q_real_loss, q_fake_loss)

train()
```
Generation
Now we generate images from the trained model. The generator is given random noise together with the one-hot encoding of the digit (0-9) we want to generate.
```python
# generating new images from trained network
import matplotlib.pyplot as plt

r, c = 10, 5   # 10 rows (one per digit), 5 samples per row
gen_imgs = []
for indx in range(10):
    # 5 noise vectors paired with the same one-hot code for digit `indx`
    noise = np.random.normal(0, 1, (5, 100))
    categorical_code = to_categorical([indx] * 5, 10)
    input_noise = np.hstack((noise, categorical_code))
    outs = generator_model.predict(input_noise)
    gen_imgs.extend(outs)
gen_imgs = np.array(gen_imgs)
# Rescale from [-1, 1] back to [0, 1] for display
gen_imgs = 0.5 * gen_imgs + 0.5

fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
    for j in range(c):
        axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
        cnt += 1
# Save before showing, since some backends discard the figure once it is shown
fig.savefig("mnist.png")
plt.show()
plt.close()
```
Here are the generated results from the model:
Referenced Research Paper: InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Goodbye until next time.