Keras Callbacks – LearningRateScheduler

In neural networks, choosing a good learning rate is always a challenging task. If the learning rate is set too high, the loss can diverge, or the model can converge too quickly to a sub-optimal solution. If it is set too low, training may take a long time. Thus, it often proves useful to decay the learning rate as training progresses. This can be done using learning rate schedules or adaptive learning rate methods such as Adagrad, RMSprop, Adam, etc. In this blog post, we will only discuss learning rate schedules.

As the name suggests, learning rate schedules adjust the learning rate according to some schedule, for instance time decay, exponential decay, etc. To implement these decays, Keras provides a callback called LearningRateScheduler that adjusts the learning rate based on the decay function you supply. So, let's discuss its Keras API.
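In TensorFlow 2.x the callback is available as `tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)`. A minimal sketch of constructing it; `keep_lr` is a hypothetical placeholder schedule that simply returns the current learning rate unchanged:

```python
import tensorflow as tf

def keep_lr(epoch, lr):
    # Hypothetical placeholder schedule: return the current learning rate unchanged.
    return lr

# schedule: a function (epoch, lr) -> new learning rate
# verbose:  0 = silent, 1 = print a message whenever the callback sets the learning rate
scheduler = tf.keras.callbacks.LearningRateScheduler(schedule=keep_lr, verbose=0)
```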

The "schedule" argument is a decay function that takes the epoch index (starting from 0) and the current learning rate as input and returns the new learning rate. The verbose argument controls whether a message is printed each time the callback sets the learning rate, which happens at the start of every epoch.

Note: This callback overwrites the learning rate set in the optimizer.
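A small sketch of what this overwrite means in practice, assuming TensorFlow 2.x: the optimizer is created with a learning rate of 0.5, but the hypothetical `constant_schedule` always returns 0.01, so after training the optimizer holds 0.01. The random data and tiny model exist only to make `fit()` run.

```python
import numpy as np
import tensorflow as tf

# Tiny model whose optimizer starts with learning_rate=0.5.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.5), loss="mse")

def constant_schedule(epoch, lr):
    # Hypothetical schedule that ignores the incoming lr and always returns 0.01.
    return 0.01

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, epochs=1, verbose=0,
          callbacks=[tf.keras.callbacks.LearningRateScheduler(constant_schedule)])

# The optimizer now holds the scheduled value, not the 0.5 passed to SGD.
final_lr = float(np.asarray(model.optimizer.learning_rate))
print(final_lr)
```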

Now, let's discuss the schedule argument in more detail. Here, we will use time-based decay, which shrinks the learning rate as training progresses by dividing the initial learning rate by a factor that grows linearly with the epoch number.
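One common way to write time-based decay, with $\eta_0$ the initial learning rate, $k$ the decay rate and $t$ the epoch index (symbols chosen here for illustration):

```latex
\eta_t = \frac{\eta_0}{1 + k\,t}
```

For example, with $\eta_0 = 0.01$ and $k = 0.01$, the learning rate at epoch 10 is $0.01 / 1.1 \approx 0.00909$.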

Now, let's see how to put this into practice with the LearningRateScheduler callback. First, create a function that takes the epoch index and the current learning rate as arguments and returns the new learning rate.
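A minimal sketch of such a function for time-based decay. `INITIAL_LR` and `DECAY` are assumed, illustrative values, and this version recomputes the rate from `INITIAL_LR` on every call, so the incoming `lr` argument is accepted (Keras passes it) but not used:

```python
INITIAL_LR = 0.01  # assumed starting learning rate (illustrative)
DECAY = 0.01       # assumed decay rate k (illustrative)

def time_based_decay(epoch, lr):
    """Time-based decay: lr_t = lr_0 / (1 + k * t).

    `epoch` is the 0-based epoch index Keras passes in; `lr` is the optimizer's
    current learning rate, which this closed-form version does not need.
    """
    return INITIAL_LR / (1.0 + DECAY * epoch)

print([time_based_decay(e, None) for e in range(3)])
```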

Then pass this function to the LearningRateScheduler callback.
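A sketch of wrapping the decay function in the callback, assuming TensorFlow 2.x. The function is repeated so the snippet runs on its own; `verbose=1` prints a message each time the learning rate is set:

```python
import tensorflow as tf

INITIAL_LR = 0.01  # assumed starting learning rate (illustrative)
DECAY = 0.01       # assumed decay rate k (illustrative)

def time_based_decay(epoch, lr):
    # lr_t = lr_0 / (1 + k * t)
    return INITIAL_LR / (1.0 + DECAY * epoch)

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(time_based_decay, verbose=1)
```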

Now, simply pass this callback inside a list to the callbacks argument of the fit() method.
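A self-contained sketch of a full training run, assuming TensorFlow 2.x. The random regression data and the small model are purely illustrative; the point is the `callbacks=[lr_scheduler]` argument:

```python
import numpy as np
import tensorflow as tf

INITIAL_LR = 0.01  # assumed starting learning rate (illustrative)
DECAY = 0.01       # assumed decay rate k (illustrative)

def time_based_decay(epoch, lr):
    # lr_t = lr_0 / (1 + k * t); the incoming lr is not needed for this form
    return INITIAL_LR / (1.0 + DECAY * epoch)

# Toy regression data (random, for illustration only).
rng = np.random.default_rng(0)
x_train = rng.random((64, 4)).astype("float32")
y_train = rng.random((64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=INITIAL_LR), loss="mse")

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(time_based_decay, verbose=0)

# The callback goes inside a list passed to `callbacks`.
history = model.fit(x_train, y_train, epochs=5, batch_size=16,
                    callbacks=[lr_scheduler], verbose=0)
```

After the last epoch (index 4) the optimizer's learning rate should equal `time_based_decay(4, None)`, i.e. 0.01 / 1.04.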

You can also check how the learning rate varied during training: the callback records the learning rate in each epoch's logs, so the History object returned by fit() holds a list of learning rates over the epochs.
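A sketch of reading those values back, assuming TensorFlow 2.x. The exact key depends on the Keras version: older tf.keras logs the rate under `"lr"`, while Keras 3 uses `"learning_rate"`, so the snippet checks for both. The tiny model and random data only exist to produce a History object:

```python
import numpy as np
import tensorflow as tf

INITIAL_LR = 0.01  # assumed starting learning rate (illustrative)
DECAY = 0.01       # assumed decay rate k (illustrative)

def time_based_decay(epoch, lr):
    # lr_t = lr_0 / (1 + k * t)
    return INITIAL_LR / (1.0 + DECAY * epoch)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=INITIAL_LR), loss="mse")
history = model.fit(x, y, epochs=5, verbose=0,
                    callbacks=[tf.keras.callbacks.LearningRateScheduler(time_based_decay)])

# Older tf.keras stores the rate under "lr"; Keras 3 uses "learning_rate".
lr_key = "learning_rate" if "learning_rate" in history.history else "lr"
learning_rates = history.history[lr_key]
print(learning_rates)
```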

You can also plot the learning rate over epochs using any plotting library.
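A sketch using matplotlib. To keep it self-contained, the learning rates are recomputed from the (assumed) time-based decay function rather than taken from a training run; with a real run you could plot the list from the History object instead. The Agg backend and the output filename `lr_schedule.png` are illustrative choices for a headless environment:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt

INITIAL_LR = 0.01  # assumed starting learning rate (illustrative)
DECAY = 0.01       # assumed decay rate k (illustrative)
EPOCHS = 50

def time_based_decay(epoch, lr):
    # lr_t = lr_0 / (1 + k * t)
    return INITIAL_LR / (1.0 + DECAY * epoch)

epochs = list(range(EPOCHS))
learning_rates = [time_based_decay(e, None) for e in epochs]

fig, ax = plt.subplots()
ax.plot(epochs, learning_rates)
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning rate")
ax.set_title("Time-based decay")
fig.savefig("lr_schedule.png")
```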

Similarly, you can use other decays such as step decay, exponential decay, or any custom decay: just create a function and pass it to the LearningRateScheduler callback. Hope you enjoyed reading.
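Two illustrative sketches with the same `(epoch, lr)` signature. The starting rate, drop factor, drop interval and decay rate `k` are assumed values, not recommendations; both functions recompute from `INITIAL_LR`, so the incoming `lr` is unused:

```python
import math

INITIAL_LR = 0.01  # assumed starting learning rate (illustrative)

def step_decay(epoch, lr, drop=0.5, epochs_per_drop=10):
    # Multiply the rate by `drop` once every `epochs_per_drop` epochs.
    return INITIAL_LR * drop ** (epoch // epochs_per_drop)

def exponential_decay(epoch, lr, k=0.1):
    # lr_t = lr_0 * exp(-k * t)
    return INITIAL_LR * math.exp(-k * epoch)
```

Keras calls the schedule with two positional arguments, so the extra keyword defaults are left untouched; either function can be wrapped as `tf.keras.callbacks.LearningRateScheduler(step_decay)`.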

If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Good-bye until next time.
