CycleGAN is a variant of the generative adversarial network (GAN). It was introduced to translate images from domain X to domain Y without a paired set of training examples. In a previous blog, I described CycleGAN in detail. In this blog, we will use the Keras library to implement CycleGAN and translate apple images into orange images and vice versa. Here are some recommended blogs that you should read before implementing CycleGAN:
- Cycle-Consistent Generative Adversarial Networks (CycleGAN)
- Image to Image Translation Using Conditional GAN
- Implementation of Image-to-image translation using conditional GAN
Load the Dataset And Preprocess
Unlike many other image translation algorithms, CycleGAN does not require a paired dataset. We will therefore use two separate datasets: one of apple images and one of orange images. The two datasets are not paired with each other. Here are some images from the dataset:
You can download the dataset from this link, or run the following commands from your terminal:
```bash
wget https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip
unzip apple2orange.zip
```
The dataset consists of four folders: trainA, trainB, testA, and testB. The 'A' folders contain apple images and the 'B' folders contain orange images. The training set has approximately 1000 images of each type, and the test set has approximately 200 images of each type.
So, let’s first import all the required libraries:
```python
import os
import time

import cv2
import numpy as np
from tqdm import tqdm
from keras.layers import (BatchNormalization, Reshape, Dense, Input, LeakyReLU, Conv2D,
                          Conv2DTranspose, Concatenate, ReLU, Dropout, ZeroPadding2D)
from keras.models import Model
from keras.initializers import RandomNormal
from keras.optimizers import Adam
```
The dataset is already partly preprocessed, because all images have the same size of (256, 256, 3). We apply two more preprocessing steps: normalization and random flipping. Each image is scaled to the range [-1, 1] and flipped horizontally with a probability of 0.5. Here is the code:
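```python
def load_img(file_path):
    # cv2.imread returns a uint8 array in BGR channel order
    img = cv2.imread(file_path)
    # Random horizontal flip with probability 0.5 (applied once, when the image is loaded)
    if np.random.rand() > 0.5:
        img = cv2.flip(img, 1)
    # Scale pixel values from [0, 255] to [-1, 1]; float32 needs half the memory of float64
    img = (img / 127.5 - 1).astype(np.float32)
    return img
```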
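The scaling in `load_img` maps pixel value 0 to -1 and 255 to 1. This matches the tanh output range of the generator. To display generated images later, you have to invert that mapping. The NumPy-only sketch below shows both directions. `normalize`, `denormalize`, and `random_hflip` are hypothetical helper names. For an (H, W, C) array, `np.flip(img, axis=1)` is the NumPy equivalent of `cv2.flip(img, 1)`.

```python
import numpy as np

def normalize(img):
    # Map uint8 pixels in [0, 255] to floats in [-1, 1] (same formula as load_img)
    return (img / 127.5 - 1.0).astype(np.float32)

def denormalize(img):
    # Invert the mapping: [-1, 1] -> [0, 255], clipped, rounded and cast back to uint8
    return np.clip((img + 1.0) * 127.5, 0, 255).round().astype(np.uint8)

def random_hflip(img, rng, p=0.5):
    # Flip left-right with probability p; axis 1 is the width axis of an (H, W, C) array
    return np.flip(img, axis=1) if rng.random() < p else img

rng = np.random.default_rng(0)
pixels = np.array([[[0, 127, 255]]], dtype=np.uint8)  # one pixel, shape (1, 1, 3)
x = normalize(pixels)
print(x)               # values close to -1, 0 and 1
print(denormalize(x))  # recovers the original values 0, 127, 255
```

`load_img` is called only once, when the dataset is loaded. As a result, each image is either flipped or not for the whole training run. Calling a function like `random_hflip` inside the training loop would instead draw a new flip every epoch.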
Now load the training images from the directory into a list.
```python
train_a = []
train_b = []

trainA_path = r'/content/apple2orange/trainA'
for files in tqdm(os.listdir(trainA_path)):
    file_path = os.path.join(trainA_path, files)
    input_img = load_img(file_path)
    train_a.append(input_img)

trainB_path = r'/content/apple2orange/trainB'
for files in tqdm(os.listdir(trainB_path)):
    file_path = os.path.join(trainB_path, files)
    input_img = load_img(file_path)
    train_b.append(input_img)

train_a = np.array(train_a)
train_b = np.array(train_b)
```
Build the Generator
The network architecture I use here is very similar to the one used for image-to-image translation with a conditional GAN. The major difference is the loss function. CycleGAN introduces two additional losses: the cycle consistency loss and the identity loss.
The generator network is a U-Net, an encoder-decoder model with skip connections between mirrored encoder and decoder layers. We use two generator networks. One translates apples to oranges (G: X -> Y) and the other translates oranges to apples (F: Y -> X). Each generator consists of an encoder and a decoder.

- Each encoder block has up to three layers (Conv -> BatchNorm -> LeakyReLU). The first block has no BatchNorm.
- Each decoder block has up to four layers (Transposed Conv -> BatchNorm -> Dropout -> ReLU). Dropout is used only in the first three decoder blocks.

The generator takes an image as input and outputs a generated image. Both images have a size of (256, 256, 3). Here is the code:
```python
def conv_init():
    # Weights drawn from N(0, 0.02); a fresh initializer instance for every layer
    return RandomNormal(mean=0., stddev=0.02)

def encoder_block(x, filters, batch_norm=True):
    # Conv -> (BatchNorm) -> LeakyReLU; the stride-2 conv halves height and width
    x = Conv2D(filters, 4, strides=2, use_bias=False,
               kernel_initializer=conv_init(), padding='same')(x)
    if batch_norm:
        x = BatchNormalization(momentum=0.8)(x)
    return LeakyReLU(alpha=0.2)(x)

def decoder_block(x, skip, filters, dropout=False):
    # Transposed Conv -> BatchNorm -> (Dropout) -> ReLU; doubles height and width
    x = Conv2DTranspose(filters, 4, strides=2, use_bias=False,
                        kernel_initializer=conv_init(), padding='same')(x)
    x = BatchNormalization(momentum=0.8)(x)
    if dropout:
        x = Dropout(0.5)(x)
    x = ReLU()(x)
    # U-Net skip connection: concatenate with the encoder output of the same size
    return Concatenate()([x, skip])

def generator():
    image_input = Input(shape=(256, 256, 3))

    # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1
    # No BatchNorm on the first block, and none on the 1x1 bottleneck either: with
    # batch size 1, batch statistics over a single 1x1 value normalize it to zero.
    enc_filters = [64, 128, 256, 512, 512, 512, 512, 512]
    skips = []
    x = image_input
    for i, filters in enumerate(enc_filters):
        x = encoder_block(x, filters, batch_norm=(0 < i < len(enc_filters) - 1))
        skips.append(x)

    # Decoder: 1 -> 2 -> ... -> 128. Block i is joined with skips[-(i + 2)], i.e.
    # the 2x2 encoder output first and the 128x128 one last.
    # Dropout is used only in the first three decoder blocks.
    for i, filters in enumerate([512, 512, 512, 512, 256, 128, 64]):
        x = decoder_block(x, skips[-(i + 2)], filters, dropout=(i < 3))

    # Final upsampling to 256x256x3; tanh matches the [-1, 1] input range
    outputs = Conv2DTranspose(3, 4, strides=2, use_bias=False, activation='tanh',
                              kernel_initializer=conv_init(), padding='same')(x)
    gen_model = Model(image_input, outputs)
    # gen_model.summary()
    return gen_model
```
```python
genA = generator()  # G: apple (domain A) -> orange (domain B)
genB = generator()  # F: orange (domain B) -> apple (domain A)
```
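The pure-Python sketch below shows how the skip connections line up. It traces the feature-map shapes through the generator, assuming 'same' padding and stride 2 in every layer, as in the code. `unet_shapes` is a hypothetical helper, not part of Keras.

```python
def unet_shapes(size=256,
                enc_filters=(64, 128, 256, 512, 512, 512, 512, 512),
                dec_filters=(512, 512, 512, 512, 256, 128, 64)):
    # Each stride-2 'same' convolution halves the spatial size
    enc = []
    for f in enc_filters:
        size //= 2
        enc.append((size, size, f))
    # Each stride-2 transposed convolution doubles it; the skip concat adds channels
    dec = []
    for i, f in enumerate(dec_filters):
        size *= 2
        skip = enc[-(i + 2)]
        assert skip[0] == size  # a skip connection must match spatially
        dec.append((size, size, f + skip[2]))
    return enc, dec

enc, dec = unet_shapes()
for shape in enc:
    print('encoder', shape)
for shape in dec:
    print('decoder + skip', shape)
print('output', (dec[-1][0] * 2, dec[-1][1] * 2, 3))
```

The deepest decoder blocks carry 1024 channels after concatenation (512 from the decoder plus 512 from the encoder). `genA.summary()` should report the same spatial sizes.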
Build the Discriminator
The discriminator network is a PatchGAN, very similar to the one used in the code for image-to-image translation with a conditional GAN. We use two discriminators. Discriminator A distinguishes real apple images from apple images generated by generator B. Discriminator B distinguishes real orange images from orange images generated by generator A.
A PatchGAN is simply a convolutional network. A normal classification network produces a single scalar. A PatchGAN instead produces an NxN array, in which each element classifies one patch of the input image as real or fake. Averaging over all patch decisions then classifies the whole image as real or fake. Here the averaging is done by taking the mean of the loss over the output map.
```python
def discriminator():
    img_inp = Input(shape=(256, 256, 3))

    # Three stride-2 convolutions: 256 -> 128 -> 64 -> 32
    conv_1 = Conv2D(64, 4, strides=2, use_bias=False, padding='same',
                    kernel_initializer=RandomNormal(mean=0., stddev=0.02))(img_inp)
    act_1 = LeakyReLU(alpha=0.2)(conv_1)

    conv_2 = Conv2D(128, 4, strides=2, use_bias=False, padding='same',
                    kernel_initializer=RandomNormal(mean=0., stddev=0.02))(act_1)
    batch_norm_2 = BatchNormalization(momentum=0.8)(conv_2)
    act_2 = LeakyReLU(alpha=0.2)(batch_norm_2)

    conv_3 = Conv2D(256, 4, strides=2, use_bias=False, padding='same',
                    kernel_initializer=RandomNormal(mean=0., stddev=0.02))(act_2)
    batch_norm_3 = BatchNormalization(momentum=0.8)(conv_3)
    act_3 = LeakyReLU(alpha=0.2)(batch_norm_3)

    # Pad by 1 pixel, then a 4x4 'valid' convolution with stride 1: 32 -> 34 -> 31
    zero_pad = ZeroPadding2D()(act_3)
    conv_4 = Conv2D(512, 4, strides=1, use_bias=False,
                    kernel_initializer=RandomNormal(mean=0., stddev=0.02))(zero_pad)
    batch_norm_4 = BatchNormalization(momentum=0.8)(conv_4)
    act_4 = LeakyReLU(alpha=0.2)(batch_norm_4)

    # Same again with one filter: 31 -> 33 -> 30, a 30x30x1 map of patch scores
    # (linear output, no sigmoid, because it is trained with the MSE loss)
    zero_pad_1 = ZeroPadding2D()(act_4)
    outputs = Conv2D(1, 4, strides=1, use_bias=False,
                     kernel_initializer=RandomNormal(mean=0., stddev=0.02))(zero_pad_1)

    disc_model = Model(img_inp, outputs)
    # disc_model.summary()
    return disc_model
```
```python
discA = discriminator()  # real apples vs. apples generated by genB
discB = discriminator()  # real oranges vs. oranges generated by genA
discA.summary()
```
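The discriminator's layer settings determine both the value of N and the patch size. The pure-Python sketch below computes both. It uses the standard output-size and receptive-field arithmetic for the five convolutions above. The helper names are hypothetical. The result is why the training code later uses `disc_patch = (30, 30, 1)`.

```python
def conv_out(size, kernel, stride, padding):
    # Output size of a convolution along one spatial dimension
    if padding == 'same':
        return -(-size // stride)          # ceil(size / stride)
    return (size - kernel) // stride + 1   # 'valid'

# (kernel, stride, padding, zero padding per side) for the five convolutions above
LAYERS = [(4, 2, 'same', 0), (4, 2, 'same', 0), (4, 2, 'same', 0),
          (4, 1, 'valid', 1), (4, 1, 'valid', 1)]

def patchgan_geometry(size=256, layers=LAYERS):
    rf, jump = 1, 1  # receptive field and stride (in input pixels) of one output unit
    for kernel, stride, padding, pad in layers:
        size = conv_out(size + 2 * pad, kernel, stride, padding)
        rf += (kernel - 1) * jump
        jump *= stride
    return size, rf

size, rf = patchgan_geometry()
print(f'output map: {size}x{size}, each unit sees a {rf}x{rf} patch')
# output map: 30x30, each unit sees a 70x70 patch
```

So each of the 30x30 outputs judges an overlapping 70x70 region of the 256x256 input, rather than the image as a whole.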
Combined Network
Now we create a combined model to train the generators. Inside this model both discriminators are non-trainable, so only the generator weights are updated. Besides the adversarial loss, the generators are also trained with the cycle consistency loss and the identity loss.
Cycle consistency means the following: if we translate an English sentence into French and then translate it back into English, we should arrive at the original sentence. To calculate the cycle consistency loss, first pass input image A to generator A, then pass the generated output to generator B. The loss is computed between the image reconstructed by generator B and the original input image A. The same applies in the other direction. Image B is passed through generator B and then generator A, and the result is compared with image B.
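The CycleGAN paper writes this loss as an L1 penalty in both directions. In that notation, $G$ is the apple-to-orange generator (genA), $F$ is the orange-to-apple generator (genB), $x$ is an apple image and $y$ is an orange image:

```latex
\mathcal{L}_{\text{cyc}}(G, F) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\lVert F(G(x)) - x \rVert_1\big]
  + \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[\lVert G(F(y)) - y \rVert_1\big]
```

In this post's code, the two terms correspond to the 'mae' losses on `reconstruct_imgA` and `reconstruct_imgB`. These match the L1 norm up to a constant factor, because MAE averages over pixels instead of summing.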
The identity loss works like this. Generator A is trained to turn images from domain A into images that look like domain B. The identity loss makes sure that if we instead pass an image that already belongs to domain B to generator A, it returns approximately the same image. Likewise, generator B should leave images from domain A unchanged. Here is the code for the combined model.
```python
def combined():
    inputA = Input(shape=(256, 256, 3))  # apple image
    inputB = Input(shape=(256, 256, 3))  # orange image

    # Translations: A -> B with genA, B -> A with genB
    gen_imgB = genA(inputA)
    gen_imgA = genB(inputB)

    # Cycle consistency: translate back to the original domain
    reconstruct_imgA = genB(gen_imgB)
    reconstruct_imgB = genA(gen_imgA)

    # Identity mapping: each generator receives an image already in its target domain
    gen_orig_imgB = genA(inputB)
    gen_orig_imgA = genB(inputA)

    # Freeze both discriminators inside the combined model
    discA.trainable = False
    discB.trainable = False

    # discA judges apples, discB judges oranges
    valid_imgA = discA(gen_imgA)
    valid_imgB = discB(gen_imgB)

    comb_model = Model([inputA, inputB],
                       [valid_imgA, valid_imgB,
                        reconstruct_imgA, reconstruct_imgB,
                        gen_orig_imgA, gen_orig_imgB])
    # comb_model.summary()
    return comb_model
```
```python
comb_model = combined()
```
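The combined model returns both reconstructions and both identity mappings, and each is compared with its input using MAE. The NumPy sketch below shows why both terms are useful. It uses hypothetical toy "generators" on random arrays, not the Keras models. A pair of generators that are exact inverses of each other but both shift the colors gets zero cycle loss. The identity loss still penalizes that shift.

```python
import numpy as np

def mae(a, b):
    # Mean absolute error, the 'mae' loss used for these terms
    return float(np.mean(np.abs(a - b)))

def cycle_and_identity_losses(G, F, x, y):
    # G: A -> B (like genA), F: B -> A (like genB); x from domain A, y from domain B
    cycle = mae(F(G(x)), x) + mae(G(F(y)), y)
    identity = mae(G(y), y) + mae(F(x), x)
    return cycle, identity

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(4, 4, 3))  # toy "apple"
y = rng.uniform(-1, 1, size=(4, 4, 3))  # toy "orange"

# Hypothetical toy generators that are exact inverses of each other
G = lambda img: img + 0.1
F = lambda img: img - 0.1
print(cycle_and_identity_losses(G, F, x, y))  # cycle about 0, identity about 0.2
```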
Loss, Optimizer and Compile the Models
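We use the mean squared error (MSE) loss for the discriminators, which gives a least-squares GAN objective. The combined model uses MSE for its two adversarial outputs and the mean absolute error (MAE) for the cycle consistency and identity outputs. The cycle terms have a loss weight of 10 and all other terms have a weight of 1. The optimizer is Adam with a learning rate of 0.0002 and beta_1 of 0.5. The batch size is 1 and the total number of epochs is 200.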
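```python
# A separate Adam instance (lr = 0.0002, beta_1 = 0.5) for each model,
# so the models do not share optimizer state
def make_optimizer():
    return Adam(learning_rate=0.0002, beta_1=0.5)

# combined() set trainable = False on both discriminators. In Keras 2 / tf.keras,
# compile() captures the trainable state, so compile the discriminators while
# trainable; otherwise train_on_batch would not update their weights.
discA.trainable = True
discB.trainable = True
discA.compile(loss='mse', optimizer=make_optimizer(), metrics=['accuracy'])
discB.compile(loss='mse', optimizer=make_optimizer(), metrics=['accuracy'])

# Freeze them again before compiling the combined (generator) model
discA.trainable = False
discB.trainable = False
comb_model.compile(loss=['mse', 'mse', 'mae', 'mae', 'mae', 'mae'],
                   loss_weights=[1, 1, 10, 10, 1, 1],
                   optimizer=make_optimizer())

disc_patch = (30, 30, 1)  # shape of the PatchGAN output map
epochs = 200
valid = np.ones((1,) + disc_patch)   # target for real images
fake = np.zeros((1,) + disc_patch)   # target for generated images
```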
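For a multi-output model, Keras minimizes the weighted sum of the per-output losses. The NumPy sketch below spells out what `comb_model` optimizes with `loss_weights=[1, 1, 10, 10, 1, 1]`. It uses hypothetical helper names and is not Keras's own implementation. The two least-squares adversarial terms push the discriminators' 30x30 outputs for generated images toward the `valid` target of ones. The cycle reconstructions count ten times as much as the identity terms.

```python
import numpy as np

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

def mae(pred, target):
    return float(np.mean(np.abs(pred - target)))

# Same order as the combined model's outputs and the compile() call
LOSSES = [mse, mse, mae, mae, mae, mae]
WEIGHTS = [1, 1, 10, 10, 1, 1]

def total_generator_loss(outputs, targets):
    # Weighted sum of the six per-output losses
    return sum(w * fn(o, t) for fn, w, o, t in zip(LOSSES, WEIGHTS, outputs, targets))

valid = np.ones((1, 30, 30, 1))
img_a = np.zeros((1, 256, 256, 3))
img_b = np.zeros((1, 256, 256, 3))
targets = [valid, valid, img_a, img_b, img_a, img_b]

# Perfect outputs give zero loss
outputs = [t.copy() for t in targets]
print(total_generator_loss(outputs, targets))  # 0.0
# A uniform error of 0.1 on the reconstruction of image A costs 10 * 0.1
outputs[2] = outputs[2] + 0.1
print(total_generator_loss(outputs, targets))  # about 1.0
```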
Train the Network
1. Use generator A to generate an image from an image in domain A. Similarly, use generator B to generate an image from an image in domain B.
2. Train discriminator A on a batch, using the image from domain A as the real image and the image generated by generator B as the fake image.
3. Train discriminator B on a batch, using the image from domain B as the real image and the image generated by generator A as the fake image.
4. Train the generators on a batch using the combined model.
5. Repeat steps 1 to 4 for every image in the training dataset, and repeat this process for 200 epochs.
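```python
def train():
    # The domains are unpaired; the i-th apple is simply combined with the i-th orange.
    # Stop at the shorter set so that both train_a[i] and train_b[i] exist.
    n_steps = min(len(train_a), len(train_b))
    for j in range(epochs):
        t1 = time.time()
        for i in range(n_steps):
            img_a = np.expand_dims(train_a[i], axis=0)
            img_b = np.expand_dims(train_b[i], axis=0)

            # Step 1: translate A -> B and B -> A
            img_b_gen = genA.predict(img_a, verbose=0)
            img_a_gen = genB.predict(img_b, verbose=0)

            # Step 2: train discriminator A (real apples vs. generated apples)
            dA_real_loss = discA.train_on_batch(img_a, valid)
            dA_fake_loss = discA.train_on_batch(img_a_gen, fake)

            # Step 3: train discriminator B (real oranges vs. generated oranges)
            dB_real_loss = discB.train_on_batch(img_b, valid)
            dB_fake_loss = discB.train_on_batch(img_b_gen, fake)

            # Step 4: train both generators through the combined model;
            # the targets follow the order of comb_model's outputs
            g_loss = comb_model.train_on_batch(
                [img_a, img_b],
                [valid, valid, img_a, img_b, img_a, img_b])

        # Report the epoch time and the losses of the last step in the epoch
        print('time taken for one epoch', time.time() - t1)
        print(j, i, dA_real_loss, dA_fake_loss, dB_real_loss, dB_fake_loss, g_loss)
```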
```python
train()
```
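After training, you can judge the model by placing a real image next to its translation. The sketch below shows one possible approach. The helper names `to_display` and `side_by_side` are hypothetical. The sketch undoes the [-1, 1] scaling from `load_img` and converts OpenCV's BGR channel order to RGB, so the result displays correctly in matplotlib.

```python
import numpy as np

def to_display(img):
    # [-1, 1] float image in BGR order (as produced by load_img) -> uint8 RGB image
    img = np.clip((np.asarray(img) + 1.0) * 127.5, 0, 255).round().astype(np.uint8)
    return img[..., ::-1]  # BGR -> RGB by reversing the channel axis

def side_by_side(real, fake):
    # Concatenate the real and the translated image along the width axis
    return np.concatenate([to_display(real), to_display(fake)], axis=1)

# Dummy data; with the trained model you might instead use, for example:
#   real = train_a[0]
#   fake = genA.predict(np.expand_dims(real, axis=0), verbose=0)[0]
real = np.full((256, 256, 3), -1.0)
fake = np.full((256, 256, 3), 1.0)
pair = side_by_side(real, fake)
print(pair.shape)  # (256, 512, 3)
```

You can display the resulting array with `matplotlib.pyplot.imshow(pair)`. Images from the testA and testB folders, loaded with `load_img`, give a fairer picture than training images.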
Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Goodbye until next time.