Nowadays we use GPUs or TPUs to train our neural networks. The challenge is to feed them data fast enough to keep them busy. If your data is stored as thousands of individual files, that is not ideal.

The rule of thumb is to split your data across several large files. If you have too many files, the overhead of opening and reading each one starts to add up. If you have too few files, like one or two, you lose the benefit of streaming from multiple files in parallel.
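The rule of thumb can be turned into a small calculation. The helper below, `suggest_num_shards`, is hypothetical. Its thresholds are assumptions in line with the figures TensorFlow's TFRecord tutorial quotes: roughly 10 files per host reading the data, and files of at least 10 MB and ideally around 100 MB. Treat the result as a starting point, not a rule.

```python
MB = 1024 * 1024


def suggest_num_shards(total_bytes: int, num_hosts: int = 1,
                       min_file_bytes: int = 10 * MB,
                       target_file_bytes: int = 100 * MB) -> int:
    """Hypothetical helper: pick a file count for `total_bytes` of records."""
    # Enough files that each one is roughly target_file_bytes...
    by_size = max(1, round(total_bytes / target_file_bytes))
    # ...and, as long as files stay above min_file_bytes,
    # at least 10 files per host that reads the data.
    by_hosts = min(10 * num_hosts, max(1, total_bytes // min_file_bytes))
    return max(by_size, by_hosts)


print(suggest_num_shards(600 * MB))        # 10 files of about 60 MB
print(suggest_num_shards(20 * 1024 * MB))  # 205 files of about 100 MB
```

With small datasets the minimum-size constraint wins (50 MB gives 5 files of 10 MB). With large datasets the target size wins.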

In this tutorial, we are going to pack the images into a small number of TFRecord files and use the power of tf.data to read from multiple files in parallel.

The TFRecord file format is a simple record-oriented binary format. If your input data is on disk, or you are working with large datasets, TensorFlow recommends the TFRecord format. It can significantly improve the performance of your input pipeline: binary data takes less space on disk, takes less time to copy, and can be read more efficiently from disk.
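"Record-oriented" means the file is just a sequence of framed records. Each record consists of four parts:

- an 8-byte little-endian length,
- a 4-byte masked CRC-32C of those length bytes,
- the payload,
- a 4-byte masked CRC-32C of the payload.

The sketch below writes and reads that framing in plain Python, for illustration only. In the tutorial code, `tf.io.TFRecordWriter` and `tf.data.TFRecordDataset` handle this for you, along with optional compression. The helper names are hypothetical.

```python
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = ((crc >> 1) ^ 0x82F63B78) if (crc & 1) else (crc >> 1)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, in 32-bit arithmetic."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def encode_record(data: bytes) -> bytes:
    """Frame one payload: length, CRC of length, payload, CRC of payload."""
    length = struct.pack("<Q", len(data))
    return (length + struct.pack("<I", masked_crc(length))
            + data + struct.pack("<I", masked_crc(data)))


def decode_records(buf: bytes) -> list:
    """Split a TFRecord byte string back into payloads, checking both CRCs."""
    records, pos = [], 0
    while pos < len(buf):
        length_bytes = buf[pos:pos + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack("<I", buf[pos + 8:pos + 12])
        if length_crc != masked_crc(length_bytes):
            raise ValueError("corrupted length field at offset %d" % pos)
        data = buf[pos + 12:pos + 12 + length]
        (data_crc,) = struct.unpack("<I", buf[pos + 12 + length:pos + 16 + length])
        if data_crc != masked_crc(data):
            raise ValueError("corrupted record data at offset %d" % pos)
        records.append(data)
        pos += 16 + length
    return records


print(decode_records(encode_record(b"hello") + encode_record(b"world")))
```

Every record costs 16 bytes of framing on top of its payload. The reader never needs an index: it reads a length, then that many bytes, and so on. This sequential layout is what makes streaming reads efficient.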


In this post, we load the well-known Dogs vs. Cats dataset, resize the images, convert them to TFRecord files, and then load them back for training. You need to download the train part of the Dogs vs. Cats dataset first.

Convert Dataset into TFRecord

First, we load each image and resize it to the target size at which we want to store it in the TFRecord files.

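import numpy as np
import tensorflow as tf

AUTO = tf.data.experimental.AUTOTUNE  # let tf.data choose the level of parallelism
TARGET_SIZE = 160                     # assumed value: side length of the stored images
CLASSES = [b'cat', b'dog']            # labels as bytes, the way .numpy() returns tf.string values

def read_image_and_label(img_path):
  bits = tf.io.read_file(img_path)
  image = tf.image.decode_jpeg(bits, channels=3)
  image = tf.image.resize(image, [TARGET_SIZE, TARGET_SIZE])
  # File names look like data/train/cat.0.jpg: keep the base name, then the part before the first '.'
  label = tf.strings.split(img_path, sep='/')
  label = tf.strings.split(label[-1], sep='.')
  return image, label[0]

dataset = tf.data.Dataset.list_files('data/train/*.jpg', seed=10000)  # This also shuffles the data
dataset = dataset.map(read_image_and_label, num_parallel_calls=AUTO)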
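The two `tf.strings.split` calls derive the label from the file name. The plain-Python function below does the same on an ordinary string, which makes the logic easy to check. It is a hypothetical helper for illustration. It assumes `/` path separators (as on Linux or Colab) and the `<class>.<id>.jpg` naming of the Dogs vs. Cats training files.

```python
def label_from_path(img_path: str) -> str:
    """Plain-Python equivalent of the two tf.strings.split calls."""
    basename = img_path.split('/')[-1]  # 'data/train/cat.123.jpg' -> 'cat.123.jpg'
    return basename.split('.')[0]       # 'cat.123.jpg' -> 'cat'


print(label_from_path('data/train/dog.12499.jpg'))  # dog
```

Because the label comes only from the file name, any file that does not follow this naming scheme gets a wrong label silently.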

Next, we pack each image and its label into a protocol buffer called tf.train.Example. An Example holds a map of named features, and each tf.train.Feature stores a list of values of one of three types: bytes, float, or int64.

def _bytestring_feature(list_of_bytestrings):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings))

def _int_feature(list_of_ints): # int64
  return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints))

def _float_feature(list_of_floats): # float32
  return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats))

def to_tfrecord(img_bytes, label):
  # Index of the label in CLASSES (label is a bytes value such as b'cat')
  class_num = np.argmax(np.array(CLASSES) == label)
  feature = {
      "image": _bytestring_feature([img_bytes]), # one image in the list
      "class": _int_feature([class_num]),        # one class in the list
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))
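`to_tfrecord` turns the label into a class number with `np.argmax(np.array(CLASSES) == label)`. One caveat: if the label is not in `CLASSES`, the comparison is all `False` and `np.argmax` returns 0. An unknown label would then be stored silently as the first class.

Below is a minimal sketch of a stricter lookup. `class_index` is a hypothetical helper, and `CLASSES` is assumed to hold the labels as bytes.

```python
import numpy as np

CLASSES = [b'cat', b'dog']  # assumed: labels as bytes, as .numpy() returns for tf.string


def class_index(label: bytes) -> int:
    """Return the position of `label` in CLASSES, refusing unknown labels."""
    matches = np.array(CLASSES) == label
    if not matches.any():
        raise ValueError("unknown label %r" % (label,))
    return int(np.argmax(matches))


print(class_index(b'dog'))  # 1
```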

Then, we re-encode each resized image as JPEG, serialize each Example to a string, and write the strings to TFRecord files, one file per shard.

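import math

SHARDS = 16  # number of TFRecord files to write
nb_images = len(tf.io.gfile.glob('data/train/*.jpg'))
shard_size = math.ceil(nb_images / SHARDS)

def recompress_image(image, label):
  # tf.image.resize returns float32: cast back to uint8 and re-encode as JPEG to keep the files small
  image = tf.cast(image, tf.uint8)
  image = tf.image.encode_jpeg(image, optimize_size=True, chroma_downsampling=False)
  return image, label

dataset = dataset.map(recompress_image, num_parallel_calls=AUTO)
dataset = dataset.batch(shard_size)  # one batch = one TFRecord file

for shard, (image, label) in enumerate(dataset):
  # Actual number of records in this shard (the last one may be smaller)
  shard_size = image.numpy().shape[0]
  filename = "cat_dog" + "{:02d}-{}.tfrec".format(shard, shard_size)
  with tf.io.TFRecordWriter(filename) as out_file:
    for i in range(shard_size):
      example = to_tfrecord(image.numpy()[i], label.numpy()[i])
      out_file.write(example.SerializeToString())
    print("Wrote file {} containing {} records".format(filename, shard_size))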
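The loop above writes one file per batch and encodes the record count in each file name. The plain-Python sketch below reproduces that arithmetic without TensorFlow, so you can predict the file names. It assumes the 25,000 images of the Kaggle Dogs vs. Cats training set and the 16 shards used here. `shard_plan` is a hypothetical helper.

```python
import math


def shard_plan(nb_images: int, shards: int, prefix: str = "cat_dog") -> list:
    """Return (file name, record count) pairs as produced by batch(shard_size)."""
    shard_size = math.ceil(nb_images / shards)
    plan = []
    for shard in range(shards):
        count = min(shard_size, nb_images - shard * shard_size)
        if count <= 0:  # batch() can yield fewer batches than requested shards
            break
        plan.append(("{}{:02d}-{}.tfrec".format(prefix, shard, count), count))
    return plan


for name, count in shard_plan(25000, 16)[-2:]:
    print(name, count)
```

The last file is smaller because 16 × 1,563 = 25,008 exceeds 25,000. Encoding the count in the file name lets you compute training steps later without opening the files.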

The code above loads the image files, resizes them to a common size, and stores them across 16 TFRecord files. The reading pipeline in the next section reads from several of these files in parallel and disregards data order in favor of reading speed.
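Some quick arithmetic shows why `recompress_image` re-encodes the images before writing instead of storing the decoded tensors. The sketch assumes a TARGET_SIZE of 160 and the 25,000 training images. `raw_image_bytes` is a hypothetical helper.

```python
def raw_image_bytes(size: int, channels: int = 3, bytes_per_value: int = 4) -> int:
    """Bytes needed to store one decoded size x size image without compression."""
    return size * size * channels * bytes_per_value


TARGET_SIZE = 160  # assumed value
NB_IMAGES = 25000  # Dogs vs. Cats training set

as_float32 = raw_image_bytes(TARGET_SIZE)                   # dtype returned by tf.image.resize
as_uint8 = raw_image_bytes(TARGET_SIZE, bytes_per_value=1)  # after tf.cast(image, tf.uint8)
print(as_float32, as_uint8)                # 307200 76800
print(as_float32 * NB_IMAGES / 2**30)      # about 7.15 GiB for the whole set as float32
```

Casting to uint8 alone cuts the size by 4x. JPEG encoding then typically shrinks each image well below its raw uint8 size, although the exact ratio depends on the images.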

Read TFRecord Dataset

Input pipelines extract tf.train.Example protocol buffer messages from a TFRecord-format file. Each tf.train.Example record contains one or more “features”, and the input pipeline typically converts these features into tensors.

def read_tfrecord(example):
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string = bytestring (not text string)
        "class": tf.io.FixedLenFeature([], tf.int64),   # shape [] means scalar
    }
    # decode the TFRecord
    example = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(example['image'], channels=3)
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.reshape(image, [TARGET_SIZE,TARGET_SIZE, 3])
    class_label = tf.cast(example['class'], tf.int32)
    return image, class_label

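BATCH_SIZE = 64  # assumed value; adjust to your hardware

def get_batched_dataset(filenames):
  option_no_order = tf.data.Options()
  option_no_order.experimental_deterministic = False  # allow out-of-order reads for speed

  dataset = tf.data.Dataset.from_tensor_slices(filenames)
  dataset = dataset.with_options(option_no_order)
  # Read up to 16 TFRecord files at once, mixing their records
  dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=16, num_parallel_calls=AUTO)
  dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)

  dataset = dataset.cache() # This dataset fits in RAM
  dataset = dataset.repeat()
  dataset = dataset.shuffle(2048)
  dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
  dataset = dataset.prefetch(AUTO) # prepare the next batches while the model is training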
  return dataset
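`dataset.interleave` keeps up to `cycle_length` files open and takes records from them in turn. The plain-Python generator below sketches the order a sequential, deterministic interleave produces. When a file runs out, its slot is refilled with the next file the next time the cycle reaches it.

This is an illustration, not TensorFlow's implementation. With `num_parallel_calls` and `experimental_deterministic = False`, as in `get_batched_dataset`, the real order can vary from run to run. The names `interleave` and `map_fn` are hypothetical stand-ins for the file list and `tf.data.TFRecordDataset`.

```python
from typing import Callable, Iterable, Iterator

_END = object()  # sentinel meaning "iterator exhausted"


def interleave(inputs: Iterable, map_fn: Callable[..., Iterable],
               cycle_length: int, block_length: int = 1) -> Iterator:
    """Yield elements in the order of a sequential (deterministic) interleave."""
    inputs = iter(inputs)
    slots = [None] * cycle_length  # one open inner iterator per cycle position
    num_open, position, taken = 0, 0, 0
    end_of_input = False
    while not end_of_input or num_open > 0:
        current = slots[position]
        if current is not None:
            value = next(current, _END)
            if value is not _END:
                yield value
                taken += 1
                if taken < block_length:
                    continue  # keep reading from the same position
            else:
                slots[position] = None  # free the slot; refill on a later visit
                num_open -= 1
            position, taken = (position + 1) % cycle_length, 0
        elif not end_of_input:
            item = next(inputs, _END)
            if item is _END:
                end_of_input = True
            else:
                slots[position] = iter(map_fn(item))  # e.g. open the next file
                num_open += 1
        else:
            position, taken = (position + 1) % cycle_length, 0


files = {"f0": [0, 1, 2], "f1": [10], "f2": [20, 21]}  # toy "files" of records
print(list(interleave(["f0", "f1", "f2"], files.__getitem__, cycle_length=2)))
# [0, 10, 1, 2, 20, 21]
```

With `cycle_length=16` and one record per turn, consecutive elements come from different files. Reads are spread across files, which is why the data stays reasonably mixed even before `shuffle`.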

Build a Simple Sequential Model

The code below defines the convolutional base using a stack of Conv2D and MaxPooling2D layers. As input, a CNN takes tensors of shape (image_height, image_width, color_channels); here that is (TARGET_SIZE, TARGET_SIZE, 3).

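model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(kernel_size=3, filters=16, padding='same', activation='relu', input_shape=[TARGET_SIZE,TARGET_SIZE, 3]),
    tf.keras.layers.Conv2D(kernel_size=3, filters=32, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=64, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=128, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Conv2D(kernel_size=3, filters=256, padding='same', activation='relu'),
    # Classification head (assumed; the original listing is cut off here)
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(len(CLASSES), activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])  # assumed settings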
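A worked count of the parameters in the five Conv2D layers shows where the model's capacity sits. Each layer has one 3 × 3 × in_channels kernel per filter plus one bias per filter. The count does not depend on the image size, and pooling layers add no parameters. `conv2d_params` is a hypothetical helper.

```python
def conv2d_params(kernel_size: int, in_channels: int, filters: int) -> int:
    """Weights (k * k * in_channels per filter) plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


filters = [16, 32, 64, 128, 256]
in_channels = [3] + filters[:-1]  # RGB input, then each layer feeds the next
per_layer = [conv2d_params(3, c, f) for c, f in zip(in_channels, filters)]
print(per_layer)       # [448, 4640, 18496, 73856, 295168]
print(sum(per_layer))  # 392608
```

The last layer alone holds about three quarters of the convolutional weights. The same per-layer numbers should appear in `model.summary()` for these Conv2D layers.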


This is how a simple convolutional neural network looks in Keras. Now that the model is defined, you can train it using a tf.data.Dataset.

Feed data using tf.data.Dataset

You can pass a tf.data.Dataset object directly into fit().

def get_training_dataset():
  return get_batched_dataset(training_filenames)

def get_validation_dataset():
  return get_batched_dataset(validation_filenames)

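# training_filenames, validation_filenames, steps_per_epoch and validation_steps must be defined beforehand
history = model.fit(get_training_dataset(), steps_per_epoch=steps_per_epoch, epochs=10,
                    validation_data=get_validation_dataset(), validation_steps=validation_steps)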

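Here, fit() uses the steps_per_epoch argument: the number of training steps (batches) the model runs before it moves to the next epoch. Because get_batched_dataset() repeats the data indefinitely, fit() cannot detect the end of an epoch on its own, so you need to pass steps_per_epoch, and validation_steps for the validation data.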
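Because each file name ends in `-<record count>.tfrec`, you can compute `steps_per_epoch` and `validation_steps` without opening the files. The sketch below assumes 16 shards named like `cat_dog00-1563.tfrec`, a hypothetical split of the first 13 files for training and the last 3 for validation, and a BATCH_SIZE of 64. `count_records` is a hypothetical helper.

```python
import re


def count_records(filenames) -> int:
    """Sum the record counts encoded as '-<count>.tfrec' at the end of each name."""
    total = 0
    for name in filenames:
        match = re.search(r"-(\d+)\.tfrec$", name)
        if match is None:
            raise ValueError("no record count in file name: %s" % name)
        total += int(match.group(1))
    return total


BATCH_SIZE = 64  # assumed value
filenames = ["cat_dog{:02d}-{}.tfrec".format(i, 1563 if i < 15 else 1555) for i in range(16)]
training_filenames, validation_filenames = filenames[:13], filenames[13:]  # assumed split

# drop_remainder=True discards the last partial batch, so floor division matches
steps_per_epoch = count_records(training_filenames) // BATCH_SIZE
validation_steps = count_records(validation_filenames) // BATCH_SIZE
print(steps_per_epoch, validation_steps)  # 317 73
```

In practice you would build `filenames` with something like `tf.io.gfile.glob('cat_dog*.tfrec')` instead of generating the names.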
