In this blog, we will discuss the Keras `TerminateOnNaN` callback. As the name suggests, it terminates training when a NaN loss is encountered. Below is the Keras API for this callback.
```python
keras.callbacks.TerminateOnNaN()
```
This callback checks the loss at the end of every batch, and if that loss is NaN or infinite, it stops training. It also prints the batch number at which training was stopped, as a message of the form `Batch <batch>: Invalid loss, terminating training`.
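The check treats NaN, positive infinity and negative infinity as invalid, while any finite loss, however large, lets training continue. The helper below, `is_invalid_loss`, is a hypothetical function (not part of Keras) that reproduces the same test for a scalar loss using only Python's standard library:

```python
import math


def is_invalid_loss(loss):
    """Hypothetical helper: True when a scalar loss would make
    TerminateOnNaN stop training (NaN, +inf or -inf)."""
    return math.isnan(loss) or math.isinf(loss)


# A large but finite loss is still considered valid.
for value in [0.25, 1e30, float("nan"), float("inf"), float("-inf")]:
    print(value, is_invalid_loss(value))
```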
Below is the callback's source code, taken from Keras, which shows how this works.
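```python
import numpy as np
from keras.callbacks import Callback


class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered."""

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')  # training loss reported for this batch
        if loss is not None:
            if np.isnan(loss) or np.isinf(loss):
                print('Batch %d: Invalid loss, terminating training' % (batch))
                # Ask fit() to stop training
                self.model.stop_training = True
```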
In the same way, you can create your own custom callback that monitors some other metric. Now, let's see how to use this callback.
```python
# Load data, preprocess it and build the model
...

# First, create an instance of the TerminateOnNaN class
from keras.callbacks import TerminateOnNaN
call = TerminateOnNaN()

# Then pass it in a list to the fit() method
record = model.fit(..., callbacks=[call], ...)
```
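Following the same pattern as `TerminateOnNaN`, a custom callback can watch a different metric. Below is a minimal sketch, assuming you want to stop training when the validation loss (the `val_loss` entry in `logs`) becomes NaN or infinite at the end of an epoch. The class `TerminateOnInvalidMetric` and its `monitor` argument are hypothetical names, not part of Keras. So that the sketch runs without Keras installed, `Callback` and `ToyModel` are minimal stand-ins and `fit` is a toy loop that replays pre-computed logs and stops once a callback sets `model.stop_training`; in a real project you would subclass `keras.callbacks.Callback` instead and pass the callback to `model.fit()` in the `callbacks` list.

```python
import math


class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback (only what this sketch needs)."""

    def set_model(self, model):
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        pass


class ToyModel:
    """Hypothetical stand-in model: only carries the stop_training flag."""

    def __init__(self):
        self.stop_training = False


class TerminateOnInvalidMetric(Callback):
    """Stop training when the monitored metric is NaN or infinite at epoch end."""

    def __init__(self, monitor="val_loss"):
        super().__init__()
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        value = logs.get(self.monitor)
        if value is not None and (math.isnan(value) or math.isinf(value)):
            print("Epoch %d: invalid %s, terminating training" % (epoch, self.monitor))
            self.model.stop_training = True


def fit(model, epoch_logs, callbacks):
    """Toy training loop: replays one logs dict per epoch, returns epochs run."""
    for cb in callbacks:
        cb.set_model(model)
    epochs_run = 0
    for epoch, logs in enumerate(epoch_logs):
        epochs_run += 1
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # a callback asked to stop
            break
    return epochs_run


history = [
    {"loss": 0.9, "val_loss": 1.1},
    {"loss": 0.7, "val_loss": float("nan")},
    {"loss": 0.6, "val_loss": 0.8},
]
epochs_run = fit(ToyModel(), history, callbacks=[TerminateOnInvalidMetric("val_loss")])
print("epochs run:", epochs_run)
```

Because the check runs in `on_epoch_end`, it can see validation metrics such as `val_loss`, which Keras adds to the logs at the end of each epoch; `TerminateOnNaN` itself checks the training loss after every batch instead.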
Hope you enjoyed reading.
If you have any doubts or suggestions, please feel free to ask and I will do my best to help or improve. Goodbye until next time.