Implementing Capsule Network in Keras

In the last blog, we saw what a capsule network is and how it can overcome the problems associated with convolutional neural networks. In this blog, we will implement a capsule network in Keras.

You can find the full code here.

Here, we will use the handwritten digit dataset (MNIST) and train a capsule network to classify the digits. The MNIST dataset consists of grayscale images of size 28*28.

The capsule network architecture is somewhat similar to a convolutional neural network, except for the capsule layers. We can break the implementation of a capsule network into the following steps:

  1. Initial convolutional layer
  2. Primary capsule layer
  3. Digit capsule layer
  4. Decoder network
  5. Loss Functions
  6. Training and testing of model

Initial Convolution Layer:

Initially, we will use a convolution layer to detect low-level features of the image. It uses 256 filters, each of size 9*9, with stride 1 and ReLU activation. Since the input image is 28*28, the output of this layer will be of size 20*20*256.
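A minimal sketch of this layer (assuming the tf.keras API; the variable names are illustrative):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hedged sketch of the initial feature-extraction layer.
image = layers.Input(shape=(28, 28, 1))   # one grayscale MNIST digit
conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1,
                      padding='valid', activation='relu')(image)
conv1_model = tf.keras.Model(image, conv1)
# With no padding, each spatial dimension shrinks to 28 - 9 + 1 = 20,
# so the output is 20 x 20 x 256.
```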

Primary Capsule Layer:

The output of the previous layer is passed through another convolution with 256 filters, each of size 9*9, with a stride of 2, which produces an output of size 6*6*256. This output is then reshaped into 8-dimensional vectors, giving 6*6*32 = 1152 capsules, each of which is 8-dimensional. Each capsule then passes through a non-linear function (squash) so that the length of the output vector is kept between 0 and 1.
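A sketch of this layer, again assuming tf.keras (the `squash` definition follows the standard capsule-network formulation):

```python
import tensorflow as tf
from tensorflow.keras import layers

def squash(s, axis=-1):
    # Scales a vector so its length lies in [0, 1) without changing direction.
    sq_norm = tf.reduce_sum(tf.square(s), axis=axis, keepdims=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / tf.sqrt(sq_norm + tf.keras.backend.epsilon())

features = layers.Input(shape=(20, 20, 256))   # output of the first conv layer
conv2 = layers.Conv2D(filters=256, kernel_size=9, strides=2,
                      padding='valid', activation='relu')(features)  # 6 x 6 x 256
caps = layers.Reshape((6 * 6 * 32, 8))(conv2)  # 1152 capsules, 8-D each
caps = layers.Lambda(squash)(caps)
primary_caps_model = tf.keras.Model(features, caps)
```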

Digit Capsule Layer:

The logic and algorithm used for this layer were explained in the previous blog. Here, we will see what we need to do in code to implement it. We need to write a custom layer in Keras. It takes 1152*8 as its input and produces an output of size 10*16, where each of the 10 capsules represents an output class with a 16-dimensional vector. Each of these 10 capsules is then converted into a single value, used to predict the output class, via a lambda layer.
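A small numpy sketch of that final lambda step (the array values are illustrative):

```python
import numpy as np

# Each 16-D digit capsule is reduced to its Euclidean length, giving one
# score per class; the longest capsule wins.
digit_caps = np.zeros((1, 10, 16))   # batch of one: 10 capsules, 16-D each
digit_caps[0, 3, :] = 0.2            # pretend the capsule for digit 3 is active
lengths = np.sqrt(np.sum(np.square(digit_caps), axis=-1))  # shape (1, 10)
predicted_class = np.argmax(lengths, axis=-1)
print(predicted_class)  # [3]
```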

Decoder Network:

To further boost the pose parameters learned by the digit capsule layer, we can add a decoder network to reconstruct the input image. The decoder network is fed an input of size 10*16 (the digit capsule layer output) and reconstructs the original 28*28 image. The decoder consists of 3 dense layers having 512, 1024 and 784 nodes.
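A sketch of the decoder under the tf.keras API (the 784-D output can then be reshaped back to 28*28):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hedged sketch: the decoder takes the flattened (masked) 10 x 16
# digit-capsule output and reconstructs the 784 pixels of a 28 x 28 image.
masked_caps = layers.Input(shape=(10 * 16,))
h = layers.Dense(512, activation='relu')(masked_caps)
h = layers.Dense(1024, activation='relu')(h)
reconstruction = layers.Dense(784, activation='sigmoid')(h)  # pixels in [0, 1]
decoder = tf.keras.Model(masked_caps, reconstruction)
```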

At training time, the input to the decoder is the output of the digit capsule layer masked with the original labels: every vector except the one corresponding to the correct label is multiplied by zero, so the decoder is trained only on the correct digit capsule. At test time, the input to the decoder is the same digit capsule layer output, but masked with the longest vector in that layer. Let's see the code.
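The masking step itself can be sketched in numpy (names are illustrative):

```python
import numpy as np

# Training-time mask: zero out every capsule except the one for the true label.
digit_caps = np.random.rand(10, 16)          # one sample's digit-capsule output
y_true = np.eye(10)[3]                       # one-hot label for digit "3"
masked = digit_caps * y_true[:, np.newaxis]  # rows other than row 3 become zero
decoder_input = masked.flatten()             # 160-D vector fed to the decoder
```

At test time, the same mask is built from the longest capsule instead of the label.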

Loss Functions:

It uses two loss functions: one is a probabilistic loss function used for classifying the digit images, and the other is the reconstruction loss, which is a mean squared error. The probabilistic loss is simple to understand once you look at the following code.
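The probabilistic loss here is the margin loss from the original capsule network paper (Sabour et al., 2017); a numpy sketch, assuming the paper's default constants m+ = 0.9, m- = 0.1 and lambda = 0.5:

```python
import numpy as np

# Margin loss: the correct capsule is pushed to be long (>= m_plus),
# all others are pushed to be short (<= m_minus).
def margin_loss(y_true, lengths, m_plus=0.9, m_minus=0.1, lam=0.5):
    present = y_true * np.square(np.maximum(0.0, m_plus - lengths))
    absent = lam * (1.0 - y_true) * np.square(np.maximum(0.0, lengths - m_minus))
    return np.sum(present + absent, axis=-1)

y_true = np.eye(10)[[3]]                     # one-hot label, batch of one
perfect = np.where(y_true == 1, 0.95, 0.05)  # correct capsule long, rest short
print(margin_loss(y_true, perfect))  # [0.]
```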

Training and Testing of model:

Now we define our training and testing models and train them on the MNIST digit dataset.
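Compiling and fitting then looks roughly like this. Note that `capsnet` below is only a tiny stand-in with the same input/output signature as the real network (inputs [image, label], outputs [class scores, reconstruction]); in the real model the classification loss would be the probabilistic margin loss rather than the cross-entropy used here, and the data would be real MNIST:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in two-input, two-output model (NOT the capsule network itself).
image_in = layers.Input(shape=(28, 28, 1))
label_in = layers.Input(shape=(10,))
flat = layers.Flatten()(image_in)
lengths_out = layers.Dense(10, activation='softmax', name='capsnet')(flat)
recon_out = layers.Dense(784, activation='sigmoid', name='decoder')(
    layers.Concatenate()([flat, label_in]))
capsnet = tf.keras.Model([image_in, label_in], [lengths_out, recon_out])

# The reconstruction (mse) loss is heavily down-weighted, as in the paper.
capsnet.compile(optimizer='adam',
                loss=['categorical_crossentropy', 'mse'],
                loss_weights=[1.0, 0.0005])

x = np.random.rand(8, 28, 28, 1).astype('float32')        # stand-in for MNIST
y = np.eye(10)[np.random.randint(0, 10, 8)].astype('float32')
history = capsnet.fit([x, y], [y, x.reshape(8, 784)],
                      epochs=1, batch_size=8, verbose=0)
```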

On the test dataset, it was able to achieve 99.09% accuracy. Pretty good, yeah! The reconstructed images also look good. Here are the reconstructed images generated by the decoder network.

Capsule networks show promising results and are yet to be explored thoroughly. There are various directions in which they can be explored further. Research on capsule networks is still at an early stage, but it has given a clear indication that they are worth exploring.

Hope you enjoy reading.

If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve myself. Good-bye until next time.

14 thoughts on “Implementing Capsule Network in Keras”

  1. sandeep

    Hey hi, how can I predict on a new image? I mean, how do I send the inputs to the predict function?

  2. Atul Krishna Singh

    you can simply use following line:

    where x_test are input images.
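A hedged numpy sketch of that prediction step (the `capsule_lengths` array stands in for the first output of the model's `predict` call; the values are illustrative):

```python
import numpy as np

# Hedged sketch: in the post's model, something like
#   label_predicted, image_predicted = model.predict(...)
# returns the capsule lengths and the reconstruction; the reconstruction can
# be ignored for classification. `label_predicted` holds one length per class.
label_predicted = np.array([
    [0.05, 0.90, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05],
    [0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.90, 0.05, 0.05],
])
predicted_digits = np.argmax(label_predicted, axis=1)
print(predicted_digits)  # [1 7]
```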

  3. satish

    Hello,

    Thanks for the nice post.

    When I try to run the code, I get the following error:

    ValueError: Can not do batch_dot on inputs with shapes (None, 10, 10, 1152, 16) and (None, 10, 1152, 1152, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 1152).

    I cannot understand the exact problem here. Sorry, I am a newbie. Could anyone please help me?

  4. satish

    Hello,

    Can anyone help me to sort this error out:

    ValueError: Can not do batch_dot on inputs with shapes (None, 10, 10, 1152, 16) and (None, 10, 1152, 1152, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 1152).

    I get this error when I try to run the DigitCapsuleLayer block of the code, i.e.

    class DigitCapsuleLayer(Layer):
        # creating a layer class in keras
        def __init__(self, **kwargs):
            super(DigitCapsuleLayer, self).__init__(**kwargs)
            self.kernel_initializer = initializers.get('glorot_uniform')

        def build(self, input_shape):
            # initialize weight matrix for each capsule in lower layer
            self.W = self.add_weight(shape=[10, 6*6*32, 16, 8], initializer=self.kernel_initializer, name='weights')
            self.built = True

        def call(self, inputs):
            inputs = K.expand_dims(inputs, 1)
            inputs = K.tile(inputs, [1, 10, 1, 1])
            # matrix multiplication b/w previous layer output and weight matrix
            inputs = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs)
            b = tf.zeros(shape=[K.shape(inputs)[0], 10, 6*6*32])

            # routing algorithm with updating coupling coefficient c, using scalar product b/w input capsule and output capsule
            for i in range(3-1):
                c = tf.nn.softmax(b, dim=1)
                s = K.batch_dot(c, inputs, [2, 2])
                v = squash(s)
                b = b + K.batch_dot(v, inputs, [2, 3])
            return v

        def compute_output_shape(self, input_shape):
            return tuple([None, 10, 16])

    1. Atul Krishna Singh

      Hi Satish,

      See your code at this line:

      In Keras, batch_dot() is used to compute the dot product between two Keras tensors or variables, where both are processed in batches.

      In the code line I mentioned above, you have specified the target dimensions as [2, 3], which means that x.shape[2] and W.shape[3] should be equal, which is not the case in your code. That's why there is an error.
      Hope this helps.

      1. satish

        But I am using the code given on this web page only; I do not have my own code. So how come I get this error message? I remember a few weeks ago when I tried this code there was no error, but today I tried the same code and I have this issue. Is there something wrong with my Keras version?

        1. Mehmet Ali

          Configuring the Keras version with the following line solved it in my case:

          !pip install -q keras==2.1.2

    1. kang & atul Post author

      you can simply use following line:

      where x_test are input images.

      Take the label_predicted for classification and you can ignore the image_predicted.

  5. Naseer

    How can I use this code for just the classification part? That is, I have MNIST data and I just need to test the classification accuracy, as I do not need the reconstruction part or the reconstruction error.

    1. kang & atul Post author

      Hi Naseer,
      Firstly, the reconstruction part also helps the classification by boosting the pose parameters learnt by the digit capsule layer.
      Still, if you want to try only the classification part, you need to remove the decoder part from the model and create the model accordingly. Then there is no need for a multi-input, multi-output model; a single-input, single-output model will work.
      Hope this helps.

  6. sreagm

    Can I use the same flow as described here to train a capsule network for binary classification task?

  7. Saim

    I am facing difficulty visualizing and understanding how the batch_dot between the inputs and the weights happens. Could you please explain? I haven't found any solution on the internet; there is some explanation of 2D multiplication, but not of 4D.

