If you are working with images, especially with datasets that are too large to be stored in memory, a binary format is a good choice: binary data takes up less space on disk and takes less time to copy and read.
The TFRecord file format is TensorFlow's own binary file format. It stores your data as a sequence of binary strings. Only the data that is required at a given time is loaded from the disk and then processed.
In this tutorial, we first create a TFRecord file from an image dataset and then read it back as a tf.data dataset.
Download Image Dataset
You’ll need a set of images to teach a CNN about the new classes you want it to recognize. Google created an archive of creative-commons licensed flower photos to use initially.
from __future__ import absolute_import, division, print_function
from tqdm import tqdm
from numpy.random import randn
import pathlib
import random
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
# Turn on eager execution so tensors can be inspected with .numpy() below.
tf.enable_eager_execution()

# Lets tf.data choose buffer sizes on its own.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Fetch the flower photo archive (untar=True asks Keras to unpack it) and keep
# the resulting location as a pathlib.Path for convenient globbing.
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        fname='flower_photos',
        origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
        untar=True,
    )
)
In order to store these features in a TFRecord, we first need to create the lists that constitute the features.
# Every entry exactly one directory level below the dataset root, as a string.
all_images = [str(path) for path in data_dir.glob('*/*')]

# Randomize the order in place so records are not grouped by directory.
random.shuffle(all_images)
image_count = len(all_images)

# Integer id assigned to each flower class name.
label_names = dict(daisy=0, dandelion=1, roses=2, sunflowers=3, tulips=4)
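Convert Image to tf.Example
Fundamentally, a tf.Example is a {"string": tf.train.Feature} mapping.
In order to convert a standard TensorFlow type to a tf.Example-compatible tf.train.Feature, you can use the helper functions below. Each one takes a single value and returns a tf.train.Feature holding a tf.train.BytesList, tf.train.FloatList, or tf.train.Int64List.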
def _bytes_feature(value):
    """Wrap one string / bytes value in a tf.train.Feature.

    The value is stored as the only element of a tf.train.BytesList.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap one float / double value in a tf.train.Feature.

    The value is stored as the only element of a tf.train.FloatList.
    """
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in a tf.train.Feature.

    The value is stored as the only element of a tf.train.Int64List.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
tf.train.BytesList, tf.train.FloatList, and tf.train.Int64List are at the core of a tf.train.Feature. All three have a single attribute value, which expects a list of respective bytes, float, and int.
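Creating a tf.Example message
Each value needs to be converted to a tf.train.Feature using one of the helper functions above; the resulting features are then collected into a single tf.train.Example.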
def _convert_to_example(image_buffer, label, text):
    """Assemble a tf.train.Example from one image and its annotations.

    Args:
        image_buffer: Image data, stored as bytes under the 'encoded' key.
        label: Integer value stored under the 'label' key.
        text: String stored as bytes under the 'text' key.

    Returns:
        A tf.train.Example with 'label', 'text' and 'encoded' features.
    """
    feature_map = {
        'label': _int64_feature(label),
        'text': _bytes_feature(tf.compat.as_bytes(text)),
        'encoded': _bytes_feature(tf.compat.as_bytes(image_buffer)),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)
Write to TFRecord file
All of the features are now stored in tf.Example messages. Next, we serialize each message and write it to a file named flower.tfrecords.
# Serialize one tf.Example per image path and append it to flower.tfrecords.
# NOTE(review): _process_image is not defined in the visible part of this
# file; it is expected to return (image_buffer, text, label) for a path.
with tf.python_io.TFRecordWriter('flower.tfrecords') as writer:
    for filename in tqdm(all_images):
        image_buffer, text, label = _process_image(filename)
        example = _convert_to_example(
            image_buffer=image_buffer, label=label, text=text)
        serialized_example = example.SerializeToString()
        writer.write(serialized_example)
Create dataset using TFRecord
The tf.data API supports the TFRecord file format, so you can read the file back with tf.data.TFRecordDataset.
# Dataset whose elements are the serialized records stored in the file.
image_dataset = tf.data.TFRecordDataset(filenames='flower.tfrecords')
Read the TFRecord
We can extract the features of each record using tf.parse_single_example(). We also read back the label to determine which flower each image shows.
# Parsing schema: each record carries one scalar int64 'label' and two scalar
# string features, 'text' and 'encoded'.
image_feature_description = dict(
    label=tf.FixedLenFeature(shape=[], dtype=tf.int64),
    text=tf.FixedLenFeature(shape=[], dtype=tf.string),
    encoded=tf.FixedLenFeature(shape=[], dtype=tf.string),
)
def _parse_image_function(example_proto):
    """Turn one serialized tf.Example into an (image, label, text) tuple.

    Args:
        example_proto: A serialized tf.Example, parsed according to
            image_feature_description.

    Returns:
        A tuple of the decoded 3-channel image resized to 192x192 and divided
        by 255.0, the 'label' feature, and the 'text' feature.
    """
    parsed = tf.parse_single_example(example_proto, image_feature_description)
    decoded = tf.image.decode_jpeg(parsed['encoded'], channels=3)
    resized = tf.image.resize_images(decoded, [192, 192])
    # Scale pixel values into the [0, 1] range.
    normalized = resized / 255.0
    return normalized, parsed['label'], parsed['text']
# Decode every serialized record into an (image, label, text) tuple.
dataset = image_dataset.map(_parse_image_function)
BATCH_SIZE = 32

# shuffle_and_repeat() with its default count already repeats the data
# indefinitely, so no additional .repeat() call is needed afterwards.
ds = dataset.apply(
    tf.data.experimental.shuffle_and_repeat(buffer_size=image_count))

# Batching stays disabled here (BATCH_SIZE is currently unused in this
# section); the preview loop below treats each element as a single image.
ds = ds.prefetch(buffer_size=AUTOTUNE)

# Preview one example, using its 'text' feature as the plot title.
for image, label, text in ds.take(1):
    # The 'text' feature comes back as bytes; decode it so the title is shown
    # as plain text instead of a b'...' literal on Python 3.
    plt.title(text.numpy().decode('utf-8'))
    plt.imshow(image)