Keras model training generates a new version of the model after each epoch: the weights are adjusted, and the result is effectively a new model. Each version will perform differently when evaluated against a validation set.

If everything goes well, training and validation loss will decrease as the number of training epochs grows. However, the best model is rarely the one obtained at the end of the training process.

In the overfitting case, both training and validation losses initially decrease as training progresses. At some point, though, the validation loss starts increasing even while the training loss continues to decrease. From this point on, the model versions produced during training are overfitting the training data.


These model versions are less likely to generalize well to unseen data. In this case, the best model would be the one obtained at the point where the validation loss started to diverge.

A much better way to handle this is to stop training when you detect that the validation loss is no longer improving. This can be achieved using a Keras callback.

Early Stopping

The Keras EarlyStopping callback interrupts training once a target metric has stopped improving for a fixed number of epochs. For instance, it lets you interrupt training as soon as the model starts overfitting, so you avoid having to retrain from scratch for a smaller number of epochs.

import keras

# Monitors the model's validation loss and interrupts training when it
# has stopped improving for more than one epoch (that is, two epochs).
callbacks_list = [
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=1,
    )]

model.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['acc'])

# Because the callback monitors the validation loss, you need to pass
# validation_data to the call to fit; Keras then computes val_loss
# automatically at the end of each epoch.
model.fit(x, y,
          epochs=10,
          batch_size=32,
          callbacks=callbacks_list,
          validation_data=(x_val, y_val))
If you pass keras.metrics.Metric objects, monitor should be set to the metric's name attribute. If you're not sure about the metric names, you can check the contents of the `history.history` dictionary returned by `history = model.fit(...)`.
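For example, here is a quick sketch (reusing the model and data from above) that prints the available keys; the exact names (e.g. 'acc' vs. 'accuracy') depend on your Keras version:

history = model.fit(x, y,
                    epochs=1,
                    batch_size=32,
                    validation_data=(x_val, y_val))

# Prints something like dict_keys(['loss', 'acc', 'val_loss', 'val_acc'])
print(history.history.keys())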

keras.callbacks.EarlyStopping is passed to model.fit() and is called by the model at various points during training. It has access to all the available data about the state of the model and its performance, and it can interrupt training, save the model, load a different set of weights, or otherwise alter the state of the model.
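To make the callback mechanism concrete, here is a minimal, simplified sketch of what EarlyStopping does internally, written as a custom callback (the class name and logic are illustrative, not the actual Keras implementation):

class SimpleEarlyStopping(keras.callbacks.Callback):
    def __init__(self, patience=1):
        super().__init__()
        self.patience = patience

    def on_train_begin(self, logs=None):
        # Called once when fit() starts.
        self.best = float('inf')
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        # Called after each epoch; logs holds that epoch's metrics.
        current = logs.get('val_loss')
        if current < self.best:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait > self.patience:
                # A callback can alter the training state directly.
                self.model.stop_training = True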

Which metric should be used for early stopping?

Up to this point, we have used validation loss as the target metric to identify the best model during training. Why validation loss, you might ask?

Validation Loss

This metric is calculated on the entire validation set after the weight updates, in order to ascertain the model's performance on unseen data, that is, the model's ability to generalize to data it has not directly seen.
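In other words, the val_loss reported at the end of an epoch corresponds (up to minor bookkeeping differences) to evaluating the current weights on the full validation set yourself; a sketch, reusing x_val and y_val from above:

# Evaluate the current weights on the entire validation set.
val_loss, val_acc = model.evaluate(x_val, y_val, batch_size=32)
print('validation loss:', val_loss)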

Training Loss

It may seem that the training loss should always be lower than the validation loss, but both cases occur when training a model. There are several reasons why the validation loss can end up lower than the training loss.

The validation set can be easier than the training set. For example, data augmentation, which is applied only to training data, often distorts or occludes parts of the image, making the training samples harder to fit.

Training loss is measured during each epoch (after each batch), while validation loss is measured after each epoch, so on average the training loss is measured half an epoch earlier. This means the validation loss has the benefit of roughly half an epoch of extra gradient updates.

Regularization is typically applied only during training, not during validation and testing. For example, if you're using dropout, the model has fewer features available to it during training.
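You can observe this directly with a Dropout layer, which is active only when the training flag is set; a minimal sketch:

import numpy as np

layer = keras.layers.Dropout(0.5)
inputs = np.ones((1, 4), dtype='float32')

# During training, roughly half the units are zeroed out (and the
# remaining ones are scaled up to preserve the expected sum).
print(layer(inputs, training=True))

# During validation/inference, dropout is a no-op.
print(layer(inputs, training=False))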

However, validation loss or training loss might not be the most relevant metric for your particular use case or domain, and any other metric can be used instead.
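For example, to stop on validation accuracy instead, monitor that metric and tell Keras that higher is better; a sketch assuming 'val_acc' is the key reported in your history:

callbacks_list = [
    keras.callbacks.EarlyStopping(
        monitor='val_acc',          # assumes this key appears in history.history
        mode='max',                 # accuracy improves upward, unlike loss
        patience=3,
        restore_best_weights=True,  # roll back to the weights from the best epoch
    )]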
