In this post, we’re going to retrain a TensorFlow image classifier model. You’ll need to have TensorFlow installed and know how to use the command line.
Step 1: Collect training data
We’re going to build a classifier for images of fruit. For starters, it will take an image of a piece of fruit as input and predict whether it’s an apple or an orange. The more training data you have, the better the classifier you can create: use at least 50 images of each kind of fruit, and more is better.
We will create a ~/tf_files/fruits folder and place each set of JPEG images in its own subdirectory (such as ~/tf_files/fruits/orange). The subfolder names are important: they define the label applied to each image, while the filenames themselves don’t matter. A quick way to download many images at once is a Chrome extension for batch downloading.
To retrain your classifier, you need to run a couple of scripts. You only need to provide one thing: the training data.
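Labels come only from subfolder names, so it can help to count the images per label before training. Below is a minimal sketch that uses only the Python standard library. The helper names (`label_counts`, `classes_below`, `scan_image_dir`) and the `.jpg`/`.jpeg` filter are illustrative assumptions and are not part of retrain.py, which does its own directory scanning.

```python
from pathlib import Path, PurePath

# Assumption: only .jpg/.jpeg files count as training images.
JPEG_SUFFIXES = {".jpg", ".jpeg"}


def label_counts(relative_paths):
    """Count images per label, where the label is the image's parent folder name.

    Each path is relative to the image directory, e.g. "orange/img_001.jpg".
    Files outside a label subfolder, or with other suffixes, are ignored, and
    folders with no JPEG files do not appear in the result.
    """
    counts = {}
    for rel in relative_paths:
        path = PurePath(rel)
        if len(path.parts) == 2 and path.suffix.lower() in JPEG_SUFFIXES:
            label = path.parts[0]
            counts[label] = counts.get(label, 0) + 1
    return counts


def classes_below(counts, minimum=50):
    """Return the labels that have fewer than `minimum` images, sorted by name."""
    return sorted(label for label, n in counts.items() if n < minimum)


def scan_image_dir(image_dir):
    """Walk image_dir/<label>/<file> and count the images per label."""
    root = Path(image_dir).expanduser()
    return label_counts(p.relative_to(root) for p in root.glob("*/*") if p.is_file())


if __name__ == "__main__":
    fruits = Path("~/tf_files/fruits").expanduser()
    if fruits.is_dir():
        counts = scan_image_dir(fruits)
        for label, n in sorted(counts.items()):
            print(f"{label}: {n} images")
        print("Needs more images:", classes_below(counts) or "none")
    else:
        print(f"{fruits} not found")
```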
The retrain.py script is part of the TensorFlow Hub repository. Download it manually into the current directory:
curl -O https://raw.githubusercontent.com/tensorflow/hub/master/examples/image_retraining/retrain.py
Now that we have a trainer and data (images), let’s train! We will retrain the Inception v3 network.
Step 2: Before starting the training, activate your TensorFlow environment.
Step 3: Start your image retraining with one big command.
python retrain.py \
  --bottleneck_dir=bottlenecks \
  --how_many_training_steps=4000 \
  --model_dir=inception \
  --summaries_dir=training_summaries/basic \
  --output_graph=retrained_graph.pb \
  --output_labels=retrained_labels.txt \
  --image_dir=fruits
This command downloads the Inception model and retrains it to classify the images in ~/tf_files/fruits. The operation can take several minutes, depending on how many images you have and how many training steps you specify. The script generates two files: the model as a protobuf file (retrained_graph.pb) and a list of all the labels it can recognize (retrained_labels.txt).
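If you would rather launch training from a Python script, you can build the same command as an argument list and run it with `subprocess`. This is a sketch under stated assumptions. The helper names `retrain_args` and `run_retrain` are hypothetical, and the sketch assumes that retrain.py and the fruits folder both sit in ~/tf_files. It only mirrors the flags of the retrain.py command above and does not change what retrain.py does.

```python
import subprocess
import sys
from pathlib import Path


def retrain_args(image_dir="fruits", steps=4000,
                 output_graph="retrained_graph.pb",
                 output_labels="retrained_labels.txt"):
    """Build the argument list for the retrain.py command shown in this post."""
    return [
        sys.executable, "retrain.py",
        "--bottleneck_dir=bottlenecks",
        f"--how_many_training_steps={steps}",
        "--model_dir=inception",
        "--summaries_dir=training_summaries/basic",
        f"--output_graph={output_graph}",
        f"--output_labels={output_labels}",
        f"--image_dir={image_dir}",
    ]


def run_retrain(workdir="~/tf_files", steps=4000):
    """Run retrain.py inside workdir and return the paths of the two output files."""
    workdir = Path(workdir).expanduser()
    subprocess.run(retrain_args(steps=steps), cwd=workdir, check=True)
    outputs = [workdir / "retrained_graph.pb", workdir / "retrained_labels.txt"]
    missing = [str(p) for p in outputs if not p.exists()]
    if missing:
        raise FileNotFoundError(f"retrain.py did not produce: {missing}")
    return outputs


if __name__ == "__main__":
    # Print the command instead of running a long training job.
    print(" ".join(retrain_args()))
```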
Clone the Git repository for testing the model
The following command clones the Git repository containing the files required to test the model.
git clone https://github.com/googlecodelabs/tensorflow-for-poets-2
The repo contains two directories, android/ and scripts/:
1. android/: contains nearly all the files necessary to build a simple Android app that classifies images.
2. scripts/: contains the Python scripts, including scripts to prepare, test, and evaluate the model.
Now copy the tf_files directory from the first part into the tensorflow-for-poets-2 working directory.
Test the Model
The scripts/ directory contains a simple command-line script, label_image.py, to test the network.
python -m scripts.label_image \
  --graph=tf_files/retrained_graph.pb \
  --image=tf_files/fruits/grapes/grapes_2.jpg
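label_image.py reports a score for each label in retrained_labels.txt. The ranking step can be sketched in plain Python. The sketch assumes one label per line in the labels file and a score list in the same order as the labels. The helper names are hypothetical, and the scores in the usage example are made up for illustration.

```python
from pathlib import Path


def parse_labels(text):
    """Split the contents of a labels file into a list with one label per line."""
    return [line.strip() for line in text.splitlines()]


def load_labels(path="tf_files/retrained_labels.txt"):
    """Read a labels file such as retrained_labels.txt."""
    return parse_labels(Path(path).read_text())


def top_k(scores, labels, k=5):
    """Pair scores with labels by index and return the k highest as (label, score)."""
    if len(scores) != len(labels):
        raise ValueError("expected exactly one score per label")
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    labels = parse_labels("apple\norange\ngrapes\n")
    made_up_scores = [0.05, 0.15, 0.80]  # illustrative values, not real model output
    for label, score in top_k(made_up_scores, labels, k=3):
        print(f"{label} (score={score:.5f})")
```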
Optimize the model for Android
The TensorFlow installation includes a tool, optimize_for_inference, that removes all nodes that aren’t needed for a given set of input and output nodes.
python -m tensorflow.python.tools.optimize_for_inference \
  --input=tf_files/retrained_graph.pb \
  --output=tf_files/optimized_graph.pb \
  --input_names="Cast" \
  --output_names="final_result"
It creates a new file at the path given by --output.
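A toy graph can illustrate the core idea of keeping only the nodes needed between the chosen inputs and outputs. This is a simplified sketch under assumptions. The graph is a plain dict that maps each node name to its input names, and all node names other than Cast and final_result are hypothetical. The real optimize_for_inference works on a TensorFlow GraphDef and also applies other graph optimizations beyond this pruning.

```python
def prune_graph(nodes, input_names, output_names):
    """Keep only the nodes needed to compute output_names when input_names are fed.

    nodes maps each node name to the list of node names it reads from.
    The walk goes backwards from the outputs and stops at the inputs, because
    a fed input no longer needs whatever produced it.
    """
    inputs = set(input_names)
    keep = set()
    stack = list(output_names)
    while stack:
        name = stack.pop()
        if name in keep:
            continue
        keep.add(name)
        if name not in inputs:
            stack.extend(nodes[name])
    # Inputs become entry points, so they keep no incoming edges.
    return {name: ([] if name in inputs else list(nodes[name])) for name in keep}


if __name__ == "__main__":
    toy_graph = {
        "DecodeJpeg": [],
        "Cast": ["DecodeJpeg"],
        "weights": [],
        "conv": ["Cast", "weights"],
        "final_result": ["conv"],
        "labels": [],
        "loss": ["final_result", "labels"],
    }
    pruned = prune_graph(toy_graph, ["Cast"], ["final_result"])
    print(sorted(pruned))  # ['Cast', 'conv', 'final_result', 'weights']
```

The training-only nodes (loss, labels) and the decoding step before the input are dropped because nothing on the path from Cast to final_result depends on them.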
Make the model compressible
The retrained model is still 84 MB in size at this point. That large download size may be a limiting factor for any app that includes it. Running a neural network requires many matrix multiplications, which means a huge number of multiply and add operations. Current mobile devices can perform some of these operations with specialized hardware.
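A fully connected layer with n inputs and m outputs performs n × m multiply-adds. The pure-Python sketch below computes one such layer and counts those operations. It illustrates the arithmetic only; TensorFlow does not execute layers this way. For example, a layer that maps 2,048 Inception v3 bottleneck features to 2 fruit classes needs 4,096 multiply-adds, and the convolutional layers before it need far more.

```python
def dense_forward(x, weights, bias):
    """Compute y = W·x + b for one fully connected layer.

    weights has one row per output, and each row has one entry per input.
    Returns the outputs and the number of multiply-add operations performed.
    """
    multiply_adds = 0
    outputs = []
    for row, b in zip(weights, bias):
        acc = b
        for w, xi in zip(row, x):
            acc += w * xi  # one multiply and one add
            multiply_adds += 1
        outputs.append(acc)
    return outputs, multiply_adds


if __name__ == "__main__":
    n_inputs, n_outputs = 2048, 2
    x = [0.01] * n_inputs
    W = [[0.001] * n_inputs for _ in range(n_outputs)]
    b = [0.0] * n_outputs
    _, count = dense_forward(x, W, b)
    print(count)  # 4096 = 2048 * 2
```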
Quantization is one technique for reducing both the memory footprint and the compute load. Use the quantize_graph script to apply it:
python -m scripts.quantize_graph \
  --input=tf_files/optimized_graph.pb \
  --output=tf_files/rounded_graph.pb \
  --output_node_names=final_result \
  --mode=weights_rounded
It does this without any changes to the structure of the network; it simply quantizes the constants in place. It creates a new file at tf_files/rounded_graph.pb.
Every mobile app distribution system compresses the package before distribution, so check how much the graph can be compressed, for example by gzipping it and comparing file sizes. You should see a significant improvement: in my run, compression made the model about 73% smaller.
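Rounding helps compression because the rounded weights stay floats but take only a limited number of distinct values. That leaves many repeated byte patterns for a compressor like gzip to exploit. The sketch below reproduces the effect on random numbers using only the standard library. `round_weights` is a simplified, hypothetical stand-in for the weights_rounded mode that uses 256 evenly spaced levels between the minimum and maximum; quantize_graph's exact bucketing may differ. zlib uses DEFLATE, the same algorithm as gzip. The sketch prints the percentage saved for the original and the rounded values, and the exact numbers depend on the data.

```python
import array
import random
import zlib


def round_weights(weights, levels=256):
    """Snap each value to one of `levels` evenly spaced values between min and max.

    Values stay floats; only the number of distinct values shrinks.
    """
    lo, hi = min(weights), max(weights)
    if hi == lo:
        return list(weights)
    step = (hi - lo) / (levels - 1)
    return [lo + round((w - lo) / step) * step for w in weights]


def to_float32_bytes(values):
    """Pack values as 32-bit floats."""
    return array.array("f", values).tobytes()


def percent_saved(raw):
    """Percentage of bytes saved by DEFLATE compression (the algorithm gzip uses)."""
    return 100.0 * (1 - len(zlib.compress(raw, 9)) / len(raw))


if __name__ == "__main__":
    rng = random.Random(0)
    weights = [rng.gauss(0.0, 0.1) for _ in range(100_000)]
    rounded = round_weights(weights)
    print(f"original weights: {percent_saved(to_float32_bytes(weights)):.1f}% saved")
    print(f"rounded weights:  {percent_saved(to_float32_bytes(rounded)):.1f}% saved")
```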