Keras Callbacks – EarlyStopping

One common problem we face while training a neural network is overfitting. This refers to a situation where the model fails to generalize; in other words, the model performs poorly on the test/validation set compared to the training set. Take a look at the plot below.

Clearly, after ‘t’ epochs, the model starts overfitting. This is evident from the widening gap between the training and validation error in the above plot. Wouldn’t it be nice if we stopped training right where the gap starts to increase? That would help prevent the model from overfitting. This method is known as Early Stopping. Some of its pros are:

  • Prevents the model from overfitting.
  • Parameter-free, unlike other regularization techniques such as L2.
  • Removes the need to manually set the number of epochs, since the model automatically stops training when the monitored quantity stops improving.

Fortunately, in Keras this is done using the EarlyStopping callback. So, let’s first discuss its Keras API and then learn how to use it.

Keras API

First, you need to specify which quantity to monitor using the “monitor” argument. This can take a value from ‘loss’, ‘acc’, ‘val_loss’, ‘val_acc’, or ‘val_<metric>’, where <metric> is the name of the metric used. For instance, if the metric is set to ‘mse’, then pass ‘val_mse’.

After setting the monitored quantity, you need to decide whether you want to minimize or maximize it. For instance, we want to minimize loss and maximize accuracy. This is done using the “mode” argument, which can take a value from [‘min’, ‘max’, ‘auto’]. The default is ‘auto’, which infers whether to maximize or minimize from the name of the monitored quantity.
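For instance, a minimal sketch (the argument values here are illustrative, not from the text):

```python
from tensorflow.keras.callbacks import EarlyStopping

# Stop when validation loss stops decreasing
# ('min' is also what 'auto' would infer from the name 'val_loss').
es_loss = EarlyStopping(monitor='val_loss', mode='min')

# Stop when validation accuracy stops increasing.
es_acc = EarlyStopping(monitor='val_acc', mode='max')
```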

The callback stops training whenever the monitored quantity stops improving. By default, any fractional change counts as an improvement; for instance, if ‘val_acc’ increases from 90% to 90.0001%, that is also considered an improvement. What counts as a meaningful improvement may vary from one application to another, so there is a “min_delta” argument: the minimum change in the monitored quantity that qualifies as an improvement. For instance, with min_delta=1, any absolute change of less than 1 counts as no improvement.

Note: This difference is calculated as the current value of the monitored quantity minus the best value observed so far.

Neural network training often plateaus: the monitored quantity may show no improvement for some time and then improve afterward. So, it’s better to wait a few epochs before making the final decision to stop training. This is done using the “patience” argument. For instance, patience=3 means that if the monitored quantity doesn’t improve for 3 consecutive epochs, training stops.
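The interplay of min_delta and patience can be sketched in plain Python. This is a simplified, framework-free model of the stopping rule described above, not the actual Keras implementation:

```python
def early_stopping_epoch(values, min_delta=0.0, patience=0):
    """Return the epoch at which training would stop, or None if it never does.

    `values` is the monitored quantity per epoch, assumed to be minimized
    (e.g. validation loss). An epoch counts as an improvement only if it
    beats the best value so far by more than `min_delta`.
    """
    best = float('inf')
    wait = 0
    for epoch, value in enumerate(values):
        if value < best - min_delta:   # genuine improvement
            best = value
            wait = 0
        else:                          # no (sufficient) improvement
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Loss effectively plateaus from epoch 3 onward; with min_delta=0.05 and
# patience=3, the third non-improving epoch (epoch 5) triggers the stop.
losses = [1.0, 0.8, 0.7, 0.69, 0.69, 0.69, 0.69]
print(early_stopping_epoch(losses, min_delta=0.05, patience=3))  # → 5
```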

The model stops training some epochs (set by the “patience” argument) after the best value of the monitored quantity, so the weights you end up with are not the best ones. To retrieve the best weights, set the “restore_best_weights” argument to True.

Sometimes you have a baseline in mind for a task, say, at least 75% accuracy within 5 epochs. If the model isn’t reaching it, there is no point in training any further; you should instead change the hyperparameters and retrain. For this, you can set the “baseline” argument. If the monitored quantity (minus min_delta) does not surpass the baseline within the number of epochs specified by the patience argument, training is stopped.

For instance, below is an example where the baseline is set to 98%.
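Such a setup might look like this (the 98% baseline is from the text; the other argument values are illustrative):

```python
from tensorflow.keras.callbacks import EarlyStopping

# Require val_acc to beat 98% (minus min_delta) within the 3-epoch
# patience window, or stop training early.
es = EarlyStopping(monitor='val_acc', mode='max',
                   min_delta=0.001, patience=3, baseline=0.98)
```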

The training process stops because val_acc − min_delta stays below the baseline for the whole patience interval (3 epochs).

After surpassing the baseline, the EarlyStopping callback works as normal, i.e. it stops training when the monitored quantity stops improving.

Note: If you are not sure about the baseline in your task, just set this argument to None.

I hope this gives you a feel for the EarlyStopping callback. Now let’s see how to use it.

How to use it?

First, create an instance of the “EarlyStopping” class as shown below.
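For instance (a minimal sketch; the argument values are illustrative):

```python
from tensorflow.keras.callbacks import EarlyStopping

# Stop if val_loss fails to improve by at least 0.001 for 5 consecutive
# epochs, and roll back to the weights from the best epoch.
early_stopping = EarlyStopping(monitor='val_loss', mode='min',
                               min_delta=0.001, patience=5,
                               restore_best_weights=True)
```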

Then pass this instance in the callbacks list while fitting the model.
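An end-to-end sketch, using a tiny illustrative model and synthetic data (the model architecture and data here are assumptions for demonstration, not from the text):

```python
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.callbacks import EarlyStopping

# Tiny illustrative model and synthetic data.
model = Sequential([Input(shape=(4,)),
                    Dense(8, activation='relu'),
                    Dense(1)])
model.compile(optimizer='adam', loss='mse')

x = np.random.rand(64, 4)
y = np.random.rand(64, 1)

early_stopping = EarlyStopping(monitor='val_loss', patience=3,
                               restore_best_weights=True)

# The callback goes in the `callbacks` list of model.fit; training ends
# early if val_loss stops improving for 3 consecutive epochs.
history = model.fit(x, y, validation_split=0.25, epochs=50,
                    verbose=0, callbacks=[early_stopping])
```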

That’s all for Early Stopping. I hope you enjoyed reading.

If you have any doubts or suggestions, please feel free to ask, and I will do my best to help or improve. Good-bye until next time.
