Introducing TensorFlow Graph Neural Networks

Posted by Sibon Li, Jan Pfeifer, Bryan Perozzi, and Douglas Yarrington

Today, we are excited to release TensorFlow Graph Neural Networks (GNNs), a library designed to make it easy to work with graph-structured data using TensorFlow. We have used an earlier version of this library in production at Google in a variety of contexts (for example, spam and anomaly detection, traffic estimation, YouTube content labeling) and as a component in our scalable graph mining pipelines. In particular, given the myriad types of data at Google, our library was designed with heterogeneous graphs in mind. We are releasing this library with the intention to encourage collaborations with researchers in industry.

Why use GNNs?

Graphs are all around us, in the real world and in our engineered systems. A set of objects, places, or people and the connections between them is generally describable as a graph. More often than not, the data we see in machine learning problems is structured or relational, and thus can also be described with a graph. And while fundamental research on GNNs is perhaps decades old, recent advances in the capabilities of modern GNNs have led to progress in domains as varied as traffic prediction, rumor and fake news detection, modeling disease spread, physics simulations, and understanding why molecules smell.

Graphs can model the relationships between many different types of data, including web pages (left), social connections (center), or molecules (right).

A graph represents the relations (edges) between a collection of entities (nodes or vertices). We can characterize each node, edge, or the entire graph, and thereby store information in each of these pieces of the graph. Additionally, we can ascribe directionality to edges to describe information or traffic flow, for example.
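To make the vocabulary concrete, here is a minimal sketch of a directed graph that stores features on nodes, edges, and the graph as a whole. It uses only the Python standard library; the `Graph` class, its method names, and the road-network data are assumptions made up for illustration and are not part of TF-GNN.

```python
from dataclasses import dataclass, field


@dataclass
class Graph:
    """A tiny directed graph with features on nodes, edges, and the graph itself."""
    node_features: dict = field(default_factory=dict)   # node id -> feature dict
    edge_features: dict = field(default_factory=dict)   # (source, target) -> feature dict
    graph_features: dict = field(default_factory=dict)  # graph-level features

    def add_node(self, node_id, **features):
        self.node_features[node_id] = features

    def add_edge(self, source, target, **features):
        # Edges are directed: (source, target) differs from (target, source).
        if source not in self.node_features or target not in self.node_features:
            raise KeyError("both endpoints must be added as nodes first")
        self.edge_features[(source, target)] = features

    def successors(self, node_id):
        """Nodes reachable from node_id by following one outgoing edge."""
        return sorted(t for (s, t) in self.edge_features if s == node_id)


# Hypothetical traffic-flow example: edge direction models the flow of cars.
roads = Graph(graph_features={"city": "example"})
roads.add_node("a", kind="intersection")
roads.add_node("b", kind="intersection")
roads.add_node("c", kind="parking")
roads.add_edge("a", "b", cars_per_hour=120)
roads.add_edge("b", "c", cars_per_hour=30)
roads.add_edge("c", "a", cars_per_hour=45)
```

Because the edge key is an ordered pair, adding `("a", "b")` does not imply that `("b", "a")` exists, which is what lets the structure describe one-way flow.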

GNNs can be used to answer questions about multiple characteristics of these graphs. By working at the graph level, we try to predict characteristics of the entire graph. We can identify the presence of certain “shapes,” like circles in a graph that might represent sub-molecules or perhaps close social relationships. GNNs can be used on node-level tasks, to classify the nodes of a graph, and predict partitions and affinity in a graph similar to image classification or segmentation. Finally, we can use GNNs at the edge level to discover connections between entities, perhaps using GNNs to “prune” edges to identify the state of objects in a scene.

Structure

TF-GNN provides building blocks for implementing GNN models in TensorFlow. Beyond the modeling APIs, our library also provides extensive tooling around the difficult task of working with graph data: a Tensor-based graph data structure, a data handling pipeline, and some example models for users to quickly onboard.

The various components of TF-GNN that make up the workflow.

The initial release of the TF-GNN library contains a number of utilities and features for use by beginners and experienced users alike, including:

  • A high-level Keras-style API to create GNN models that can easily be composed with other types of models. GNNs are often used in combination with ranking, deep-retrieval (dual-encoders) or mixed with other types of models (image, text, etc.)
    • GNN API for heterogeneous graphs. Many of the graph problems we approach at Google and in the real world contain different types of nodes and edges. Hence we chose to provide an easy way to model this.
  • A well-defined schema to declare the topology of a graph, and tools to validate it. This schema describes the shape of its training data and serves to guide other tools.
  • A GraphTensor composite tensor type which holds graph data, can be batched, and has graph manipulation routines available.
  • A library of operations on the GraphTensor structure:
    • Various efficient broadcast and pooling operations on nodes and edges, and related tools.
    • A library of standard baked convolutions, that can be easily extended by ML engineers/researchers.
    • A high-level API for product engineers to quickly build GNN models without necessarily worrying about their details.
  • An encoding of graph-shaped training data on disk, as well as a library used to parse this data into a data structure from which your model can extract the various features.
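As a rough illustration of what declaring and validating a heterogeneous graph topology can involve, the sketch below describes node sets and edge sets in a plain dictionary and checks that every edge set connects declared node sets. The dictionary layout, the edge set names, and the `validate_schema` helper are all hypothetical; this is not the TF-GNN schema format or its validation tooling.

```python
# Hypothetical schema layout for illustration; not the TF-GNN schema format.
SCHEMA = {
    "node_sets": {
        "user": {"features": ["age"]},
        "movie": {"features": ["title"]},
        "genre": {"features": ["name"]},
    },
    "edge_sets": {
        "has_genre": {"source": "movie", "target": "genre"},
        "watched_by": {"source": "movie", "target": "user"},
        "liked_by": {"source": "genre", "target": "user"},
    },
}


def validate_schema(schema):
    """Returns a list of problems; an empty list means the topology is consistent."""
    problems = []
    node_sets = schema.get("node_sets", {})
    for name, edge_set in schema.get("edge_sets", {}).items():
        for endpoint in ("source", "target"):
            node_set = edge_set.get(endpoint)
            if node_set not in node_sets:
                problems.append(
                    f"edge set {name!r}: {endpoint} {node_set!r} is not a declared node set"
                )
    return problems


# A deliberately broken variant: "actor" is never declared as a node set.
BROKEN = {
    "node_sets": {"movie": {"features": ["title"]}},
    "edge_sets": {"stars": {"source": "actor", "target": "movie"}},
}
problems = validate_schema(BROKEN)
```

Checking the topology before training catches mismatches between the declared schema and the data early, which is the role the schema tools play in the workflow described above.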

Example usage

In the example below, we build a model using the TF-GNN Keras API to recommend movies to a user based on what they watched and genres that they liked.

We use ConvGNNBuilder to specify how edges and nodes are handled: WeightedSumConvolution (defined below) is applied on each edge set, and on each pass through the GNN the node states are updated through a Dense layer:

import tensorflow as tf
import tensorflow_gnn as tfgnn

# Model hyper-parameters:
h_dims = {'user': 256, 'movie': 64, 'genre': 128}

# Model builder initialization:
gnn = tfgnn.keras.ConvGNNBuilder(
    lambda edge_set_name: WeightedSumConvolution(),
    lambda node_set_name: tfgnn.keras.layers.NextStateFromConcat(
        tf.keras.layers.Dense(h_dims[node_set_name]))
)

# Two rounds of message passing to target node sets:
model = tf.keras.models.Sequential([
    gnn.Convolve({'genre'}),  # sends messages from movie to genre
    gnn.Convolve({'user'}),  # sends messages from movie and genre to users
    tfgnn.keras.layers.Readout(node_set_name="user"),
    tf.keras.layers.Dense(1)
])

The code above works great, but sometimes we may want to use a more powerful custom model architecture for our GNNs. For example, in our previous use case, we might want to specify that certain movies or genres hold more weight when we give our recommendation. In the following snippet, we define a more advanced GNN with custom graph convolutions, in this case with weighted edges. We define the WeightedSumConvolution class to pool edge values as a sum of weights across all edges:

class WeightedSumConvolution(tf.keras.layers.Layer):
    """Weighted sum of source nodes states."""

    def call(self, graph: tfgnn.GraphTensor,
             edge_set_name: tfgnn.EdgeSetName) -> tfgnn.Field:
        messages = tfgnn.broadcast_node_to_edges(
            graph,
            edge_set_name,
            tfgnn.SOURCE,
            feature_name=tfgnn.DEFAULT_STATE_NAME)
        weights = graph.edge_sets[edge_set_name]['weight']
        weighted_messages = tf.expand_dims(weights, -1) * messages
        pooled_messages = tfgnn.pool_edges_to_node(
            graph,
            edge_set_name,
            tfgnn.TARGET,
            reduce_type='sum',
            feature_value=weighted_messages)
        return pooled_messages

Note that even though the convolution was written with only the source and target nodes in mind, TF-GNN makes sure it’s applicable and works on heterogeneous graphs (with various types of nodes and edges) seamlessly.
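For readers who want to see the arithmetic behind the convolution above without any TensorFlow machinery, the following is a minimal pure-Python sketch of the same three steps: copy each source state onto its edge, scale it by the edge weight, and sum the results at the target node. The function name, the list-based data layout, and the numbers are assumptions chosen for illustration.

```python
def weighted_sum_pool(source_states, edges, num_targets):
    """Pools weighted source states onto target nodes.

    source_states: list of feature vectors, one per source node.
    edges: list of (source_index, target_index, weight) triples.
    num_targets: number of nodes in the target node set.
    """
    dim = len(source_states[0])
    pooled = [[0.0] * dim for _ in range(num_targets)]
    for source, target, weight in edges:
        # "Broadcast" the source state onto the edge, then scale it by the edge weight.
        message = [weight * value for value in source_states[source]]
        # "Pool" the message into the target node with a sum.
        pooled[target] = [acc + m for acc, m in zip(pooled[target], message)]
    return pooled


# Made-up data: two movies with 2-d states, two users, three weighted edges.
movie_states = [[1.0, 0.0], [0.0, 2.0]]
edges = [(0, 0, 0.5), (1, 0, 1.0), (1, 1, 2.0)]  # (movie, user, weight)
user_messages = weighted_sum_pool(movie_states, edges, num_targets=2)
```

Here user 0 receives 0.5 times the first movie's state plus 1.0 times the second movie's state, and user 1 receives 2.0 times the second movie's state.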

Next steps

You can check out the TF-GNN GitHub repo for more information. To stay up to date, you can read the TensorFlow blog, join the TensorFlow Forum at discuss.tensorflow.org, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub. Thank you!

Acknowledgments

The work described here was a research collaboration between Oleksandr Ferludin, Martin Blais, Jan Pfeifer, Arno Eigenwillig, Dustin Zelle, Bryan Perozzi and Da-Cheng Juan of Google, and Sibon Li, Alvaro Sanchez-Gonzalez, Peter Battaglia, Kevin Villela, Jennifer She and David Wong of DeepMind.


ML Community Day 2021 Recap

Posted by the TensorFlow Team

Thanks to everyone who joined our inaugural virtual ML Community Day! It was so great to get the community together and hear incredible talks like how JAX and TPUs make AlphaFold possible from the DeepMind team, and how Edge Impulse makes it easy for developers to work with TinyML using TensorFlow.

We also celebrated TensorFlow’s 6th birthday! The TensorFlow ecosystem has come a long way in 6 years, and we love seeing what you all achieve with our tools, from using machine learning to help advance access to human rights information to creating a custom, TensorFlow-powered drumming arm.

In this article are a few of the updates and topics we shared during the event. You can watch the keynote below, and you can find recordings of every talk on the TensorFlow YouTube channel.


Model building

TensorFlow 2.7 is here! This release offers performance and usability improvements, including TFLite use of XNNPack for mobile inference performance boosts, training improvements on GPUs, and a dramatic improvement in debugging efficiency in Keras and TF.

Keras has been modularized as a separate pip package on top of TensorFlow (installed by default) and now lives in a separate GitHub repository. This will make it much easier for the community to contribute to the development of Keras. We welcome your PRs!

Responsible AI

The Responsible AI team also announced v0.4 of our Language Interpretability Tool (LIT). LIT is an open-source platform for visualization and understanding of NLP models. This new release includes new interpretability techniques like TCAV (Testing with Concept Activation Vectors). TCAV is an interpretability method for ML models that shows the importance of high-level concepts for a predicted class.

Mobile

We recently launched on-device training in TensorFlow Lite. When deploying a TensorFlow Lite machine learning model to a mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. Using on-device training techniques allows you to update a model without data leaving your users’ devices, improving user privacy, and without requiring users to update the device software. It’s currently available on Android.

And we continue to work on making performance better on TensorFlow Lite. As mentioned above, XNNPACK, a library for faster floating point ops, is now turned on by default in TensorFlow Lite. This allows your models to run on average 2.3x faster on the CPU.

Find all the talks here

You can find all of the content in this playlist, and for your convenience here are direct links to each of the sessions also:


3D Hand Pose with MediaPipe and TensorFlow.js

Posted by Valentin Bazarevsky, Ivan Grishchenko, Eduard Gabriel Bazavan, Andrei Zanfir, Mihai Zanfir, Jiuqiang Tang, Jason Mayes, Ahmed Sabie, Google

Today, we’re excited to share a new version of our model for hand pose detection, with improved accuracy for 2D, novel support for 3D, and the new ability to predict keypoints on both hands simultaneously. Support for multi-hand tracking was one of the most common requests from the developer community, and we’re pleased to support it in this release.

You can try a live demo of the new model here. This work improves on our previous model which predicted 21 keypoints, but could only detect a single hand at a time. In this article, we’ll describe the new model, and how you can get started.

The new hand pose detection model in action.

Try out the live demo!

How to use it

1. The first step is to import the library. You can either use the <script> tag in your html file or use NPM:

Through script tag:

<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/hand-pose-detection"></script>
<!-- Optional: Include below scripts if you want to use MediaPipe runtime. -->
<script src="https://cdn.jsdelivr.net/npm/@mediapipe/hands"></script>

Through NPM:

yarn add @tensorflow-models/hand-pose-detection

# Run below commands if you want to use TF.js runtime.
yarn add @tensorflow/tfjs-core @tensorflow/tfjs-converter
yarn add @tensorflow/tfjs-backend-webgl

# Run below commands if you want to use MediaPipe runtime.
yarn add @mediapipe/hands

If installed through NPM, you need to import the libraries first:

import * as handPoseDetection from '@tensorflow-models/hand-pose-detection';

Next create an instance of the detector:

const model = handPoseDetection.SupportedModels.MediaPipeHands;
const detectorConfig = {
  runtime: 'mediapipe', // or 'tfjs'
  modelType: 'full'
};
const detector = await handPoseDetection.createDetector(model, detectorConfig);

Choose a modelType that fits your application needs. There are two options to choose from: lite and full. From lite to full, the accuracy increases while the inference speed decreases.

2. Once you have a detector, you can pass in a video stream or static image to detect poses:

const video = document.getElementById('video');
const hands = await detector.estimateHands(video);

The output format is as follows: hands represents an array of detected hand predictions in the image frame. For each hand, the structure contains a prediction of the handedness (left or right) as well as a confidence score for this prediction. An array of 2D keypoints is also returned, where each keypoint contains x, y, and name. The x and y values denote the horizontal and vertical position of the hand keypoint in image pixel space, and name denotes the joint label. In addition to 2D keypoints, we also return 3D keypoints (x, y, z values) in a metric scale, with the origin at an auxiliary keypoint formed as the average of the first knuckles of the index, middle, ring and pinky fingers.

[
  {
    score: 0.8,
    handedness: 'Right',
    keypoints: [
      {x: 105, y: 107, name: "wrist"},
      {x: 108, y: 160, name: "pinky_finger_tip"},
      ...
    ],
    keypoints3D: [
      {x: 0.00388, y: -0.0205, z: 0.0217, name: "wrist"},
      {x: -0.025138, y: -0.0255, z: -0.0051, name: "pinky_finger_tip"},
      ...
    ]
  }
]

You can refer to our README for more details about the API.

Model deep dive

The updated version of our hand pose detection API improves the quality of 2D keypoint prediction and of handedness (a classification output indicating whether the hand is left or right), and minimizes the number of false positive detections. More details about the updated model can be found in our recent paper: On-device Real-time Hand Gesture Recognition.

Following our recently released BlazePose GHUM 3D in TensorFlow.js, we also added metric-scale 3D keypoint prediction to hand pose detection in this release, with the origin being represented by an auxiliary keypoint, formed as a mean of first knuckles for index, middle, ring and pinky fingers. Our 3D ground truth is based on a statistical 3D human body model called GHUM, which is built using a large corpus of human shapes and motions.
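The auxiliary origin can be sketched in a few lines of Python. The snippet below averages the four first-knuckle positions and shifts a set of 3D keypoints so that this average becomes the origin. The joint names in `KNUCKLES` are an assumption about how the first knuckles are labeled, the keypoint dictionaries mirror the x, y, z, name layout described earlier, and the coordinates are made up; the snippet only illustrates the definition, not how the model produces its output.

```python
# Assumed joint names for the first knuckles; check them against the API's output.
KNUCKLES = ("index_finger_mcp", "middle_finger_mcp", "ring_finger_mcp", "pinky_finger_mcp")


def knuckle_center(keypoints3d):
    """Mean of the four first-knuckle positions, as an (x, y, z) tuple."""
    by_name = {kp["name"]: kp for kp in keypoints3d}
    points = [by_name[name] for name in KNUCKLES]
    return tuple(sum(p[axis] for p in points) / len(points) for axis in ("x", "y", "z"))


def recenter(keypoints3d):
    """Shifts keypoints so the knuckle center becomes the origin."""
    cx, cy, cz = knuckle_center(keypoints3d)
    return [
        {"name": kp["name"], "x": kp["x"] - cx, "y": kp["y"] - cy, "z": kp["z"] - cz}
        for kp in keypoints3d
    ]


# Made-up coordinates, chosen only so the arithmetic is easy to follow.
sample = [
    {"name": "wrist", "x": 0.0, "y": -4.0, "z": 1.0},
    {"name": "index_finger_mcp", "x": 1.0, "y": 0.0, "z": 2.0},
    {"name": "middle_finger_mcp", "x": 2.0, "y": 0.0, "z": 2.0},
    {"name": "ring_finger_mcp", "x": 3.0, "y": 0.0, "z": 2.0},
    {"name": "pinky_finger_mcp", "x": 4.0, "y": 0.0, "z": 2.0},
]
center = knuckle_center(sample)
centered = recenter(sample)
```

After recentering, the knuckle center of `centered` is (0, 0, 0) and the wrist sits at its offset from that point.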

To obtain hand pose ground truth, we fitted the GHUM hand model to our existing 2D hand dataset and recovered real world 3D keypoint coordinates. The shape and the hand pose variables of the GHUM hand model were optimized such that the reconstructed model aligns with the image evidence. This includes 2D keypoint alignment, shape, and pose regularization terms as well as anthropometric joint angle limits and model self contact penalties.

Sample GHUM hand fittings for hand images with 2D keypoint annotations overlaid. The data was used to train and test a variety of poses leading to better results for more extreme poses.

Model quality

In this new release, we substantially improved the quality of models, and evaluated them on a dataset of American Sign Language (ASL) gestures. As the evaluation metric for 2D screen coordinates, we used Mean Average Precision (mAP) as suggested by the COCO keypoint challenge methodology.

Hand model evaluation on American Sign Language dataset

For 3D evaluation we used Mean Absolute Error in Euclidean 3D metric space, with the average error measured in centimeters.
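One plausible reading of this metric is the average Euclidean distance between each predicted 3D keypoint and its ground-truth counterpart. The sketch below computes that quantity with the standard library; the function name and the sample coordinates are assumptions for illustration, and the exact evaluation protocol used for the reported numbers may differ.

```python
import math


def mean_3d_error(predicted, ground_truth):
    """Average Euclidean distance between matching 3D keypoints."""
    if len(predicted) != len(ground_truth) or not predicted:
        raise ValueError("need two non-empty keypoint lists of equal length")
    distances = [math.dist(p, g) for p, g in zip(predicted, ground_truth)]
    return sum(distances) / len(distances)


# Made-up coordinates in centimeters, only to exercise the function.
predicted = [(0.0, 0.0, 0.0), (3.0, 4.0, 0.0)]
ground_truth = [(0.0, 0.0, 1.0), (0.0, 0.0, 0.0)]
error_cm = mean_3d_error(predicted, ground_truth)
```

In this toy example the two keypoints are off by 1 cm and 5 cm, so the mean error is 3 cm.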

Model Name                        2D, mAP, %    3D, mean 3D error, cm
HandPose GHUM Lite                79.2          1.4
HandPose GHUM Full                83.8          1.3
Previous TensorFlow.js HandPose   66.5          N/A

Quality metrics for the newly released HandPose GHUM models vs. the previously released TensorFlow.js HandPose model, for 2D and 3D predictions.

Browser performance

We benchmarked the model across multiple devices. All benchmarks were run with two hands present in the frame.

Runtime                                       MacBook Pro 15” 2019 (FPS)   iPhone 11 (FPS)   Pixel 5 (FPS)   Desktop (FPS)
MediaPipe Runtime (with WASM & GPU Accel.)    62 | 48                      8 | 5             19 | 15         136 | 120
TensorFlow.js Runtime (with WebGL backend)    36 | 31                      15 | 12           11 | 8          42 | 35

MacBook Pro 15” 2019: Intel Core i9, AMD Radeon Pro Vega 20 Graphics. Desktop: Intel i9-10900K, Nvidia GTX 1070 GPU.

Inference speed of HandPose across different devices and runtimes. The first number in each cell is for the lite model, and the second number is for the full model.

To see the model’s FPS on your device, try our demo. You can switch the model type and runtime live in the demo UI to see what works best for your device.

Cross platform availability

In addition to the JavaScript hand pose detection API, these updated hand models are also available in MediaPipe Hands as a ready-to-use Android Solution API and Python Solution API, with prebuilt packages in Android Maven Repository and Python PyPI respectively.

For instance, for Android developers the Maven package can be easily integrated into an Android Studio project by adding the following into the project’s Gradle dependencies:

dependencies {
    implementation 'com.google.mediapipe:solution-core:latest.release'
    implementation 'com.google.mediapipe:hands:latest.release'
}

The MediaPipe Android Solution is designed to handle different use scenarios such as processing live camera feeds, video files, as well as static images. It also comes with utilities to facilitate overlaying the output landmarks onto either CPU images (with Canvas) or GPU (using OpenGL). For instance, the following code snippet demonstrates how it can be used to process a live camera feed and render the output on screen in real-time:

// Creates MediaPipe Hands.
HandsOptions handsOptions =
    HandsOptions.builder()
        .setModelComplexity(1)
        .setMaxNumHands(2)
        .setRunOnGpu(true)
        .build();
Hands hands = new Hands(activity, handsOptions);

// Connects MediaPipe Hands to camera.
CameraInput cameraInput = new CameraInput(activity);
cameraInput.setNewFrameListener(textureFrame -> hands.send(textureFrame));

// Registers a result listener.
hands.setResultListener(
    handsResult -> {
      handsView.setRenderData(handsResult);
      handsView.requestRender();
    });

// Starts the camera to feed data to MediaPipe Hands.
handsView.post(this::startCamera);

To learn more about MediaPipe Android Solutions, please refer to our documentation and try them out with the example Android Studio project. Also visit MediaPipe Solutions for more cross-platform solutions.

Acknowledgements

We would like to acknowledge our colleagues who participated in or sponsored creating HandPose GHUM 3D and building the APIs: Cristian Sminchisescu, Michael Hays, Na Li, Ping Yu, George Sung, Jonathan Baccash, Esha Uboweja, David Tian, Kanstantsin Sokal, Gregory Karpiak, Tyler Mullen, Chuo-Ling Chang, Matthias Grundmann.


Women in Machine Learning Symposium – Event Recap

Posted by Joana Carrasqueira, Program Manager, TensorFlow.

Thank you to everyone who joined us at the first Women in Machine Learning Symposium!

Hundreds of practitioners joined from all over the world to share tips and insights for careers in ML, how to be involved in the community, contribute to open source, and much more. It was very inspiring to learn from each other’s experiences. Following is a quick recap, and an overview of the resources we discussed at the event. Thanks again.

Online education

Get involved in the community

Build your portfolio

Connect with (or become) a GDE


What’s new in TensorFlow 2.7?

Posted by Goldie Gadde and Josh Gordon for the TensorFlow team

TensorFlow 2.7 is here! This release improves usability with clearer error messages and simplified stack traces, and adds new tools and documentation for users migrating to TF2.

Improved Debugging Experience

The process of debugging your code is a fundamental part of the user experience of a machine learning framework. In this release, we’ve considerably improved the TensorFlow debugging experience to make it more productive and more enjoyable, via three major changes: simplified stack traces, displaying additional context information in errors that originate from custom Keras layers, and a wide-ranging audit of all error messages in Keras and TensorFlow.

Simplified stack traces

TensorFlow is now filtering by default the stack traces displayed upon error to hide any frame that originates from TensorFlow-internal code, and keep the information focused on what matters to you: your own code. This makes stack traces simpler and shorter, and it makes it easier to understand and fix the problems in your code.

If you’re actually debugging the TensorFlow codebase itself (for instance, because you’re preparing a PR for TensorFlow), you can turn off the filtering mechanism by calling tf.debugging.disable_traceback_filtering().
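The general idea of hiding framework-internal frames can be illustrated with the standard `traceback` module. The sketch below is not TensorFlow's implementation: the `filtered_frames` helper, the marker-based matching on file names, and the fake `<framework_internal>` file name are all assumptions made for the example.

```python
import traceback


def filtered_frames(exc, hidden_markers):
    """Returns traceback frames whose file path contains none of hidden_markers."""
    frames = traceback.extract_tb(exc.__traceback__)
    kept = [f for f in frames if not any(m in f.filename for m in hidden_markers)]
    # Fall back to the full traceback rather than showing nothing at all.
    return kept or list(frames)


# Simulate "framework-internal" code by compiling a helper under a fake file name.
internal_source = "def internal_op(x):\n    return 1 / x\n"
namespace = {}
exec(compile(internal_source, "<framework_internal>", "exec"), namespace)


def user_code():
    return namespace["internal_op"](0)


try:
    user_code()
except ZeroDivisionError as exc:
    all_names = [f.name for f in traceback.extract_tb(exc.__traceback__)]
    kept_names = [f.name for f in filtered_frames(exc, ["<framework_internal>"])]
```

The unfiltered list contains the `internal_op` frame, while the filtered list keeps only frames from the caller's own code, which is the effect the filtering described above aims for.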

Automatic context injection for Keras layer exceptions

One of the most common use cases for writing low-level code is creating custom Keras layers, so we wanted to make debugging your layers as easy and productive as possible. The first thing you do when you’re debugging a layer is to print the shapes and dtypes of its inputs, as well as the values of its training and mask arguments. We now add this information automatically to all stack traces that originate from custom Keras layers.

See the effect of stack trace filtering and call context information display in practice in the image below:

Simplified stack traces in TensorFlow 2.7

Audit and improve all error messages in the TensorFlow and Keras codebases

Lastly, we’ve audited every error message in the Keras and TensorFlow codebases (thousands of error locations!) and improved them to make sure they follow UX best practices. A good error message should tell you what the framework expected, what you did that didn’t match the framework’s expectations, and should provide tips to fix the problem.
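As a small illustration of that three-part pattern, the sketch below raises an error that states the expectation, reports what was actually received, and ends with a tip. The `check_rank` helper and its wording are hypothetical and are not taken from the Keras or TensorFlow codebases.

```python
def check_rank(name, shape, expected_rank):
    """Raises a ValueError that states the expectation, the actual input, and a fix."""
    if len(shape) != expected_rank:
        raise ValueError(
            f"Expected `{name}` to have rank {expected_rank}. "
            f"Received: {name}.shape={tuple(shape)} (rank {len(shape)}). "
            f"Tip: reshape the input or add a batch dimension so that it has "
            f"{expected_rank} dimensions."
        )


try:
    check_rank("inputs", (32,), expected_rank=2)
    message = ""
except ValueError as error:
    message = str(error)
```

Keeping the three parts in a fixed order makes messages predictable to scan: what was expected, what arrived, and what to try next.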

Improve tf.function error messages

We have improved two common types of tf.function error messages, runtime error messages and “Graph” tensor error messages, by including tracebacks that point to the error source in the user code. We also updated other vague or inaccurate tf.function error messages to be clearer and more accurate.

Consider the runtime error caused by the following user code:

@tf.function
def f():
  l = tf.range(tf.random.uniform((), minval=1, maxval=10, dtype=tf.int32))
  return l[20]

A summary of the old error message looks like

# … Python stack trace of the function call …

InvalidArgumentError: slice index 20 of dimension 0 out of bounds.
[[node strided_slice (defined at <ipython-input-8-250c76a76c0e>:5) ]] [Op:__inference_f_75]

Errors may have originated from an input operation.
Input Source operations connected to node strided_slice:
range (defined at <ipython-input-8-250c76a76c0e>:4)

Function call stack:
f

A summary of the new error message looks like

# … Python stack trace of the function call …

InvalidArgumentError: slice index 20 of dimension 0 out of bounds.
[[node strided_slice
(defined at <ipython-input-3-250c76a76c0e>:5)
]] [Op:__inference_f_15]

Errors may have originated from an input operation.
Input Source operations connected to node strided_slice:
In[0] range (defined at <ipython-input-3-250c76a76c0e>:4)
In[1] strided_slice/stack:
In[2] strided_slice/stack_1:
In[3] strided_slice/stack_2:

Operation defined at: (most recent call last)
# … Stack trace of the error within the function …
>>> File "<ipython-input-3-250c76a76c0e>", line 7, in <module>
>>> f()
>>>
>>> File "<ipython-input-3-250c76a76c0e>", line 5, in f
>>> return l[20]
>>>

The main difference is that runtime errors raised while executing a tf.function now include a stack trace that shows the source of the error in the user’s code.

# … Original error message and information …
# … More stack frames …
>>> File "<ipython-input-3-250c76a76c0e>", line 7, in <module>
>>> f()
>>>
>>> File "<ipython-input-3-250c76a76c0e>", line 5, in f
>>> return l[20]
>>>

Now consider the “Graph” tensor error caused by the following user code:

x = None

@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2

@tf.function
def captures_leaked_tensor(b):
  b += x
  return b

leaky_function(tf.constant(1))
captures_leaked_tensor(tf.constant(2))

A summary of the old error message looks like

# … Python stack trace of the function call …

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: add:0

A summary of the new error message looks like

# … Python stack trace of the function call …

TypeError: Originated from a graph execution error.

The graph execution error is detected at a node built at (most recent call last):
# … Stack trace of the error within the function …
>>> File <ipython-input-5-95ca3a98778f>, line 6, in leaky_function
# … More stack trace of the error within the function …

Error detected in node 'add' defined at: File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

The main difference is that errors for attempting to capture a tensor that was leaked from an unreachable graph now include a stack trace that shows where the tensor was created in the user’s code:

# … Original error message and information …
# … More stack frames …
>>> File <ipython-input-5-95ca3a98778f>, line 6, in leaky_function

Error detected in node 'add' defined at: File "<ipython-input-5-95ca3a98778f>", line 6, in leaky_function

TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

Introducing tf.experimental.ExtensionType

User-defined types can make your projects more readable, modular, and maintainable. TensorFlow 2.7.0 introduces the ExtensionType API, which can be used to create user-defined object-oriented types that work seamlessly with TensorFlow’s APIs. Extension types are a great way to track and organize the tensors used by complex models. Extension types can also be used to define new tensor-like types, which specialize or extend the basic concept of “Tensor.” To create an extension type, simply define a Python class with tf.experimental.ExtensionType as its base, and use type annotations to specify the type for each field:

import typing
import tensorflow as tf

class TensorGraph(tf.experimental.ExtensionType):
  """A collection of labeled nodes connected by weighted edges."""
  edge_weights: tf.Tensor  # shape=[num_nodes, num_nodes]
  node_labels: typing.Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any

class MaskedTensor(tf.experimental.ExtensionType):
  """A tensor paired with a boolean mask, indicating which values are valid."""
  values: tf.Tensor
  mask: tf.Tensor  # shape=values.shape; false for missing/invalid values.

class CSRSparseMatrix(tf.experimental.ExtensionType):
  """Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix)."""
  values: tf.Tensor  # shape=[num_nonzero]; dtype=any
  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64
  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64

The ExtensionType base class adds a constructor and special methods based on the field type annotations (similar to typing.NamedTuple and @dataclasses.dataclass from the standard Python library). You can optionally customize the type by overriding these defaults, or adding new methods, properties, or subclasses.
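To show the analogy with the standard library, here is a sketch that uses `@dataclasses.dataclass` rather than `ExtensionType`: the constructor is generated from the field annotations, and a custom method is added on top. The `CSRMatrix` class, its tuple-based fields, the extra `num_cols` field, and the `to_dense` method are assumptions for illustration; the class holds plain Python values, not tensors, and does not demonstrate any TensorFlow integration.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class CSRMatrix:
    """Compressed sparse row matrix stored as plain Python tuples."""
    values: Tuple[float, ...]   # non-zero entries, row by row
    col_index: Tuple[int, ...]  # column of each entry in `values`
    row_index: Tuple[int, ...]  # row r owns values[row_index[r]:row_index[r + 1]]
    num_cols: int               # hypothetical extra field, needed to size dense rows

    def to_dense(self):
        """Expands the matrix into a list of rows."""
        dense = []
        for r in range(len(self.row_index) - 1):
            row = [0.0] * self.num_cols
            for k in range(self.row_index[r], self.row_index[r + 1]):
                row[self.col_index[k]] = self.values[k]
            dense.append(row)
        return dense


# The generated constructor accepts the fields by name, in declaration order.
m = CSRMatrix(
    values=(5.0, 8.0, 3.0),
    col_index=(0, 1, 2),
    row_index=(0, 1, 2, 2, 3),
    num_cols=3,
)
dense = m.to_dense()
```

The `row_index` field has one more entry than there are rows, matching the `num_rows+1` shape noted in the extension type above; the repeated 2 marks the third row as empty.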

Extension types are supported by the following TensorFlow APIs:

  • Keras: Extension types can be used as inputs and outputs for Keras Models and Layers.
  • Dataset: Extension types can be included in Datasets, and returned by dataset Iterators.
  • TensorFlow Hub: Extension types can be used as inputs and outputs for tf.hub modules.
  • SavedModel: Extension types can be used as inputs and outputs for SavedModel functions.
  • tf.function: Extension types can be used as arguments and return values for functions wrapped with the @tf.function decorator.
  • control flow: Extension types can be used by control flow operations, such as tf.cond and tf.while_loop. This includes control flow operations added by autograph.
  • tf.py_function: Extension types can be used as arguments and return values for the func argument to tf.py_function.
  • Tensor ops: Extension types can be extended to support most TensorFlow ops that accept Tensor inputs (e.g., tf.matmul, tf.gather, and tf.reduce_sum), using dispatch decorators.
  • distribution strategy: Extension types can be used as per-replica values.

For more information about extension types, see the Extension Type guide.

Note: The tf.experimental prefix indicates that this is a new API, and we would like to collect feedback from real-world usage; barring any unforeseen design issues, we plan to migrate ExtensionType out of the experimental package in accordance with the TF experimental policy.

TF2 Migration made easier!

To support users interested in migrating their workloads from TF1 to TF2, we have created a new Migrate to TF2 tab on the TensorFlow website, which includes updated guides and completely new documentation with concrete, runnable examples in Colab.

A new shim tool has been added which dramatically eases migration of variable_scope-based models to TF2. It is expected to enable most TF1 users to run existing model architectures as-is (or with only minor adjustments) in TF2 pipelines without having to rewrite their modeling code. You can learn more about it in the model mapping guide.

New community contributed models on TensorFlow Hub

Since the last TensorFlow release, the community really came together to make many new models available on TensorFlow Hub. Now you can find models like MLP-Mixer, Vision Transformers, Wav2Vec2, RoBERTa, ConvMixer, DistilBERT, YOLOv5 and many more. All of these models are ready to use via TensorFlow Hub. You can learn more about publishing your models here.

Next steps

Check out the release notes for more information. To stay up to date, you can read the TensorFlow blog, follow twitter.com/tensorflow, or subscribe to youtube.com/tensorflow. If you’ve built something you’d like to share, please submit it for our Community Spotlight at goo.gle/TFCS. For feedback, please file an issue on GitHub or post to the TensorFlow Forum. Thank you!


On-device training in TensorFlow Lite

Posted by the TensorFlow Lite team

TensorFlow Lite is Google’s machine learning framework to deploy machine learning models on multiple devices and surfaces such as mobile (iOS and Android), desktops and other edge devices. Recently, we added support to run TensorFlow Lite models in a browser as well. In order to build apps using TensorFlow Lite, you can either use an off-the-shelf model from TensorFlow Hub, or convert an existing TensorFlow model to a TensorFlow Lite model using the converter. Once the model is deployed in an app, you can run inference on the model based on input data.

TensorFlow Lite now supports training your models on-device, in addition to running inference. On-device training enables interesting personalization use cases where models can be fine-tuned based on user needs. For instance, you could deploy an image classification model and allow a user to fine-tune the model to recognize bird species using transfer learning, while allowing another user to retrain the same model to recognize fruits. This new feature is available in TensorFlow 2.7 and later and is currently available for Android apps. (iOS support will be added in the future.)

On-device training is also a necessary foundation for Federated Learning use cases to train global models on decentralized data. This blog post does not cover Federated Learning and instead focuses on helping you integrate on-device training in your Android apps.

Later in this article we will reference a Colab and Android sample app as we walk you through the end-to-end implementation path for on-device learning to fine-tune an image classification model.

Improvements over the earlier approach

In our 2019 blog post, we introduced on-device training concepts and an example of on-device training in TensorFlow Lite. However, there were several limitations. For example, it was not easy to customize the model structure and optimizers. You also had to deal with multiple physical TensorFlow Lite (.tflite) models instead of a single TensorFlow Lite model. Similarly, there was no easy way to store and update the training weights. Our latest TensorFlow Lite version streamlines this process by providing more convenient options for on-device training, as explained below.

How does it work?

In order to deploy a TensorFlow Lite model with on-device training built-in, here are the high level steps:

  • Build a TensorFlow model for training and inference
  • Convert the TensorFlow model to TensorFlow Lite format
  • Integrate the model in your Android app
  • Invoke model training in the app, similar to how you would invoke model inference

These steps are explained below.

Build a TensorFlow model for training and inference

The TensorFlow Lite model should not only support model inference, but also model training, which typically involves saving the model’s weights to the file system and restoring the weights from the file system. This is done to save the training weights after each training epoch, so that the next training epoch can use the weights from the previous one, instead of starting training from scratch.

Our suggested approach is to implement these tf.functions to represent training, inference, saving weights, and loading weights:

  • A train function that trains the model using training data. The train function below makes a prediction, calculates the loss (or error), and uses tf.GradientTape() to record operations for automatic differentiation and update the model’s parameters.
    # The `train` function takes a batch of input images and labels.
    @tf.function(input_signature=[
        tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
        tf.TensorSpec([None, 10], tf.float32),
    ])
    def train(self, x, y):
        with tf.GradientTape() as tape:
            prediction = self.model(x)
            loss = self._LOSS_FN(prediction, y)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self._OPTIM.apply_gradients(
            zip(gradients, self.model.trainable_variables))
        result = {"loss": loss}
        for grad in gradients:
            result[grad.name] = grad
        return result
  • An infer or a predict function that invokes model inference. This is similar to how you currently use TensorFlow Lite for inference.
    @tf.function(input_signature=[tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32)])
    def predict(self, x):
        return {
            "output": self.model(x)
        }
  • A save/restore function that saves training weights (i.e., parameters used by the model) in Checkpoints format to the file system. The save function’s code is shown below.
    @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)])
    def save(self, checkpoint_path):
        tensor_names = [weight.name for weight in self.model.weights]
        tensors_to_save = [weight.read_value() for weight in self.model.weights]
        tf.raw_ops.Save(
            filename=checkpoint_path, tensor_names=tensor_names,
            data=tensors_to_save, name='save')
        return {
            "checkpoint_path": checkpoint_path
        }

Convert to TensorFlow Lite format

You may already be familiar with the workflow to convert your TensorFlow model to the TensorFlow Lite format. Some of the low level features for on-device training (e.g., variables to store the model parameters) are still experimental, and others (e.g., weight serialization) currently rely on TF Select operators, so you will need to set these flags during conversion. You can find an example of all the flags you need to set in the Colab.

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

Integrate the model in your Android app

Once you have converted your model to the TensorFlow Lite format, you’re ready to integrate the model into your app! Refer to the Android app samples for more details.

Invoke model training and inference in app

On Android, TensorFlow Lite on-device training can be performed using either Java or C++ APIs. You can create an instance of the TensorFlow Lite Interpreter to load a model and drive model training tasks. We had previously defined multiple tf.functions: these functions can be invoked using TensorFlow Lite’s support for Signatures, which allow a single TensorFlow Lite model to support multiple ‘entry’ points. For example, we had defined a train function for on-device training, which is one of the model’s signatures. The train function can be invoked using TensorFlow Lite’s runSignature method by specifying the name of the signature (‘train’):

// Run training for a few steps.
float[] losses = new float[NUM_EPOCHS];
for (int epoch = 0; epoch < NUM_EPOCHS; ++epoch) {
  for (int batchIdx = 0; batchIdx < NUM_BATCHES; ++batchIdx) {
    Map<String, Object> inputs = new HashMap<>();
    inputs.put("x", trainImageBatches.get(batchIdx));
    inputs.put("y", trainLabelBatches.get(batchIdx));

    Map<String, Object> outputs = new HashMap<>();
    FloatBuffer loss = FloatBuffer.allocate(1);
    outputs.put("loss", loss);

    interpreter.runSignature(inputs, outputs, "train");

    // Record the last loss.
    if (batchIdx == NUM_BATCHES - 1) losses[epoch] = loss.get(0);
  }
}


Similarly, the following example shows how to invoke inference using the model’s ‘infer’ signature:

try (Interpreter anotherInterpreter = new Interpreter(modelBuffer)) {
  // Restore the weights from the checkpoint file.

  int NUM_TESTS = 10;
  // Direct buffers are allocated in bytes (4 bytes per float) and viewed as floats.
  FloatBuffer testImages = ByteBuffer.allocateDirect(NUM_TESTS * 28 * 28 * 4)
      .order(ByteOrder.nativeOrder()).asFloatBuffer();
  FloatBuffer output = ByteBuffer.allocateDirect(NUM_TESTS * 10 * 4)
      .order(ByteOrder.nativeOrder()).asFloatBuffer();

  // Fill the test data.

  // Run the inference.
  Map<String, Object> inputs = new HashMap<>();
  inputs.put("x", testImages.rewind());
  Map<String, Object> outputs = new HashMap<>();
  outputs.put("output", output);
  anotherInterpreter.runSignature(inputs, outputs, "infer");
  output.rewind();

  // Process the result to get the final category values.
  int[] testLabels = new int[NUM_TESTS];
  for (int i = 0; i < NUM_TESTS; ++i) {
    // Argmax over the 10 class scores of test image i.
    int index = 0;
    for (int j = 1; j < 10; ++j) {
      if (output.get(i * 10 + index) < output.get(i * 10 + j))
        index = j;
    }
    testLabels[i] = index;
  }
}

And that’s it! You now have a TensorFlow Lite model that is able to use on-device training. We hope that this code walkthrough gives you a good idea of how to run on-device training in TensorFlow Lite, and we’re excited to see where you take it.

Practical considerations

In theory, you should be able to apply on-device training in TensorFlow Lite to any use case that TensorFlow supports. However, in reality there are a few practical considerations that you need to keep in mind before you deploy on-device training in your apps:

  • Use cases: The Colab example shows an example of on-device training for a vision use case. If you run into issues for specific models or use cases, please let us know on GitHub.
  • Performance: Depending on the use case, on-device training could take anywhere from a few seconds to much longer. If you run on-device training as part of a user-facing feature (e.g., your end user is interacting with the feature), you should measure the time taken for a wide range of possible training inputs in your app to limit the training time. If your use-case requires very long on-device training times, consider training a model using a desktop or the cloud first, then fine-tuning it on-device.
  • Battery usage: Just like model inference, invoking model training on device may result in a battery drain. If model training is part of a feature that is not user facing, we recommend following Android’s guidelines to implement background tasks.
  • Training from scratch vs. retraining: In theory, it should be possible to train a model from scratch on device using the above features. However, in reality, training from scratch involves an enormous amount of training data and may take several days even on servers with powerful processors. Consequently, for on-device applications, we recommend retraining on an already trained model (i.e., transfer learning) as shown in the Colab example.
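As one way to act on the performance point above, a training loop can be given a time budget so that a user-facing feature never trains longer than planned. The sketch below is illustrative only and is written in Python for brevity: `train_step` is a hypothetical callable standing in for a single training invocation (for example, one call to the ‘train’ signature), and the `clock` parameter exists only so the behavior can be checked without waiting.

```python
import time


def train_with_time_budget(train_step, batches, budget_seconds,
                           clock=time.monotonic):
    """Run `train_step` on each batch until done or the budget is spent.

    Returns the losses of the steps that actually ran.
    """
    losses = []
    start = clock()
    for batch in batches:
        if clock() - start >= budget_seconds:
            break  # Do not start another step once the budget is used up.
        losses.append(train_step(batch))
    return losses
```

A step that is already running is allowed to finish, so the budget should leave room for the duration of one step.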

Roadmap

Future work includes (but is not limited to) on-device training support on iOS, performance improvements to leverage on-device accelerators (e.g. GPUs) for on-device training, reducing the binary size by implementing more training ops natively in TensorFlow Lite, higher level API support (e.g. via the TensorFlow Lite Task Library) to abstract away the implementation details and examples covering other on-device training use cases (e.g. NLP). Our long term roadmap involves potentially providing on-device end-to-end Federated Learning solutions.

Next steps

Thank you for reading! We’re excited to see what you build using on-device learning. Once again, here are links to the sample app and Colab. If you have any feedback, please let us know on the TensorFlow Forum, or on GitHub.

Acknowledgements

This post reflects the significant contributions of many people in Google’s TensorFlow Lite team including Michelle Carney, Lawrence Chan, Jaesung Chung, Jared Duke, Terry Heo, Jared Lim, Yu-Cheng Ling, Thai Nguyen, Karim Nosseir, Arun Venkatesan, Haoliang Zhang, other TensorFlow Lite team members, and our collaborators in Google Research.


Building a board game app with TensorFlow: a new TensorFlow Lite reference app

Posted by Wei Wei, Developer Advocate

Games are often used as test grounds for various reinforcement learning (RL) algorithms. While it is very exciting that machine learning researchers invent new RL algorithms to master challenging games, we are also curious to see that game developers are using RL to build gaming bots in TensorFlow for various purposes, such as quality testing, game balance tuning and game difficulty assessment.

We already have a detailed tutorial that demonstrates how to implement the actor-critic RL method for the classical CartPole gym environment with TensorFlow. In this end-to-end tutorial, we are going to show you how to use TensorFlow core, TensorFlow Agents and TensorFlow Lite to build a game agent to play against a human user in a small board game app. The end result is an Android reference app that looks like the one below, and we have open sourced all the code in the tensorflow/examples repository for your reference.

Demo game play in ‘Plane Strike’

The game is called ‘Plane Strike’, a small board game that resembles the board game ‘Battleship’. The rules are very simple:

  • At the beginning of the game, the user and the agent each have a ‘plane’ object (8 blue cells that form a ‘plane’ as you can see in the animation above) on their own boards; these planes are only visible to the owners of the board and hidden to their opponents.
  • The user and the agent take turns to strike at one cell of each other’s board. The user can tap any cell in the agent’s board, while the agent will automatically make the choice based on the prediction of a machine learning model. The attempted cell turns red if it is a ‘plane’ cell (‘hit’); otherwise it turns yellow (‘miss’).
  • Whoever achieves 8 red cells first wins the game; then the game is restarted with fresh boards.
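The rules above can be sketched as a small Python class. This is only an illustration of the hit/miss/win logic, not the code of the app or of the gym environment: the class name, the cell markers, and the choice to pass the plane as a collection of (x, y) cells are all assumptions.

```python
HIT, MISS, UNTRIED = 1, -1, 0  # Hypothetical cell markers.
PLANE_SIZE = 8  # The post states that a plane occupies 8 cells.


class PlaneStrikeBoard:
    """Tracks strikes against one player's hidden plane."""

    def __init__(self, board_size, plane_cells):
        if len(set(plane_cells)) != PLANE_SIZE:
            raise ValueError("a plane must cover exactly 8 distinct cells")
        self.board_size = board_size
        self.plane_cells = set(plane_cells)
        # What the opponent can see: every cell starts out untried.
        self.visible = [[UNTRIED] * board_size for _ in range(board_size)]
        self.hits = 0

    def strike(self, x, y):
        """Strike cell (x, y) and return HIT or MISS."""
        if self.visible[x][y] != UNTRIED:
            raise ValueError("cell was already tried")
        if (x, y) in self.plane_cells:
            self.visible[x][y] = HIT
            self.hits += 1
        else:
            self.visible[x][y] = MISS
        return self.visible[x][y]

    def is_defeated(self):
        """True once all 8 plane cells have been hit."""
        return self.hits == PLANE_SIZE
```

Each player would own one such board, and the two sides would alternate calls to `strike` on the opponent's board until `is_defeated()` returns True for one of them.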

Even though it may be possible to create handcrafted rules for such a small game, we turn to reinforcement learning to create a smart agent that a human player can’t easily beat. For a general introduction of reinforcement learning, please refer to this RL course from DeepMind and UCL.

We provide 2 paths of training and deployment for this game app

TensorFlow Lite with a Python model written from scratch

In this path, to train the agent, we first create a custom OpenAI gym environment ‘PlaneStrike-v0’, which allows us to easily roll out game plays and gather game logs. Then we use REINFORCE, a policy gradient algorithm in RL, in its reward-to-go form to train the agent. Its basic idea is to adjust the policy network parameters based on the reward signals collected during the gameplay, so that the policy network can maximize the return in future plays.

Mathematically, the policy gradient is defined as:
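One standard way to write this gradient, assuming the usual REINFORCE formulation, is shown here. The symbol τ denotes a sampled episode (trajectory); it is notation introduced for this sketch and stands in for the argument of R(*) in the list that follows. In the reward-to-go variant, R(τ) is replaced by the rewards collected from timestep t onward.

```latex
\nabla_\theta J(\pi_\theta)
  = \mathbb{E}_{\tau \sim \pi_\theta}\left[
      \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)
    \right]
```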

where:

  • T: the number of timesteps per episode, which can vary per episode
  • s_t: the state at timestep t
  • a_t: the action chosen at timestep t given state s_t
  • π_θ: the policy parameterized by θ
  • R(*): is the reward gathered, given the policy

Please refer to this DeepMind lecture on policy gradient for a more detailed discussion. To implement it with TensorFlow, we define a simple 3-layer MLP as our policy network, which predicts the agent’s next strike position, given the human player’s board state. Note that the log expression of the above policy gradient without the reward part is the equivalent of negative cross entropy loss. In this case, since we want to maximize the rewards, we can just minimize the categorical cross entropy loss to achieve that.

model.compile(loss='sparse_categorical_crossentropy', optimizer=sgd)
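To make that connection explicit, here is a short derivation sketch. It assumes a one-hot target on the chosen action a_t and a per-step weight R_t (the reward signal attached to step t); the names CE_t, R_t and L are notation introduced here for illustration.

```latex
\mathrm{CE}_t = -\log \pi_\theta(a_t \mid s_t), \qquad
L(\theta) = \sum_{t} R_t \, \mathrm{CE}_t
          = -\sum_{t} R_t \log \pi_\theta(a_t \mid s_t), \qquad
\nabla_\theta L(\theta)
          = -\sum_{t} R_t \, \nabla_\theta \log \pi_\theta(a_t \mid s_t)
```

Under these assumptions, a gradient descent step on L moves θ in the direction of the sampled policy gradient, which is why minimizing a reward-weighted cross entropy loss maximizes the expected reward.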

We create a play_game() function to roll out the game and help us gather game logs. After each episode, we train the agent via Keras fit() function:

model.fit(x=board_log, y=action_log, sample_weight=rewards)

Note that we pass the discounted rewards-to-go as ‘sample_weight’ into the Keras fit() function as a shortcut, to implement the policy gradient algorithm without writing a custom training loop. An intuitive way to think about this is we need a tuple of (x, y, reward) instead of just (x, y) as in supervised learning. Rewards, which can be negative, help the predictor output move toward/away from y, based on x. This is different from supervised learning (in which case your ‘sample_weight’ can never be negative).
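For readers who have not computed discounted rewards-to-go before, here is a minimal sketch in plain Python. The function name, the discount factor, and the example rewards are all hypothetical; the post does not state the values used in the reference app.

```python
def discounted_rewards_to_go(rewards, gamma=0.99):
    """Return G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... for each t."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each step reuses the return of the step after it.
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# Hypothetical episode: two misses followed by a hit.
example_returns = discounted_rewards_to_go([0.0, 0.0, 1.0], gamma=0.5)
```

The resulting list has one entry per step of the episode, so it can be passed alongside the board and action logs as the per-sample weight.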

Since what we are doing isn’t supervised learning, we can’t really use training loss to monitor the training progress. Instead, we are going to use a proxy metric ‘game_length’, which indicates how many steps the agent takes to finish each episode. Intuitively you can understand that if the agent is smarter and makes better predictions, the game length becomes shorter.

Training progress in TensorBoard

Since this is a game that needs instantaneous responses from the agent, we want to deploy the model on mobile devices instead of servers. After training the model, we use the TFLite converter to convert the Keras model into a TFLite model, and integrate it into our Android app.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

The exported model is very fast and takes <1 ms to execute on Pixel phones. During the game play, at each step the agent looks at the user’s board position and predicts its next strike position to achieve 8 red cells as fast as possible.

convertBoardStateToByteBuffer(board);
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
// Pick the untried cell with the highest predicted probability.
for (int i = 0; i < probArray.length; i++) {
  int x = i / Constants.BOARD_SIZE;
  int y = i % Constants.BOARD_SIZE;
  if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
    agentStrikePosition = i;
    maxProb = probArray[i];
  }
}
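The selection step above is a masked argmax: the agent takes the most probable cell among those not yet tried. A Python sketch of the same idea follows, assuming a flat probability list and a square board stored as a list of rows; the `UNTRIED` marker and the function name are hypothetical.

```python
UNTRIED = 0  # Hypothetical marker for a cell that has not been struck yet.


def pick_strike_position(prob_array, board, board_size):
    """Return the flat index of the most probable untried cell, or -1 if none."""
    best_position = -1
    best_prob = 0.0
    for i, prob in enumerate(prob_array):
        # Convert the flat index into row and column coordinates.
        x, y = divmod(i, board_size)
        if board[x][y] == UNTRIED and prob > best_prob:
            best_position = i
            best_prob = prob
    return best_position
```

Because the running maximum starts at 0, a cell whose predicted probability is exactly 0 is never selected, so callers should be prepared to handle a result of -1.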

TensorFlow Lite with a model trained with TensorFlow Agents

While it’s a good exercise to write our agent from scratch using the TensorFlow API, it’s better to leverage existing implementations of RL algorithms. TensorFlow Agents is a library for reinforcement learning in TensorFlow, and makes it easier to design, implement and test new RL algorithms by providing well tested modular components that can be modified and extended. TF Agents has implemented several state-of-the-art RL algorithms, including DQN, DDPG, REINFORCE, PPO, SAC and TD3. Policies trained with TF Agents can be converted to TFLite directly and deployed into mobile apps (note that this feature was only recently enabled, so you will need the nightly builds of TensorFlow and TensorFlow Agents).

We use the TF Agents REINFORCE agent to train our agent. First, we need to define a TF Agents training environment as we did with the gym environment in the previous section. Then we can define an actor net as our policy network:



actor_net = tfa.networks.Sequential([
    tfa.keras_layers.InnerReshape([BOARD_SIZE, BOARD_SIZE], [BOARD_SIZE**2]),
    tf.keras.layers.Dense(FC_LAYER_PARAMS, activation='relu'),
    tf.keras.layers.Dense(BOARD_SIZE**2),
    tf.keras.layers.Lambda(lambda t: tfp.distributions.Categorical(logits=t)),
], input_spec=train_py_env.observation_spec())

We are going to use the built-in REINFORCE agent that TF Agents has already implemented. The agent is built on top of the ‘actor_net’ defined above:

tf_agent = reinforce_agent.ReinforceAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    actor_network=actor_net,
    optimizer=optimizer,
    normalize_returns=True,
    train_step_counter=train_step_counter)
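The `normalize_returns=True` argument asks the agent to normalize returns before computing the loss. The exact computation inside TF Agents is not shown in this post; as a rough sketch of what return normalization commonly means (standardizing to zero mean and roughly unit variance), assuming plain Python lists and a hypothetical `epsilon` for numerical safety:

```python
import math


def normalize_returns(returns, epsilon=1e-8):
    """Standardize returns to zero mean and (roughly) unit variance."""
    if not returns:
        raise ValueError("returns must not be empty")
    mean = sum(returns) / len(returns)
    variance = sum((r - mean) ** 2 for r in returns) / len(returns)
    std = math.sqrt(variance)
    # epsilon avoids a division by zero when all returns are identical.
    return [(r - mean) / (std + epsilon) for r in returns]
```

Normalizing in this way keeps the scale of the gradient updates comparable across batches with very different raw returns.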

To train the agent, we need to collect some trajectories as experience. We define a function just for that using DeepMind Reverb and TF Agent PyDriver:

def collect_episode(environment, policy, num_episodes, replay_buffer_observer):
    """Collect game episode trajectories."""
    driver = py_driver.PyDriver(
        environment,
        py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True),
        [replay_buffer_observer],
        max_episodes=num_episodes)
    # Reset once, right before the driver starts collecting.
    initial_time_step = environment.reset()
    driver.run(initial_time_step)

Now we are ready to train the model:

for i in range(iterations):
    # Collect a few episodes using collect_policy and save to the replay buffer.
    collect_episode(train_py_env, collect_policy,
                    COLLECT_EPISODES_PER_ITERATION, replay_buffer_observer)

    # Use data from the buffer and update the agent's network.
    iterator = iter(replay_buffer.as_dataset(sample_batch_size=1))
    trajectories, _ = next(iterator)
    tf_agent.train(experience=trajectories)
    replay_buffer.clear()

You can monitor the training progress using TensorBoard. In this case, we visualize both the average episode length and average return.

TF Agents training progress in TensorBoard

Once the policy has been trained and exported as a SavedModel, you can convert it into a TFLite model:

converter = tf.lite.TFLiteConverter.from_saved_model(
    policy_dir, signature_keys=['action'])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_policy = converter.convert()
with open(os.path.join(model_dir, 'planestrike_tf_agents.tflite'), 'wb') as f:
    f.write(tflite_policy)

Currently there are a few TensorFlow ops that are required during the conversion. The converted model is slightly different from the model we trained using TensorFlow directly, because it takes 4 tensors as the input. What really matters here is the ‘observation’ tensor. Our agent will look at this ‘observation’ tensor and predict its next move. The other 3 can be safely ignored at inference time.

Visualizing TFLite model converted from TF Agents using Netron

Also, the model directly outputs the strike position instead of the probability distribution, so we no longer need to do argmax manually.

@Override
protected void runInference() {
  // `inputs` (the Object[] of input tensors) is prepared elsewhere in the app.
  Map<Integer, Object> output = new HashMap<>();
  // TF Agent directly returns the predicted action
  int[][] prediction = new int[1][1];
  output.put(0, prediction);
  tflite.runForMultipleInputsOutputs(inputs, output);
  agentStrikePosition = prediction[0][0];
}

So to summarize, in this post we showed you 2 paths of how to train a game agent, convert the trained model to TFLite and deploy it into an Android app. Hopefully this end-to-end tutorial helps you better understand how to leverage the TensorFlow ecosystem to build cool games.

And lastly, if this little game looks interesting to you, we challenge you to install the app on your phone and see if you can beat the agent we have trained 😃.


TF Challenge Winners! Graphic

Announcing the Winners of the TensorFlow Lite for Microcontrollers Challenge!

Posted by Pete Warden for the TensorFlow team


In May 2021, we published the TensorFlow Microcontroller Challenge, inviting developers to push the boundaries of TensorFlow Lite for Microcontrollers. Our sincere thanks go out to all those who participated in our competition and contributed to its success! Submissions came from 20 countries across 6 continents.

We’re excited to announce the five winning entries.

  • Mapping Dance by Eduardo Padrón: Take control of lighting and video projections with your dance moves.
  • Move! by Yongjae Kim, Jonghyun Baek, Eunji Lee, Yeonhee Kim, and Jueun Choi: Stay active, using movement to control a variety of games.
  • Snoring Guardian by Naveen Kumar: A snore-no-more device embedded in your pillow.
  • Squat Counter by Manas Pange: Focus on your form, while this tracker counts your squats.
  • Voice Turn by Alvaro Gonzalez-Vila: A safer way for cyclists to signal using their voice.

These projects push boundaries, spark joy, and show off the helpfulness of TensorFlow Lite for Microcontrollers. The teams who created them will each receive a prize and meet with the TensorFlow team.

To view the winning entries, check out the TensorFlow Lite for Microcontrollers collection on Experiments with Google.

ML Community Day: Save the date

Posted by the TensorFlow team

Please join us for ML Community Day, a virtual developer event on November 9th. You’ll hear from the TensorFlow, JAX, and DeepMind teams and the community, covering new products, updates, and more.

This event takes place on TensorFlow’s sixth birthday (time flies!), and celebrates many years of open-source work by the developer community. We’ll have talks on on-device Machine Learning, JAX, Responsible AI, Cloud TPUs, Machine Learning in production, and sessions from the community.

You’ll also learn how you can get involved with the machine learning community by connecting with Google Developer Experts, joining Special Interest Groups, TensorFlow User Groups, and more.

Later in the day we’ll host the TensorFlow Contributor Summit, in which Google Developer Experts, Special Interest Groups, TensorFlow User Group organizers, and other community leaders gather to connect in a small group to discuss topics such as documentation and how to contribute to the TensorFlow ecosystem. During this event, we will also host the first TensorFlow Awards Ceremony to recognize outstanding community contributions.

You can find more details at the event website. We’ll see you soon!


Get Started with TensorFlow Lite Micro on Sony’s Spresense

A guest post by Daniel Sandblom, Sony

Editor’s note: an earlier version of this article appeared on the Sony Developers site.

 Photo of a Sony Spresense microcontroller board

Now you can develop solutions with TensorFlow Lite Micro (TFLM) for the Spresense microcontroller board from Sony. TFLM is designed to run on microcontroller systems where the hardware resources are more limited compared to larger computerized systems. The footprint of TFLM is typically on the order of only tens of kilobytes.

What you get is a combination of a leading machine learning ecosystem with a high performance microcontroller running at super low power consumption. The Spresense board was designed with camera and hi-res audio inputs as core features, which open up a substantial set of use cases. Pete Warden, a research engineer on the TensorFlow team, shares his view on TFLM now being available for use with the Spresense board: “It’s great to see this kind of compute capability tightly integrated into a low power sensor, the combination will help make machine learning accessible to developers in medical, agriculture, industrial monitoring and many other areas where a small form factor and energy are strong constraints.”

The development of TFLM has been a tight collaboration between Google and Arm to optimize functionality while keeping the footprint to a minimum. Fredrik Knutsson, Team Lead at Arm, explains how TFLM has been optimized for the Arm processor architecture: “Arm’s open source CMSIS-NN library provides high performance implementations of common neural network functions for Arm Cortex-M processors. Arm’s engineers have worked closely with the TensorFlow team to develop optimized versions of the TensorFlow Lite kernels in the CMSIS-NN library, delivering extremely fast performance on Arm Cortex-M cores like Spresense.”

How to get started with TensorFlow on Spresense

The easiest and quickest way to get started with TensorFlow on Spresense is to run one of the examples. There is one hello_world example that shows the basic steps and functionality. There is also a micro_speech example using Spresense’s audio abilities, and there’s a person_detection example utilizing the Spresense camera. The latter two examples demonstrate how to link visual and audio sensors to the inputs of TensorFlow models.

Below are the general steps to run the examples:

  1. Set up the Spresense SDK: Getting started with TensorFlow for Spresense
  2. Download the Spresense repository including the examples
  3. Build and Flash the binary into Spresense main board
  4. Run the example

Heads-up: we will run an upcoming webinar for “TensorFlow on Spresense” on October 14 – register here!

Check out these links for more info:
