Implementing Semi-Supervised Learning Using GANs

Semi-supervised learning uses a large amount of unlabeled data to boost the performance of a model that has only a small amount of labeled data. Such models are very useful when collecting labeled data is cumbersome and expensive. Several semi-supervised deep learning models have performed quite well on standard benchmarks. In this blog, we will learn how GANs can help with semi-supervised learning.

If you are new to GANs, you should first read this blog: An Introduction to Generative Adversarial Networks. In a GAN, we train two networks adversarially: a generator and a discriminator. After training, we usually discard the discriminator and use only the generator to produce new data. In the semi-supervised model it is the other way around: after training, we discard the generator and keep the discriminator. However, here the discriminator is designed differently.

In a semi-supervised GAN (SGAN), the discriminator is trained not only to distinguish real from fake data but also to predict the label of the input image. Take the MNIST dataset as an example: it contains handwritten digits from 0 to 9, a total of 10 classes. In a semi-supervised GAN for MNIST digits, the discriminator is trained both to classify images as real or fake and to predict these 10 classes.

So in an SGAN, the discriminator is trained with these three types of data:

  1. Fake images generated by the generator network.
  2. Real images from the dataset without labels (a large amount of unlabeled data).
  3. Real images from the dataset with labels (a small amount of labeled data).

The generator in an SGAN is trained the same way as in a vanilla GAN. This type of training allows the model to learn useful features from the unlabeled data and to use these features to train a supervised discriminator that predicts the label of the input image.

Implementing Semi-Supervised GAN

Now we will implement a semi-supervised GAN using the MNIST digits dataset. If you want to implement a simple GAN, you can follow this blog: Implementation of GANs to generated Handwritten Digits.

The MNIST digits dataset consists of 60,000 training images, of which we will use only 1,000 as labeled images and the rest as unlabeled images. The 1,000 labeled images are selected at random so that each class contributes 100 images. Let’s see the code for this:
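A minimal NumPy sketch of this split follows. The function name `split_labeled_unlabeled`, the fixed seed, and the scaling of pixels to [-1, 1] are assumptions for illustration, not details taken from the original code. It picks 100 random indices per class for the labeled set, puts every remaining image in the unlabeled pool, and one-hot encodes the labels (step 1 of the training procedure).

```python
import numpy as np

def split_labeled_unlabeled(x, y, n_per_class=100, num_classes=10, seed=0):
    """Hypothetical helper: pick n_per_class labeled images per class.

    All images not picked become the unlabeled pool. Pixels are scaled
    from [0, 255] to [-1, 1] (an assumption) and labels are one-hot encoded.
    """
    rng = np.random.default_rng(seed)
    labeled_idx = []
    for c in range(num_classes):
        class_idx = np.flatnonzero(y == c)
        # Sample without replacement so no image is picked twice.
        labeled_idx.extend(rng.choice(class_idx, size=n_per_class, replace=False))
    labeled_idx = np.array(labeled_idx)

    # Everything not chosen above is treated as unlabeled.
    unlabeled_mask = np.ones(len(x), dtype=bool)
    unlabeled_mask[labeled_idx] = False

    # Scale to [-1, 1] and add a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    x = (x.astype("float32") - 127.5) / 127.5
    x = x[..., np.newaxis]

    x_lab = x[labeled_idx]
    y_lab = np.eye(num_classes, dtype="float32")[y[labeled_idx]]  # one-hot labels
    x_unlab = x[unlabeled_mask]
    return x_lab, y_lab, x_unlab
```

With the Keras MNIST loader (`tf.keras.datasets.mnist.load_data()`), calling this on the 60,000 training images with the defaults would give 1,000 labeled and 59,000 unlabeled images.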

Discriminator in SGAN

For this semi-supervised GAN, we will create two discriminator models that share the weights of every layer but have different output layers. One is a binary classifier (it discriminates between real and fake images) and the other is a multi-class classifier (it predicts the label of the input image). Let’s see the code for this:
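To make the weight sharing concrete, here is a framework-free forward-pass sketch in NumPy. The class name, the hidden sizes (512, 256), the LeakyReLU(0.2) activations, and the weight initialization are assumptions chosen for illustration; the original post's layers may differ. Both heads call the same `features` method, so any change to the shared trunk affects both outputs.

```python
import numpy as np

class SharedDiscriminator:
    """Hypothetical sketch: one shared trunk, two output heads (forward pass only)."""

    def __init__(self, input_dim=784, hidden=(512, 256), num_classes=10, seed=0):
        rng = np.random.default_rng(seed)
        dims = (input_dim, *hidden)
        # Shared hidden layers, used by BOTH heads.
        self.trunk = [(rng.normal(0, 0.02, (i, o)), np.zeros(o))
                      for i, o in zip(dims[:-1], dims[1:])]
        # Head 1: multi-class classifier (softmax over the digit classes).
        self.w_cls = rng.normal(0, 0.02, (dims[-1], num_classes))
        self.b_cls = np.zeros(num_classes)
        # Head 2: binary real/fake classifier (a single sigmoid unit).
        self.w_bin = rng.normal(0, 0.02, (dims[-1], 1))
        self.b_bin = np.zeros(1)

    def features(self, x):
        h = x.reshape(len(x), -1)  # flatten (N, 28, 28, 1) -> (N, 784)
        for w, b in self.trunk:
            z = h @ w + b
            h = np.where(z > 0, z, 0.2 * z)  # LeakyReLU(0.2)
        return h

    def predict_classes(self, x):
        logits = self.features(x) @ self.w_cls + self.b_cls
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        e = np.exp(logits)
        return e / e.sum(axis=1, keepdims=True)

    def predict_real(self, x):
        logit = self.features(x) @ self.w_bin + self.b_bin
        return 1.0 / (1.0 + np.exp(-logit))
```

In a framework such as Keras, the same structure is typically built by feeding one shared stack of layers into two `Dense` output layers (10 softmax units and 1 sigmoid unit) and compiling them as two models, one with categorical cross-entropy and one with binary cross-entropy.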

Generator in SGAN

The generator in this SGAN is a simple multi-layer neural network with three hidden layers of 512, 256 and 128 units. The output layer has the shape of the original image, (28, 28, 1). The input to the generator is a random noise vector of size 100. Here is the code.
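A forward-pass NumPy sketch of that architecture follows. LeakyReLU(0.2) hidden activations and a `tanh` output (so generated pixels lie in [-1, 1], matching images scaled to that range) are common GAN choices assumed here, not details confirmed by the text; the class name and initialization are hypothetical.

```python
import numpy as np

class Generator:
    """Hypothetical sketch: noise (100) -> 512 -> 256 -> 128 -> image (28, 28, 1)."""

    def __init__(self, latent_dim=100, hidden=(512, 256, 128),
                 img_shape=(28, 28, 1), seed=0):
        rng = np.random.default_rng(seed)
        self.img_shape = img_shape
        dims = (latent_dim, *hidden, int(np.prod(img_shape)))  # last layer: 784 units
        self.layers = [(rng.normal(0, 0.02, (i, o)), np.zeros(o))
                       for i, o in zip(dims[:-1], dims[1:])]

    def predict(self, noise):
        h = noise
        for k, (w, b) in enumerate(self.layers):
            z = h @ w + b
            if k < len(self.layers) - 1:
                h = np.where(z > 0, z, 0.2 * z)  # LeakyReLU(0.2) on hidden layers
            else:
                h = np.tanh(z)  # output pixels in [-1, 1]
        # Reshape the flat 784-vector into the original image shape.
        return h.reshape(len(noise), *self.img_shape)
```

For example, `Generator().predict(np.random.default_rng(0).normal(0, 1, (8, 100)))` returns a batch of eight 28x28x1 images.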

Training the model

Training this model consists of the following steps:

  1. Split the MNIST training data into labeled and unlabeled images, normalize the images, and convert the labels to categorical (one-hot) form.
  2. Train the multi-class discriminator on a batch of labeled real images.
  3. Train the binary discriminator on a batch of unlabeled real images.
  4. Sample noise vectors of size 100 and train the binary discriminator on fake images produced by the generator.
  5. Sample noise vectors of size 100 and train the combined model to update the generator.
  6. Repeat steps 2-5 for a number of iterations. I trained it for 10,000 iterations.
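The schedule above can be sketched as a loop over duck-typed model objects. This sketch assumes each model exposes a Keras-style `train_on_batch(x, y)` method (and the generator a `predict(x)` method), that `combined` is the generator stacked on the binary discriminator with the discriminator's weights frozen, that noise is drawn from a standard normal, and a batch size of 32. `train_sgan` and its parameter names are hypothetical.

```python
import numpy as np

def train_sgan(d_multi, d_binary, generator, combined, x_lab, y_lab, x_unlab,
               iterations=10000, batch_size=32, latent_dim=100, seed=0):
    """Hypothetical sketch of steps 2-5, repeated for `iterations` rounds."""
    rng = np.random.default_rng(seed)
    real = np.ones((batch_size, 1))   # target for real images
    fake = np.zeros((batch_size, 1))  # target for generated images
    for _ in range(iterations):
        # Step 2: supervised update on a random batch of labeled real images.
        idx = rng.integers(0, len(x_lab), batch_size)
        d_multi.train_on_batch(x_lab[idx], y_lab[idx])

        # Step 3: real/fake update on a random batch of unlabeled real images.
        idx = rng.integers(0, len(x_unlab), batch_size)
        d_binary.train_on_batch(x_unlab[idx], real)

        # Step 4: real/fake update on images generated from fresh noise.
        noise = rng.normal(0.0, 1.0, (batch_size, latent_dim))
        d_binary.train_on_batch(generator.predict(noise), fake)

        # Step 5: generator update. "Real" targets push the generator to fool
        # the (frozen) binary discriminator inside the combined model.
        noise = rng.normal(0.0, 1.0, (batch_size, latent_dim))
        combined.train_on_batch(noise, real)
```

Counting calls per iteration makes the asymmetry explicit: two binary-discriminator updates (real and fake), one multi-class update, and one generator update.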

In the training steps above, the multi-class and binary discriminators are trained in separate steps. However, they share the weights of the same network except for the output layer (as mentioned earlier), so an update through either model also updates the shared layers.

Also, the binary discriminator is trained twice in every iteration: once with real images from the dataset and once with fake images from the generator. The multi-class discriminator is trained only once per iteration, with real labeled images, because class labels are not available for generated images.

I also evaluated the SGAN model on the 10,000 test images provided by MNIST after every 1,000 iterations. Here is the result of that.

As you can see, this SGAN model, trained with only 1,000 labeled images, reaches an accuracy of about 94.8%, which is quite good.
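The test-set accuracy reported above comes down to comparing the multi-class discriminator's most probable class with the true digit. A minimal NumPy sketch follows; the function name is hypothetical, and it accepts either integer or one-hot labels.

```python
import numpy as np

def classification_accuracy(class_probs, y_true):
    """Fraction of images whose most probable class equals the true label.

    class_probs: (N, num_classes) softmax outputs of the multi-class discriminator.
    y_true: (N,) integer labels or (N, num_classes) one-hot labels.
    """
    y_true = np.asarray(y_true)
    if y_true.ndim == 2:  # one-hot -> integer labels
        y_true = y_true.argmax(axis=1)
    return float(np.mean(np.argmax(class_probs, axis=1) == y_true))
```

Every 1,000 iterations, one could pass it the multi-class model's predictions on the 10,000 MNIST test images (scaled the same way as the training images) together with the test labels.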


Hope you enjoyed reading.

If you have any doubts or suggestions, please feel free to ask and I will do my best to help or improve. Good-bye until next time.
