In the last blog, we saw what a capsule network is and how it can overcome the problems associated with convolutional neural networks. In this blog, we will implement a capsule network in Keras.
Here, we will use the handwritten digit dataset (MNIST) and train a capsule network to classify the digits. The MNIST dataset consists of 60,000 training images and 10,000 test images of handwritten digits, each a 28*28 grayscale image.
The capsule network architecture is somewhat similar to that of a convolutional neural network, and we will implement it in the following steps:
- Initial convolutional layer
- Primary capsule layer
- Digit capsule layer
- Decoder network
- Loss Functions
- Training and testing of model
Initial Convolution Layer:
Initially, we use a convolution layer to detect the low-level features of the image. It has 256 filters, each of size 9*9, with stride 1 and ReLU activation. The input image is of size 28*28; after applying this layer, the output size is 20*20*256 (with 'valid' padding, 28 - 9 + 1 = 20).
import keras
import tensorflow as tf
from keras import backend as K, initializers
from keras.layers import Input, Conv2D, Reshape, Lambda, Layer, Dense, Flatten
from keras.models import Model

input_shape = Input(shape=(28,28,1))  # size of input image is 28*28
# a convolution layer, output shape = 20*20*256
conv1 = Conv2D(256, (9,9), activation = 'relu', padding = 'valid')(input_shape)
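If you want to verify the claimed output size, you can inspect the symbolic shape of conv1 (a quick check, not part of the model):

print(K.int_shape(conv1))   # (None, 20, 20, 256)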
Primary Capsule Layer:
The output from the previous layer is passed through another convolution with 256 filters, each of size 9*9, with a stride of 2, giving an output of size 6*6*256. This output is reshaped into 1152 (6*6*32) capsules of 8 dimensions each, and every capsule vector is squashed with a non-linear function so that its length lies between 0 and 1.
# convolution layer with stride 2 and 256 filters of size 9*9
conv2 = Conv2D(256, (9,9), strides = 2, padding = 'valid')(conv1)

# reshape into 1152 capsules of 8-dimensional vectors
reshaped = Reshape((6*6*32, 8))(conv2)

def squash(inputs):
    # take the squared norm of the input vectors
    squared_norm = K.sum(K.square(inputs), axis = -1, keepdims = True)
    # use the formula for the non-linear function to return the squashed output
    return ((squared_norm/(1+squared_norm))/(K.sqrt(squared_norm+K.epsilon())))*inputs

# squash the reshaped output to make the length of each vector b/w 0 and 1
squashed_output = Lambda(squash)(reshaped)
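To get a feel for what squash does, here is a quick sanity check (not part of the model): a vector of ones has norm sqrt(8) ~ 2.83, and squashing shrinks its length to ||v||^2 / (1 + ||v||^2) = 8/9 ~ 0.889 while preserving its direction.

import numpy as np

v = K.constant(np.ones((1, 8)))   # one 8-D capsule; its norm is sqrt(8) ~ 2.83
squashed_norm = K.sqrt(K.sum(K.square(squash(v)), -1))
print(K.eval(squashed_norm))      # ~0.889 = 8/9, i.e. squeezed below 1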
Digit Capsule Layer:
The logic and algorithm used for this layer were explained in the previous blog; here we will see how to implement it in code. We need to write a custom layer in Keras. It takes the 1152*8 primary capsule output as its input and produces an output of size 10*16, where each of the 10 capsules represents an output class with a 16-dimensional vector. Each of these 10 capsules is then converted into a single value (the length of its vector) by a Lambda layer, to predict the output class.
class DigitCapsuleLayer(Layer):  # creating a layer class in keras
    def __init__(self, **kwargs):
        super(DigitCapsuleLayer, self).__init__(**kwargs)
        self.kernel_initializer = initializers.get('glorot_uniform')

    def build(self, input_shape):
        # initialize a weight matrix for each capsule in the lower layer
        self.W = self.add_weight(shape = [10, 6*6*32, 16, 8],
                                 initializer = self.kernel_initializer,
                                 name = 'weights')
        self.built = True

    def call(self, inputs):
        inputs = K.expand_dims(inputs, 1)
        inputs = K.tile(inputs, [1, 10, 1, 1])
        # matrix multiplication b/w previous layer output and weight matrix
        inputs = K.map_fn(lambda x: K.batch_dot(x, self.W, [2, 3]), elems=inputs)
        b = tf.zeros(shape = [K.shape(inputs)[0], 10, 6*6*32])
        # routing algorithm: update the coupling coefficients c using the
        # scalar product b/w input capsules and output capsules
        for i in range(3-1):
            c = tf.nn.softmax(b, dim=1)
            s = K.batch_dot(c, inputs, [2, 2])
            v = squash(s)
            b = b + K.batch_dot(v, inputs, [2, 3])
        return v

    def compute_output_shape(self, input_shape):
        return tuple([None, 10, 16])

def output_layer(inputs):
    return K.sqrt(K.sum(K.square(inputs), -1) + K.epsilon())

digit_caps = DigitCapsuleLayer()(squashed_output)
outputs = Lambda(output_layer)(digit_caps)
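As a quick check (again, not part of the model), the shapes coming out of this layer match what we expect:

print(K.int_shape(digit_caps))   # (None, 10, 16): a 16-D vector per digit class
print(K.int_shape(outputs))      # (None, 10): the length of each capsule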
Decoder Network:
To encourage the digit capsule layer to learn good pose parameters, we can add a decoder network that reconstructs the input image. The decoder is fed the 10*16 output of the digit capsule layer and reconstructs the original 28*28 image.
At training time, the input to the decoder is the output of the digit capsule layer masked with the original labels: every vector except the one corresponding to the correct label is multiplied by zero, so the decoder is trained only on the correct digit capsule. At test time, the input to the decoder is the same digit capsule output, but masked with the longest vector in that layer.
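To make the masking concrete, here is a tiny NumPy illustration (a made-up example with 3 classes and 2-D capsules instead of 10 classes and 16-D capsules):

import numpy as np

caps = np.array([[1., 2.],
                 [3., 4.],
                 [5., 6.]])                    # 3 capsules of 2 dimensions each
label = np.array([0., 1., 0.]).reshape(3, 1)   # one-hot label for class 1
print((caps * label).flatten())  # [0. 0. 3. 4. 0. 0.]: only the correct
                                 # capsule reaches the decoder

With that picture in mind, let's see the code.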
def mask(outputs):
    if type(outputs) != list:  # mask at test time
        norm_outputs = K.sqrt(K.sum(K.square(outputs), -1) + K.epsilon())
        y = K.one_hot(indices=K.argmax(norm_outputs, 1), num_classes = 10)
        y = Reshape((10,1))(y)
        return Flatten()(y*outputs)
    else:  # mask at train time
        y = Reshape((10,1))(outputs[1])
        masked_output = y*outputs[0]
        return Flatten()(masked_output)

inputs = Input(shape = (10,))
masked = Lambda(mask)([digit_caps, inputs])
masked_for_test = Lambda(mask)(digit_caps)

decoded_inputs = Input(shape = (16*10,))
dense1 = Dense(512, activation = 'relu')(decoded_inputs)
dense2 = Dense(1024, activation = 'relu')(dense1)
decoded_outputs = Dense(784, activation = 'sigmoid')(dense2)
decoded_outputs = Reshape((28,28,1))(decoded_outputs)
Loss Functions:
The model uses two loss functions: a margin loss on the lengths of the digit capsules, used for classifying the digit image, and a reconstruction loss, which is the mean squared error between the original and reconstructed images. The margin loss is simple to understand once you look at the following code.
def loss_fn(y_true, y_pred):
    # margin loss with m+ = 0.9, m- = 0.1 and lambda = 0.5
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, 1))
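As a quick (hypothetical) numeric check, a confident correct prediction incurs almost no loss, since the correct capsule's length exceeds 0.9 and the others fall below 0.1:

import numpy as np

y_true = K.constant(np.array([[0., 1.]]))      # true class is the second one
y_pred = K.constant(np.array([[0.05, 0.95]]))  # predicted capsule lengths
print(K.eval(loss_fn(y_true, y_pred)))         # 0.0: both margins are satisfied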
Training and Testing of model:
Now we define our training and testing models and train them on the MNIST digit dataset.
decoder = Model(decoded_inputs, decoded_outputs)
model = Model([input_shape, inputs], [outputs, decoder(masked)])
test_model = Model(input_shape, [outputs, decoder(masked_for_test)])

m = 128        # batch size
epochs = 10

model.compile(optimizer = keras.optimizers.Adam(lr=0.001),
              loss = [loss_fn, 'mse'],
              loss_weights = [1., 0.0005],
              metrics = ['accuracy'])

model.fit([x_train, y_train], [y_train, x_train],
          batch_size = m, epochs = epochs,
          validation_data = ([x_test, y_test], [y_test, x_test]))
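One thing the snippet above does not show is where x_train, y_train, x_test and y_test come from; a minimal preprocessing sketch might look like this:

from keras.datasets import mnist
from keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
y_train = to_categorical(y_train, 10)   # one-hot labels for the mask and the loss
y_test = to_categorical(y_test, 10)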
On the test set, the model was able to achieve 99.09% accuracy. Pretty good, yeah! The reconstructed images also look good. Here are the reconstructed images generated by the decoder network.
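If you want to produce such reconstructions yourself, you can run the test-time model on a few test images (a sketch):

preds, reconstructions = test_model.predict(x_test[:5])
print(preds.argmax(axis = 1))    # predicted digit for each image
print(reconstructions.shape)     # (5, 28, 28, 1): the decoded images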
Capsule networks come with promising results and are yet to be explored thoroughly. Research on capsule networks is still at an early stage, but it has given a clear indication that they are worth exploring.
Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Good-bye until next time.