How to Create a Cartoonizer with TensorFlow Lite

Written by ML GDEs Margaret Maynard-Reid and Sayak Paul

In this end-to-end tutorial, we will show you how to:

  • Download the .tflite models directly from TensorFlow Hub if you are only interested in using the models for deployment.
  • Use TFLite tools such as the Android Benchmark Tool, Model Metadata, and Codegen.
  • Create a mobile application with TFLite models easily, using the ML Model Binding feature from Android Studio.

Create the TensorFlow Lite Model

The authors of White-box CartoonGAN provide pre-trained weights that can be used for running inference on images. We convert these weights to TFLite models, which are better suited for running on mobile apps. Refer to the details of the model conversion on GitHub here. The conversion involves the following steps:

  • Convert SavedModel with post-training quantization using the latest TFLiteConverter.
  • Run inference in Python with the converted model.
  • Add metadata to enable easy integration with a mobile app.
  • Benchmark the model to make sure it runs well on mobile.

Generate a SavedModel from the pre-trained model weights

The pre-trained weights of White-box CartoonGAN come in the following format (also referred to as checkpoints):

├── checkpoint
├── model-33999.data-00000-of-00001
└── model-33999.index
To generate a SavedModel from these checkpoints, we need to:

  • Instantiate the model and run the input placeholder through it to get a placeholder for the model output.
  • Load the pre-trained checkpoints into the current session of the model.
  • Finally, export the model in the SavedModel format (a minimal sketch follows this list).
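As a rough sketch of those three steps, assuming the generator is exposed as network.unet_generator as in the authors' repository (the names and shapes here are illustrative; the conversion notebook linked above has the exact code):

import tensorflow as tf
from network import unet_generator  # Assumed: generator from the White-box CartoonGAN repo.

def export_saved_model(checkpoint_dir, saved_model_dir):
    with tf.Graph().as_default():
        # Run an input placeholder through the model to get an output placeholder.
        input_photo = tf.compat.v1.placeholder(
            tf.float32, [1, None, None, 3], name="input_photo")
        output_cartoon = unet_generator(input_photo)
        saver = tf.compat.v1.train.Saver()
        with tf.compat.v1.Session() as sess:
            # Load the pre-trained checkpoints into the current session.
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
            # Finally, export to SavedModel.
            tf.compat.v1.saved_model.simple_save(
                sess, saved_model_dir,
                inputs={"input_photo": input_photo},
                outputs={"output_cartoon": output_cartoon})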

Convert SavedModel to TFLite

TFLite provides support for three different post-training quantization strategies:

  1. Dynamic range
  2. Float16
  3. Integer

We use the float16 and dynamic-range strategies for this model. The snippet below shows conversion with float16 quantization:
import tensorflow as tf

# Create a concrete function from the SavedModel.
model = tf.saved_model.load(saved_model_dir)
concrete_func = model.signatures[
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

# Specify the input shape.
concrete_func.inputs[0].set_shape([1, IMG_SHAPE, IMG_SHAPE, 3])

# Convert the model and export.
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]  # Only for float16.
tflite_model = converter.convert()
open(tflite_model_path, 'wb').write(tflite_model)
  • To convert the model using dynamic-range quantization instead, simply comment out the line converter.target_spec.supported_types = [tf.float16].
When deciding which quantized model to ship, we considered the following criteria:

  • Inference time (the lower the better).
  • Hardware accelerator compatibility.
  • Memory usage.

Running inference in Python

After you have generated the TFLite models, it is important to make sure that the models perform as expected after the conversion and before integrating them into mobile apps. So let's run inference with the models in Python.

interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
input_details = interpreter.get_input_details()
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], preprocessed_source_image)
interpreter.invoke()
raw_prediction = interpreter.tensor(
    interpreter.get_output_details()[0]['index'])()
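The snippet above assumes a preprocessed_source_image. As a minimal sketch of the preprocessing, assuming the image is resized to IMG_SHAPE × IMG_SHAPE and scaled to [-1, 1] (consistent with the normalization parameters in the metadata section below):

import tensorflow as tf

def preprocess_image(image_path, img_shape):
    raw = tf.io.read_file(image_path)
    image = tf.io.decode_image(raw, channels=3)
    image = tf.image.resize(image, [img_shape, img_shape])
    # Scale pixel values from [0, 255] to [-1, 1].
    image = (tf.cast(image, tf.float32) - 127.5) / 127.5
    return tf.expand_dims(image, axis=0)  # Add the batch dimension.

preprocessed_source_image = preprocess_image('input.jpg', IMG_SHAPE)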

Add metadata for easy integration with a mobile app

Model metadata in TFLite makes the life of mobile application developers much easier. If your TFLite model is populated with the right metadata, integrating it into a mobile application becomes a matter of only a few keystrokes. Discussing all the code needed to populate a TFLite model with metadata is out of scope for this tutorial; please refer to the metadata guide for that. In this section, we provide some of the important pointers about metadata population for the TFLite models we generated, and you can follow this notebook to refer to all the code. The most important pieces are the normalization parameters for the input and output images:

# Input normalization: the model expects pixels scaled to [-1, 1],
# i.e. (pixel - 127.5) / 127.5.
input_image_normalization.options.mean = [127.5]
input_image_normalization.options.std = [127.5]
# Output denormalization: (value - (-1)) / (1/127.5) = (value + 1) * 127.5,
# which maps the model's [-1, 1] output back to [0, 255].
output_image_normalization.options.mean = [-1]
output_image_normalization.options.std = [0.00784313]  # 1/127.5
Populating the metadata also produces a .json file that you can print out to inspect the metadata. When you import the model into Android Studio, the metadata can be displayed visually as well.
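For example, with the tflite-support library you can dump the metadata of a populated model as JSON like so (a small sketch; model_path is assumed to point to the populated .tflite file):

from tflite_support import metadata as _metadata

# Load the populated TFLite model and export its metadata as JSON.
displayer = _metadata.MetadataDisplayer.with_model_file(model_path)
with open('metadata.json', 'w') as f:
    f.write(displayer.get_metadata_json())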

Benchmark models on Android (Optional)

As an optional step, we used the TFLite Android Model Benchmark tool to get an idea of the runtime performance on Android before deploying it.

bazel build -c opt \
  --config=android_arm64 \
  tensorflow/lite/tools/benchmark:benchmark_model

adb push bazel-bin/tensorflow/lite/tools/benchmark/benchmark_model /data/local/tmp
adb shell chmod +x /data/local/tmp/benchmark_model
adb push whitebox_cartoon_gan_dr.tflite /data/local/tmp
adb shell /data/local/tmp/benchmark_model \
  --graph=/data/local/tmp/whitebox_cartoon_gan_dr.tflite \
  --num_threads=4

Model deployment to Android

Now that we have the quantized TensorFlow Lite models with metadata, either by following the previous steps or by downloading the models directly from TensorFlow Hub here, we are ready to deploy them to Android. Follow along with the Android code on GitHub here. The implementation breaks down into these steps:

  • Create a new Android project and set up the UI navigation.
  • Set up the CameraX API for image capture.
  • Import the .tflite models with ML Model Binding.
  • Put everything together.

Download Android Studio 4.1 Preview

We first need to install Android Studio Preview (4.1 Beta 1) in order to use the new ML Model Binding feature, which imports a .tflite model and auto-generates code for it. You can then explore the tflite models visually and, most importantly, use the generated classes directly in your Android projects.

Create a new Android Project

First, let's create a new Android project with an empty activity, MainActivity.kt, which contains a companion object defining the output directory where the captured image will be stored. We then add two fragments (a sketch of the companion object follows the list below):

  • `CameraFragment.kt` handles camera setup, image capture and saving.
  • `CartoonFragment.kt` handles the display of input and cartoon image in the UI.
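A minimal sketch of what that companion object might look like, borrowing the usual CameraX pattern (the helper name getOutputDirectory and the directory name are illustrative, not necessarily the exact ones in the repository):

import android.content.Context
import androidx.appcompat.app.AppCompatActivity
import java.io.File

class MainActivity : AppCompatActivity() {

    companion object {
        // Resolve the app's media directory, falling back to internal storage.
        fun getOutputDirectory(context: Context): File {
            val mediaDir = context.externalMediaDirs.firstOrNull()?.let {
                File(it, "cartoonizer").apply { mkdirs() }
            }
            return if (mediaDir != null && mediaDir.exists()) mediaDir
            else context.filesDir
        }
    }
}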

Set up CameraX for image capture

CameraX is a Jetpack support library which makes camera app development much easier. We take the following steps to set it up (an image-capture sketch follows the list):

  • Use CameraFragment.kt to hold the CameraX code
  • Request camera permission
  • Update AndroidManifest.xml
  • Check permission in MainActivity.kt
  • Implement a viewfinder with the CameraX Preview class
  • Implement image capture
  • Capture an image and convert it to a Bitmap
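Here is a minimal sketch of the capture step inside CameraFragment.kt, following the standard CameraX pattern (imageCapture and outputDirectory are assumed to be initialized during CameraX setup; the exact code lives in the repository):

import android.graphics.BitmapFactory
import android.util.Log
import androidx.camera.core.ImageCapture
import androidx.camera.core.ImageCaptureException
import androidx.core.content.ContextCompat
import androidx.fragment.app.Fragment
import java.io.File

class CameraFragment : Fragment() {
    private var imageCapture: ImageCapture? = null
    private lateinit var outputDirectory: File

    private fun takePhoto() {
        val imageCapture = imageCapture ?: return
        val photoFile = File(outputDirectory, "${System.currentTimeMillis()}.jpg")
        val outputOptions = ImageCapture.OutputFileOptions.Builder(photoFile).build()

        imageCapture.takePicture(
            outputOptions,
            ContextCompat.getMainExecutor(requireContext()),
            object : ImageCapture.OnImageSavedCallback {
                override fun onImageSaved(output: ImageCapture.OutputFileResults) {
                    // Decode the saved photo into a Bitmap for the model.
                    val bitmap = BitmapFactory.decodeFile(photoFile.absolutePath)
                    // ...pass the bitmap on to CartoonFragment.
                }

                override fun onError(exc: ImageCaptureException) {
                    Log.e("CameraFragment", "Photo capture failed", exc)
                }
            })
    }
}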

Import the TensorFlow Lite models

Now that the UI code has been completed, it's time to import the TensorFlow Lite models for inference. ML Model Binding takes care of this with ease. In Android Studio, go to File > New > Other > TensorFlow Lite Model:

  • “Auto add build feature and required dependencies to gradle” is checked by default.
  • Make sure to also check “Auto add TensorFlow Lite gpu dependencies to gradle”, since the GAN models are complex and slow, so we need to enable the GPU delegate. (A usage sketch of the generated classes follows.)
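Once imported, ML Model Binding generates a Kotlin wrapper class per model. The class name is derived from the .tflite file name (e.g. whitebox_cartoon_gan_fp16.tflite becomes WhiteboxCartoonGanFp16); treat the exact names below as illustrative:

import android.content.Context
import android.graphics.Bitmap
import org.tensorflow.lite.support.image.TensorImage

// Run the cartoonizer on a Bitmap via the generated wrapper class.
fun cartoonize(context: Context, sourceBitmap: Bitmap): Bitmap {
    val model = WhiteboxCartoonGanFp16.newInstance(context)
    val sourceImage = TensorImage.fromBitmap(sourceBitmap)
    val cartoonizedImage = model.process(sourceImage).cartoonizedImageAsTensorImage
    model.close()  // Release interpreter resources when done.
    return cartoonizedImage.bitmap
}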

Putting everything together

Now that we have set up the UI navigation, configured CameraX for image capture, and imported the tflite models, let's put all the pieces together!

  • Run inference on the input image and create a cartoonized version.
  • Display both the original photo and the cartoonized photo in the UI.
  • Use Kotlin coroutines to prevent the model inference from blocking the main UI thread (a sketch follows this list).
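A minimal sketch of the coroutine piece, reusing the hypothetical cartoonize() helper from the previous section (showResult is likewise an illustrative stand-in for the actual UI update code):

import android.graphics.Bitmap
import androidx.fragment.app.Fragment
import androidx.lifecycle.lifecycleScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

class CartoonFragment : Fragment() {

    // Run inference on a background dispatcher so the main UI thread stays
    // responsive, then switch back to the main thread to show the result.
    private fun runCartoonizer(sourceBitmap: Bitmap) {
        lifecycleScope.launch(Dispatchers.Default) {
            val cartoon = cartoonize(requireContext(), sourceBitmap)
            withContext(Dispatchers.Main) {
                showResult(sourceBitmap, cartoon)
            }
        }
    }

    // Illustrative stand-in: display both photos in the UI.
    private fun showResult(original: Bitmap, cartoon: Bitmap) {
        // ...set the bitmaps on the corresponding ImageViews.
    }
}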
Cartoonizer Android app

Acknowledgments

This Cartoonizer with TensorFlow Lite project, with its end-to-end tutorial, was created through a great collaboration between ML GDEs and the TensorFlow Lite team. It is one of a series of end-to-end TensorFlow Lite tutorials. We would like to thank Khanh LeViet and Lu Wang (TensorFlow Lite), Hoi Lam (Android ML), Trevor McGuire (CameraX), and Soonson Kwon (ML GDEs, Google Developers Experts Program) for their collaboration and continuous support.

