Training a model takes time, and real-life models can train for days or even weeks. If training stops after running for a day, for example because you cancel it, the model's weights are lost and you have to restart training from the beginning.

If you save your model periodically, you can resume training from the last saved point instead. Saving also lets you move a model to another computer and continue training there. In this tutorial, we'll look at how to save and load models.

Create a model

We use the CIFAR10 dataset to demonstrate model loading and saving. We normalize all pixel values to be between 0 and 1.

import tensorflow as tf

(x_train, y_train), (x_val, y_val) = tf.keras.datasets.cifar10.load_data()

x_train = x_train.astype('float32')
x_val = x_val.astype('float32')

x_train /= 255
x_val /= 255

IMG_SIZE=32

BATCH_SIZE=32

Next, we define the model in the create_model function. The model is deliberately basic: the goal here is to learn how to save and load models, not to build the best model for the CIFAR10 dataset.

def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', activation='relu', input_shape=[IMG_SIZE,IMG_SIZE, 3]),
      
      tf.keras.layers.Conv2D(kernel_size=3, filters=64, padding='same', activation='relu'),
      tf.keras.layers.MaxPooling2D(pool_size=2),
    
      tf.keras.layers.Conv2D(kernel_size=3, filters=128, padding='same', activation='relu'),
      tf.keras.layers.Conv2D(kernel_size=1, filters=256, padding='same', activation='relu'),
      tf.keras.layers.GlobalAveragePooling2D(),
    
      tf.keras.layers.Dense(10, activation='softmax')])

  model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model

Save the model at the end of an epoch

Now we can see how to save a model during training: we train it by calling the fit method and pass it a Keras callback that does the saving.

Create ModelCheckpoint

The ModelCheckpoint callback saves the model during training, by default at the end of every epoch. Depending on the file path you give it, it either writes a separate file for each save or keeps overwriting a single file. Setting save_weights_only to False (the default) saves the full model rather than only its weights.

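model = create_model()
# save_freq counts batches: 2 epochs x 63 batches per epoch (2,000 images / batch size 32, rounded up)
model_checkpoint = tf.keras.callbacks.ModelCheckpoint('CIFAR10{epoch:02d}.h5', save_freq=2 * ((2000 + BATCH_SIZE - 1) // BATCH_SIZE), save_weights_only=False)

Make sure to include the {epoch} placeholder in your file path; otherwise each save overwrites the previous file. To save less often than once per epoch, set save_freq. It counts batches rather than epochs, so saving every 2 epochs here means 2 × 63 = 126 batches. Older TensorFlow versions used a period parameter (period=2) for this, but it is deprecated in favor of save_freq.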
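ModelCheckpoint's save_freq argument counts batches, not epochs, so "save every N epochs" has to be converted into a batch count. The plain-Python sketch below shows that conversion for this tutorial's 2,000-image subset and batch size of 32. It also shows why the {epoch:02d} placeholder matters. The helper save_freq_for_epochs is a hypothetical name for illustration, not part of Keras.

```python
import math


def save_freq_for_epochs(num_samples, batch_size, every_n_epochs):
    """Convert "save every N epochs" into a batch count for save_freq (hypothetical helper)."""
    # A final partial batch still counts as one step, hence the ceiling.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return steps_per_epoch * every_n_epochs


freq = save_freq_for_epochs(2000, 32, 2)
print(freq)  # 126

# With the {epoch:02d} placeholder, every save gets its own file name.
template = 'CIFAR10{epoch:02d}.h5'
names = [template.format(epoch=e) for e in (2, 4, 6)]
print(names)  # ['CIFAR1002.h5', 'CIFAR1004.h5', 'CIFAR1006.h5']

# Without it, str.format ignores the epoch and every save targets one file.
overwritten = {'CIFAR10.h5'.format(epoch=e) for e in (2, 4, 6)}
print(overwritten)  # {'CIFAR10.h5'}
```

If you change the batch size or the amount of training data, the batch count changes with it, so it is safer to compute it than to hard-code it.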

history = model.fit(x_train[:2000], y_train[:2000],
              batch_size=BATCH_SIZE,
              epochs=6,
              callbacks=[model_checkpoint],
              validation_data=(x_val[:100], y_val[:100]),
              shuffle=True)

This saves a model every 2 epochs, so after six epochs the working directory contains the files CIFAR1002.h5, CIFAR1004.h5 and CIFAR1006.h5. Each .h5 file contains the model's configuration, the model's weights, and the state of the model's optimizer.
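To resume from the most recent of these files, you first have to find it. Below is a minimal plain-Python sketch of one way to do that. It assumes the 'CIFAR10{epoch:02d}.h5' naming used above, and latest_checkpoint is a hypothetical helper, not a Keras function.

```python
import os
import re
import tempfile

# Assumed naming scheme: 'CIFAR10{epoch:02d}.h5', e.g. 'CIFAR1004.h5' -> epoch 4.
CHECKPOINT_RE = re.compile(r'^CIFAR10(\d{2,})\.h5$')


def latest_checkpoint(directory):
    """Return (path, epoch) for the highest-numbered checkpoint, or None if there is none."""
    best = None
    for name in os.listdir(directory):
        match = CHECKPOINT_RE.match(name)
        if match is None:
            continue  # skip unrelated files
        epoch = int(match.group(1))
        if best is None or epoch > best[1]:
            best = (os.path.join(directory, name), epoch)
    return best


# Demo: empty placeholder files stand in for real checkpoints.
demo_dir = tempfile.mkdtemp()
for name in ('CIFAR1002.h5', 'CIFAR1004.h5', 'CIFAR1006.h5', 'notes.txt'):
    open(os.path.join(demo_dir, name), 'w').close()

path, epoch = latest_checkpoint(demo_dir)
print(os.path.basename(path), epoch)  # CIFAR1006.h5 6
```

With real checkpoints, you could pass the returned path to tf.keras.models.load_model and the returned epoch to fit() as initial_epoch. A file named for epoch N is written after N epochs have completed, so training resumes at epoch index N.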

Save the Final Model as an HDF5 file

Another way to save a model is to call its save() method. Given a file name ending in .h5, as below, it writes an HDF5 file. Besides the weights, the file stores additional data such as the model's configuration and the state of the optimizer. A model saved with save() can be loaded with tf.keras.models.load_model.

model.save('my_model.h5') 

This allows you to save the entirety of the state of a model in a single file.

Load the Model and Continue Training

The saved model can be re-instantiated in the exact same state, without any of the code used for model definition or training.

new_model = tf.keras.models.load_model('my_model.h5')

new_model.evaluate(x_val,y_val)

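The model returned by load_model() is already compiled and ready to use, unless it was never compiled before it was saved. You don't need to compile it again. Recompiling keeps the learned weights, but if you pass a new optimizer, the optimizer state restored from the file is discarded.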
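A save/load round trip is a quick way to check that load_model() really restores the same model. The sketch below is minimal and rests on a few assumptions: TensorFlow 2.x with h5py installed (needed for .h5 files), a tiny hypothetical model, and random arrays in place of CIFAR10 so that it runs in seconds without downloading anything.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in data: random 8x8 RGB "images" with 10 classes.
rng = np.random.default_rng(0)
x = rng.random((64, 8, 8, 3), dtype=np.float32)
y = rng.integers(0, 10, size=(64,))

# A tiny model in the same style as create_model(), kept small for speed.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    tf.keras.layers.Conv2D(8, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
model.fit(x, y, batch_size=16, epochs=1, verbose=0)

# Save to a throwaway directory, then load it back.
path = os.path.join(tempfile.mkdtemp(), 'round_trip.h5')
model.save(path)
restored = tf.keras.models.load_model(path)

# Same architecture and weights, so the predictions should match.
original_preds = model.predict(x, verbose=0)
restored_preds = restored.predict(x, verbose=0)

# The restored model comes back compiled, so evaluate() works without compile().
scores = restored.evaluate(x, y, verbose=0, return_dict=True)
print(scores)
```

The model has no dropout or other random layers, so prediction is deterministic and the two sets of predictions can be compared directly.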

You can also save a partly trained model, reload it later, and continue training. This is useful when more training data becomes available and you don't want to retrain the whole model from scratch.

history = new_model.fit(x_train[:2000], y_train[:2000],
              batch_size=BATCH_SIZE,
              epochs=6,
              validation_data=(x_val[:100], y_val[:100]),
              shuffle=True)
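Unless you pass initial_epoch, fit() counts epochs from 0 again. This matters if you also pass a ModelCheckpoint callback whose file names include {epoch}, because the names would restart and could overwrite earlier files. The sketch below shows resuming with initial_epoch. It makes the same assumptions as before: TensorFlow 2.x with h5py, a tiny hypothetical model, and random stand-in data.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in data so the sketch runs offline in seconds.
rng = np.random.default_rng(1)
x = rng.random((32, 8, 8, 3), dtype=np.float32)
y = rng.integers(0, 10, size=(32,))

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    tf.keras.layers.Conv2D(8, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='rmsprop',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])

# First session: train epochs 0 and 1, then save the partly trained model.
first = model.fit(x, y, batch_size=16, epochs=2, verbose=0)
path = os.path.join(tempfile.mkdtemp(), 'partly_trained.h5')
model.save(path)

# Later session: reload and continue. epochs is the index to stop at, not a
# count of extra epochs, so initial_epoch=2 with epochs=4 runs epochs 2 and 3.
resumed = tf.keras.models.load_model(path)
second = resumed.fit(x, y, batch_size=16, epochs=4, initial_epoch=2, verbose=0)

print(first.epoch)   # [0, 1]
print(second.epoch)  # [2, 3]
```

If the resumed fit() also received a ModelCheckpoint callback, its file names would continue the numbering rather than restarting at 01.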
