If you need lower-level control over training and evaluation than fit() and evaluate() provide, you can write your own training code. Custom training and evaluation code works the same way across every kind of Keras model: Sequential models, models built with the Functional API, and models written from scratch via model subclassing.

In this tutorial, we write the training logic ourselves using tf.GradientTape while still running it through fit().


In the next code snippet, we load the CIFAR-10 dataset as a tf.data.Dataset. We use it to demonstrate how to use optimizers, losses, and metrics in a custom training function.

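import tensorflow as tf
import tensorflow_datasets as tfds

BATCH_SIZE = 64  # assumed batch size; the original snippet does not define one
IMG_SIZE = 32    # CIFAR-10 images are 32x32 RGB

train_ds = tfds.load(name="cifar10", split="train")
test_ds = tfds.load(name="cifar10", split="test")

def scale_image(features):
    # Convert uint8 pixels in [0, 255] to float32 values in [0, 1]
    image = tf.cast(features['image'], tf.float32) / 255.0
    return image, features['label']

# Shuffle only the training data; evaluation order does not matter
train_batches = train_ds.map(scale_image).shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

test_batches = test_ds.map(scale_image).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)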
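Because tfds.load() downloads the full dataset, here is a minimal offline sketch of the same preprocessing. It assumes a tiny synthetic stand-in for CIFAR-10, with random uint8 images and labels built with tf.data.Dataset.from_tensor_slices. It checks that scale_image() yields batches of float32 images in [0, 1].

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 4  # illustrative value, not from the original

def scale_image(features):
    # Same preprocessing as above: uint8 pixels in [0, 255] -> float32 in [0, 1]
    image = tf.cast(features['image'], tf.float32) / 255.0
    return image, features['label']

# Hypothetical stand-in for the CIFAR-10 feature dicts returned by tfds.load()
fake_ds = tf.data.Dataset.from_tensor_slices({
    'image': np.random.randint(0, 256, size=(10, 32, 32, 3), dtype=np.uint8),
    'label': np.random.randint(0, 10, size=(10,)).astype(np.int64),
})

batches = fake_ds.map(scale_image).shuffle(1024).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
images, labels = next(iter(batches))
print(images.shape, images.dtype, labels.shape)
```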

Create a Model

Let’s consider the following model. Here, we build with the Sequential API, but it could be a Functional or a subclassed model as well.

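def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', activation='relu', input_shape=[IMG_SIZE, IMG_SIZE, 3]),
      tf.keras.layers.Conv2D(kernel_size=3, filters=64, padding='same', activation='relu'),
      tf.keras.layers.Conv2D(kernel_size=3, filters=128, padding='same', activation='relu'),
      tf.keras.layers.Conv2D(kernel_size=1, filters=256, padding='same', activation='relu'),
      # Assumed classification head (not in the original snippet): pool the feature
      # maps and output softmax probabilities for the 10 CIFAR-10 classes
      tf.keras.layers.GlobalAveragePooling2D(),
      tf.keras.layers.Dense(10, activation='softmax'),
  ])
  return model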
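The paragraph above says the same model could also be built with the Functional API. The sketch below does that. It assumes the same pooling-plus-softmax head as above and IMG_SIZE = 32. With all-zero inputs and zero-initialized biases, every logit is 0, so each image gets a uniform distribution over the 10 classes.

```python
import tensorflow as tf

IMG_SIZE = 32  # CIFAR-10 image size

def create_functional_model():
    # Same layers as create_model(), wired with the Functional API
    inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    x = tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', activation='relu')(inputs)
    x = tf.keras.layers.Conv2D(kernel_size=3, filters=64, padding='same', activation='relu')(x)
    x = tf.keras.layers.Conv2D(kernel_size=3, filters=128, padding='same', activation='relu')(x)
    x = tf.keras.layers.Conv2D(kernel_size=1, filters=256, padding='same', activation='relu')(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)  # assumed head, as above
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
    return tf.keras.Model(inputs, outputs)

model = create_functional_model()
probs = model(tf.zeros([2, IMG_SIZE, IMG_SIZE, 3]))  # two blank images
print(probs.shape)
```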

Compile Model

You can skip passing a loss function and metrics to compile() and handle them manually in your custom training step instead. In that case, compile() is only used to configure the optimizer.
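The sketch below shows a compile() call that configures only the optimizer. The small stand-in model and the learning rate are illustrative assumptions. Any Keras model can be compiled this way, leaving the loss and metrics to train_step().

```python
import tensorflow as tf

# Small stand-in model (an assumption); the CIFAR-10 model is compiled the same way
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation='softmax'),
])

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)  # illustrative learning rate
# No loss or metrics here: a custom train_step() computes and tracks them itself
model.compile(optimizer=optimizer)
```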


Specifying Loss and Metrics

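To train a model with the built-in fit(), you normally pass a loss function, an optimizer, and optionally some metrics to compile(). With a custom training step, you can instead create the loss and metric objects yourself: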

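loss_avg = tf.keras.metrics.Mean(name='loss')                      # running mean of the batch losses
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="acc")  # accuracy for integer labels
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()          # expects probabilities (from_logits=False)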
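These objects also work on their own, outside any model. The small sketch below uses hand-picked labels and probabilities, which are made-up values for illustration. The loss is the mean of -log(0.7) and -log(0.2), about 0.983. Only the first prediction's argmax matches its label, so the accuracy is 0.5.

```python
import tensorflow as tf

loss_avg = tf.keras.metrics.Mean(name='loss')
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="acc")
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Made-up batch: 2 samples, 3 classes (integer labels, predicted probabilities)
y_true = tf.constant([0, 2])
y_pred = tf.constant([[0.7, 0.2, 0.1],
                      [0.5, 0.3, 0.2]])

loss = loss_fn(y_true, y_pred)          # mean of -log(0.7) and -log(0.2)
loss_avg.update_state(loss)             # running mean of batch losses
accuracy.update_state(y_true, y_pred)   # argmax matches the label for sample 0 only

print(float(loss), float(accuracy.result()))
```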

Override train_step

When you need to customize what fit() does, override the train_step() method of the Model class. fit() calls this method once for every batch of data, so it should contain the logic for a single training step. You can then call fit() as usual, and it will run your own learning algorithm.

class CustomModel(tf.keras.Model):
  def train_step(self, data):
    # Forward pass, loss, weight update, and metric updates go here (covered in the next sections)
    return {'loss': loss_avg.result(), 'accuracy': accuracy.result()}
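The sketch below puts the pieces together as a minimal end-to-end custom training step. Several choices are assumptions not in the original: a tiny stand-in model, random 8x8 images instead of CIFAR-10 (to keep it fast), the unpacking `x, y = data`, and a `metrics` property that lists loss_avg and accuracy so that fit() resets them at the start of each epoch. compile() receives only the optimizer.

```python
import numpy as np
import tensorflow as tf

loss_avg = tf.keras.metrics.Mean(name='loss')
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='acc')
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

class CustomModel(tf.keras.Model):
    def train_step(self, data):
        x, y = data  # assumes the dataset yields (inputs, labels) pairs
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            loss = loss_fn(y, y_pred)
        # Backward pass: gradients of the loss w.r.t. the trainable weights
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Update the running metrics and report their current values
        loss_avg.update_state(loss)
        accuracy.update_state(y, y_pred)
        return {'loss': loss_avg.result(), 'accuracy': accuracy.result()}

    @property
    def metrics(self):
        # Listing the metrics here lets fit() reset them at the start of each epoch
        return [loss_avg, accuracy]

# Tiny stand-in for the CIFAR-10 model: 8x8 RGB inputs, 10 classes
inputs = tf.keras.Input(shape=(8, 8, 3))
x = tf.keras.layers.Conv2D(8, 3, padding='same', activation='relu')(inputs)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
model = CustomModel(inputs, outputs)

model.compile(optimizer='adam')  # loss and metrics are handled in train_step()

x_train = np.random.rand(64, 8, 8, 3).astype('float32')
y_train = np.random.randint(0, 10, size=(64,))
history = model.fit(x_train, y_train, batch_size=16, epochs=2, verbose=0)
print(sorted(history.history))
```

You could train the CIFAR-10 model from this tutorial the same way. One option, also an assumption not from the original, is to wrap the Sequential model with the Functional API: `inputs = tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))` and `model = CustomModel(inputs, create_model()(inputs))`, then call `model.compile(optimizer='adam')` and `model.fit(train_batches)`.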

You can now use custom training logic while model.fit() still handles everything else for you: distribution strategies, callbacks, data formats, the looping logic, and so on. The same applies to validation and inference, which you can customize by overriding test_step() and predict_step().

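train_step() returns a dict that maps metric names to their current values, for example {'loss': 0.2, 'accuracy': 0.7}. fit() shows these values in its progress bar and records them in the History object it returns.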
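Evaluation follows the same pattern. The minimal sketch below overrides test_step(), assuming the same loss and metric objects, a tiny stand-in model, and random data. The class name EvalModel is hypothetical. With return_dict=True, evaluate() returns the dict produced by the step.

```python
import numpy as np
import tensorflow as tf

loss_avg = tf.keras.metrics.Mean(name='loss')
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='acc')
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

class EvalModel(tf.keras.Model):  # hypothetical name for this sketch
    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)  # inference mode; no GradientTape needed
        loss_avg.update_state(loss_fn(y, y_pred))
        accuracy.update_state(y, y_pred)
        return {'loss': loss_avg.result(), 'accuracy': accuracy.result()}

    @property
    def metrics(self):
        # Lets evaluate() reset these metrics before it starts
        return [loss_avg, accuracy]

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(10, activation='softmax')(inputs)
model = EvalModel(inputs, outputs)
model.compile(optimizer='adam')  # compile() before evaluate(); only the optimizer is set

x_val = np.random.rand(32, 8).astype('float32')
y_val = np.random.randint(0, 10, size=(32,))
results = model.evaluate(x_val, y_val, batch_size=8, verbose=0, return_dict=True)
print(results)
```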

This also removes the need to call model._set_inputs() manually when using custom training loops.

Gradient Tape

Calling a model inside a tf.GradientTape scope lets you retrieve the gradients of the model's trainable weights with respect to a loss value. You can then pass these gradients to an optimizer instance to update those variables.

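def train_step(self, data):
    x, y = data  # one batch of images and labels
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # forward pass, recorded by the tape
        loss = loss_fn(y, y_pred)
    # Gradients of the loss with respect to the trainable weights
    gradients = tape.gradient(loss, self.trainable_variables)
    # Let the optimizer configured in compile() update the weights
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))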

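The sketch below shows what the tape records, without a full model. It uses a single scalar variable with the loss w², and applies a plain gradient-descent update by hand with assign_sub() instead of an optimizer instance. An SGD optimizer without momentum applies the same rule. The starting value and learning rate are illustrative.

```python
import tensorflow as tf

w = tf.Variable(3.0)      # a single trainable variable
learning_rate = 0.1       # illustrative value

with tf.GradientTape() as tape:
    loss = w * w          # operations on w inside the scope are recorded

grad = tape.gradient(loss, w)        # d(w^2)/dw = 2w = 6.0
w.assign_sub(learning_rate * grad)   # plain SGD step: w <- 3.0 - 0.1 * 6.0
print(float(grad), float(w.numpy()))
```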


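Alternatively, if you passed a loss function to compile(), you can compute the loss via self.compiled_loss. It wraps the loss function(s) passed to compile() and adds any regularization losses you supply. In newer Keras versions, self.compute_loss() plays this role.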

loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
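The regularization_losses=self.losses argument adds weight penalties to the data loss. The sketch below computes that sum by hand, assuming a small stand-in model with an L2 kernel_regularizer and random inputs. It roughly mirrors what self.compiled_loss returns, but does not go through compile().

```python
import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Stand-in model with an L2 weight penalty on its kernel (hypothetical example)
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation='softmax',
                          kernel_regularizer=tf.keras.regularizers.l2(0.01)),
])

x = tf.random.normal([8, 4])
y = tf.constant([0, 1, 2, 0, 1, 2, 0, 1])

y_pred = model(x, training=True)
data_loss = loss_fn(y, y_pred)       # loss on the batch
reg_loss = tf.add_n(model.losses)    # L2 penalty collected from kernel_regularizer
total_loss = data_loss + reg_loss    # data loss plus regularization
print(float(data_loss), float(reg_loss), float(total_loss))
```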

Update Loss and Metrics

Keras metrics keep state updates and result computation separate, in update_state() and result() respectively. In some cases, computing the result is expensive and only needs to be done periodically. In train_step(), update the loss and metric objects after each batch and return their current results:

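def train_step(self, data):
    # ... forward pass, loss, and weight update as in the previous section ...
    loss_avg.update_state(loss)        # accumulate this batch's loss
    accuracy.update_state(y, y_pred)   # accumulate correct predictions
    return {'loss': loss_avg.result(), 'accuracy': accuracy.result()}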

update_state(y_true, y_pred) uses the targets y_true and the model predictions y_pred to update the state variables.

result() uses the state variables to compute the final result.
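The split between update_state() and result() is easy to see when a metric is fed two batches by hand. The small sketch below uses made-up labels and predictions. reset_state() clears the accumulated state, and fit() calls it on the model's metrics at the start of each epoch.

```python
import tensorflow as tf

acc = tf.keras.metrics.SparseCategoricalAccuracy(name='acc')

# Batch 1: argmax predictions are [0, 0] vs labels [0, 1] -> 1 of 2 correct
acc.update_state([0, 1], [[0.9, 0.1], [0.8, 0.2]])
# Batch 2: argmax predictions are [1, 1] vs labels [1, 1] -> 2 of 2 correct
acc.update_state([1, 1], [[0.3, 0.7], [0.4, 0.6]])

running = float(acc.result())   # computed from the accumulated state: 3 / 4
acc.reset_state()               # clear the state, e.g. at the start of an epoch
after_reset = float(acc.result())
print(running, after_reset)
```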

If you passed metrics to compile(), you can instead call self.compiled_metrics.update_state(y, y_pred) to update them. At the end of the step, read their current values from the self.metrics property, for example with {m.name: m.result() for m in self.metrics}.


Note that overriding train_step() works whether you build Sequential models, Functional API models, or subclassed models. With this custom training step, you still benefit from the convenient features of fit(), such as callbacks, built-in distribution support, and step fusing.
