How to Create a Cartoonizer with TensorFlow Lite?
This is an end-to-end tutorial on how to convert a TensorFlow model to TensorFlow Lite (TFLite) and deploy it to an Android app to cartoonize an image captured by the camera.
We created this end-to-end tutorial to help developers with these objectives:
- Provide a reference for developers looking to convert models written in TensorFlow 1.x to their TFLite variants using the new features of the latest (v2) converter, such as the MLIR-based converter, more supported ops, and improved kernels.
(To convert TensorFlow 2.x models to TFLite, please follow this guide.)
- Show how to download the .tflite models directly from TensorFlow Hub if you are only interested in using the models for deployment.
- Understand how to use the TFLite tools such as the Android Benchmark Tool, Model Metadata, and Codegen.
- Guide developers on how to easily create a mobile application with TFLite models, using the ML Model Binding feature from Android Studio.
The project repo contains notebooks for saving and converting to .tflite models and the Android code (learn more about the SavedModel format on the TensorFlow doc). The TFLite models are also available for download directly from TensorFlow Hub here.
White-box CartoonGAN (by Xinrui Wang and Jinze Yu) transforms an input image (preferably a natural image) to its cartoonized representation. This tutorial uses the generator of the White-box CartoonGAN for inference in the Android app.
Create the TensorFlow Lite Model
The authors of White-box CartoonGAN provide pre-trained weights that can be used for running inference on images. We convert these pre-trained weights to TFLite models, which are better suited to running in a mobile app. Refer to the details on model conversion on GitHub here.
Step-by-step summary of this section:
- Generate a SavedModel out of the pre-trained model checkpoints.
- Convert SavedModel with post-training quantization using the latest TFLiteConverter.
- Run inference in Python with the converted model.
- Add metadata to enable easy integration with a mobile app.
- Run model benchmark to make sure the model runs well on mobile.
Generate a SavedModel from the pre-trained model weights
The pre-trained weights of White-box CartoonGAN come in the following format (also referred to as checkpoints):
├── checkpoint
├── model-33999.data-00000-of-00001
└── model-33999.index
As the original White-box CartoonGAN model is implemented in TensorFlow 1.x, we first need to generate a single self-contained model file in the
SavedModel format using TensorFlow 1.15. Then we will switch to TensorFlow 2 later to convert it to the lightweight TFLite format.
This is how to create a SavedModel in TensorFlow 1.x:
- Create a placeholder for the model input.
- Instantiate the model instance and run the input placeholder through the model to get a placeholder for the model output.
- Load the pre-trained checkpoints into the current session of the model.
- Finally, export to the SavedModel format.
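The steps listed above can be sketched roughly as follows. This is a minimal sketch, not the authors' actual export script: `build_generator` is a hypothetical callable standing in for the White-box CartoonGAN generator, and the tensor names, the default image size, and the function name `export_saved_model` are assumptions made for illustration. It relies on the `tf.compat.v1` API, which is available in TensorFlow 1.15.

```python
def export_saved_model(build_generator, checkpoint_dir, export_dir, img_size=512):
    """Export TF 1.x checkpoints as a SavedModel (illustrative sketch)."""
    # Imported lazily so that defining this helper does not require TensorFlow.
    import tensorflow as tf

    tf1 = tf.compat.v1
    with tf.Graph().as_default():
        # 1. Placeholder for the model input (name and shape are assumptions).
        input_photo = tf1.placeholder(
            tf.float32, [1, img_size, img_size, 3], name="input_photo")
        # 2. Run the placeholder through the model to get the output tensor.
        #    build_generator is a hypothetical stand-in for the real generator.
        final_output = build_generator(input_photo)
        # 3. Load the pre-trained checkpoint into the current session.
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            saver.restore(sess, tf1.train.latest_checkpoint(checkpoint_dir))
            # 4. Export to the SavedModel format.
            tf1.saved_model.simple_save(
                sess, export_dir,
                inputs={"input_photo": input_photo},
                outputs={"final_output": final_output})
```

Building everything inside an explicit `tf.Graph()` keeps the placeholder, the variables, and the session tied to the same graph, which is what the checkpoint restore expects.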
Now that we have the original model in the SavedModel format, we can switch to TensorFlow 2 and proceed toward converting it to TFLite.
Convert SavedModel to TFLite
TFLite supports three post-training quantization strategies:
- Dynamic range
- Float16
- Integer
The right strategy depends on one’s use case; in this tutorial, however, we cover all three.
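To give a feel for what these strategies do to the stored weights, here is a rough NumPy illustration. It does not reproduce the TFLite converter's internals: the weight tensor is random placeholder data, and the symmetric, per-tensor int8 scheme is a simplifying assumption. Float16 quantization stores weights in half precision, while dynamic-range and integer quantization store them as 8-bit integers.

```python
import numpy as np

rng = np.random.default_rng(0)
# Placeholder weights standing in for one convolution kernel.
weights = rng.normal(0.0, 0.05, size=(3, 3, 32, 32)).astype(np.float32)

# float16: cast each weight to half precision (2 bytes instead of 4).
weights_fp16 = weights.astype(np.float16)

# int8: map [-max|w|, max|w|] onto [-127, 127] with a single scale factor.
scale = np.abs(weights).max() / 127.0
weights_int8 = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
dequantized = weights_int8.astype(np.float32) * scale

print("float32 bytes:", weights.nbytes)
print("float16 bytes:", weights_fp16.nbytes)
print("int8 bytes:   ", weights_int8.nbytes)
print("max int8 round-trip error:", np.abs(weights - dequantized).max())
```

The int8 copy is a quarter of the float32 size, at the cost of a rounding error of at most half a quantization step per weight.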
TFLite models with dynamic-range and float16 quantization
The steps to convert models to TFLite using these two quantization strategies are almost identical, except that float16 quantization requires one extra option. The conversion steps are shown in the code below:
import tensorflow as tf

# Create a concrete function from the SavedModel
model = tf.saved_model.load(saved_model_dir)
concrete_func = model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

# Specify the input shape
concrete_func.inputs[0].set_shape([1, IMG_SHAPE, IMG_SHAPE, 3])

# Convert the model and export
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16] # Only for float16
tflite_model = converter.convert()
A couple of things to note from the code above:
- Here, we specify the input shape of the model that will be converted to TFLite. Note, however, that TFLite has supported dynamic-shaped models since TensorFlow 2.3. We used fixed-shape inputs in order to restrict the memory usage of the models running on mobile devices.
- To convert the model using dynamic-range quantization, simply comment out the line converter.target_spec.supported_types = [tf.float16].
TFLite models with integer quantization
In order to convert the model using integer quantization, we need to pass a representative dataset to the converter so that the activation ranges can be calibrated accordingly. TFLite models generated using this strategy are known to sometimes work better than the other two that we just saw. Integer quantized models are generally smaller as well.
For the sake of brevity, we are going to skip the representative dataset generation part but you can refer to it in this notebook.
In order to let the TFLiteConverter take advantage of this strategy, we need to just pass converter.representative_dataset = representative_dataset_gen and remove converter.target_spec.supported_types = [tf.float16].
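For orientation, a representative dataset is simply a Python generator that yields lists of input arrays shaped like the model input. The sketch below is not the notebook's code: it yields random arrays as stand-in data, and `IMG_SIZE` and `NUM_CALIBRATION_STEPS` are hypothetical values. In practice you would yield real, preprocessed images so that the calibrated activation ranges reflect actual inputs.

```python
import numpy as np

IMG_SIZE = 512  # assumed input resolution; match the shape given to the converter
NUM_CALIBRATION_STEPS = 10  # hypothetical; use enough real samples in practice


def representative_dataset_gen():
    """Yield one batch per calibration step, each as a list of input arrays."""
    rng = np.random.default_rng(0)
    for _ in range(NUM_CALIBRATION_STEPS):
        # Stand-in for a preprocessed image with values in [-1, 1].
        sample = rng.uniform(-1.0, 1.0, size=(1, IMG_SIZE, IMG_SIZE, 3))
        yield [sample.astype(np.float32)]
```

Each yielded item is a list because a model may have several inputs; this model has one, so the list holds a single array.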
After generating these different models, here’s how they compare in terms of model size:
You might feel tempted to just go with the integer-quantized model, but you should also consider the following before finalizing this decision:
- Quality of the end results of the models.
- Inference time (the lower the better).
- Hardware accelerator compatibility.
- Memory usage.
We will get to these in a moment. If you want to dig deeper into these different quantization strategies refer to the official guide here.
Running inference in Python
After you have generated the TFLite models, it is important to make sure that models perform as expected after the conversion and before integrating them in mobile apps. So let’s run inference with the models in Python.
Before feeding an image to our White-box CartoonGAN TFLite models it’s important to make sure that the image is preprocessed well. Otherwise, the models might perform unexpectedly. The original model was trained using BGR images, so we need to account for this fact in the preprocessing steps as well. You can find all of the preprocessing steps in this notebook.
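As a rough illustration of the BGR point, the sketch below reverses the channel order of an RGB array, scales the pixels to [-1, 1], and adds a batch dimension. It is not the notebook's code: `preprocess_image` is a hypothetical helper, resizing is omitted, and the [-1, 1] range is an assumption consistent with the normalization values given in the metadata section of this tutorial.

```python
import numpy as np


def preprocess_image(rgb_image):
    """Convert an HxWx3 uint8 RGB image to a 1xHxWx3 float32 BGR tensor in [-1, 1]."""
    bgr = rgb_image[..., ::-1]                      # RGB -> BGR
    scaled = bgr.astype(np.float32) / 127.5 - 1.0   # [0, 255] -> [-1, 1]
    return np.expand_dims(scaled, axis=0)           # add the batch dimension


# Usage with a tiny synthetic image: a single pure-red pixel.
red_pixel = np.array([[[255, 0, 0]]], dtype=np.uint8)
print(preprocess_image(red_pixel))  # red ends up in the last (BGR) channel
```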
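Here is how to run inference with the .tflite model on a preprocessed input image (preprocessed_image is a placeholder name for the output of the preprocessing steps):
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
input_details = interpreter.get_input_details()
interpreter.allocate_tensors()
# preprocessed_image: placeholder name for the preprocessed input tensor
interpreter.set_tensor(input_details[0]['index'], preprocessed_image)
interpreter.invoke()
raw_prediction = interpreter.tensor(
    interpreter.get_output_details()[0]['index'])()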
The output is an image with BGR channel ordering, so we need to convert it to RGB in the postprocessing steps. Here is what the cartoonized image looks like alongside the original input image:
Again, you can find all of the postprocessing steps in this notebook.
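For reference, postprocessing can be sketched as the mirror image of preprocessing. This is an assumption-laden illustration rather than the notebook's code: `postprocess_output` is a hypothetical helper, the model output is assumed to lie in [-1, 1], and the clipping step is an added safeguard.

```python
import numpy as np


def postprocess_output(raw_prediction):
    """Convert a 1xHxWx3 float BGR tensor in [-1, 1] to an HxWx3 uint8 RGB image."""
    image = np.squeeze(raw_prediction, axis=0)       # drop the batch dimension
    image = (image + 1.0) * 127.5                    # [-1, 1] -> [0, 255]
    image = np.clip(image, 0, 255).astype(np.uint8)  # guard against overshoot
    return image[..., ::-1]                          # BGR -> RGB


# Usage with a single synthetic output pixel.
raw = np.array([[[[-1.0, 0.0, 1.0]]]], dtype=np.float32)
print(postprocess_output(raw))
```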
Add metadata for easy integration with a mobile app
Model metadata in TFLite makes the life of mobile application developers much easier. If your TFLite model is populated with the right metadata, integrating it into a mobile application takes only a few keystrokes. The code to populate a TFLite model with metadata is out of scope for this tutorial; please refer to the metadata guide. In this section, we provide some important pointers about metadata population for the TFLite models we generated. You can follow this notebook for all the code.
Two of the most important parameters we discovered during metadata population are the mean and standard deviation with which the data should be processed. In our case, the mean and standard deviation are needed for both preprocessing and postprocessing. For normalizing the input image, the metadata configuration should look like the following:
input_image_normalization.options.mean = [127.5]
input_image_normalization.options.std = [127.5]
This converts the pixel range of an input image to [-1, 1]. Then, during postprocessing, the pixels need to be scaled back to the range [0, 255]:
output_image_normalization.options.mean = [-1]
output_image_normalization.options.std = [0.00784313] # 1/127.5
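To see why these values work, recall that TFLite metadata normalization computes `(value - mean) / std`. The short check below applies that formula to the two configurations above; the helper name `normalize` is hypothetical and exists only for this illustration.

```python
def normalize(value, mean, std):
    """Apply the normalization formula used by TFLite metadata: (value - mean) / std."""
    return (value - mean) / std


# Input normalization: pixel values in [0, 255] map to [-1, 1].
print(normalize(0.0, 127.5, 127.5), normalize(255.0, 127.5, 127.5))

# Output normalization: model outputs in [-1, 1] map back to roughly [0, 255].
print(normalize(-1.0, -1.0, 0.00784313), normalize(1.0, -1.0, 0.00784313))
```

Dividing by 0.00784313 is the same as multiplying by approximately 127.5, so an output of 1 lands very close to 255.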
Two files are created by the “add metadata” process:
- A .tflite file with the same name as the original model, with metadata added, including the model name, description, version, and input and output tensors.
- A .json file that you can print out to display the metadata. When you import the model into Android Studio, the metadata can be displayed visually as well.
Models populated with metadata are easy to import into Android Studio, which we discuss later in the “Model deployment to Android” section.
Benchmark models on Android (Optional)
As an optional step, we used the TFLite Android Model Benchmark tool to get an idea of the runtime performance on Android before deploying it.
Here is a high-level summary using the benchmark C++ binary:
1. Configure Android SDK/NDK prerequisites
2. Build the benchmark C++ binary with bazel
bazel build -c opt \
  --config=android_arm64 \
3. Use adb (Android Debug Bridge) to push the benchmarking tool binary to device and make executable
adb push benchmark_model /data/local/tmp
adb shell chmod +x /data/local/tmp/benchmark_model
4. Push the whitebox_cartoon_gan_dr.tflite model to device
adb push whitebox_cartoon_gan_dr.tflite /data/local/tmp
5. Run the benchmark tool
adb shell /data/local/tmp/benchmark_model \
  --graph=/data/local/tmp/whitebox_cartoon_gan_dr.tflite \
You will see a result in the terminal like this:
Repeat the steps above for the other two .tflite models: the float16 and int8 variants.
In summary, here is the average inference time we got from the benchmark tool running on a Pixel 4:
Refer to the documentation of the benchmark tool (C++ binary | Android APK) for details and additional options such as how to reduce variance between runs and how to profile operators, etc. You can also see the performance values of some of the popular ML models on the TensorFlow official documentation here.
Model deployment to Android
Now that we have the quantized TensorFlow Lite models with metadata, either by following the previous steps or by downloading the models directly from TensorFlow Hub here, we are ready to deploy them to Android. Follow along with the Android code on GitHub here.
The Android app uses Jetpack Navigation Component for UI navigation and CameraX for image capture. We use the new ML Model Binding feature for importing the tflite model and then Kotlin Coroutine for async handling of the model inference so that the UI is not blocked while waiting for the results.
Let’s dive into the details step by step:
- Download Android Studio 4.1 Preview.
- Create a new Android project and set up the UI navigation.
- Set up the CameraX API for image capture.
- Import the .tflite models with ML Model Binding.
- Putting everything together.
Download Android Studio 4.1 Preview
We first need to install Android Studio Preview (4.1 Beta 1) in order to use the new ML Model Binding feature, which imports a .tflite model and automatically generates code for it. You can then explore the tflite models visually and, most importantly, use the generated classes directly in your Android projects.
Download the Android Studio Preview here. You should be able to run the Preview version side by side with a stable version of Android Studio. Make sure to update your Gradle plug-in to at least 4.1.0-alpha10; otherwise the ML Model Binding menu may be inaccessible.
Create a new Android Project
First let’s create a new Android project with an empty Activity called MainActivity.kt which contains a companion object that defines the output directory where the captured image will be stored.
Use Jetpack Navigation Component to navigate the UI of the app. Please refer to the tutorial here to learn more details about this support library.
There are 3 screens in this sample app:
- `PermissionsFragment.kt` handles checking the camera permission.
- `CameraFragment.kt` handles camera setup, image capture and saving.
- `CartoonFragment.kt` handles the display of input and cartoon image in the UI.
The navigation graph in nav_graph.xml defines the navigation of the three screens and data passing between CameraFragment and CartoonFragment.
Set up CameraX for image capture
CameraX is a Jetpack support library which makes camera app development much easier.
Camera1 API was simple to use but it lacked a lot of functionality. Camera2 API provides more fine control than Camera1 but it’s very complex — with almost 1000 lines of code in a very basic example.
CameraX, on the other hand, is much easier to set up, with 10 times less code. In addition, it’s lifecycle-aware, so you don’t need to write extra code to handle the Android lifecycle.
Here are the steps to set up CameraX for this sample app:
- Create `CameraFragment.kt` to hold the CameraX code
- Request camera permission
- Check permission in
- Implement a viewfinder with the CameraX Preview class
- Implement image capture
- Capture an image and convert it to a Bitmap
CameraSelector is configured to make use of both the front-facing and rear-facing cameras, since the model can stylize any type of face or object, not just a selfie.
Once we capture an image, we convert it to a Bitmap, which is passed to the TFLite model for inference. We then navigate to a new screen, CartoonFragment.kt, where both the original image and the cartoonized image are displayed.
Import the TensorFlow Lite models
Now that the UI code has been completed, it’s time to import the TensorFlow Lite model for inference. ML Model Binding takes care of this with ease. In Android Studio, go to File > New > Other > TensorFlow Lite Model:
- Specify the .tflite file location.
- “Auto add build feature and required dependencies to gradle” is checked by default.
- Make sure to also check “Auto add TensorFlow Lite gpu dependencies to gradle”: the GAN models are complex and slow, so we need to enable the GPU delegate.
This import accomplishes two things:
1. It automatically creates an ml folder and places the .tflite model file there.
2. It automatically generates a Java class under app/build/generated/ml_source_out/debug/[package-name]/ml, which handles tasks such as model loading, image preprocessing and postprocessing, and running model inference to stylize the input image.
Once the import completes, the *.tflite file view displays the model metadata as well as code snippets in both Kotlin and Java that can be copied and pasted in order to use the model:
Repeat the steps above to import the other two .tflite model variants.
Putting everything together
Now that we have set up the UI navigation, configured CameraX for image capture, and imported the tflite models, let’s put all the pieces together!
- Model input: capture a photo with CameraX and save it
- Run inference on the input image and create a cartoonized version
- Display both the original photo and the cartoonized photo in the UI
- Use Kotlin coroutine to prevent the model inference from blocking UI main thread
First, we capture a photo with CameraX in imageCapture?.takePicture(). Then, in onImageSaved(), we convert the .jpg image to a Bitmap, rotate it if necessary, and save it to the output directory defined in MainActivity.kt.
With the Jetpack Navigation Component, we can easily navigate to CartoonFragment.kt and pass the image directory location as a string argument and the type of tflite model as an integer. Then, in CartoonFragment.kt, we retrieve the file directory string where the photo was stored, create an image file, and convert it to a Bitmap that can be used as the input to the tflite model.
In CartoonFragment.kt, we also retrieve the type of tflite model that was chosen for inference. We run model inference on the input image to create a cartoon image, and display both the original image and the cartoonized image in the UI.
Note: inference takes time, so we use a Kotlin coroutine to prevent the model inference from blocking the UI main thread. We show a ProgressBar until the model inference completes.
Here is what we have once all pieces are put together and here are the cartoon images created by the model:
This brings us to the end of the tutorial. We hope you have enjoyed reading it and will apply what you learned to your real-world applications with TensorFlow Lite. If you have created any cool samples with what you learned here, please remember to add them to awesome-tflite, a repo with TensorFlow Lite samples, tutorials, tools and learning resources.
This Cartoonizer with TensorFlow Lite project and its end-to-end tutorial were created through a great collaboration between ML GDEs and the TensorFlow Lite team. This is one of a series of end-to-end TensorFlow Lite tutorials. We would like to thank Khanh LeViet and Lu Wang (TensorFlow Lite), Hoi Lam (Android ML), Trevor McGuire (CameraX), and Soonson Kwon (ML GDEs, Google Developers Experts Program) for their collaboration and continuous support.
Also thanks to the authors of the paper Learning to Cartoonize Using White-box Cartoon Representations: Xinrui Wang and Jinze Yu.
A full version of this tutorial was originally published here: https://blog.tensorflow.org/2020/09/how-to-create-cartoonizer-with-tf-lite.html