Writing your own input pipeline in Python to read and transform data can be pretty inefficient. TensorFlow provides the tf.data API so you can easily build performant and scalable input pipelines.
We are going to talk about TensorFlow’s Dataset APIs, which you can use to make your training more performant. This is the API for writing high-performance pipelines that avoid various sorts of stalls and make sure your training always has data available as soon as it’s ready to consume it.
To demonstrate all of these APIs, we’re going to walk through a case study. We start with the most naive implementation, then progressively add more performant APIs and look at how each one helps our training. We’re going to train an image classifier, and the model we’ll be training is VGG16.
You can think of the tf.data input pipeline as an ETL (Extract, Transform, Load) process. The first stage is the extract stage. Here we read the data from, let’s say, network storage or your local disk, and then potentially parse the file format.
In this tutorial we will use the TensorFlow flowers dataset:
flowers = keras.utils.get_file(
tf.data doesn’t provide a dedicated tool for splitting datasets. You could use sklearn.model_selection.train_test_split to generate train/eval/test splits, then create a tf.data.Dataset from each one.
from sklearn.model_selection import train_test_split
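A minimal sketch of that split, assuming the inputs are image file paths whose parent directory is the class name. The 100-path file list, the class names, the 60/20/20 ratios, and random_state are all hypothetical. Only path strings are split here. Label extraction is left to the parse_image step later in the pipeline.

```python
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Hypothetical file list: 100 paths spread evenly over 5 class directories.
filenames = [f"flower_photos/class_{i % 5}/img_{i}.jpg" for i in range(100)]
# The class is the parent directory, mirroring how parse_image derives labels.
classes = [path.split("/")[-2] for path in filenames]

# Hold out 20% for test, then split the remaining 80% into 60% train and
# 20% eval (0.25 * 80 = 20). stratify keeps the class proportions equal.
rest_files, test_files, rest_classes, _ = train_test_split(
    filenames, classes, test_size=0.2, stratify=classes, random_state=0)
train_files, eval_files = train_test_split(
    rest_files, test_size=0.25, stratify=rest_classes, random_state=0)

# One dataset of file paths per split; each can then be mapped with parse_image.
train_ds = tf.data.Dataset.from_tensor_slices(train_files)
eval_ds = tf.data.Dataset.from_tensor_slices(eval_files)
test_ds = tf.data.Dataset.from_tensor_slices(test_files)
```

Splitting the path list before building any Dataset keeps the three splits disjoint. Each split can then go through the same map/batch/prefetch steps independently.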
It’s worth noting that different parts of the data pipeline stress different parts of the system. Loading from disk is an I/O-bound task. We’ll generally want to consume this I/O as fast as possible so that we’re not constantly waiting for images to arrive from disk one at a time.
We’ll need to load the images from disk into memory so that the training process can consume them. We’ll also resize them from their native size down to the size the model expects.
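import tensorflow as tf
def parse_image(filename):
    # The parent directory name is the class, e.g. .../roses/xyz.jpg
    parts = tf.strings.split(filename, '/')
    # Boolean one-hot vector: True where CLASS_NAMES matches that directory
    label = parts[-2] == CLASS_NAMES
    # Read the raw bytes and decode them into a uint8 RGB image
    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image, channels=3)
    # Convert to float32 in [0, 1], then resize to the model's input size
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    return image, label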
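The label line in parse_image compares one string tensor (the directory name) against an array of class names. The result is a boolean one-hot vector. A small eager-mode sketch of what that comparison produces, assuming a hypothetical five-class CLASS_NAMES array and a made-up file path:

```python
import numpy as np
import tensorflow as tf

# Hypothetical class list; the real one comes from the dataset's sub-directories.
CLASS_NAMES = np.array(["daisy", "dandelion", "roses", "sunflowers", "tulips"])

path = tf.constant("flower_photos/roses/example.jpg")  # made-up path
parts = tf.strings.split(path, "/")     # ["flower_photos", "roses", "example.jpg"]
label = parts[-2] == CLASS_NAMES        # elementwise string comparison

# The boolean vector is True only at the index of "roses"; casting it gives
# the float one-hot target later used with categorical_crossentropy.
one_hot = tf.cast(label, tf.float32)
print(label.numpy(), one_hot.numpy())
```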
We’ll also want to do some augmentation, which is task-specific. In this case, we’ll randomly flip the images and add some pixel-level noise to make our training a little more robust.
Augmentations tend to be CPU-intensive because of the per-pixel math they involve.
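def augment_image(image, label):
    im_shape = image.shape
    # Random flipping and pixel-level noise would be applied to `image` here
    # Cast the boolean one-hot label to float32 for the loss function
    return image, tf.cast(label, tf.float32)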
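One possible body for the flip-and-noise augmentation described above. This is a sketch, not the author’s exact code. The Gaussian noise model and NOISE_STDDEV are hypothetical choices, and tf.shape(image) is used instead of the static im_shape so the function also works when the height and width are unknown.

```python
import tensorflow as tf

NOISE_STDDEV = 0.02  # hypothetical noise level; tune for your data


def augment_image(image, label):
    # Randomly mirror the image horizontally (50% chance per call).
    image = tf.image.random_flip_left_right(image)
    # Add zero-mean Gaussian pixel noise with the same shape as the image.
    noise = tf.random.normal(tf.shape(image), mean=0.0, stddev=NOISE_STDDEV)
    image = image + noise
    # Keep pixel values in the [0, 1] range produced by convert_image_dtype.
    image = tf.clip_by_value(image, 0.0, 1.0)
    # Boolean one-hot label -> float32 for categorical_crossentropy.
    return image, tf.cast(label, tf.float32)
```

Clipping after adding noise keeps the image in the same value range the model saw from parse_image. Without it, noise could push pixels slightly below 0 or above 1.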
In short, we decode the image data and then apply image-processing transformations such as resizing, flipping, normalization, and so on. The key thing to note is that each of these steps corresponds one to one with a tf.data transformation in the pipeline.
First, we create a simple dataset consisting of all the filenames in our input.
Second, we’ll want to shuffle the data so that we see a different ordering for each epoch.
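A minimal sketch of per-epoch shuffling on a toy range dataset. The 10-element dataset and buffer size are illustrative only. With reshuffle_each_iteration=True (the default), each pass over the dataset draws a new order.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
# A buffer at least as large as the dataset gives a uniform shuffle; smaller
# buffers use less memory but only shuffle within a sliding window.
ds = ds.shuffle(buffer_size=10, reshuffle_each_iteration=True)

epoch_1 = [int(x) for x in ds]  # first pass: one permutation
epoch_2 = [int(x) for x in ds]  # second pass: usually a different permutation
print(epoch_1)
print(epoch_2)
```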
Next, we apply a transformation called the map transformation, providing our custom parse_image() function. This function reads each file using the tf.io.read_file API, uses the file path to compute the label, and returns both.
Then we apply another map transformation. It takes the decoded image data and converts and processes it to make it suitable for our training task. Here, we provide the augment_image function to this map transformation.
By default, the map transformation applies the custom function you provide to each element of your input dataset in sequence. But if there is no dependency between these elements, there’s no reason to do this in sequence, right? So you can parallelize it by passing the num_parallel_calls argument to the map transformation.
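A toy sketch of a parallel map. The cheap squaring function is hypothetical and stands in for per-element work such as parse_image or augment_image. The value 4 is just an example. Up to four elements are processed concurrently, and by default the output still comes back in input order.

```python
import tensorflow as tf


def slow_square(x):
    # Stand-in for independent per-element work (decoding, augmentation, ...).
    return x * x


ds = tf.data.Dataset.range(8)
# Up to 4 calls of slow_square run at the same time. Output order stays
# deterministic (matches the input) unless you explicitly opt out of it.
ds = ds.map(slow_square, num_parallel_calls=4)

result = [int(x) for x in ds]
print(result)  # [0, 1, 4, 9, 16, 25, 36, 49]
```

Parallelism only helps because each element is independent. A function that depended on previously processed elements could not be mapped this way.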
Then we batch the examples. Batching tends to be somewhat memory-intensive, because we’re copying examples from their original location into the memory buffer of our mini-batch.
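A small sketch of batch() on a synthetic dataset, with a hypothetical BATCH_SIZE of 4 and fake 8x8 images. It shows how consecutive elements are grouped and what happens to the final partial batch.

```python
import tensorflow as tf

BATCH_SIZE = 4  # hypothetical; choose what fits in accelerator memory

# Ten fake examples: 8x8x3 "images" with one-hot labels over 5 classes.
images = tf.zeros([10, 8, 8, 3])
labels = tf.one_hot(tf.range(10) % 5, depth=5)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# batch() stacks consecutive elements along a new leading dimension,
# copying each example into the batch buffer.
batched = ds.batch(BATCH_SIZE)
batch_sizes = [int(image_batch.shape[0]) for image_batch, _ in batched]
print(batch_sizes)  # [4, 4, 2]

# drop_remainder=True discards the final partial batch, so every batch has
# the same static shape.
full_only = ds.batch(BATCH_SIZE, drop_remainder=True)
```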
After batching, we add prefetching, a very common practice for training efficiency in ML tasks. While the accelerator works on the current training step, we use the CPU to process and prepare the next batch of data. When the next training step starts, we don’t have to wait for its data to be prepared because it’s already there. This can reduce the overall training time significantly.
ds = ds.prefetch(buffer_size=AUTOTUNE)
If you’ve been paying close attention, you’ll notice magic numbers such as num_parallel_calls. You might be wondering how to determine their optimal values. In practice, it’s not straightforward to compute them. If you set them too low, you might not use enough of the parallelism in your system. If you set them too high, it might lead to contention and actually have the opposite effect of what you want.
Fortunately, tf.data makes this easy. Instead of hard-coding values for these arguments, you can simply use the constant tf.data.experimental.AUTOTUNE. This tells the tf.data runtime to do the autotuning for you and determine the optimal values for these arguments based on your workload, environment, and setup.
AUTOTUNE = tf.data.experimental.AUTOTUNE
ds = ds.prefetch(buffer_size=AUTOTUNE)
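Putting the stages together, here is one possible end-to-end version of the pipeline, with AUTOTUNE used for both num_parallel_calls and the prefetch buffer. It is a sketch under stated assumptions. It writes a few random JPEGs into a temporary directory so it runs without the flowers download. IMG_SIZE, BATCH_SIZE, the two-class CLASS_NAMES, the shuffle buffer, and the noise level are all hypothetical values. parse_image and augment_image are compact restatements of the functions above, with Gaussian noise as one guess at the pixel-level noise.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE
IMG_SIZE = 32      # hypothetical small size for a quick run (VGG16 usually uses 224)
BATCH_SIZE = 4     # hypothetical
CLASS_NAMES = np.array(["daisy", "roses"])  # hypothetical two-class subset


def write_fake_images(root, per_class=3):
    """Write small random JPEGs into root/<class_name>/<i>.jpg."""
    for name in CLASS_NAMES:
        class_dir = os.path.join(root, name)
        os.makedirs(class_dir, exist_ok=True)
        for i in range(per_class):
            pixels = tf.random.uniform([40, 50, 3], maxval=256, dtype=tf.int32)
            jpeg = tf.io.encode_jpeg(tf.cast(pixels, tf.uint8))
            tf.io.write_file(os.path.join(class_dir, f"{i}.jpg"), jpeg)


def parse_image(filename):
    parts = tf.strings.split(filename, "/")
    label = tf.equal(parts[-2], CLASS_NAMES)  # same as parts[-2] == CLASS_NAMES
    image = tf.image.decode_jpeg(tf.io.read_file(filename), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    return tf.image.resize(image, [IMG_SIZE, IMG_SIZE]), label


def augment_image(image, label):
    image = tf.image.random_flip_left_right(image)
    image = image + tf.random.normal(tf.shape(image), stddev=0.02)
    return tf.clip_by_value(image, 0.0, 1.0), tf.cast(label, tf.float32)


def build_pipeline(file_pattern):
    ds = tf.data.Dataset.list_files(file_pattern)             # Extract: file names
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.map(parse_image, num_parallel_calls=AUTOTUNE)     # Transform
    ds = ds.map(augment_image, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(BATCH_SIZE)                                  # Load
    return ds.prefetch(buffer_size=AUTOTUNE)


root = tempfile.mkdtemp()
write_fake_images(root)
train_ds = build_pipeline(os.path.join(root, "*", "*.jpg"))
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)
```

For the real flowers data, you would drop write_fake_images and point build_pipeline at the extracted directory instead.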
Create a Model
We’re training VGG16. We’ve chosen to use the Keras Applications VGG16, which is just a canned VGG16 implementation. Since we’re training a classifier, we’re going to be using categorical_crossentropy as the loss. Then, finally, we’re going to be using the
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(len(CLASS_NAMES), activation='softmax')
model = tf.keras.Sequential([
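One way these pieces could fit together, as a sketch. It assumes the VGG16 base is created with include_top=False so the global average pooling and softmax layers replace its ImageNet classifier. The name base_model, NUM_CLASSES = 5, and the Adam optimizer are hypothetical choices; the optimizer is not specified here.

```python
import tensorflow as tf

IMG_SIZE = 224            # VGG16's usual input resolution
NUM_CLASSES = 5           # e.g. len(CLASS_NAMES) for the flowers dataset
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)


def build_model(weights="imagenet"):
    # Convolutional VGG16 backbone without its 1000-way ImageNet head.
    base_model = tf.keras.applications.VGG16(
        input_shape=IMG_SHAPE, include_top=False, weights=weights)
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),   # (h, w, c) -> (c,)
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ])
    # Float one-hot labels + softmax outputs -> categorical_crossentropy.
    model.compile(optimizer=tf.keras.optimizers.Adam(),  # hypothetical choice
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model
```

Calling build_model() downloads the pretrained ImageNet weights. Passing weights=None builds the same architecture with random initialization.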
Starting from TensorFlow 1.9, you can pass tf.data.Dataset objects directly into model.fit():
history = model.fit(train_ds,
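A self-contained sketch of passing Datasets straight to fit(). To keep it fast, it deliberately swaps VGG16 for a tiny Flatten + Dense model and uses random 8x8 "images". The data, the model, and the epoch count are all hypothetical. Because each Dataset element is already a (features, labels) batch, fit() takes no separate y or batch_size arguments.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 5
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Tiny synthetic stand-in for the flowers pipeline.
images = np.random.rand(32, 8, 8, 3).astype("float32")
labels = tf.one_hot(np.arange(32) % NUM_CLASSES, depth=NUM_CLASSES)
train_ds = (tf.data.Dataset.from_tensor_slices((images, labels))
            .shuffle(32)
            .batch(8)
            .prefetch(AUTOTUNE))
val_ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(8)

# Deliberately small model so the example runs quickly; not VGG16.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy"])

# Datasets go in directly: batches come from train_ds, validation from val_ds.
history = model.fit(train_ds, epochs=2, validation_data=val_ds)
```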
Note that this is just one way in which you can read data using tf.data, and there are a number of different APIs that you can use for other situations.
The most notable one is the TFRecordDataset API, which you would use if your data is in the TFRecord file format. This is potentially the most performant format to use with tf.data.
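A minimal round-trip sketch of the TFRecord path. It writes a few tf.train.Example records to a temporary file, then reads them back with tf.data.TFRecordDataset and parses them in a parallel map. The file name, the "image"/"label" feature keys, and the placeholder byte strings are hypothetical. In a real image pipeline, "image" would hold encoded JPEG bytes and be decoded with tf.image.decode_jpeg after parsing.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "flowers.tfrecord")  # hypothetical file


def make_example(image_bytes, label):
    # Each record is a serialized tf.train.Example with two features.
    return tf.train.Example(features=tf.train.Features(feature={
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))


# Write three records; placeholder bytes stand in for encoded JPEGs.
with tf.io.TFRecordWriter(path) as writer:
    for i in range(3):
        writer.write(make_example(f"img-{i}".encode(), i).SerializeToString())

feature_spec = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}


def parse_record(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)


ds = tf.data.TFRecordDataset([path]).map(
    parse_record, num_parallel_calls=tf.data.experimental.AUTOTUNE)
for record in ds:
    print(record["image"].numpy(), record["label"].numpy())
```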