In the previous blog, we studied GANs; in this blog, we will implement a GAN to generate handwritten digits resembling the MNIST dataset.
In a generative adversarial network, the generator and the discriminator are trained simultaneously, and either network can overpower the other if training is not balanced. If the discriminator is trained too much, it easily tells fake images from real ones, and the generator never learns to produce realistic-looking images. If the generator is trained too heavily, the discriminator will not be able to distinguish between real and fake images. We can mitigate this problem by setting appropriate learning rates for both networks.
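For example, one way to balance the two networks is to give each its own optimizer with a different learning rate. Below is a minimal sketch of that idea; the learning-rate values and variable names are hypothetical, and the code later in this post actually uses a single shared Adam optimizer for all models.

```python
from keras.optimizers import Adam

# Hypothetical choice: the discriminator updates more slowly than the generator,
# so it does not overpower the generator early in training.
disc_optimizer = Adam(0.0001, 0.5)  # would be used when compiling the discriminator
gen_optimizer = Adam(0.0002, 0.5)   # would be used when compiling the combined (generator) model
```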
When we train the discriminator we do not update the generator, and when we train the generator we do not update the discriminator. This allows each network to train properly. Now, let's look at the code for each part of the GAN network.
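In Keras this alternation is typically achieved by compiling the discriminator on its own and then freezing it inside a combined model, so that training the combined model only updates the generator. Here is a tiny, self-contained sketch of that mechanism; the layer sizes and variable names are made up for illustration and do not match the models built later in this post.

```python
from keras.layers import Input, Dense
from keras.models import Model
import numpy as np

# Toy discriminator: trainable on its own because it is compiled before being frozen.
disc_in = Input(shape=(2,))
disc = Model(disc_in, Dense(1, activation='sigmoid')(disc_in))
disc.compile(loss='binary_crossentropy', optimizer='adam')

# Toy generator.
gen_in = Input(shape=(3,))
gen = Model(gen_in, Dense(2)(gen_in))

# Freeze the discriminator before building the combined model, so that
# training the combined model updates only the generator's weights.
disc.trainable = False
combined = Model(gen_in, disc(gen(gen_in)))
combined.compile(loss='binary_crossentropy', optimizer='adam')

# Calling train_on_batch on `combined` leaves the discriminator unchanged.
combined.train_on_batch(np.random.normal(size=(4, 3)), np.ones((4, 1)))
```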
Discriminator Network:
We are using the MNIST digits dataset, which has an image shape of (28, 28, 1). Since the images are small, we can use an MLP network for the discriminator instead of convolutional layers. To do this we first need to reshape each input image into a single vector of 784 values. Then I have applied three dense layers with 512, 256 and 128 hidden units respectively.
```python
def discriminator(self):
    input_disc = Input(shape=(784,))
    hidden1 = Dense(512, activation='relu')(input_disc)
    hidden2 = Dense(256, activation='relu')(hidden1)
    hidden3 = Dense(128, activation='relu')(hidden2)
    output = Dense(1, activation='sigmoid')(hidden3)

    disc_model = Model(input_disc, output)
    disc_model.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
    print(disc_model.summary())
    return disc_model
```
Generator Network:
To create the generator network we first take random noise of shape (100,) as input. Then I have used three hidden layers with 256, 512 and 1024 units. The output of the generator network is reshaped to (28, 28, 1). I have added batch normalization after each hidden layer. Batch normalization improves the quality of the trained model and also stabilizes the training process.
```python
def generator(self):
    input_gen = Input(shape=(self.latent_dim,))
    hidden1 = BatchNormalization(momentum=0.8)(Dense(256, activation='relu')(input_gen))
    hidden2 = BatchNormalization(momentum=0.8)(Dense(512, activation='relu')(hidden1))
    hidden3 = BatchNormalization(momentum=0.8)(Dense(1024, activation='relu')(hidden2))
    output = Dense(784, activation='tanh')(hidden3)
    reshaped_output = Reshape((28, 28, 1))(output)

    gen_model = Model(input_gen, reshaped_output)
    gen_model.compile(loss='binary_crossentropy', optimizer=self.optimizer)
    print(gen_model.summary())
    return gen_model
```
Combined Model:
To train the generator we need to create a combined model in which the discriminator is not trained. In the combined model, random noise is given as input to the generator network, and the generated image is then passed through the discriminator network to get a label. Here I have flagged the discriminator model as non-trainable.
```python
def combined(self):
    inputs = Input(shape=(self.latent_dim,))
    gen_img = self.generator_model(inputs)
    gen_img = Reshape((784,))(gen_img)

    self.discriminator_model.trainable = False
    outs = self.discriminator_model(gen_img)

    comb_model = Model(inputs, outs)
    comb_model.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
    print(comb_model.summary())
    return comb_model
```
Training the GAN network:
Training a GAN network requires careful hyperparameter tuning. If the model is not trained carefully it will not converge to produce good results. We will use the following steps to train this GAN network:
- First, normalize the input dataset (MNIST images).
- Train the discriminator with real images (from the MNIST dataset).
- Sample the same number of noise vectors and predict the generator's outputs (the generator is not trained here).
- Train the discriminator network with the images generated in the previous step.
- Take new random noise samples and train the generator through the combined model, without training the discriminator.
- Repeat steps 2-5 for a number of iterations. I have trained it for 30000 iterations.
```python
def train(self):
    # Scale pixel values to [-1, 1] to match the generator's tanh output.
    train_data = (self.x_train.astype(np.float32) - 127.5) / 127.5

    for i in range(self.iterations):
        # Sample a half batch of real images and flatten them for the discriminator.
        batch_indx = np.random.randint(0, train_data.shape[0], size=self.half_batch_size)
        batch_x = train_data[batch_indx]
        batch_x = batch_x.reshape((-1, 784))

        # Generate a half batch of fake images (the generator is not trained here).
        input_noise = np.random.normal(0, 1, size=(self.half_batch_size, 100))
        gen_outs = self.generator_model.predict(input_noise)
        gen_outs = gen_outs.reshape((-1, 784))

        # Train the discriminator: real images labelled 1, generated images labelled 0.
        real_loss = self.discriminator_model.train_on_batch(batch_x, np.ones((self.half_batch_size, 1)))
        fake_loss = self.discriminator_model.train_on_batch(gen_outs, np.zeros((self.half_batch_size, 1)))
        disc_loss = 0.5 * np.add(fake_loss, real_loss)

        # Train the generator through the combined model (discriminator frozen),
        # asking it to make the discriminator output 1 for generated images.
        full_batch_input_noise = np.random.normal(0, 1, size=(self.batch_size, 100))
        gan_loss = self.combined_model.train_on_batch(full_batch_input_noise, np.array([1] * self.batch_size))

        print(i, disc_loss, gan_loss)
```
Take a look at the images generated by this GAN network.
Here is the full code.
```python
from keras.layers import Input, Dense, Reshape, BatchNormalization
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import mnist
import numpy as np


class GAN():
    def __init__(self):
        (self.x_train, self.y_train), (self.x_test, self.y_test) = mnist.load_data()
        self.batch_size = 128
        self.half_batch_size = 64
        self.latent_dim = 100
        self.iterations = 30000
        self.optimizer = Adam(0.0002, 0.5)
        self.generator_model = self.generator()
        self.discriminator_model = self.discriminator()
        self.combined_model = self.combined()

    def generator(self):
        input_gen = Input(shape=(self.latent_dim,))
        hidden1 = BatchNormalization(momentum=0.8)(Dense(256, activation='relu')(input_gen))
        hidden2 = BatchNormalization(momentum=0.8)(Dense(512, activation='relu')(hidden1))
        hidden3 = BatchNormalization(momentum=0.8)(Dense(1024, activation='relu')(hidden2))
        output = Dense(784, activation='tanh')(hidden3)
        reshaped_output = Reshape((28, 28, 1))(output)
        gen_model = Model(input_gen, reshaped_output)
        gen_model.compile(loss='binary_crossentropy', optimizer=self.optimizer)
        print(gen_model.summary())
        return gen_model

    def discriminator(self):
        input_disc = Input(shape=(784,))
        hidden1 = Dense(512, activation='relu')(input_disc)
        hidden2 = Dense(256, activation='relu')(hidden1)
        hidden3 = Dense(128, activation='relu')(hidden2)
        output = Dense(1, activation='sigmoid')(hidden3)
        disc_model = Model(input_disc, output)
        disc_model.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
        print(disc_model.summary())
        return disc_model

    def combined(self):
        inputs = Input(shape=(self.latent_dim,))
        gen_img = self.generator_model(inputs)
        gen_img = Reshape((784,))(gen_img)
        self.discriminator_model.trainable = False
        outs = self.discriminator_model(gen_img)
        comb_model = Model(inputs, outs)
        comb_model.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
        print(comb_model.summary())
        return comb_model

    def train(self):
        train_data = (self.x_train.astype(np.float32) - 127.5) / 127.5
        for i in range(self.iterations):
            batch_indx = np.random.randint(0, train_data.shape[0], size=self.half_batch_size)
            batch_x = train_data[batch_indx]
            batch_x = batch_x.reshape((-1, 784))

            input_noise = np.random.normal(0, 1, size=(self.half_batch_size, 100))
            gen_outs = self.generator_model.predict(input_noise)
            gen_outs = gen_outs.reshape((-1, 784))

            fake_loss = self.discriminator_model.train_on_batch(gen_outs, np.zeros((self.half_batch_size, 1)))
            real_loss = self.discriminator_model.train_on_batch(batch_x, np.ones((self.half_batch_size, 1)))
            disc_loss = 0.5 * np.add(fake_loss, real_loss)

            full_batch_input_noise = np.random.normal(0, 1, size=(self.batch_size, 100))
            gan_loss = self.combined_model.train_on_batch(full_batch_input_noise, np.array([1] * self.batch_size))

            print(i, disc_loss, gan_loss)


# training the network
gan = GAN()
gan.train()

# generating new images from the trained network
import matplotlib.pyplot as plt

r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, 100))
gen_imgs = gan.generator_model.predict(noise)

# Rescale images to 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5

fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
    for j in range(c):
        axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
        cnt += 1
plt.show()
fig.savefig("mnist.png")
plt.close()
```
Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or to improve. Goodbye until next time.