If you are working with images, especially with datasets that are too large to fit in memory, storing the data in a binary format takes up less space on disk and takes less time to copy and read.

TFRecord is TensorFlow's binary file format. It stores your data as a sequence of binary strings, so only the records needed at the moment are loaded from disk and processed. With TFRecord everything lives in a single file, which can be shuffled dynamically and batched as part of an input pipeline.

In this tutorial, we first create a TFRecord file from images and then consume it using tf.data.

Download Image Dataset

You’ll need a set of images to train a CNN on the classes you want to recognize. Google provides an archive of creative-commons licensed flower photos that we use here.

from __future__ import absolute_import, division, print_function
from tqdm import tqdm

import pathlib
import random
import matplotlib.pyplot as plt

import tensorflow as tf
import numpy as np


tf.enable_eager_execution()

AUTOTUNE = tf.data.experimental.AUTOTUNE

data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

data_dir = pathlib.Path(data_dir)

In order to store the images in a TFRecord, we first collect the image paths and define the mapping from class name to integer label.

all_images = list(data_dir.glob('*/*'))
all_images = [str(path) for path in all_images]
random.shuffle(all_images)

image_count = len(all_images)

label_names={'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}

Convert Image to tf.Example

In order to store a standard TensorFlow type in a tf.Example, you first convert it to a tf.train.Feature. A tf.Example is essentially a {"string": tf.train.Feature} mapping.

tf.train.Features is a collection of named features. It has a single attribute, feature, which maps each feature name to a tf.train.Feature.

def _bytes_feature(value):
  """Returns a bytes_list from a string / byte."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
  """Returns a float_list from a float / double."""
  return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
  """Returns an int64_list from a bool / enum / int / uint."""
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

tf.train.BytesList, tf.train.FloatList, and tf.train.Int64List are at the core of a tf.train.Feature. All three have a single attribute, value, which expects a list of bytes, floats, or ints respectively.
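For example, printing the proto returned by one of the helper functions above shows the wrapped value; a quick sanity check:

# Inspect the feature protos produced by the helpers above.
print(_int64_feature(2))
# int64_list {
#   value: 2
# }
print(_bytes_feature(b'daisy'))
# bytes_list {
#   value: "daisy"
# }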

Creating tf.Example message

Each value needs to be converted to a tf.train.Feature containing one of the three compatible types, using one of the functions above. We then create a map from the feature name string to the encoded feature value, and that map is wrapped in a tf.train.Features message.

def _convert_to_example(image_buffer, label, text):
  """Builds a tf.Example proto for one image."""
  example = tf.train.Example(features=tf.train.Features(feature={
      'label': _int64_feature(label),
      'text': _bytes_feature(tf.compat.as_bytes(text)),
      'encoded': _bytes_feature(tf.compat.as_bytes(image_buffer))}))
  return example
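
As a quick sanity check, an Example can be serialized to a binary string and parsed back; here with placeholder values (the image bytes below are not a real JPEG):

# Round-trip check with placeholder values (not a real JPEG).
example = _convert_to_example(b'fake-jpeg-bytes', label=2, text='roses')
serialized = example.SerializeToString()
restored = tf.train.Example.FromString(serialized)
print(restored.features.feature['label'].int64_list.value[0])  # 2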

Write to TFRecord file

All of the features are now stored in a tf.Example message. Next, we write an example message for every image to a file, flower.tfrecords.
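
The writer loop below calls a _process_image helper that is not defined above; a minimal sketch, assuming the class name is taken from each image's parent directory and mapped through label_names:

def _process_image(filename):
  # Read the raw JPEG bytes; decoding is deferred to the tf.data pipeline.
  with tf.gfile.GFile(filename, 'rb') as f:
    image_buffer = f.read()
  # The class name is the parent directory (e.g. 'roses'), mapped to an int label.
  text = pathlib.Path(filename).parent.name
  label = label_names[text]
  return image_buffer, text, label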

with tf.python_io.TFRecordWriter('flower.tfrecords') as writer:
  for filename in tqdm(all_images):
    image_buffer, text, label = _process_image(filename)
    example = _convert_to_example(image_buffer, label, text)
    writer.write(example.SerializeToString())

Create dataset using TFRecord

The tf.data API supports the TFRecord file format, so you can process TFRecord files as part of an input pipeline. The tf.data.TFRecordDataset class enables you to stream over the contents of one or more TFRecord files.

image_dataset = tf.data.TFRecordDataset('flower.tfrecords')
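
With eager execution enabled, you can peek at one raw record and parse it back into a tf.train.Example to verify the file was written as expected:

# Inspect one raw record to confirm the TFRecord contents.
for raw_record in image_dataset.take(1):
  parsed = tf.train.Example.FromString(raw_record.numpy())
  print(parsed.features.feature['text'].bytes_list.value[0])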

Read the TFRecord

We can extract the features with tf.parse_single_example(), decode the JPEG bytes, and use the label to determine which flower each image shows.

# Create a dictionary describing the features.  
image_feature_description = {
    'label': tf.FixedLenFeature([], tf.int64),
    'text': tf.FixedLenFeature([], tf.string),  
    'encoded': tf.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
  # Parse the input tf.Example proto using the dictionary above.
  feature = tf.parse_single_example(example_proto, image_feature_description)

  image = tf.image.decode_jpeg(feature['encoded'], channels=3)
  image = tf.image.resize_images(image, [192, 192])
  image /= 255.0  # normalize to [0, 1] range
  return image, feature['label'], feature['text']

dataset = image_dataset.map(_parse_image_function)

BATCH_SIZE = 32

# shuffle_and_repeat shuffles with a buffer as large as the dataset and
# repeats indefinitely, so a separate repeat() call is not needed.
ds = dataset.apply(
  tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
# Batching is skipped here so a single image can be plotted below.
# ds = ds.batch(BATCH_SIZE)
ds = ds.prefetch(buffer_size=AUTOTUNE)


for image, label, text in ds.take(1):
  plt.title(text.numpy().decode('utf-8'))
  plt.imshow(image)
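
For training, you would typically re-enable batching (left out above so that a single image can be plotted); a minimal sketch of a batched pipeline:

# Batched pipeline: shuffle, repeat, batch and prefetch for training.
train_ds = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

for images, labels, texts in train_ds.take(1):
  print(images.shape)  # (32, 192, 192, 3)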
