In this blog, we will discuss how to create custom callbacks in Keras. This is actually very simple: you just need to create a class that takes keras.callbacks.Callback as its base class. The names of the methods that Keras calls are fixed, so we only override the ones we need and write our logic inside them. Let’s understand this with the help of an example. Here, we will create a callback that stops training once the accuracy has crossed a threshold and prints a message.
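```python
from tensorflow import keras

class myCallback(keras.callbacks.Callback):               # Line-1
    def on_epoch_end(self, epoch, logs=None):             # Line-2
        logs = logs or {}
        acc = logs.get('accuracy', logs.get('acc'))       # 'acc' is the key in older Keras
        if acc is not None and acc > 0.99:                # Line-3
            self.model.stop_training = True               # Line-4
            print('Stopped training as accuracy above threshold')

callbacks1 = myCallback()
```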
In “Line-1”, we create a class “myCallback” that takes keras.callbacks.Callback as its base class.
In “Line-2”, we define a method “on_epoch_end”. Note that the names of the methods Keras calls are predefined according to their functionality. For instance, a method named “on_epoch_end” is called at the end of every epoch. If you rename it to “on_epoch_end1”, Keras no longer recognizes it, and it is never called.
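To see the naming rule in action, here is a small sketch, assuming TensorFlow 2.x (tf.keras or Keras 3) is installed. The class name NameCheckCallback and the toy regression data are hypothetical; the callback defines both “on_epoch_end” and “on_epoch_end1” and counts how often each one runs.

```python
import numpy as np
from tensorflow import keras


class NameCheckCallback(keras.callbacks.Callback):
    """Hypothetical callback that counts how often each method is called."""

    def __init__(self):
        super().__init__()
        self.calls = {'on_epoch_end': 0, 'on_epoch_end1': 0}

    def on_epoch_end(self, epoch, logs=None):
        # Recognized hook name: Keras calls this after every epoch.
        self.calls['on_epoch_end'] += 1

    def on_epoch_end1(self, epoch, logs=None):
        # Not a hook name: Keras never calls this method.
        self.calls['on_epoch_end1'] += 1


# Toy regression data, made up for illustration.
x = np.random.default_rng(0).normal(size=(64, 4)).astype('float32')
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')

counter = NameCheckCallback()
model.fit(x, y, epochs=3, verbose=0, callbacks=[counter])
print(counter.calls)  # {'on_epoch_end': 3, 'on_epoch_end1': 0}
```

Only the correctly named method runs, once per epoch; the misnamed one is just an ordinary method that nothing calls.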
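Keras recognizes only a fixed set of method names, each named after the point in training at which it runs, and each takes a fixed set of arguments. Some commonly used ones are “on_train_begin(logs)”, “on_train_end(logs)”, “on_epoch_begin(epoch, logs)”, “on_epoch_end(epoch, logs)”, “on_batch_begin(batch, logs)” and “on_batch_end(batch, logs)”.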
The epoch and batch arguments are the zero-based indices of the current epoch and batch. “logs” is a dictionary holding the metric values computed so far, such as “loss” and “accuracy” (named “acc” in older Keras versions).
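The following sketch overrides the six hooks “on_train_begin”, “on_train_end”, “on_epoch_begin”, “on_epoch_end”, “on_batch_begin” and “on_batch_end”, and simply records when each one fires. It assumes TensorFlow 2.x; the class name EventRecorder and the toy data are made up for illustration. With 64 samples and batch_size=32, each epoch has two batches.

```python
import numpy as np
from tensorflow import keras


class EventRecorder(keras.callbacks.Callback):
    """Hypothetical callback that records every hook call as (hook, index)."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append(('train_begin', None))

    def on_train_end(self, logs=None):
        self.events.append(('train_end', None))

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(('epoch_begin', epoch))

    def on_epoch_end(self, epoch, logs=None):
        # logs holds this epoch's metrics, e.g. {'loss': ...}
        self.events.append(('epoch_end', epoch))

    def on_batch_begin(self, batch, logs=None):
        self.events.append(('batch_begin', batch))

    def on_batch_end(self, batch, logs=None):
        self.events.append(('batch_end', batch))


# Toy regression data: 64 samples, so batch_size=32 gives 2 batches per epoch.
x = np.random.default_rng(0).normal(size=(64, 4)).astype('float32')
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')

recorder = EventRecorder()
model.fit(x, y, epochs=2, batch_size=32, verbose=0, callbacks=[recorder])
for event in recorder.events:
    print(event)
```

The printed events show the nesting: training begins, then each epoch runs its begin hook, its batch hooks (with the batch index restarting at 0), and its end hook, and finally training ends.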
In “Line-3,4”, we check the stopping condition and, if it is met, set self.model.stop_training to True, which makes fit() stop after the current epoch. Note that the base class gives us access to the model being trained as self.model, so we can also use its other methods and attributes, such as save_weights(), save() or trainable.
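As a small illustration of using self.model inside a callback, the sketch below (assuming TensorFlow 2.x; WeightNormLogger is a hypothetical name) reads the current weights with self.model.get_weights() after every epoch and records their overall L2 norm, which can help spot weights that grow out of control.

```python
import numpy as np
from tensorflow import keras


class WeightNormLogger(keras.callbacks.Callback):
    """Hypothetical callback: record the L2 norm of all model weights per epoch."""

    def __init__(self):
        super().__init__()
        self.norms = []

    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by the base class; get_weights() returns NumPy arrays.
        flat = np.concatenate([w.ravel() for w in self.model.get_weights()])
        self.norms.append(float(np.linalg.norm(flat)))


# Toy regression data, made up for illustration.
x = np.random.default_rng(0).normal(size=(64, 4)).astype('float32')
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')

logger = WeightNormLogger()
model.fit(x, y, epochs=3, verbose=0, callbacks=[logger])
print(logger.norms)
```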
Finally, we create an instance of this class and pass it to the fit() method inside a list, through the callbacks argument (for example, model.fit(x, y, epochs=10, callbacks=[callbacks1])). Below is the output of applying the above callback.
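Here is an end-to-end usage sketch, assuming TensorFlow 2.x. It generalizes the callback above into a hypothetical ThresholdStop class whose monitored key and threshold are constructor arguments, trains a one-layer classifier on made-up, linearly separable data, and passes the callback to fit() through the callbacks list.

```python
import numpy as np
from tensorflow import keras


class ThresholdStop(keras.callbacks.Callback):
    """Hypothetical callback: stop once logs[monitor] exceeds threshold."""

    def __init__(self, monitor='accuracy', threshold=0.99):
        super().__init__()
        self.monitor = monitor
        self.threshold = threshold
        self.stopped_epoch = None

    def on_epoch_end(self, epoch, logs=None):
        value = (logs or {}).get(self.monitor)
        if value is not None and value > self.threshold:
            self.model.stop_training = True  # fit() ends after this epoch
            self.stopped_epoch = epoch
            print(f'Stopped training at epoch {epoch}: {self.monitor}={value:.4f}')


# Made-up, linearly separable data: the label is the sign of the first feature,
# and the classes are pushed apart so that 100% accuracy is reachable.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2)).astype('float32')
y = (x[:, 0] > 0).astype('float32').reshape(-1, 1)
x[:, 0] += np.where(y[:, 0] == 1, 1.0, -1.0)

model = keras.Sequential([keras.Input(shape=(2,)),
                          keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.1),
              loss='binary_crossentropy', metrics=['accuracy'])

stopper = ThresholdStop(monitor='accuracy', threshold=0.99)
history = model.fit(x, y, epochs=100, batch_size=32, verbose=0,
                    callbacks=[stopper])
print(len(history.history['loss']), 'epochs run')
```

Making the metric name a parameter avoids hard-coding “acc” versus “accuracy”, which differs between Keras versions. On this easy data, training should end well before the 100-epoch limit.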
Below is another example that saves the weights at the end of each epoch.
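```python
from tensorflow import keras

class myCallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Save a separate weights file after every epoch (best0, best1, ...).
        # Keras 3 requires the '.weights.h5' suffix for save_weights().
        self.model.save_weights('D:/downloads/best{}.weights.h5'.format(epoch))

callbacks1 = myCallback()
```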
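Saving after every epoch fills the disk quickly. A common refinement, sketched below under the assumption of TensorFlow 2.x, is to save only when the training loss improves. SaveBestWeights and the temporary output directory are hypothetical choices; Keras also ships a built-in keras.callbacks.ModelCheckpoint that does this with save_best_only=True and save_weights_only=True.

```python
import os
import tempfile

import numpy as np
from tensorflow import keras


class SaveBestWeights(keras.callbacks.Callback):
    """Hypothetical callback: save weights only when the monitored value improves."""

    def __init__(self, directory, monitor='loss'):
        super().__init__()
        self.directory = directory
        self.monitor = monitor
        self.best = np.inf
        self.saved_paths = []

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is not None and current < self.best:
            self.best = current
            # '.weights.h5' works in Keras 3 and is saved as HDF5 in Keras 2.
            path = os.path.join(self.directory, 'best{}.weights.h5'.format(epoch))
            self.model.save_weights(path)
            self.saved_paths.append(path)


# Toy regression data, made up for illustration.
x = np.random.default_rng(0).normal(size=(64, 4)).astype('float32')
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')

saver = SaveBestWeights(tempfile.mkdtemp())
model.fit(x, y, epochs=5, verbose=0, callbacks=[saver])
print(saver.saved_paths)
```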
In the same way, you can create many other custom callbacks in Keras, which gives you finer control over the training process. Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Goodbye until next time.