Training a model can take a long time: real-world models may train for a day or even weeks. If you do not save your trained model, all of its weights are lost when your program ends and you have to restart training from the beginning. If you save the model, you can always resume training from where you left off.

Another benefit is that you can move a saved model to another computer and continue training there. In this tutorial, we'll cover saving, loading, and making predictions with trained TensorFlow models.

After completing this tutorial, you will know:

  • How to develop a convolutional neural network to classify images of dogs and cats (binary classification).
  • How to save the TensorFlow model in the SavedModel format (a directory containing a saved_model.pb file).
  • How to load the saved model and use it to classify images of dogs and cats.

Download Data

The TensorFlow Datasets package (tensorflow_datasets) is the easiest way to load predefined datasets.

import os

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

tfds.disable_progress_bar()

# Slice the single 'train' split into 80% / 10% / 10% subsets.
# The older tfds.Split.TRAIN.subsplit API has been removed from recent TFDS releases.
(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True)

The tfds.load method downloads and caches the data and returns tf.data.Dataset objects. cats_vs_dogs ships with only a train split, so we slice it into training, validation, and test sets containing 80%, 10%, and 10% of the data, respectively.
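To get a feel for the resulting subset sizes, the hypothetical helper below (it is not part of TFDS) turns the 8:1:1 weights into approximate example counts for contiguous slices. TFDS applies its own rounding rules at the slice boundaries, so treat these numbers as approximate.

```python
def split_sizes(num_examples, weights=(8, 1, 1)):
    """Hypothetical helper: approximate sizes of contiguous, weighted slices."""
    total = sum(weights)
    boundaries = [0]
    running = 0
    for w in weights:
        running += w
        # Round each cumulative boundary to the nearest example index.
        boundaries.append(round(num_examples * running / total))
    # Each subset size is the distance between consecutive boundaries.
    return [end - start for start, end in zip(boundaries, boundaries[1:])]

print(split_sizes(1000))  # [800, 100, 100]
```

Because the sizes come from differences of cumulative boundaries, they always add up to the full dataset, with no example lost or duplicated to rounding.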

Format Data

Use tf.cast and the tf.image module to format the images for the task: resize them to a fixed input size and rescale the pixel values to the range [-1, 1].

IMG_SIZE = 160 # All images will be resized to 160x160

def format_example(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
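The rescaling step is the affine map x / 127.5 - 1, which sends the pixel range [0, 255] onto [-1, 1]. A small NumPy check of that arithmetic, using three representative intensities (black, mid-grey, white), plus the inverse map you would need to display a formatted image:

```python
import numpy as np

# Representative pixel intensities: black, mid-grey, white.
pixels = np.array([0.0, 127.5, 255.0], dtype=np.float32)

# Same affine map as format_example: divide by 127.5, then subtract 1.
scaled = pixels / 127.5 - 1
print(scaled)  # [-1.  0.  1.]

# Inverse map, back to [0, 255], e.g. before plotting a formatted image.
restored = (scaled + 1) * 127.5
print(restored)
```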

Apply this function to each element of the dataset with the map method, then shuffle and batch the data.

train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
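Dataset.shuffle does not shuffle the whole dataset at once: it keeps a buffer of SHUFFLE_BUFFER_SIZE elements and repeatedly emits a random element from it. The pure-Python sketch below mimics that idea together with batching (it is an illustration, not TensorFlow's implementation). Note that batch keeps a smaller final batch unless drop_remainder=True is passed.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Illustrative buffered shuffle: emit random elements from a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Buffer is full: emit a random element to make room for the next one.
            yield buffer.pop(rng.randrange(len(buffer)))
    # Input exhausted: drain whatever remains, in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch(items, batch_size):
    """Group consecutive items into lists of batch_size; the last one may be smaller."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

batches = list(batch(buffered_shuffle(range(100), buffer_size=10), batch_size=32))
print([len(b) for b in batches])  # [32, 32, 32, 4]
```

A small buffer only mixes nearby elements (here the first output always comes from the first 10 items), which is why the buffer size should be large relative to any ordering in the source data.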

Create the base CNN model

First, we develop a baseline convolutional neural network model for the dogs vs. cats dataset. This architecture serves as a starting point that can be studied and improved.

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(kernel_size=3, filters=16, padding='same', activation='relu', input_shape=[IMG_SIZE,IMG_SIZE, 3]),
    tf.keras.layers.Conv2D(kernel_size=3, filters=30, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=60, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=90, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=110, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=130, padding='same', activation='relu'),
    tf.keras.layers.Conv2D(kernel_size=1, filters=40, padding='same', activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

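The model is a stack of convolutional layers with small 3×3 filters, most of them followed by a 2×2 max pooling layer that halves the height and width of the feature maps. The number of filters increases with the depth of the network (16, 30, 60, 90, 110, 130), and a final 1×1 convolution reduces the feature maps to 40 channels. 'same' padding is used on the convolutional layers so that the height and width of each output feature map match its input. Every convolutional layer uses the ReLU activation function. GlobalAveragePooling2D then averages each of the 40 channels into a single value, and the Dense output layer maps those 40 values to one prediction.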
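The shape and size arithmetic can be traced by hand. The sketch below walks through the Conv2D layers listed above (the CONV_LAYERS table is just a transcription of the model definition): 'same' padding keeps the spatial size, each pooling halves it, and a Conv2D layer has kernel × kernel × input_channels × filters weights plus one bias per filter. The total should agree with the trainable parameter count that model.summary() reports.

```python
# (kernel_size, filters, followed_by_max_pool) for each Conv2D layer in the model above.
CONV_LAYERS = [
    (3, 16, False),
    (3, 30, True),
    (3, 60, True),
    (3, 90, True),
    (3, 110, True),
    (3, 130, False),
    (1, 40, False),
]

def trace_model(img_size=160, in_channels=3):
    """Return (final spatial size, final channels, total parameter count)."""
    size, channels, total_params = img_size, in_channels, 0
    for kernel, filters, pooled in CONV_LAYERS:
        # Weights k*k*in*out plus one bias per filter; 'same' padding keeps height/width.
        total_params += kernel * kernel * channels * filters + filters
        channels = filters
        if pooled:
            size //= 2  # MaxPooling2D(pool_size=2) halves height and width
        print(f"Conv2D {kernel}x{kernel}x{filters}"
              + (" + MaxPooling2D" if pooled else "")
              + f" -> {size}x{size}x{channels}")
    # GlobalAveragePooling2D yields a vector of length `channels` (no parameters);
    # Dense(1) adds one weight per channel plus one bias.
    total_params += channels + 1
    return size, channels, total_params

print(trace_model())  # (10, 40, 293069)
```

Four poolings shrink 160×160 to 10×10, and global average pooling removes the spatial dimensions entirely, so the Dense layer needs only 41 parameters regardless of the input image size.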

Compile Model

You must compile the model before training it. Since there are two classes, use a binary cross-entropy loss.

model.compile(optimizer='Adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])
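Binary cross-entropy averages -[y·log(p) + (1 - y)·log(1 - p)] over the examples, where y is the 0/1 label and p the predicted probability of class 1. A NumPy sketch of that formula (Keras clips predictions by a small epsilon in a similar way to avoid log(0); the exact clipping here is an assumption of the sketch):

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; clipping keeps log() away from 0."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# One dog (label 1) predicted at 0.9 and one cat (label 0) predicted at 0.1.
loss = binary_crossentropy([1, 0], [0.9, 0.1])
print(round(loss, 4))  # 0.1054
```

Confident wrong predictions are penalized heavily: predicting 0.1 for a dog costs -log(0.1), about 2.30, versus about 0.105 for predicting 0.9.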

Binary classification requires predicting a single value. The output layer therefore has 1 node with a sigmoid activation, which produces a value between 0 and 1 that can be read as the probability of class 1 (dog); class 0 is cat.
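The sigmoid function 1 / (1 + e^(-z)) squashes the output node's raw score z into (0, 1). The sketch below applies it to a few sample scores and then uses the same 0.5 threshold as the prediction step at the end of this tutorial to pick a class:

```python
import math

def sigmoid(z):
    """Squash a raw score (logit) into a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def to_label(prob, threshold=0.5):
    # Below the threshold -> cat (class 0); otherwise dog (class 1).
    return "cat" if prob < threshold else "dog"

for z in (-3.0, 0.0, 2.0):
    p = sigmoid(z)
    print(f"logit {z:+.1f} -> p(dog) = {p:.3f} -> {to_label(p)}")
```

A score of exactly 0 maps to a probability of 0.5, the decision boundary between the two classes.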

Train the model

After training for 5 epochs, you should see roughly 84% accuracy.

history = model.fit(train_batches,
                    epochs=5,
                    validation_data=validation_batches)
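With a binary cross-entropy loss and a single sigmoid output, Keras resolves the 'accuracy' metric to binary accuracy: each predicted probability is thresholded at 0.5 and compared with the 0/1 label. A NumPy sketch of that computation, on made-up labels and probabilities:

```python
import numpy as np

def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of examples whose thresholded prediction matches the 0/1 label."""
    y_pred = (np.asarray(y_prob) > threshold).astype(int)
    return float(np.mean(y_pred == np.asarray(y_true)))

# Made-up example: the first two predictions are right, the last two are wrong.
acc = binary_accuracy([1, 0, 1, 0], [0.8, 0.3, 0.4, 0.6])
print(acc)  # 0.5
```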

Evaluate Model

The tf.keras.Model.evaluate method accepts a tf.data.Dataset and returns the inference-mode loss and metrics for the data provided:

model.evaluate(test_batches)
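evaluate accumulates metrics over every example it sees rather than taking an unweighted average of per-batch values. The distinction matters when the last batch is smaller than BATCH_SIZE, as it usually is. The per-batch counts below are made up purely to show the two ways of combining them:

```python
# Hypothetical per-batch results: (number of correct predictions, batch size).
batch_results = [(28, 32), (30, 32), (2, 4)]

# Accumulate over examples, as Keras metrics do: total correct / total seen.
correct = sum(c for c, _ in batch_results)
seen = sum(n for _, n in batch_results)
accumulated = correct / seen

# Naive alternative: average the per-batch accuracies (over-weights the small last batch).
naive = sum(c / n for c, n in batch_results) / len(batch_results)

print(f"accumulated={accumulated:.4f} naive={naive:.4f}")
```

The four-example final batch counts as a full third of the naive average, which drags it well below the true accuracy over all 68 examples.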

Save Keras Model as .pb

Saving a fully functional model is very useful: you can load it in TensorFlow.js (from HDF5 or SavedModel) and then train and run it in a web browser, or convert it to run on mobile devices with TensorFlow Lite (from HDF5 or SavedModel).

tf.saved_model.save(model, "/tmp/cnn/1/")

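tf.saved_model.save writes a SavedModel directory: saved_model.pb holds the serialized graph and signatures, and the variables/ subdirectory holds the weights. For a Keras model, it also saves the list of variables, trainable variables, regularization losses, and the call function.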

SavedModels have named functions called signatures. Keras models export their forward pass under the serving_default signature key. The SavedModel command line interface is useful for inspecting SavedModels on disk:

!saved_model_cli show --dir /tmp/cnn/1 --tag_set serve --signature_def serving_default
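The same save, inspect, and load cycle can be tried on a tiny model without waiting for training. The sketch below swaps in a minimal tf.Module (the class name Scale and its single weight are illustrative, not part of the tutorial) with an explicit input signature. It saves the module to a temporary directory, lists the files the SavedModel format writes, and calls the restored serving_default signature with a keyword argument named after the function parameter.

```python
import os
import tempfile

import tensorflow as tf

class Scale(tf.Module):
    """Hypothetical stand-in model: multiplies its input by a trainable weight."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def __call__(self, x):
        # Returning a dict gives the signature output a readable key.
        return {"output": self.w * x}

module = Scale()
export_dir = os.path.join(tempfile.mkdtemp(), "1")
tf.saved_model.save(module, export_dir, signatures=module.__call__)

# Among other files, the directory holds saved_model.pb and the variables/ folder.
print(sorted(os.listdir(export_dir)))

loaded = tf.saved_model.load(export_dir)
infer = loaded.signatures["serving_default"]
print(infer.structured_input_signature)  # the input is named after the argument, x
print(infer.structured_outputs)          # the output key is "output"

result = infer(x=tf.constant([1.0, 2.0, 3.0]))["output"]
print(result.numpy())  # [2. 4. 6.]
```

Calling the signature with a keyword argument works regardless of TensorFlow version, which is why the input name is worth looking up with the CLI or structured_input_signature.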

Load .pb Model and Predict

We can load the SavedModel back into Python with tf.saved_model.load and see how an image from the test set is classified.

loaded = tf.saved_model.load("/tmp/cnn/1/")
print(list(loaded.signatures.keys()))  # ["serving_default"]

Running inference from the SavedModel gives the same result as the original model.

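infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)

# Read the input and output names from the signature instead of hard-coding
# them; Keras derives both from layer names (e.g. 'dense'), which can vary.
input_name = list(infer.structured_input_signature[1].keys())[0]
output_name = list(infer.structured_outputs.keys())[0]

get_label_name = metadata.features['label'].int2str

for image, label in raw_test.take(1):
  img, _ = format_example(image, label)
  img = tf.expand_dims(img, axis=0)  # add a batch dimension: (1, 160, 160, 3)
  prob = infer(**{input_name: img})[output_name]  # sigmoid output: P(dog)
  predicted = "cat" if prob[0][0] < 0.5 else "dog"
  plt.figure()
  plt.imshow(image)
  plt.title(f"predicted: {predicted}, actual: {get_label_name(int(label.numpy()))}")
  plt.show()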