3D Hand Pose with MediaPipe and TensorFlow.js

Posted by Valentin Bazarevsky, Ivan Grishchenko, Eduard Gabriel Bazavan, Andrei Zanfir, Mihai Zanfir, Jiuqiang Tang, Jason Mayes, Ahmed Sabie, Google

Today, we’re excited to share a new version of our model for hand pose detection, with improved accuracy for 2D, novel support for 3D, and the new ability to predict keypoints on both hands simultaneously. Support for multi-hand tracking was one of the most common requests from the developer community, and we’re pleased to support it in this release.

You can try a live demo of the new model here. This work improves on our previous model which predicted 21 keypoints, but could only detect a single hand at a time. In this article, we’ll describe the new model, and how you can get started.

The new hand pose detection model in action.

Try out the live demo!

How to use it

1. The first step is to import the library. You can either use the <script> tag in your HTML file or use NPM:

Through script tag:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/hand-pose-detection"></script>
<!-- Optional: Include the script below if you want to use the MediaPipe runtime. -->
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/hands"></script>

Through NPM:

yarn add @tensorflow-models/hand-pose-detection

# Run below commands if you want to use TF.js runtime.
yarn add @tensorflow/tfjs-core @tensorflow/tfjs-converter
yarn add @tensorflow/tfjs-backend-webgl

# Run below commands if you want to use MediaPipe runtime.
yarn add @mediapipe/hands

If installed through NPM, you need to import the libraries first:

import * as handPoseDetection from '@tensorflow-models/hand-pose-detection';

Next, create an instance of the detector:

const model = handPoseDetection.SupportedModels.MediaPipeHands;
const detectorConfig = {
  runtime: 'mediapipe', // or 'tfjs'
  modelType: 'full'
};
const detector = await handPoseDetection.createDetector(model, detectorConfig);

Choose a modelType that fits your application's needs. There are two options: lite and full. From lite to full, accuracy increases while inference speed decreases.

2. Once you have a detector, you can pass in a video stream or static image to detect poses:

const video = document.getElementById('video');
const hands = await detector.estimateHands(video);

The output format is as follows: hands represents an array of detected hand predictions in the image frame. For each hand, the structure contains a prediction of the handedness (left or right) as well as a confidence score for this prediction. An array of 2D keypoints is also returned, where each keypoint contains x, y, and name. The x and y values denote the horizontal and vertical position of the hand keypoint in image pixel space, and name denotes the joint label. In addition to 2D keypoints, we also return 3D keypoints (x, y, z values) in a metric scale, with the origin at an auxiliary keypoint formed as the average of the first knuckles of the index, middle, ring, and pinky fingers.

[
  {
    score: 0.8,
    handedness: 'Right',
    keypoints: [
      {x: 105, y: 107, name: "wrist"},
      {x: 108, y: 160, name: "pinky_finger_tip"},
      ...
    ],
    keypoints3D: [
      {x: 0.00388, y: -0.0205, z: 0.0217, name: "wrist"},
      {x: -0.025138, y: -0.0255, z: -0.0051, name: "pinky_finger_tip"},
      ...
    ]
  }
]

You can refer to our README for more details about the API.
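
To make the 3D origin concrete, here is a minimal Python sketch that averages the four first-knuckle keypoints of one hand. It is meant for offline analysis of exported results, not as part of the JavaScript API. The joint labels in KNUCKLES, the helper name knuckle_centroid, and the sample coordinates are assumptions made for illustration; check the README for the exact keypoint names.

```python
from statistics import fmean

# Assumed joint labels for the first knuckles; confirm them against the README.
KNUCKLES = ("index_finger_mcp", "middle_finger_mcp",
            "ring_finger_mcp", "pinky_finger_mcp")


def knuckle_centroid(keypoints3d):
    """Return the (x, y, z) mean of the four first-knuckle keypoints."""
    by_name = {kp["name"]: kp for kp in keypoints3d}
    picked = [by_name[name] for name in KNUCKLES]
    return tuple(fmean(kp[axis] for kp in picked) for axis in ("x", "y", "z"))


# Made-up sample values, not real model output.
sample = [
    {"x": 0.02, "y": -0.01, "z": 0.004, "name": "index_finger_mcp"},
    {"x": 0.00, "y": -0.01, "z": 0.000, "name": "middle_finger_mcp"},
    {"x": -0.01, "y": 0.00, "z": -0.002, "name": "ring_finger_mcp"},
    {"x": -0.01, "y": 0.02, "z": -0.002, "name": "pinky_finger_mcp"},
    {"x": 0.00, "y": 0.08, "z": 0.03, "name": "wrist"},
]
centroid = knuckle_centroid(sample)
```

If the 3D keypoints really are expressed relative to this auxiliary point, the centroid computed from a real prediction should land close to (0, 0, 0), which makes this a cheap sanity check on exported data.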

Model deep dive

The updated version of our hand pose detection API improves the quality of 2D keypoint prediction and of handedness classification (whether a hand is left or right), and minimizes the number of false positive detections. More details about the updated model can be found in our recent paper: On-device Real-time Hand Gesture Recognition.

Following our recently released BlazePose GHUM 3D in TensorFlow.js, we also added metric-scale 3D keypoint prediction to hand pose detection in this release, with the origin being represented by an auxiliary keypoint, formed as a mean of first knuckles for index, middle, ring and pinky fingers. Our 3D ground truth is based on a statistical 3D human body model called GHUM, which is built using a large corpus of human shapes and motions.

To obtain hand pose ground truth, we fitted the GHUM hand model to our existing 2D hand dataset and recovered real world 3D keypoint coordinates. The shape and the hand pose variables of the GHUM hand model were optimized such that the reconstructed model aligns with the image evidence. This includes 2D keypoint alignment, shape, and pose regularization terms as well as anthropometric joint angle limits and model self contact penalties.

Sample GHUM hand fittings for hand images with 2D keypoint annotations overlaid. The data was used to train and test a variety of poses leading to better results for more extreme poses.

Model quality

In this new release, we substantially improved the quality of the models and evaluated them on a dataset of American Sign Language (ASL) gestures. As the evaluation metric for 2D screen coordinates, we used the Mean Average Precision (mAP) suggested by the COCO keypoint challenge methodology.

Hand model evaluation on American Sign Language dataset

For 3D evaluation we used Mean Absolute Error in Euclidean 3D metric space, with the average error measured in centimeters.
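
As a rough illustration of the 3D metric, the sketch below averages the Euclidean distance between predicted and ground-truth keypoints and reports it in centimeters. It assumes the metric is a plain per-keypoint mean and that coordinates are given in meters; the function name mean_3d_error_cm and the sample points are hypothetical, and the actual evaluation code may differ.

```python
import math


def mean_3d_error_cm(predicted, ground_truth):
    """Mean Euclidean distance between matching 3D keypoints, in centimeters.

    Both inputs are equal-length lists of (x, y, z) tuples, assumed to be in meters.
    """
    if len(predicted) != len(ground_truth) or not predicted:
        raise ValueError("inputs must be non-empty and of equal length")
    total = sum(math.dist(p, g) for p, g in zip(predicted, ground_truth))
    return 100.0 * total / len(predicted)


# Two made-up keypoints: one is off by 1 cm, the other by 5 cm.
pred = [(0.0, 0.0, 0.0), (0.03, 0.04, 0.0)]
truth = [(0.0, 0.0, 0.01), (0.0, 0.0, 0.0)]
error_cm = mean_3d_error_cm(pred, truth)
```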

Model Name                         2D, mAP, %    3D, mean 3D error, cm
HandPose GHUM Lite
HandPose GHUM Full
Previous TensorFlow.js HandPose

Quality metrics for the newly released HandPose GHUM models vs. the previously released TensorFlow.js HandPose model, for 2D and 3D predictions

Browser performance

We benchmarked the model across multiple devices. All benchmarks were run with two hands present.

Device                                                                   MediaPipe Runtime           TensorFlow.js Runtime
                                                                         (With WASM & GPU Accel.)    (With WebGL backend)
MacBook Pro 15” 2019 (Intel Core i9, AMD Radeon Pro Vega 20 Graphics)    62 | 48                     36 | 31
iPhone 11                                                                8 | 5                       15 | 12
Pixel 5                                                                  19 | 15                     11 | 8
Intel i9-10900K, Nvidia GTX 1070 GPU                                     136 | 120                   42 | 35

Inference speed of HandPose across different devices and runtimes. The first number in each cell is for the lite model, and the second number is for the full model.

To see the model’s FPS on your device, try our demo. You can switch the model type and runtime live in the demo UI to see what works best for your device.
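
For a sense of what these numbers mean for a real-time budget, the sketch below converts throughput into time per frame. It assumes the benchmark figures are frames per second, as the FPS wording above suggests; the helper name frame_time_ms is hypothetical, and only the MacBook Pro figures are copied in.

```python
def frame_time_ms(fps):
    """Convert a frames-per-second figure into milliseconds per frame."""
    if fps <= 0:
        raise ValueError("fps must be positive")
    return 1000.0 / fps


# (lite, full) figures for the MacBook Pro, read as FPS.
macbook = {"mediapipe": (62, 48), "tfjs": (36, 31)}
latency = {
    runtime: tuple(round(frame_time_ms(fps), 1) for fps in pair)
    for runtime, pair in macbook.items()
}
```

Under that reading, the lite model on the MediaPipe runtime fits inside a 60 FPS frame budget on this machine, while the other three configurations take more than 20 ms per frame.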

Cross platform availability

In addition to the JavaScript hand pose detection API, these updated hand models are also available in MediaPipe Hands as a ready-to-use Android Solution API and Python Solution API, with prebuilt packages in Android Maven Repository and Python PyPI respectively.

For instance, for Android developers the Maven package can be easily integrated into an Android Studio project by adding the following into the project’s Gradle dependencies:

dependencies {
    implementation 'com.google.mediapipe:solution-core:latest.release'
    implementation 'com.google.mediapipe:hands:latest.release'
}

The MediaPipe Android Solution is designed to handle different use scenarios such as processing live camera feeds, video files, as well as static images. It also comes with utilities to facilitate overlaying the output landmarks onto either CPU images (with Canvas) or GPU (using OpenGL). For instance, the following code snippet demonstrates how it can be used to process a live camera feed and render the output on screen in real-time:

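// Creates MediaPipe Hands.
HandsOptions handsOptions =
    HandsOptions.builder()
        .setMaxNumHands(2)  // Illustrative value.
        .setRunOnGpu(true)  // Illustrative value.
        .build();
Hands hands = new Hands(activity, handsOptions);

// Connects MediaPipe Hands to camera.
CameraInput cameraInput = new CameraInput(activity);
cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame));

// Registers a result listener.
hands.setResultListener(
    handsResult -> {
      // Consume the detected hand landmarks here, for example by rendering them.
    });

// Starts the camera to feed data to MediaPipe Hands.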

To learn more about MediaPipe Android Solutions, please refer to our documentation and try them out with the example Android Studio project. Also visit MediaPipe Solutions for more cross-platform solutions.


We would like to acknowledge our colleagues who participated in or sponsored creating HandPose GHUM 3D and building the APIs: Cristian Sminchisescu, Michael Hays, Na Li, Ping Yu, George Sung, Jonathan Baccash, Esha Uboweja, David Tian, Kanstantsin Sokal, Gregory Karpiak, Tyler Mullen, Chuo-Ling Chang, Matthias Grundmann.


Women in Machine Learning Symposium – Event Recap

Posted by Joana Carrasqueira, Program Manager, TensorFlow.

Thank you to everyone who joined us at the first Women in Machine Learning Symposium!

Hundreds of practitioners joined from all over the world to share tips and insights for careers in ML, how to be involved in the community, contribute to open source, and much more. It was very inspiring to learn from each other’s experiences. Following is a quick recap, and an overview of the resources we discussed at the event. Thanks again.

Online education

Get involved in the community

Build your portfolio

Connect with (or become) a GDE


What’s new in TensorFlow 2.7?

Posted by Goldie Gadde and Josh Gordon for the TensorFlow team

TensorFlow 2.7 is here! This release improves usability with clearer error messages and simplified stack traces, and adds new tools and documentation for users migrating to TF2.

Improved Debugging Experience

The process of debugging your code is a fundamental part of the user experience of a machine learning framework. In this release, we’ve considerably improved the TensorFlow debugging experience to make it more productive and more enjoyable, via three major changes: simplified stack traces, displaying additional context information in errors that originate from custom Keras layers, and a wide-ranging audit of all error messages in Keras and TensorFlow.

Simplified stack traces

By default, TensorFlow now filters the stack traces displayed upon error to hide any frame that originates from TensorFlow-internal code, keeping the information focused on what matters to you: your own code. This makes stack traces simpler and shorter, and makes it easier to understand and fix the problems in your code.

If you’re actually debugging the TensorFlow codebase itself (for instance, because you’re preparing a PR for TensorFlow), you can turn off the filtering mechanism by calling tf.debugging.disable_traceback_filtering().
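
The idea behind frame filtering can be sketched with Python's standard traceback module. This is only an illustration of the concept, not TensorFlow's implementation: the file names, the "/site-packages/framework/" marker, and the helper filter_frames are all hypothetical.

```python
import traceback


def filter_frames(frames, hidden_marker):
    """Keep only frames whose filename does not contain hidden_marker."""
    return [f for f in frames if hidden_marker not in f.filename]


# Hypothetical frames: two from user code, one from framework internals.
frames = [
    traceback.FrameSummary("train.py", 7, "<module>", line="f()"),
    traceback.FrameSummary("/site-packages/framework/core.py", 120, "call",
                           line="return fn(*args)"),
    traceback.FrameSummary("train.py", 5, "f", line="return l[20]"),
]
visible = filter_frames(frames, "/site-packages/framework/")
report = "".join(traceback.format_list(visible))
```

Printing report shows only the two user-code frames, which is the effect the default filtering aims for; turning filtering off corresponds to formatting the full frames list instead.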

Automatic context injection for Keras layer exceptions

One of the most common use cases for writing low-level code is creating custom Keras layers, so we wanted to make debugging your layers as easy and productive as possible. The first thing you do when you’re debugging a layer is to print the shapes and dtypes of its inputs, as well as the values of its training and mask arguments. We now add this information automatically to all stack traces that originate from custom Keras layers.

See the effect of stack trace filtering and call context information display in practice in the image below:

Simplified stack traces in TensorFlow 2.7

Audit and improve all error messages in the TensorFlow and Keras codebases

Lastly, we’ve audited every error message in the Keras and TensorFlow codebases (thousands of error locations!) and improved them to make sure they follow UX best practices. A good error message should tell you what the framework expected, what you did that didn’t match the framework’s expectations, and should provide tips to fix the problem.
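
As a small illustration of that three-part structure, here is a plain-Python sketch of a check whose error states the expectation, the received input, and a tip. The function check_rank and its wording are hypothetical and are not taken from the Keras or TensorFlow codebases.

```python
def check_rank(shape, expected_rank):
    """Raise a ValueError that states the expectation, the input, and a fix."""
    if len(shape) != expected_rank:
        raise ValueError(
            f"Expected an input of rank {expected_rank}. "
            f"Received an input of rank {len(shape)} with shape {tuple(shape)}. "
            f"Tip: reshape or expand the input so it has {expected_rank} dimensions."
        )
    return True


try:
    check_rank((28, 28), expected_rank=3)
except ValueError as err:
    message = str(err)
```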

Improve tf.function error messages

We have improved two common types of tf.function error messages, runtime error messages and “Graph” tensor error messages, by including tracebacks that point to the error source in the user code. We also updated other vague or inaccurate tf.function error messages to be clearer and more accurate.

Consider the runtime error raised by the following user code:

import tensorflow as tf
@tf.function
def f():
  l = tf.range(tf.random.uniform((), minval=1, maxval=10, dtype=tf.int32))
  return l[20]

A summary of the old error message looks like

# … Python stack trace of the function call …

InvalidArgumentError: slice index 20 of dimension 0 out of bounds.
[[node strided_slice (defined at <ipython-input-8-250c76a76c0e>:5) ]] [Op:__inference_f_75]

Errors may have originated from an input operation.
Input Source operations connected to node strided_slice:
range (defined at <ipython-input-8-250c76a76c0e>:4)

Function call stack:

A summary of the new error message looks like

# … Python stack trace of the function call …

InvalidArgumentError: slice index 20 of dimension 0 out of bounds.
[[node strided_slice
(defined at <ipython-input-3-250c76a76c0e>:5)
]] [Op:__inference_f_15]

Errors may have originated from an input operation.
Input Source operations connected to node strided_slice:
In[0] range (defined at <ipython-input-3-250c76a76c0e>:4)
In[1] strided_slice/stack:
In[2] strided_slice/stack_1:
In[3] strided_slice/stack_2:

Operation defined at: (most recent call last)
# … Stack trace of the error within the function …
>>> File "<ipython-input-3-250c76a76c0e>", line 7, in <module>
>>> f()
>>> File "<ipython-input-3-250c76a76c0e>", line 5, in f
>>> return l[20]

The main difference is that runtime errors raised while executing a tf.function now include a stack trace that shows the source of the error in the user’s code:

# … Original error message and information …
# … More stack frames …
>>> File "<ipython-input-3-250c76a76c0e>", line 7, in <module>
>>> f()
>>> File "<ipython-input-3-250c76a76c0e>", line 5, in f
>>> return l[20]

Consider the “Graph” tensor error caused by the following user code:

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

@tf.function
def captures_leaked_tensor(b):
  b += x
  return b

leaky_function(tf.constant(1))
captures_leaked_tensor(tf.constant(2))


A summary of the old error message looks like

# … Python stack trace of the function call …

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: add:0

A summary of the new error message looks like

# … Python stack trace of the function call …

TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
# … Stack trace of the error within the function …
>>> File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function
# … More stack trace of the error within the function …

Error detected in node 'add' defined at: File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

The main difference is that errors for attempting to capture a tensor that was leaked from an unreachable graph now include a stack trace that shows where the tensor was created in the user’s code:

# … Original error message and information …
# … More stack frames …
>>> File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

Error detected in node 'add' defined at: File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

Introducing tf.experimental.ExtensionType

User-defined types can make your projects more readable, modular, and maintainable. TensorFlow 2.7.0 introduces the ExtensionType API, which can be used to create user-defined object-oriented types that work seamlessly with TensorFlow’s APIs. Extension types are a great way to track and organize the tensors used by complex models. Extension types can also be used to define new tensor-like types, which specialize or extend the basic concept of “Tensor.” To create an extension type, simply define a Python class with tf.experimental.ExtensionType as its base, and use type annotations to specify the type for each field:

import typing
import tensorflow as tf

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor  # shape=[num_nodes, num_nodes]
  node_labels: typing.Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor  # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor  # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

The ExtensionType base class adds a constructor and special methods based on the field type annotations (similar to typing.NamedTuple and @dataclasses.dataclass from the standard Python library). You can optionally customize the type by overriding these defaults, or adding new methods, properties, or subclasses.
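
For readers who have not used annotation-driven classes before, the standard-library analogue mentioned above behaves as sketched here. This is plain Python with dataclasses, not an ExtensionType; the class MaskedValues and its method are hypothetical stand-ins, and frozen=True is simply a choice made for this sketch.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class MaskedValues:
    """Plain-Python stand-in: a list of values paired with a validity mask."""
    values: list
    mask: list

    def valid_values(self):
        """A custom method added on top of the generated constructor."""
        return [v for v, ok in zip(self.values, self.mask) if ok]


# The constructor and equality come from the field annotations alone.
mv = MaskedValues(values=[1, 2, 3], mask=[True, False, True])
```

The same pattern of declaring fields with annotations and then layering methods on top is what the extension type classes above rely on, with tensors in place of lists.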

Extension types are supported by the following TensorFlow APIs:

  • Keras: Extension types can be used as inputs and outputs for Keras Models and Layers.
  • Dataset: Extension types can be included in Datasets, and returned by dataset Iterators.
  • TensorFlow Hub: Extension types can be used as inputs and outputs for tf.hub modules.
  • SavedModel: Extension types can be used as inputs and outputs for SavedModel functions.
  • tf.function: Extension types can be used as arguments and return values for functions wrapped with the @tf.function decorator.
  • control flow: Extension types can be used by control flow operations, such as tf.cond and tf.while_loop. This includes control flow operations added by autograph.
  • tf.py_function: Extension types can be used as arguments and return values for the func argument to tf.py_function.
  • Tensor ops: Extension types can be extended to support most TensorFlow ops that accept Tensor inputs (e.g., tf.matmul, tf.gather, and tf.reduce_sum), using dispatch decorators.
  • distribution strategy: Extension types can be used as per-replica values.

For more information about extension types, see the Extension Type guide.

Note: The tf.experimental prefix indicates that this is a new API, and we would like to collect feedback from real-world usage; barring any unforeseen design issues, we plan to migrate ExtensionType out of the experimental package in accordance with the TF experimental policy.

TF2 Migration made easier!

To support users interested in migrating their workloads from TF1 to TF2, we have created a new Migrate to TF2 tab on the TensorFlow website, which includes updated guides and completely new documentation with concrete, runnable examples in Colab.

A new shim tool has been added which dramatically eases migration of variable_scope-based models to TF2. It is expected to enable most TF1 users to run existing model architectures as-is (or with only minor adjustments) in TF2 pipelines without having to rewrite their modeling code. You can learn more about it in the model mapping guide.

New community contributed models on TensorFlow Hub

Since the last TensorFlow release, the community really came together to make many new models available on TensorFlow Hub. Now you can find models like MLP-Mixer, Vision Transformers, Wav2Vec2, RoBERTa, ConvMixer, DistilBERT, YOLOv5 and many more. All of these models are ready to use via TensorFlow Hub. You can learn more about publishing your models here.

Next steps

Check out the release notes for more information. To stay up to date, you can read the TensorFlow blog, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub or post to the TensorFlow Forum. Thank you!


On-device training in TensorFlow Lite

Posted by the TensorFlow Lite team

TensorFlow Lite is Google’s machine learning framework to deploy machine learning models on multiple devices and surfaces such as mobile (iOS and Android), desktops and other edge devices. Recently, we added support to run TensorFlow Lite models in a browser as well. In order to build apps using TensorFlow Lite, you can either use an off-the-shelf model from TensorFlow Hub, or convert an existing TensorFlow model to a TensorFlow Lite model using the converter. Once the model is deployed in an app, you can run inference on the model based on input data.

TensorFlow Lite now supports training your models on-device, in addition to running inference. On-device training enables interesting personalization use cases where models can be fine-tuned based on user needs. For instance, you could deploy an image classification model and allow a user to fine-tune the model to recognize bird species using transfer learning, while allowing another user to retrain the same model to recognize fruits. This new feature is available in TensorFlow 2.7 and later and is currently available for Android apps. (iOS support will be added in the future.)

On-device training is also a necessary foundation for Federated Learning use cases to train global models on decentralized data. This blog post does not cover Federated Learning and instead focuses on helping you integrate on-device training in your Android apps.

Later in this article we will reference a Colab and Android sample app as we walk you through the end-to-end implementation path for on-device learning to fine-tune an image classification model.

Improvements over the earlier approach

In our 2019 blog post, we introduced on-device training concepts and an example of on-device training in TensorFlow Lite. However, there were several limitations. For example, it was not easy to customize the model structure and optimizers. You also had to deal with multiple physical TensorFlow Lite (.tflite) models instead of a single TensorFlow Lite model. Similarly, there was no easy way to store and update the training weights. Our latest TensorFlow Lite version streamlines this process by providing more convenient options for on-device training, as explained below.

How does it work?

In order to deploy a TensorFlow Lite model with on-device training built-in, here are the high level steps:

  • Build a TensorFlow model for training and inference
  • Convert the TensorFlow model to TensorFlow Lite format
  • Integrate the model in your Android app
  • Invoke model training in the app, similar to how you would invoke model inference

These steps are explained below.

Build a TensorFlow model for training and inference

The TensorFlow Lite model should not only support model inference, but also model training, which typically involves saving the model’s weights to the file system and restoring the weights from the file system. This is done to save the training weights after each training epoch, so that the next training epoch can use the weights from the previous one, instead of starting training from scratch.

Our suggested approach is to implement these tf.functions to represent training, inference, saving weights, and loading weights:

  • A train function that trains the model using training data. The train function below makes a prediction, calculates the loss (or error), and uses tf.GradientTape() to record operations for automatic differentiation and update the model’s parameters.
    # The `train` function takes a batch of input images and labels.
    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
        tf.TensorSpec([None, 10], tf.float32),
    ])
    def train(self, x, y):
      with tf.GradientTape() as tape:
        prediction = self.model(x)
        loss = self._LOSS_FN(prediction, y)
      gradients = tape.gradient(loss, self.model.trainable_variables)
      # Assumes the optimizer is stored as self._OPTIM, mirroring self._LOSS_FN.
      self._OPTIM.apply_gradients(
          zip(gradients, self.model.trainable_variables))
      result = {"loss": loss}
      for grad in gradients:
        result[grad.name] = grad
      return result
  • An infer or a predict function that invokes model inference. This is similar to how you currently use TensorFlow Lite for inference.
    @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
    def predict(self, x):
      return {
          "output": self.model(x)
      }
  • A save/restore function that saves training weights (i.e., parameters used by the model) in Checkpoints format to the file system. The save function’s code is shown below.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def save(self, checkpoint_path):
      tensor_names = [weight.name for weight in self.model.weights]
      tensors_to_save = [weight.read_value() for weight in self.model.weights]
      tf.raw_ops.Save(
          filename=checkpoint_path, tensor_names=tensor_names,
          data=tensors_to_save, name='save')
      return {
          "checkpoint_path": checkpoint_path
      }

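The division of labor between these functions can be sketched without TensorFlow. The toy model below exposes the same four entry points (train, infer, save, restore) on a one-variable linear model. It swaps in hand-written gradients for tf.GradientTape and a JSON file for the Checkpoints format, and every name in it is hypothetical, so treat it as an illustration of the structure rather than of the TensorFlow Lite API.

```python
import json
import os
import tempfile


class TinyLinearModel:
    """Toy 1-D linear model exposing train / infer / save / restore."""

    def __init__(self):
        self.weights = {"w": 0.0, "b": 0.0}

    def infer(self, x):
        return self.weights["w"] * x + self.weights["b"]

    def train(self, xs, ys, lr=0.1):
        """One gradient-descent step on mean squared error; returns the loss."""
        n = len(xs)
        errors = [self.infer(x) - y for x, y in zip(xs, ys)]
        loss = sum(e * e for e in errors) / n
        grad_w = sum(2 * e * x for e, x in zip(errors, xs)) / n
        grad_b = sum(2 * e for e in errors) / n
        self.weights["w"] -= lr * grad_w
        self.weights["b"] -= lr * grad_b
        return loss

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.weights, f)

    def restore(self, path):
        with open(path) as f:
            self.weights = json.load(f)


xs, ys = [0.0, 1.0, 2.0], [1.0, 3.0, 5.0]  # samples of y = 2x + 1
checkpoint = os.path.join(tempfile.mkdtemp(), "weights.json")

model = TinyLinearModel()
losses = [model.train(xs, ys) for _ in range(50)]
model.save(checkpoint)

# A fresh instance resumes from the saved weights instead of from scratch.
resumed = TinyLinearModel()
resumed.restore(checkpoint)
```

Keeping save and restore as separate entry points is what lets a later training session continue from the previous weights.
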
Convert to TensorFlow Lite format

You may already be familiar with the workflow to convert your TensorFlow model to the TensorFlow Lite format. Some of the low level features for on-device training (e.g., variables to store the model parameters) are still experimental, and others (e.g., weight serialization) currently rely on TF Select operators, so you will need to set these flags during conversion. You can find an example of all the flags you need to set in the Colab.

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

Integrate the model in your Android app

Once you have converted your model to the TensorFlow Lite format, you’re ready to integrate the model into your app! Refer to the Android app samples for more details.

Invoke model training and inference in app

On Android, TensorFlow Lite on-device training can be performed using either Java or C++ APIs. You can create an instance of the TensorFlow Lite Interpreter to load a model and drive model training tasks. We had previously defined multiple tf.functions: these functions can be invoked using TensorFlow Lite’s support for Signatures, which allow a single TensorFlow Lite model to support multiple ‘entry’ points. For example, we had defined a train function for on-device training, which is one of the model’s signatures. The train function can be invoked using TensorFlow Lite’s runSignature method by specifying the name of the signature (‘train’):

// Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
  for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", trainImageBatches.get(batchIdx));
    inputs.put("y", trainLabelBatches.get(batchIdx));

    Map<String, Object> outputs = new HashMap<>();
    FloatBuffer loss = FloatBuffer.allocate(1);
    outputs.put("loss", loss);

    interpreter.runSignature(inputs, outputs, "train");

    // Record the last loss.
    if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
  }
}

Similarly, the following example shows how to invoke inference using the model’s ‘infer’ signature:

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
  // Restore the weights from the checkpoint file.

  int NUM_TESTS = 10;
  // Direct buffers are allocated in bytes (4 per float) and viewed as floats.
  FloatBuffer testImages = ByteBuffer.allocateDirect(NUM_TESTS * 28 * 28 * 4)
      .order(ByteOrder.nativeOrder()).asFloatBuffer();
  FloatBuffer output = ByteBuffer.allocateDirect(NUM_TESTS * 10 * 4)
      .order(ByteOrder.nativeOrder()).asFloatBuffer();

  // Fill the test data.

  // Run the inference.
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("x", testImages.rewind());
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("output", output);
  anotherInterpreter.runSignature(inputs, outputs, "infer");

  // Process the result to get the final category values.
  int[] testLabels = new int[NUM_TESTS];
  for (int i = 0; i < NUM_TESTS; ++i) {
    int index = 0;
    for (int j = 1; j < 10; ++j) {
      if (output.get(i * 10 + index) < output.get(i * 10 + j)) index = j;
    }
    testLabels[i] = index;
  }
}
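
The post-processing step at the end of the snippet above is intended as an argmax over each example's output scores. A minimal Python sketch of that step, with a hypothetical helper name and made-up scores, looks like this:

```python
def argmax_per_row(flat_scores, num_classes):
    """Split a flat score list into rows and return each row's argmax index."""
    if num_classes <= 0 or len(flat_scores) % num_classes != 0:
        raise ValueError("flat_scores length must be a multiple of num_classes")
    labels = []
    for start in range(0, len(flat_scores), num_classes):
        row = flat_scores[start:start + num_classes]
        # On ties, max() keeps the first index, like a strict "<" comparison.
        labels.append(max(range(num_classes), key=row.__getitem__))
    return labels


# Two made-up examples with three classes each.
scores = [0.1, 0.7, 0.2,
          0.5, 0.3, 0.2]
labels = argmax_per_row(scores, num_classes=3)
```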

And that’s it! You now have a TensorFlow Lite model that is able to use on-device training. We hope that this code walkthrough gives you a good idea of how to run on-device training in TensorFlow Lite, and we’re excited to see where you take it.

Practical considerations

In theory, you should be able to apply on-device training in TensorFlow Lite to any use case that TensorFlow supports. However, in reality there are a few practical considerations that you need to keep in mind before you deploy on-device training in your apps:

  • Use cases: The Colab example shows an example of on-device training for a vision use case. If you run into issues for specific models or use cases, please let us know on GitHub.
  • Performance: Depending on the use case, on-device training could take anywhere from a few seconds to much longer. If you run on-device training as part of a user-facing feature (e.g., your end user is interacting with the feature), you should measure the time taken for a wide range of possible training inputs in your app to limit the training time. If your use case requires very long on-device training times, consider training a model using a desktop or the cloud first, then fine-tuning it on-device.
  • Battery usage: Just like model inference, invoking model training on device may result in a battery drain. If model training is part of a feature that is not user facing, we recommend following Android’s guidelines to implement background tasks.
  • Training from scratch vs. retraining: In theory, it should be possible to train a model from scratch on device using the above features. However, in reality, training from scratch involves an enormous amount of training data and may take several days even on servers with powerful processors. Consequently, for on-device applications, we recommend retraining on an already trained model (i.e., transfer learning) as shown in the Colab example.


Future work includes (but is not limited to) on-device training support on iOS, performance improvements to leverage on-device accelerators (e.g. GPUs) for on-device training, reducing the binary size by implementing more training ops natively in TensorFlow Lite, higher level API support (e.g. via the TensorFlow Lite Task Library) to abstract away the implementation details and examples covering other on-device training use cases (e.g. NLP). Our long term roadmap involves potentially providing on-device end-to-end Federated Learning solutions.

Next steps

Thank you for reading! We’re excited to see what you build using on-device learning. Once again, here are links to the sample app and Colab. If you have any feedback, please let us know on the TensorFlow Forum, or on GitHub.


This post reflects the significant contributions of many people in Google’s TensorFlow Lite team including Michelle Carney, Lawrence Chan, Jaesung Chung, Jared Duke, Terry Heo, Jared Lim, Yu-Cheng Ling, Thai Nguyen, Karim Nosseir, Arun Venkatesan, Haoliang Zhang, other TensorFlow Lite team members, and our collaborators in Google Research.


Building a board game app with TensorFlow: a new TensorFlow Lite reference app

Posted by Wei Wei, Developer Advocate

Games are often used as test grounds for various reinforcement learning (RL) algorithms. While it is very exciting that machine learning researchers invent new RL algorithms to master challenging games, we are also curious to see that game developers are using RL to build gaming bots in TensorFlow for various purposes, such as quality testing, game balance tuning and game difficulty assessment.

We already have a detailed tutorial that demonstrates how to implement the actor-critic RL method for the classical CartPole gym environment with TensorFlow. In this end-to-end tutorial, we are going to show you how to use TensorFlow core, TensorFlow Agents and TensorFlow Lite to build a game agent to play against a human user in a small board game app. The end result is an Android reference app that looks like below, and we have open sourced all the code in tensorflow/examples repository for your reference.

Demo game play in ‘Plane Strike’

The game is called ‘Plane Strike’, a small board game that resembles the board game ‘Battleship’. The rules are very simple:

  • At the beginning of the game, the user and the agent each have a ‘plane’ object (8 blue cells that form a ‘plane’ as you can see in the animation above) on their own boards; these planes are only visible to the owners of the board and hidden to their opponents.
  • The user and the agent take turns to strike at one cell of each other’s board. The user can tap any cell in the agent’s board, while the agent will automatically make the choice based on the prediction of a machine learning model. The attempted cell turns red if it is a ‘plane’ cell (‘hit’); otherwise it turns yellow (‘miss’).
  • Whoever achieves 8 red cells first wins the game; then the game is restarted with fresh boards.
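The rules above can be sketched in a few lines of Python. This is a minimal illustration, not the app's code: the board size, the plane layout, and the names `PlaneStrikeBoard`, `strike`, and `defeated` are all assumptions made for this example.

```python
from typing import Iterable, List, Set, Tuple

# Assumed values for illustration; the app's real board size and plane
# shape are not given in the text.
BOARD_SIZE = 8
PLANE_SIZE = 8
UNTRIED, HIT, MISS = 0, 1, -1

Cell = Tuple[int, int]


class PlaneStrikeBoard:
    """One player's board: a hidden plane plus the status of every cell."""

    def __init__(self, plane_cells: Iterable[Cell]) -> None:
        self.plane_cells: Set[Cell] = set(plane_cells)
        if len(self.plane_cells) != PLANE_SIZE:
            raise ValueError(f"a plane must occupy exactly {PLANE_SIZE} cells")
        self.status: List[List[int]] = [
            [UNTRIED] * BOARD_SIZE for _ in range(BOARD_SIZE)
        ]
        self.hits = 0

    def strike(self, x: int, y: int) -> int:
        """Mark a cell as HIT (red) or MISS (yellow) and return the result."""
        if self.status[x][y] != UNTRIED:
            raise ValueError("cell was already tried")
        if (x, y) in self.plane_cells:
            self.status[x][y] = HIT
            self.hits += 1
        else:
            self.status[x][y] = MISS
        return self.status[x][y]

    @property
    def defeated(self) -> bool:
        """True once every plane cell on this board has been hit."""
        return self.hits == PLANE_SIZE


# Hypothetical plane: nose, five-cell wing row, body, tail.
plane = [(1, 3), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (3, 3), (4, 3)]
board = PlaneStrikeBoard(plane)
print(board.strike(0, 0))  # a miss
print(board.strike(2, 3))  # a hit
```

Each player owns one such board, and the side whose opponent's board reports `defeated` first wins the round.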

Even though it may be possible to create handcrafted rules for such a small game, we turn to reinforcement learning to create a smart agent that a human player can’t easily beat. For a general introduction of reinforcement learning, please refer to this RL course from DeepMind and UCL.

We provide two paths for training and deploying the agent for this game app.

TensorFlow Lite with a Python model written from scratch

In this path, to train the agent, we first create a custom OpenAI gym environment ‘PlaneStrike-v0’, which allows us to easily roll out game plays and gather game logs. Then we train the agent with REINFORCE, a policy gradient algorithm in RL, using its reward-to-go variant. The basic idea is to adjust the policy network parameters based on the reward signals collected during the gameplay, so that the policy network can maximize the return in future plays.

Mathematically, the policy gradient is defined as:
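In standard textbook notation, the REINFORCE gradient can be written as follows. This is a sketch of the usual form; the objective J and the trajectory τ are symbols assumed here, and the remaining symbols are the ones defined in the list that follows.

```latex
\nabla_\theta J(\pi_\theta)
  = \mathbb{E}_{\tau \sim \pi_\theta}
    \left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t) \, R(\tau) \right]
```

In the reward-to-go variant, R(τ) is replaced at each timestep by the return collected from t onward, \sum_{t'=t}^{T} r_{t'}, where r_{t'} (also assumed notation) is the reward received at timestep t'.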


  • T: the number of timesteps per episode, which can vary per episode
  • s_t: the state at timestep t
  • a_t: the action chosen at timestep t, given state s_t
  • π_θ: the policy parameterized by θ
  • R(*): the reward gathered, given the policy

Please refer to this DeepMind lecture on policy gradient for a more detailed discussion. To implement it with TensorFlow, we define a simple 3-layer MLP as our policy network, which predicts the agent’s next strike position, given the human player’s board state. Note that the log expression of the above policy gradient without the reward part is the equivalent of negative cross entropy loss. In this case, since we want to maximize the rewards, we can just minimize the categorical cross entropy loss to achieve that.

model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd)

We create a play_game() function to roll out the game and help us gather game logs. After each episode, we train the agent via Keras fit() function:

model.fit(x=board_log, y=action_log, sample_weight=rewards)

Note that we pass the discounted rewards-to-go as ‘sample_weight’ into the Keras fit() function as a shortcut, to implement the policy gradient algorithm without writing a custom training loop. An intuitive way to think about this is we need a tuple of (x, y, reward) instead of just (x, y) as in supervised learning. Rewards, which can be negative, help the predictor output move toward/away from y, based on x. This is different from supervised learning (in which case your ‘sample_weight’ can never be negative).
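As a minimal sketch of how the per-step weights could be computed, the function below turns one episode's rewards into discounted rewards-to-go. The function name, the discount factor, and the example reward values are assumptions for illustration; the post does not state the reward scheme it uses.

```python
from typing import List, Sequence


def discounted_rewards_to_go(rewards: Sequence[float], gamma: float = 0.99) -> List[float]:
    """Return G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... for each t."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the return of the step after it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# Hypothetical episode: two misses followed by a hit.
episode_rewards = [-1.0, -1.0, 1.0]
weights = discounted_rewards_to_go(episode_rewards, gamma=0.5)
print(weights)  # [-1.25, -0.5, 1.0]
```

The resulting list has one entry per logged step, so it lines up with `board_log` and `action_log` and can be passed as `sample_weight`. Negative entries push the policy away from the logged action, positive entries pull it toward that action.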

Since what we are doing isn’t supervised learning, we can’t really use training loss to monitor the training progress. Instead, we are going to use a proxy metric ‘game_length’, which indicates how many steps the agent takes to finish each episode. Intuitively you can understand that if the agent is smarter and makes better predictions, the game length becomes shorter.

Training progress in TensorBoard

Since this is a game that needs instantaneous responses from the agent, we want to deploy the model on mobile devices instead of servers. After training the model, we use the TFLite converter to convert the Keras model into a TFLite model, and integrate it into our Android app.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

The exported model is very fast and takes <1 ms to execute on Pixel phones. During the game play, at each step the agent looks at the user’s board position and predicts its next strike position to achieve 8 red cells as fast as possible.

// The TFLite model returns a probability for every board cell.
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
// Pick the untried cell with the highest predicted probability.
for (int i = 0; i < probArray.length; i++) {
  int x = i / Constants.BOARD_SIZE;
  int y = i % Constants.BOARD_SIZE;
  if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
    agentStrikePosition = i;
    maxProb = probArray[i];
  }
}

TensorFlow Lite with a model trained with TensorFlow Agents

While it’s a good exercise to write our agent from scratch using the TensorFlow API, it’s better to leverage existing implementations of RL algorithms. TensorFlow Agents is a library for reinforcement learning in TensorFlow, and makes it easier to design, implement and test new RL algorithms by providing well tested modular components that can be modified and extended. TF Agents has implemented several state-of-the-art RL algorithms, including DQN, DDPG, REINFORCE, PPO, SAC and TD3. Policies trained by TF Agents can be converted to TFLite directly and deployed into mobile apps (note that this feature was only recently enabled, so you will need the nightly builds of TensorFlow and TensorFlow Agents).

We use the TF Agents REINFORCE agent to train our agent. First, we need to define a TF Agents training environment as we did with the gym environment in the previous section. Then we can define an actor net as our policy network:

actor_net = tfa.networks.Sequential([
    tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE], [BOARD_SIZE**2]),
    tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'),
    tf.keras.layers.Lambda(lambda t: tfp.distributions.Categorical(logits=t)),
], input_spec=train_py_env.observation_spec())

We are going to use the built-in REINFORCE agent that TF Agents has already implemented. The agent is built on top of the ‘actor_net’ defined above:

  tf_agent = reinforce_agent.ReinforceAgent(

To train the agent, we need to collect some trajectories as experience. We define a function just for that using DeepMind Reverb and TF Agent PyDriver:

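def collect_episode(environment, policy, num_episodes, replay_buffer_observer):
  """Collect game episode trajectories."""
  # The driver steps the environment with the policy and forwards every
  # trajectory to the replay buffer observer.
  driver = py_driver.PyDriver(
      environment,
      py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
      [replay_buffer_observer],
      max_episodes=num_episodes)
  initial_time_step = environment.reset()
  driver.run(initial_time_step)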

Now we are ready to train the model:

for i in range(iterations):
  # Collect a few episodes using collect_policy and save to the replay buffer.
  collect_episode(train_py_env, collect_policy,
                  COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer)

  # Use data from the buffer and update the agent's network.
  iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
  trajectories, _ = next(iterator)
  tf_agent.train(experience=trajectories)
  # REINFORCE is on-policy, so drop the used episodes before collecting more.
  replay_buffer.clear()

You can monitor the training progress using TensorBoard. In this case, we visualize both the average episode length and average return.

TF Agents training progress in TensorBoard

Once the policy has been trained and exported as a SavedModel, you can convert it into a TFLite model:

converter = tf.lite.TFLiteConverter.from_saved_model(
    policy_dir, signature_keys=['action'])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_policy = converter.convert()
with open(os.path.join(model_dir, 'planestrike_tf_agents.tflite'), 'wb') as f:
  f.write(tflite_policy)

Currently there are a few TensorFlow ops that are required during the conversion. The converted model is slightly different from the model we trained using TensorFlow directly, because it takes 4 tensors as the input. What really matters here is the ‘observation’ tensor. Our agent will look at this ‘observation’ tensor and predict its next move. The other 3 can be safely ignored at inference time.

Visualizing TFLite model converted from TF Agents using Netron

Also, the model directly outputs the strike position instead of the probability distribution, so we no longer need to do argmax manually.

protected void runInference() {
  Map<Integer, Object> output = new HashMap<>();
  // TF Agent directly returns the predicted action
  int[][] prediction = new int[1][1];
  output.put(0, prediction);
  tflite.runForMultipleInputsOutputs(inputs, output);
  agentStrikePosition = prediction[0][0];
}

So to summarize, in this post we showed you 2 paths of how to train a game agent, convert the trained model to TFLite and deploy it into an Android app. Hopefully this end-to-end tutorial helps you better understand how to leverage the TensorFlow ecosystem to build cool games.

And lastly, if this little game looks interesting to you, we challenge you to install the app on your phone and see if you can beat the agent we have trained 😃.


TF Challenge Winners! Graphic

Announcing the Winners of the TensorFlow Lite for Microcontrollers Challenge!

Posted by Pete Warden for the TensorFlow team


In May 2021, we published the TensorFlow Microcontroller Challenge, inviting developers to push the boundaries of TensorFlow Lite for Microcontrollers. Our sincere thanks go out to all those who participated in our competition and contributed to its success! Submissions came from 20 countries across 6 continents.

We’re excited to announce the five winning entries.

  • Mapping Dance by Eduardo Padrón: Take control of lighting and video projections with your dance moves.
  • Move! by Yongjae Kim, Jonghyun Baek, Eunji Lee, Yeonhee Kim, and Jueun Choi: Stay active, using movement to control a variety of games.
  • Snoring Guardian by Naveen Kumar: A snore-no-more device embedded in your pillow.
  • Squat Counter by Manas Pange: Focus on your form, while this tracker counts your squats.
  • Voice Turn by Alvaro Gonzalez-Vila: A safer way for cyclists to signal using their voice.

These projects push boundaries, spark joy, and show off the helpfulness of TensorFlow Lite for Microcontrollers. The teams who created them will each receive a prize and meet with the TensorFlow team.

To view the winning entries, check out the TensorFlow Lite for Microcontrollers collection on Experiments with Google.

ML Community Day: Save the date

Posted by the TensorFlow team

Please join us for ML Community Day, a virtual developer event on November 9th. You’ll hear from the TensorFlow, JAX, and DeepMind teams and the community, covering new products, updates, and more.

This event takes place on TensorFlow’s sixth birthday (time flies!), and celebrates many years of open-source work by the developer community. We’ll have talks on on-device Machine Learning, JAX, Responsible AI, Cloud TPUs, Machine Learning in production, and sessions from the community.

You’ll also learn how you can get involved with the machine learning community by connecting with Google Developer Experts, joining Special Interest Groups, TensorFlow User Groups, and more.

Later in the day we’ll host the TensorFlow Contributor Summit, in which Google Developer Experts, Special Interest Groups, TensorFlow User Group organizers, and other community leaders gather to connect in a small group to discuss topics such as documentation and how to contribute to the TensorFlow ecosystem. During this event, we will also host the first TensorFlow Awards Ceremony to recognize outstanding community contributions.

You can find more details at the event website. We’ll see you soon!


Get Started with TensorFlow Lite Micro on Sony’s Spresense

A guest post by Daniel Sandblom, Sony

Editor’s note: an earlier version of this article appeared on the Sony Developers.

 Photo of a Sony Spresense microcontroller board

Now you can develop solutions with TensorFlow Lite Micro (TFLM) for the Spresense microcontroller board from Sony. TFLM is designed to run on microcontroller systems, where hardware resources are far more limited than on larger computer systems. The footprint of TFLM is typically on the order of only tens of kilobytes.

What you get is a combination of a leading machine learning ecosystem with a high performance microcontroller running at very low power consumption. The Spresense board was designed with camera and hi-res audio inputs as core features, which opens up a substantial set of use cases. Pete Warden, a research engineer on the TensorFlow team, shares his view on TFLM now being available for the Spresense board: “It’s great to see this kind of compute capability tightly integrated into a low power sensor, the combination will help make machine learning accessible to developers in medical, agriculture, industrial monitoring and many other areas where a small form factor and energy are strong constraints.”

The development of TFLM has been a tight collaboration between Google and Arm to optimize functionality while keeping the footprint to a minimum. Fredrik Knutsson, Team Lead at Arm, explains how TFLM has been optimized for the Arm processor architecture: “Arm’s open source CMSIS-NN library provides high performance implementations of common neural network functions for Arm Cortex-M processors. Arm’s engineers have worked closely with the TensorFlow team to develop optimized versions of the TensorFlow Lite kernels in the CMSIS-NN library, delivering extremely fast performance on Arm Cortex-M cores like Spresense.”

How to get started with TensorFlow on Spresense

The easiest and quickest way to get started with TensorFlow on Spresense is to run one of the examples. There is one hello_world example that shows the basic steps and functionality. There is also a micro_speech example using Spresense’s audio abilities, and there’s a person_detection example utilizing the Spresense camera. The latter two examples demonstrate how to link visual and audio sensors to the inputs of TensorFlow models.

Below are the general steps to run the examples:

  1. Set up the Spresense SDK: Getting started with TensorFlow for Spresense
  2. Download the Spresense repository including the examples
  3. Build and Flash the binary into Spresense main board
  4. Run the example

Heads-up: we will run an upcoming webinar for “TensorFlow on Spresense” on October 14 – register here!

Check out these links for more info:


How DermAssist uses TensorFlow.js for on-device image quality checks

Posted by Miles Hutson and Aaron Loh, Google Health

At Google I/O in May, we previewed our DermAssist web application designed to help people understand issues related to their skin. The tool is designed to be easy to use. Upon opening it, users are expected to take three images of their skin, hair, or nail concern from multiple angles, and provide some additional information about themselves and their condition.

This product has been CE marked as a Class I medical device in the EU. It is not available in the United States.

We recognize that when users take pictures on their phone, some images could be blurry or have poor lighting. To address this, we initially added a “quality check” after images had been uploaded, which would prompt people to retake an image when necessary. However, these prompts could be a frustrating experience for them, depending on their upload speed, how long they took to acquire the image, and the multiple retakes it might require to pass the quality check.

Letting the user know they uploaded an image with insufficient quality and advising them to retake it before they proceed.

To improve their experience, we decided to give users image quality feedback both on-device as they line up a photo, and when they review the photo before uploading. The way this feature works can be seen below. As the user lines up their camera for a shot, they may get a notification that their environment has a lighting issue (right image). Or, they may be notified that they took a blurry photo as they moved their camera (left image); the model helpfully lets them know their image is blurred before they go to the trouble of uploading it. They can decide to go back and correct the issue without the need to upload the image.

Examples of poor lighting or blurry images that obtain real time feedback so the user knows to take a new photo

Developing the Model

When developing the model, it was important to ensure that the model could comfortably run on-device. One such architecture designed for that purpose is MobileNetV2, which we selected as the backbone for the model.

Our discussions with dermatologists highlighted recurrent issues with image quality, such as the image being too blurry, too badly lit, or inappropriate for interpreting skin diseases. We curated several datasets to tackle those issues, which also informed the outputs of the model. The datasets included a crowdsourced data collection, public datasets, data obtained from tele-dermatology services, and synthetically generated images, many of which were further labeled by trained human graders. Combined, we trained the model on more than 30k images.

We trained the model with multiple binary heads, one for each quality issue. In the diagram below, we see how the input image is fed into a MobileNet feature extractor. This feature embedding is then fed to multiple distinct fully connected layers, producing a binary output (yes/no), each corresponding to a certain quality issue.

The infrastructure we used to train the model was built using TensorFlow, and exported models in the standard SavedModel format.

Translating the model to TensorFlow.js

Our team’s infrastructure for training models makes use of TensorFlow examples, which meant that the exported SavedModel had nodes for loading and preprocessing TensorFlow Examples.

TensorFlow.js at present does not support such preprocessing nodes. Therefore, we modified the signature of the SavedModel to use the image input node after the preprocessing nodes as input to the model. We re-implemented the processing in our Angular integration below.

Having rebuilt the SavedModel in the correct format for conversion, we employed the TensorFlow.js converter to convert it to the TensorFlow.js model format, which consists of a JSON file identifying the model topology, as well as the weights in sharded bin files.

tensorflowjs_converter --input_format=tf_saved_model /path/to/tfjs/signature/ /path/to/write/tfjs_model

Integrating TensorFlow.js with Observables and the Image Capture API

With the model trained, serialized, and made available for TensorFlow.js, it might feel like the job is pretty much done. However, we still had to integrate the TensorFlow.js model into our Angular 2 web application. While doing that, we had the goal that the model would ultimately be exposed as an API similar to other components. A good abstraction would allow frontend engineers to work with the TensorFlow.js model as they would work with any other part of the application, rather than as a unique component.

To begin, we created a wrapper class around the model, ImageQualityPredictor. This TypeScript class exposed only two methods:

  1. A static method createImageQualityPredictor that, given a URL for the model, returns a promise for an ImageQualityPredictor.
  2. A makePrediction method that takes ImageData and returns an array of quality predictions above a given threshold.

We found that the implementation of makePrediction was key for abstracting the inner workings of our model. The result of calling execute on our model was an array of Tensors representing yes/no probabilities for each binary head. But we didn’t want the downstream application code to be responsible for the delicate task of thresholding these tensors and connecting them back to the heads’ descriptions. Instead, we moved these details inside of our wrapper class. The final return value to the caller was instead an interface ImageQualityPrediction.
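The thresholding step described above can be sketched as follows. The production wrapper is a TypeScript class; this Python version only mirrors the logic, and the head names, their order, the default threshold, and the function name `make_prediction` are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass(frozen=True)
class ImageQualityPrediction:
    score: float
    quality_issue: str


# Hypothetical head order; the real model defines its own set of issues.
HEAD_ISSUES = ("BLURRY", "POOR_LIGHTING", "INAPPROPRIATE")


def make_prediction(head_scores: Sequence[float],
                    threshold: float = 0.5) -> List[ImageQualityPrediction]:
    """Keep only the heads whose 'yes' probability exceeds the threshold."""
    if len(head_scores) != len(HEAD_ISSUES):
        raise ValueError("expected one score per binary head")
    return [
        ImageQualityPrediction(score=score, quality_issue=issue)
        for issue, score in zip(HEAD_ISSUES, head_scores)
        if score > threshold
    ]


# One 'yes' probability per head, as if read from the model's output tensors.
predictions = make_prediction([0.91, 0.12, 0.55], threshold=0.6)
print(predictions)
```

Callers receive named issues with scores and never touch the raw tensors, which is the point of keeping this step inside the wrapper.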

export interface ImageQualityPrediction {
  score: number;
  qualityIssue: QualityIssue;
}

In order to make sure that a single ImageQualityPredictor was shared across the application, we in turn wrapped ImageQualityPredictor in a singleton ImageQualityModelService. This service handled the initialization of the predictor and tracked if the predictor already had a request in progress. It also contained helper methods for extracting frames from the ImageCapture API that our camera feature is built on and translating QualityIssue to plain English strings.

Finally, we combined the CameraService and our ImageQualityModelService in an ImageQualityService. The final product exposed for use in any given front end component is a simple observable that provides text describing any quality issues.

export class ImageQualityService {
  readonly realTimeImageQualityText$: Observable<string>;

  constructor(
      private readonly cameraService: CameraService,
      private readonly imageQualityModelService: ImageQualityModelService) {
    const retrieveText = () =>
        // ...
    this.realTimeImageQualityText$ =
        // ...
            filter(() => !imageQualityModelService.requestInProgress),
            // ...
  }
}

This lends itself well to Angular’s normal templating system, accomplishing our goal of making a TensorFlow.js model in Angular as easy to work with as any other component for front end engineers. For example, a suggestion chip can be as easy to include in a component as

 <suggestive-chip *ngIf="(imageQualityText$ | async) as text"

Looking Ahead

With helping users to capture better pictures in mind, we developed an on-device image quality check for the DermAssist app to provide real-time guidance on image intake. Part of making users’ lives easier is making sure that this model works fast enough such that we can show a notification as quickly as possible while they’re taking a picture. For us, this means finding ways to reduce the model size in order to reduce the time it takes to load on a user’s device. Possible techniques to further advance this goal may be model quantization, or attempts at model distillation into smaller architectures.

To learn more about the DermAssist application, check out our blog post from Google I/O.

To learn more about TensorFlow.js, you can visit the main site here. Also be sure to check out the tutorials and guide.


TensorFlow Model Optimization Toolkit — Collaborative Optimization API

A guest post by Mohamed Nour Abouelseoud and Elena Zhelezina at Arm.

This blog post introduces collaborative techniques for machine learning model optimization for edge devices, proposed and contributed by Arm to the TensorFlow Model Optimization Toolkit, available starting from release v0.7.0.

The main idea of the collaborative optimization pipeline is to apply the different optimization techniques in the TensorFlow Model Optimization Toolkit one after another while maintaining the balance between compression and accuracy required for deployment. This leads to significant reduction in the model size and could improve inference speed given framework and hardware-specific support such as that offered by the Arm Ethos-N and Ethos-U NPUs.

This work is a part of the toolkit’s roadmap to support the development of smaller and faster ML models. You can see previous posts on post-training quantization, quantization-aware training, sparsity, and clustering for more background on the toolkit and what it can do.

What is Collaborative Optimization? And why?

The motivation behind collaborative optimization remains the same as that behind the Model Optimization Toolkit (TFMOT) in general, which is to enable model conditioning and compression for improving deployment to edge devices. The push towards edge computing and endpoint-oriented AI creates high demand for such tools and techniques. The Collaborative Optimization API stacks all of the available techniques for model optimization to take advantage of their cumulative effect and achieve the best model compression while maintaining required accuracy.

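Given the following optimization techniques, various combinations for deployment are possible:

  • Weight pruning
  • Weight clustering
  • Quantization (post-training quantization or quantization-aware training, QAT)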

In other words, it is possible to apply one or both of pruning and clustering, followed by post-training quantization or QAT before deployment.

The challenge in combining these techniques is that each API is unaware of the techniques applied before it, so each optimization and fine-tuning step does not preserve the results of the preceding one. This spoils the overall benefit of applying them together: clustering doesn’t preserve the sparsity introduced by pruning, and the fine-tuning process of QAT loses both the pruning and clustering benefits. To overcome these problems, we introduce the following collaborative optimization techniques:

  • Sparsity-preserving clustering
  • Sparsity-preserving quantization-aware training (PQAT)
  • Cluster-preserving quantization-aware training (CQAT)
  • Sparsity- and cluster-preserving quantization-aware training (PCQAT)

Considered together, along with the option of post-training quantization instead of QAT, these provide several paths for deployment, visualized in the following deployment tree, where the leaf nodes are deployment-ready models, meaning they are fully quantized and in TFLite format. The green fill indicates steps where retraining/fine-tuning is required and a dashed red border highlights the collaborative optimization steps. The technique used to obtain a model at a given node is indicated in the corresponding label.

Keras Model Flow Chart

The direct, quantization-only (post-training or QAT) deployment path is omitted in the figure above.

The idea is to reach the fully optimized model at the third level of the above deployment tree; however, any of the other levels of optimization could prove satisfactory and achieve the required inference latency, compression, and accuracy target, in which case no further optimization is needed. The recommended training process would be to iteratively go through the levels of the deployment tree applicable to the target deployment scenario and see if the model fulfils the optimization requirements and, if not, use the corresponding collaborative optimization technique to compress the model further and repeat until the model is fully optimized (pruned, clustered, and quantized), if needed.

To further unlock the improvements in memory usage and speed at inference time associated with collaborative optimization, specialized run-time or compiler software and dedicated machine learning hardware is required. Examples include the Arm ML Ethos-N driver stack for the Ethos-N processor and the Ethos-U Vela compiler for the Ethos-U processor. Both examples currently require quantizing and converting optimized Keras models to TensorFlow Lite first.

The figure below shows the density plots of a sample weight kernel going through the full collaborative optimization pipeline.

Quantized deployment model with reduced number of unique values

The result is a quantized deployment model with a reduced number of unique values as well as a significant number of sparse weights, depending on the target sparsity specified at training time. This results in substantial model compression advantages and significantly reduced inference latency on specialized hardware.
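To see why fewer unique values and more zeros help compression, here is a toy sketch using only the Python standard library. It is not the toolkit's algorithm: pruning is a plain magnitude cut, clustering snaps non-zero weights to evenly spaced centroids while leaving zeros untouched, and zlib stands in for the compression applied to a serialized model. All function names are assumptions for this example.

```python
import random
import struct
import zlib
from typing import List


def prune(weights: List[float], sparsity: float) -> List[float]:
    """Zero out the given fraction of weights with the smallest magnitude."""
    count = int(len(weights) * sparsity)
    smallest = sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:count]
    pruned = list(weights)
    for i in smallest:
        pruned[i] = 0.0
    return pruned


def cluster_preserving_zeros(weights: List[float], num_clusters: int) -> List[float]:
    """Snap non-zero weights to evenly spaced centroids; zeros stay zero."""
    nonzero = [w for w in weights if w != 0.0]
    low, high = min(nonzero), max(nonzero)
    step = (high - low) / (num_clusters - 1)
    centroids = [low + i * step for i in range(num_clusters)]
    return [
        0.0 if w == 0.0 else min(centroids, key=lambda c: abs(c - w))
        for w in weights
    ]


def compressed_size(weights: List[float]) -> int:
    """Size in bytes of the float32-packed weights after zlib compression."""
    raw = struct.pack(f"{len(weights)}f", *weights)
    return len(zlib.compress(raw, 9))


random.seed(0)
dense = [random.gauss(0.0, 1.0) for _ in range(4096)]
pruned = prune(dense, sparsity=0.5)
clustered = cluster_preserving_zeros(pruned, num_clusters=8)

for name, kernel in [("dense", dense), ("pruned", pruned),
                     ("pruned+clustered", clustered)]:
    print(name, "unique:", len(set(kernel)), "zeros:", kernel.count(0.0),
          "compressed bytes:", compressed_size(kernel))
```

Each stage keeps the zeros introduced by the previous one, which is the property the collaborative techniques are designed to maintain through fine-tuning.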

Compression and accuracy results

The tables below show the results of running several experiments on popular models, demonstrating the compression benefits vs. accuracy loss incurred by applying these techniques. More aggressive optimizations can be applied, but at the cost of accuracy. Though the tables below include measurements for TensorFlow Lite models, similar benefits are observed for other serialization formats.

Sparsity-preserving Quantization aware training (PQAT)

| Model | Item | Baseline | Pruned Model (50% sparsity) | QAT Model | PQAT Model |
|---|---|---|---|---|---|
| DS-CNN-L | FP32 Top-1 Accuracy | | | (Fake INT8) 94.721% | (Fake INT8) 94.128% |
| DS-CNN-L | INT8 Top-1 Accuracy | | | | |
| DS-CNN-L | Compression | 528,128 → 434,879 (17.66%) | 528,128 → 334,154 (36.73%) | 512,224 → 403,261 (21.27%) | 512,032 → 303,997 (40.63%) |
| MobileNet_v1 on ImageNet | FP32 Top-1 Accuracy | | | (Fake INT8) 70.67% | (Fake INT8) 70.29% |
| MobileNet_v1 on ImageNet | INT8 Top-1 Accuracy | | | | |
| MobileNet_v1 on ImageNet | Compression | 4,665,520 → 3,880,331 (16.83%) | 4,665,520 → 2,939,734 (37.00%) | 4,569,416 → 3,808,781 (16.65%) | 4,569,416 → 2,869,600 (37.20%) |

Note: DS-CNN-L is a keyword spotting model designed for edge devices. More can be found in Arm’s ML Examples repository.

Cluster-preserving Quantization aware training (CQAT)

| Model | Item | Baseline | Clustered Model | QAT Model | CQAT Model |
|---|---|---|---|---|---|
| MobileNet_v1 on CIFAR-10 | FP32 Top-1 Accuracy | | 94.48% (16 clusters) | (Fake INT8) 94.80% | (Fake INT8) 94.60% |
| MobileNet_v1 on CIFAR-10 | INT8 Top-1 Accuracy | | 94.41% (16 clusters) | | |
| MobileNet_v1 on ImageNet | FP32 Top-1 Accuracy | | 65.30% (32 clusters) | (Fake INT8) 70.39% | (Fake INT8) 65.35% |
| MobileNet_v1 on ImageNet | INT8 Top-1 Accuracy | | 60.60% (32 clusters) | | |
| MobileNet_v1 on ImageNet | Compression | 4,665,568 → 3,886,277 (16.7%) | 4,665,568 → 3,035,752 (34.9%) | 4,569,416 → 3,804,871 (16.7%) | 4,569,472 → 2,912,655 (36.25%) |

Sparsity and cluster preserving quantization aware training (PCQAT)

| Model | Item | Baseline | Pruned Model (50% sparsity) | QAT Model | Pruned Clustered Model | PCQAT Model |
|---|---|---|---|---|---|---|
| DS-CNN-L | FP32 Top-1 Accuracy | | | (Fake INT8) 94.85% | 93.76% (8 clusters) | (Fake INT8) 94.28% |
| DS-CNN-L | INT8 Top-1 Accuracy | | | | 93.21% (8 clusters) | |
| DS-CNN-L | Compression | 506,400 → 425,006 (16.07%) | 506,400 → 317,937 (37.22%) | 507,296 → 424,368 (16.35%) | 506,400 → 205,333 | 507,296 → 201,744 (60.23%) |
| MobileNet_v1 on ImageNet | FP32 Top-1 Accuracy | | | (Fake INT8) 70.88% | 67.64% (16 clusters) | (Fake INT8) 67.80% |
| MobileNet_v1 on ImageNet | INT8 Top-1 Accuracy | | | | 66.89% (16 clusters) | |
| MobileNet_v1 on ImageNet | Compression | 4,665,552 → 3,886,236 (16.70%) | 4,665,552 → 2,909,148 (37.65%) | 4,569,416 → 3,808,781 (16.65%) | 4,665,552 → 2,013,010 (56.85%) | 4,569,472 → 1,943,957 (57.46%) |

Applying PCQAT

To apply PCQAT, you will need to first use the pruning API to prune the model, then chain it with clustering using the sparsity-preserving clustering API. After that, the QAT API is used along with the custom collaborative optimization quantization scheme. An example is shown below.

import tensorflow_model_optimization as tfmot
model = build_your_model()

# Prune the model, then fine-tune model_for_pruning as usual.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model, ...)

# Pruning wrappers must be stripped before clustering.
stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

After pruning the model, cluster and fit as below.

# Sparsity preserving clustering
from tensorflow_model_optimization.python.core.clustering.keras.experimental import (cluster)

# Specify your clustering parameters along
# with the `preserve_sparsity` flag
clustering_params = {
    # ... your other clustering parameters
    'preserve_sparsity': True
}

# Cluster and fine-tune as usual
cluster_weights = cluster.cluster_weights
sparsity_clustered_model = cluster_weights(stripped_pruned_model, **clustering_params)

# Strip clustering wrappers before the PCQAT step
stripped_sparsity_clustered_model = tfmot.clustering.keras.strip_clustering(sparsity_clustered_model)

Then apply PCQAT.

# Annotate the model, then apply QAT with the cluster-preserving scheme.
pcqat_annotate_model = tfmot.quantization.keras.quantize_annotate_model(stripped_sparsity_clustered_model)
scheme = tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True)
pcqat_model = tfmot.quantization.keras.quantize_apply(pcqat_annotate_model, scheme=scheme)

The example above shows the training process to achieve a fully optimized PCQAT model. For the other techniques, please refer to the CQAT, PQAT, and sparsity-preserving clustering example notebooks. Note that the API used for PCQAT is the same as that of CQAT, the only difference being the use of the preserve_sparsity flag to ensure that the zero cluster is preserved during training. The PQAT API usage is similar but uses a different, sparsity-preserving quantization scheme.


The features and results presented in this post are the work of many people including the Arm ML Tooling team and our collaborators in Google’s TensorFlow Model Optimization Toolkit team.

From Arm – Anton Kachatkou, Aron Virginas-Tar, Ruomei Yan, Saoirse Stewart, Peng Sun, Elena Zhelezina, Gergely Nagy, Les Bell, Matteo Martincigh, Benjamin Klimczak, Thibaut Goetghebuer-Planchon, Tamás Nyíri, Johan Gras.

From Google – David Rim, Frederic Rechtenstein, Alan Chiao, Pulkit Bhuwalka
