Join us at the third Women in ML Symposium!

Posted by Sharbani Roy – Senior Director, Product Management, Google

We’re back with the third annual Women in Machine Learning Symposium on December 7, 2023! Join us virtually from 9:30 am to 1:00 pm PT for an immersive and insightful set of deep dives for every level of Machine Learning experience.

The Women in ML Symposium is an inclusive event for anyone passionate about the transformative fields of Machine Learning (ML) and Artificial Intelligence (AI). Dive into the latest advancements in generative AI, explore the intricacies of privacy-preserving AI, dig into the underlying accelerators and ML frameworks that power models, and uncover practical applications of ML across multiple industries.

Our event offers sessions for all expertise levels, from beginners to advanced practitioners. Hear about what’s new in ML and building with Google AI from our keynote speakers, gain insights from seasoned industry leaders across Google Health, Nvidia, Adobe, and more – and discover a wealth of knowledge on topics ranging from foundational AI concepts to open source tools, techniques, and beyond.

RSVP today to secure your spot and explore our exciting agenda. We can’t wait to see you there!

Simulated Spotify Listening Experiences for Reinforcement Learning with TensorFlow and TF-Agents

Posted by Surya Kanoria, Joseph Cauteruccio, Federico Tomasi, Kamil Ciosek, Matteo Rinaldi, and Zhenwen Dai – Spotify

Introduction

Many of our music recommendation problems involve providing users with ordered sets of items that satisfy users’ listening preferences and intent at that point in time. We base current recommendations on previous interactions with our application and, in the abstract, are faced with a sequential decision making process as we continually recommend content to users.

Reinforcement Learning (RL) is an established tool for sequential decision making that can be leveraged to solve sequential recommendation problems. We decided to explore how RL could be used to craft listening experiences for users. Before we could start training Agents, we needed to pick a RL library that allowed us to easily prototype, test, and potentially deploy our solutions.

At Spotify we leverage TensorFlow and the extended TensorFlow Ecosystem (TFX, TensorFlow Serving, and so on) as part of our production Machine Learning Stack. We made the decision early on to leverage TensorFlow Agents as our RL Library of choice, knowing that integrating our experiments with our production systems would be vastly more efficient down the line.

One missing bit of technology we required was an offline Spotify environment we could use to prototype, analyze, explore, and train Agents offline prior to online testing. The flexibility of the TF-Agents library, coupled with the broader advantages of TensorFlow and its ecosystem, allowed us to cleanly design a robust and extendable offline Spotify simulator.

We based our simulator design on TF-Agents Environment primitives and using this simulator we developed, trained and evaluated sequential models for item recommendations, vanilla RL Agents (PPG, DQN) and a modified deep Q-Network, which we call the Action-Head DQN (AH-DQN), that addressed the specific challenges imposed by the large state and action space of our RL formulation.

Through live experiments we were able to show that our offline performance estimates were strongly correlated with online results. This then opened the door for large scale experimentation and application of Reinforcement Learning across Spotify, enabled by the technological foundations unlocked by TensorFlow and TF-Agents.

In this post we’ll provide more details about our RL problem and how we used TF-Agents to enable this work end to end.

The RL Loop and Simulated Users

Reinforcement Learning loop
In RL, Agents interact with the environment continuously. At a given time step t, the Agent consumes an observation from the environment and, using this observation, produces an action according to its policy. The environment then processes the action and emits both a reward and the next observation. (Although the two terms are often used interchangeably, the State is the complete information required to summarize the environment after an action, while the Observation is the portion of this information actually exposed to the Agent.)
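
The loop described above can be sketched in a few lines of plain Python. Everything in this sketch is hypothetical and only for illustration: `ToyEnvironment`, its reward rule, and the constant policy are assumptions, not Spotify's simulator or the TF-Agents API. Note how the hidden target belongs to the state while the observation exposes only the step count.

```python
import random
from typing import Callable, List, Tuple


class ToyEnvironment:
    """Hypothetical environment used only to illustrate the RL loop."""

    def __init__(self, horizon: int = 5, seed: int = 0):
        self._rng = random.Random(seed)
        self._horizon = horizon
        self._target = 0  # part of the state, never shown to the agent
        self._t = 0

    def reset(self) -> List[float]:
        self._target = self._rng.randint(0, 2)
        self._t = 0
        return [float(self._t)]  # observation: a partial view of the state

    def step(self, action: int) -> Tuple[List[float], float, bool]:
        reward = 1.0 if action == self._target else 0.0
        self._t += 1
        done = self._t >= self._horizon
        return [float(self._t)], reward, done


def run_episode(env: ToyEnvironment, policy: Callable[[List[float]], int]) -> float:
    """Runs one episode and returns the cumulative reward."""
    observation = env.reset()
    total_reward, done = 0.0, False
    while not done:
        action = policy(observation)  # the agent acts on the observation
        observation, reward, done = env.step(action)  # the environment responds
        total_reward += reward
    return total_reward


# A constant policy that always plays action 0.
episode_return = run_episode(ToyEnvironment(), policy=lambda obs: 0)
```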

In our case the reward emitted from the environment is the response of a user to music recommendations driven by the Agent’s action. In the absence of a simulator we would need to expose real users to Agents to observe rewards. We utilize a model-based RL approach to avoid letting an untrained Agent interact with real users (with the potential of hurting user satisfaction in the training process).

In this model-based RL formulation the Agent is not trained online against real users. Instead, it makes use of a user model that predicts responses to a list of tracks derived via the Agent’s action. Using this model we optimize actions in such a way as to maximize a (simulated) user satisfaction metric. During the training phase the environment makes use of this user model to return a predicted user response to the action recommended by the Agent.

We use Keras to design and train our user model. The serialized user model is then unpacked by the simulator and used to calculate rewards during Agent training and evaluation.

Simulator Design

In the abstract, what we needed to build was clear. We needed a way to simulate user listening sessions for the Agent. Given a simulated user and some content, instantiate a listening session and let the Agent drive recommendations in that session. Allow the simulated user to “react” to these recommendations and let the Agent adjust its strategy based on this result to drive some expected cumulative reward.

The TensorFlow Agents environment design guided us in developing the modular components of our system, each of which was responsible for different parts of the overall simulation.

In our codebase we define an environment abstraction that requires the following be defined for every concrete instantiation:

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class AbstractEnvironment(ABC):
    # Components that every concrete environment is expected to provide.
    _user_model: AbstractUserModel = None
    _track_sampler: AbstractTrackSampler = None
    _episode_tracker: EpisodeTracker = None
    _episode_sampler: AbstractEpisodeSampler = None

    @abstractmethod
    def reset(self) -> List[float]:
        pass

    @abstractmethod
    def step(self, action: float) -> Tuple[List[float], float, bool]:
        pass

    def observation_space(self) -> Dict:
        pass

    @abstractmethod
    def action_space(self) -> Dict:
        pass

Set-Up

At the start of Agent training we need to instantiate a simulation environment that has representations of hypothetical users and the content we’re looking to recommend to them. We base these instantiations on both real and hypothetical Spotify listening experiences. The critical information that defines these instantiations is passed to the environment via _episode_sampler. As mentioned, we also need to provide the simulator with a trained user model, in this case via _user_model.
Flow chart of agent training set up

Actions and Observations

Just like any Agent environment, our simulator requires that we specify the action_spec and observation_spec. Actions in our case may be continuous or discrete depending both on our Agent selection and how we propose to translate an Agent’s action into actual recommendations. We typically recommend ordered lists of items drawn from a pool of potential items. Formulating this action space directly would lead to it being combinatorially complex. We also assume the user will interact with multiple items, and as such previous work in this area that relies on single choice assumptions doesn’t apply.

In the absence of a discrete action space consisting of item collections we need to provide the simulator with a method for turning the Agent’s action into actual recommendations. This logic is contained in the _track_sampler. The “example play modes” proposed by the episode sampler contain information on items that can be presented to the simulated user. The track sampler consumes these and the Agent’s action and returns actual item recommendations.
Flow chart of Agent actions_spec and observation_spec combining to create a recommendation
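
To make the role of the track sampler concrete, here is a minimal sketch under stated assumptions: the Agent's action is taken to be a continuous preference vector, each candidate track has a feature vector of the same length, and tracks are ranked by dot product. The function name `sample_tracks` and this scoring rule are hypothetical and are not taken from Spotify's codebase.

```python
from typing import Dict, List, Sequence


def sample_tracks(
    action: Sequence[float],
    candidates: Dict[str, Sequence[float]],
    slate_size: int,
) -> List[str]:
    """Ranks candidate tracks by their dot product with the action vector."""

    def score(features: Sequence[float]) -> float:
        return sum(a * f for a, f in zip(action, features))

    ranked = sorted(candidates, key=lambda tid: score(candidates[tid]), reverse=True)
    return ranked[:slate_size]


# Hypothetical candidate pool proposed by an episode sampler.
candidate_pool = {
    "track_a": [1.0, 0.0],
    "track_b": [0.0, 1.0],
    "track_c": [0.7, 0.7],
}
slate = sample_tracks(action=[1.0, 0.5], candidates=candidate_pool, slate_size=2)
```

With this kind of mapping the Agent only has to emit a low-dimensional action, which sidesteps the combinatorial action space of ordered item lists.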

Termination and Reset

We also need to handle the episode termination dynamics. In our simulator, the reset rules are set by the model builder and based on empirical investigations of interaction data relevant to a specific music listening experience. As a hypothetical, we may determine that 92% of listening sessions terminate after 6 sequential track skips and we’d construct our simulation termination logic to match. This also requires that we design abstractions in our simulator that allow us to check if the episode should be terminated after each step.

When the episode is reset the simulator will sample a new hypothetical user listening session pair and begin the next episode.
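
A termination check of the kind described above could look like the following sketch. The threshold of six consecutive skips is the hypothetical figure from the text, and the function name and the string encoding of user responses are assumptions made for illustration.

```python
from typing import List


def should_terminate(responses: List[str], max_consecutive_skips: int = 6) -> bool:
    """Returns True once the most recent responses are all skips."""
    if len(responses) < max_consecutive_skips:
        return False
    return all(r == "skip" for r in responses[-max_consecutive_skips:])
```

The simulator would call a check like this after every step and reset the episode when it returns True.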

Episode Steps

As with standard TF-Agents Environments, we need to define the step dynamics for our simulation. We also have optional dynamics that must be enforced at each step. For example, we may require that the same item cannot be recommended more than once. If the Agent’s action indicates a recommendation of an item that was previously recommended, we need to build in the functionality to pick the next best item based on this action.

We also need to call the termination (and other supporting functions) mentioned above as needed at each step.
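
The no-repeat rule can be sketched as a small helper. It assumes the Agent's action has already been turned into a ranked list of item ids; the name `pick_next_best` is hypothetical.

```python
from typing import List, Set


def pick_next_best(ranked_items: List[str], already_recommended: Set[str]) -> str:
    """Returns the highest-ranked item that has not been recommended yet."""
    for item in ranked_items:
        if item not in already_recommended:
            return item
    raise ValueError("every candidate item has already been recommended")


# "t1" was recommended earlier in the episode, so the next best item is chosen.
next_item = pick_next_best(["t1", "t2", "t3"], already_recommended={"t1"})
```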

Episode Storage and Replay

The functionality mentioned up until this point collectively created a very complex simulation setup. While the TF-Agents replay buffer provided us with the functionality required to store episodes for Agent training and evaluation, we quickly realized the need to store more episode data for debugging purposes, and for more detailed evaluations specific to our simulation, distinct from standard Agent performance measures.

We thus allowed for the inclusion of an expanded _episode_tracker that would store additional information about the user model predictions, information noting the sampled users/content pairs, and more.

Creating TF-Agent Environments

Our environment abstraction gives us a template that matches that of a standard TF-Agents Environment class. Some inputs to our environment need to be resolved before we can actually create the concrete TF-Agents environment instance. This happens in three steps.

First we define a specific simulation environment that conforms to our abstraction. For example:

class PlaylistEnvironment(AbstractEnvironment):
    def __init__(
        self,
        user_model: AbstractUserModel,
        track_sampler: AbstractTrackSampler,
        episode_tracker: EpisodeTracker,
        episode_sampler: AbstractEpisodeSampler,
        ....
    ):

        ...

Next we use an Environment Builder Class that takes as input a user model, track sampler, etc. and an environment class like PlaylistEnvironment. The builder creates a concrete instance of this environment:

self.playlist_env: PlaylistEnvironment = environment_ctor(
    user_model=user_model,
    track_sampler=track_sampler,
    episode_tracker=episode_tracker,
    episode_sampler=self._eps_sampler,
)

Lastly, we utilize a conversion class that constructs a TF-Agents Environment from a concrete instance of ours:

class TFAgtPyEnvironment(py_environment.PyEnvironment):
    def __init__(self, environment: AbstractEnvironment):
        super().__init__()
        self.env = environment

This is then executed internally to our Environment Builder:

class EnvironmentBuilder(AbstractEnvironmentBuilder):

    def __init__(self, ...):
        ...

    def get_tf_env(self):
        ...
        tf_env: TFAgtPyEnvironment = TFAgtPyEnvironment(
            self.playlist_env
        )
        return tf_env

The resulting TensorFlow Agents environment can then be used for Agent training.
Flow chart showing simulator design
This simulator design allows us to easily create and manage multiple environments with a variety of different configurations as needed.

We next discuss how we used our simulator to train RL Agents to generate Playlists.

A Customized Agent for Playlist Generation

As mentioned, Reinforcement Learning provides us with a set of methods that naturally accommodates the sequential nature of music listening, allowing us to adapt to users’ ever-evolving preferences as sessions progress.

One specific problem we can attempt to use RL to solve is that of automatic music playlist generation. Given a (large) set of tracks, we want to learn how to create one optimal playlist to recommend to the user in order to maximize satisfaction metrics. Our use case is different from standard slate recommendation tasks, where usually the target is to select at most one item in the sequence. In our case, we assume we have a user-generated response for multiple items in the slate, making slate recommendation systems not directly applicable. Another complication is that the set of tracks from which recommendations are drawn is ever changing.

We designed a DQN variant capable of handling these constraints, which we call the Action-Head DQN (AH-DQN).
Moving image of AH-DQN network creating recommendations based on changing variables
The AH-DQN network takes as input the current state and an available action to produce a single Q value for the input action. This process is repeated for every possible item in the input. Finally, the item with the highest Q value is selected and added to the slate, and the process continues until the slate is full.
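
The selection procedure described above can be sketched as a greedy loop. This is a simplified illustration, not the AH-DQN implementation: `toy_q_function` is a hypothetical dot product standing in for the trained network, candidates are plain feature vectors, and the state is held fixed while the slate is filled.

```python
from typing import Callable, Dict, List, Sequence

Vector = Sequence[float]


def toy_q_function(state: Vector, action: Vector) -> float:
    """Hypothetical stand-in for the trained network: a plain dot product."""
    return sum(s * a for s, a in zip(state, action))


def build_slate(
    state: Vector,
    candidates: Dict[str, Vector],
    q_function: Callable[[Vector, Vector], float],
    slate_size: int,
) -> List[str]:
    """Repeatedly adds the candidate with the highest Q value to the slate."""
    remaining = dict(candidates)
    slate: List[str] = []
    while remaining and len(slate) < slate_size:
        # Score every remaining candidate, one (state, action) pair at a time.
        best = max(remaining, key=lambda tid: q_function(state, remaining[tid]))
        slate.append(best)
        del remaining[best]
    return slate


slate = build_slate(
    state=[1.0, 0.0],
    candidates={"a": [0.9, 0.1], "b": [0.2, 0.8], "c": [0.5, 0.5]},
    q_function=toy_q_function,
    slate_size=2,
)
```

Because each candidate is scored individually, the loop works unchanged when the pool of available tracks grows or shrinks between episodes.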

Experiments In Brief

We tested our approach both offline and online at scale to assess the ability of the Agent to power our real-world recommender systems. In addition to testing the Agent itself we were also keen to assess the extent to which our offline performance estimates for various policies returned by our simulator matched (or at least directionally aligned) with our online results.
Graph measuring simulated performance assessment by scaled online reward for different policies

We observed this directional alignment for numerous naive, heuristic, model driven, and RL policies.

Please refer to our KDD paper for more information on the specifics of our model-based RL approach and Agent design.

Federico Tomasi, Joseph Cauteruccio, Surya Kanoria, Kamil Ciosek, Matteo Rinaldi, and Zhenwen Dai
KDD 2023

Acknowledgements

We’d like to thank all our Spotify teammates past and present who contributed to this work. Particularly, we’d like to thank Mehdi Ben Ayed for his early work in helping to develop our RL codebase. We’d also like to thank the TensorFlow Agents team for their support and encouragement throughout this project (and for the library that made it possible).

Building a board game with the TFLite plugin for Flutter

Posted by Wei Wei, Developer Advocate

In our previous blog posts Building a board game app with TensorFlow: a new TensorFlow Lite reference app and Building a reinforcement learning agent with JAX, and deploying it on Android with TensorFlow Lite, we demonstrated how to train a reinforcement learning (RL) agent with TensorFlow, TensorFlow Agents and JAX respectively, and then deploy the converted TFLite model in an Android app using TensorFlow Lite, to play a simple board game ‘Plane Strike’.

While these end-to-end tutorials are helpful for Android developers, we have heard from the Flutter developer community that it would be interesting to make the app cross-platform. Inspired by the recent official release of the TensorFlow Lite Plugin for Flutter, we are going to write one last tutorial and port the app to Flutter.
Flow Chart illustrating training a Reinforcement Learning (RL) Agent with TensorFlow, TensorFlow Agents and JAX, deploying the converted model in an Android app and Flutter using the TensorFlow Lite plugin

Since we already have the model trained with TensorFlow and converted to TFLite, we can just load the model with TFLite interpreter:

void _loadModel() async {
  // Create the interpreter
  _interpreter = await Interpreter.fromAsset(_modelFile);
}

Then we pass in the user board state and help the game agent identify the most promising position to strike next (please refer to our previous blog posts if you need a refresher on the game rules) by running TFLite inference:

int predict(List<List<double>> boardState) {
  var input = [boardState];
  var output = List.filled(_boardSize * _boardSize, 0)
      .reshape([1, _boardSize * _boardSize]);

  // Run inference
  _interpreter.run(input, output);

  // Argmax
  double max = output[0][0];
  int maxIdx = 0;
  for (int i = 1; i < _boardSize * _boardSize; i++) {
    if (max < output[0][i]) {
      maxIdx = i;
      max = output[0][i];
    }
  }

  return maxIdx;
}

That’s it! With some additional Flutter frontend code to render the game boards and track game progress, we can immediately run the game on both Android and iOS (currently the plugin only supports these two mobile platforms). You can find the complete code on GitHub.

If you want to dig deeper, there are a couple of things you can try:
  1. Convert the TFAgents-trained model to TFLite and run it with the plugin
  2. Leverage the RL technique we have used and build a new agent for the tic tac toe game in the Flutter Casual Games Toolkit. You will need to create a new RL environment and train the model from scratch before deployment, but the core concept and technique are pretty much the same.

This concludes this mini-series of blogs on leveraging TensorFlow/JAX to build games for Android and Flutter. And we very much look forward to all the exciting things you build with our tooling, so be sure to share them with @googledevs, @TensorFlow, and your developer communities!

People of AI: Season 2

Posted by Ashley Oldacre

If you are joining us for the first time, you can binge listen to our amazing 8 episodes from Season 1 wherever you get your podcasts.

We are back for another season of People of AI with a new lineup of incredible guests! I am so excited to introduce my new co-host Luiz Gustavo Martins as we meet inspiring people with interesting stories in the field of Artificial Intelligence.

Last season we focused on the incredible journeys that our guests took to get into the field of AI. Through our stories, we highlighted that no matter who you are, what your interests are, or what you work on, there is a place for anyone to get into this field. We also explored how much more accessible the technology has become over the years, as well as the importance of building AI-related products responsibly and ethically. It is easier than ever to use tools, platforms and services powered by machine learning to leverage the benefits of AI, and break down the barrier of entry.

For season 2, we will feature amazing conversations, focusing on Generative AI! Specifically, we will be discussing the explosive growth of Generative AI tools and the major technology shift that has happened in recent months. We will dive into various topics to explore areas where Generative AI can contribute tremendous value, as well as boost both productivity and economic growth. We will also continue to explore the personal paths and career development of this season’s guests as they share how their interest in technology was sparked, how they worked hard to get to where they are today, and explore what it is that they are currently working on.

Starting today, we will release one new episode of season 2 per week. Listen to the first episode on the People of AI site or wherever you get your podcasts. And stay tuned for later in the season when we premiere our first video podcasts as well!

  • Episode 1: meet your hosts, Ashley and Gus, and learn about Generative AI, Bard and the big shift that has dramatically changed the industry.
  • Episode 2: meet Sunita Verma, a long-time Googler, as she shares her personal journey from Engineering to CS, and into Google. With Sunita, an early pioneer of AI and Google Ads, we will talk about the evolution of AI and how Generative AI will transform the way we work.
  • Episode 3: meet Sayak Paul, a Google Developer Expert (GDE) as we explore what it means to be a GDE and how to leverage the power of your community through community contributions. 
  • Episode 4: meet Crispin Velez, the lead for Cloud’s Vertex AI as we dig into his experience in Cloud working with customers and partners on how to integrate and deploy AI. We also learn how he grew his AI developer community in LATAM from scratch. 
  • Episode 5: meet Joyce Shen, venture capital/private equity investor. She shares her fascinating career in AI and how she has worked with businesses to spot AI talent, incorporate AI technology into workflows and implement responsible AI into their products. 
  • Episode 6: meet Anne Simonds and Brian Gary, founders of Muse https://www.museml.com. Join us as we talk about their recent journeys into AI and their new company which uses the power of Generative AI to spark creativity. 
  • Episode 7: meet Tulsee Doshi, product lead for Google’s Responsible AI efforts as we discuss the development of Google-wide resources and best practices for developing more inclusive, diverse, and ethical algorithm driven products. 
  • Episode 8: meet Jeanine Banks, Vice President and General Manager of Google Developer X and Head of Developer Relations. Join us as we debunk AI and get down to what Generative AI really is, how it has changed over the past few months and will continue to change the developer landscape. 
  • Episode 9: meet Simon Tokumine, Director of Product Management at Google. We will talk about how AI has brought us into the era of task-orientated products and is fueling a new community of makers.

Listen now to the first episode of Season 2. We can’t wait to share the stories of these exceptional People of AI with you!

This podcast is sponsored by Google. Any remarks made by the speakers are their own and are not endorsed by Google.

Pre-processing temporal data made easier with TensorFlow Decision Forests and Temporian

Posted by Google: Mathieu Guillame-Bert, Richard Stotz, Robert Crowe, Luiz GUStavo Martins (Gus), Ashley Oldacre, Kris Tonthat, Glenn Cameron, and Tryolabs: Ian Spektor, Braulio Rios, Guillermo Etchebarne, Diego Marvid, Lucas Micol, Gonzalo Marín, Alan Descoins, Agustina Pizarro, Lucía Aguilar, Martin Alcala Rubi

Temporal data is omnipresent in applied machine learning applications. Data often changes over time or is only available or valuable at a certain point in time. For example, market prices and weather conditions change constantly. Temporal data is also often highly discriminative in decision-making tasks. For example, the rate of change and interval between two consecutive heartbeats provides valuable insights into a person’s physical health, and temporal patterns of network logs are used to detect configuration issues and intrusions. Hence, it is essential to incorporate temporal data and temporal information in ML applications.

INFO:  Temporian is a new open-source Python library for preprocessing and feature engineering temporal data for machine learning applications. It is developed in collaboration between Google and Tryolabs. Check the sister blog post for more details.

This blog post demonstrates how to train a forecasting model on transactional data. Specifically, we will show how to forecast the total weekly sales from individual sales records. For the modeling part, we will use TensorFlow Decision Forests as they are well suited to handle temporal data. To feed the transaction data to our model, and to compute temporal specific features, we will use Temporian, a newly released library designed for ingesting and aggregating transactional data from multiple non-synchronized sources.

Time series are the most commonly used representation for temporal data. They consist of uniformly sampled values, which can be useful for representing aggregate signals. However, time series are sometimes not sufficient to represent the richness of available data. Instead, multivariate time series can represent multiple signals together, while time sequences or event sets can represent non-uniformly sampled measurements. Multi-index time sequences can be used to represent relations between different time sequences. In this blog post, we will use multivariate multi-index time sequences, also known as event sets. Don’t worry, they’re not as complex as they sound.

Examples of temporal data include:

  • Weather and other environmental data for weather forecasting, soil profile forecasting and crop yield optimization, temperature tracking, and climate change characterization.

  • Sensory data for quality monitoring, and predictive maintenance.

  • Health data for early treatment, personalized medicine, and epidemic detection.

  • Retail customer data for sales forecasting, sales optimization, and targeted advertising.

  • Banking customer data for fraud detection and loan risk analysis.

  • Economic and financial data for risk analysis, budgetary analysis, stock market analysis, and yield projections.

A simple example

Let’s start with a simple example. We have collected sales records from a fictitious online shop. Each time a client makes a purchase, we record the following information: time of the purchase, client id, product purchased, and price of the product.

The dataset is stored in a single CSV file, with one transaction per line:

$ head -n 5 sales.csv
timestamp,client,product,price
2010-10-05 11:09:56,c64,p35,405.35
2010-09-27 15:00:49,c87,p29,605.35
2010-09-09 12:58:33,c97,p10,108.99
2010-09-06 12:43:45,c60,p85,443.35

Looking at data is crucial to understand the data and spot potential issues. Our first task is to load the sales data into an EventSet and plot it.

INFO: A Temporian EventSet is a general-purpose container for temporal data. It can represent multivariate time series, time sequences, and indexed data.

# Import Temporian
import temporian as tp

# Load the csv dataset
sales = tp.from_csv("/tmp/sales.csv")

# Print details about the EventSet
sales

This code snippet loads and prints the data:

We can also plot the data:
# Plot "price" feature of the EventSet
sales["price"].plot()

We have shown how to load and visualize temporal data in just a few lines of code. However, the resulting plot is very busy, as it shows all transactions for all clients in the same view.

A common operation on temporal data is to calculate the moving sum. Let’s calculate and plot the sum of sales for each transaction in the previous seven days. The moving sum can be computed using the moving_sum operator.

weekly_sales = sales["price"].moving_sum(tp.duration.days(7))

weekly_sales.plot()

BONUS: To make the plots interactive, you can add the interactive=True argument to the plot function. 
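
To build intuition for what a moving sum over events computes, here is a plain-Python sketch that does not use Temporian. It assumes a left-open, right-closed window (t - window, t]; check the Temporian documentation for the exact boundary convention of the moving_sum operator. The sample events are made up for illustration.

```python
from datetime import datetime, timedelta
from typing import List, Tuple

Event = Tuple[datetime, float]


def moving_sum(events: List[Event], window: timedelta) -> List[Event]:
    """For each event at time t, sums the values of events in (t - window, t]."""
    events = sorted(events)
    result = []
    for t, _ in events:
        total = sum(v for ts, v in events if t - window < ts <= t)
        result.append((t, total))
    return result


# Made-up sales events: (timestamp, price).
sample_sales = [
    (datetime(2010, 9, 1), 10.0),
    (datetime(2010, 9, 3), 20.0),
    (datetime(2010, 9, 9), 5.0),
]
weekly = moving_sum(sample_sales, timedelta(days=7))
```

The output has one value per input event, which is why the plot above shows a point for every transaction. This quadratic loop is for clarity only and is not how a production library would compute it.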

Sales per product

In the previous step, we computed the overall moving sum of sales for the entire shop. However, what if we wanted to calculate the rolling sum of sales for each product or client separately?

For this task, we can use an index.

# Index the data by "product"
sales_per_product = sales.add_index("product")

# Compute the moving sum for each product
weekly_sales_per_product = sales_per_product["price"].moving_sum(
        tp.duration.days(7)
)

# Plot the results
weekly_sales_per_product.plot()

NOTE: Many operators, such as moving_sum, are applied independently to each index.

Aggregate transactions into time series

Our dataset contains individual client transactions. To use this data with a machine learning model, it is often useful to aggregate it into time series, where the data is sampled uniformly over time. For example, we could aggregate the sales weekly, or calculate the total sales in the last week for each day.

However, it is important to note that aggregating transaction data into time series can result in some data loss. For example, the individual transaction timestamps and values would be lost. This is because the aggregated time series would only represent the total sales for each time period.

Let’s compute the total sales in the last week for each day for each product individually.

# The data is sampled daily
daily_sampling = sales_per_product.tick(tp.duration.days(1))

weekly_sales_daily = sales_per_product["price"].moving_sum(
    tp.duration.days(7),
    sampling=daily_sampling, # The new bit
)

weekly_sales_daily.plot()

NOTE: The current plot is a continuous line, while the previous plots have markers. This is because Temporian uses continuous lines by default when the data is uniformly sampled, and markers otherwise.
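
The combination of an index and a sampling can be illustrated without Temporian. The sketch below groups made-up events by product and evaluates the moving sum at regular ticks instead of at the event timestamps. The function name, the shared list of ticks, and the (tick - window, tick] window are simplifying assumptions, not Temporian's implementation.

```python
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Dict, List, Tuple


def daily_moving_sum(
    events: List[Tuple[str, datetime, float]],
    window: timedelta,
    ticks: List[datetime],
) -> Dict[str, List[float]]:
    """Per product, sums prices in (tick - window, tick] for every tick."""
    by_product = defaultdict(list)
    for product, ts, price in events:
        by_product[product].append((ts, price))
    return {
        product: [
            sum((p for ts, p in rows if tick - window < ts <= tick), 0.0)
            for tick in ticks
        ]
        for product, rows in by_product.items()
    }


# Made-up events: (product, timestamp, price).
events = [
    ("p1", datetime(2010, 9, 1, 10), 10.0),
    ("p1", datetime(2010, 9, 2, 15), 20.0),
    ("p2", datetime(2010, 9, 2, 9), 7.0),
]
ticks = [datetime(2010, 9, day) for day in (2, 3, 4)]  # one tick per day at midnight
sums = daily_moving_sum(events, timedelta(days=2), ticks)
```

Each product now has exactly one value per tick, which is the uniformly sampled shape a tabular model expects.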

After the data preparation stage is finished, the data can be exported to a Pandas DataFrame as a final step.

tp.to_pandas(weekly_sales_daily)

Train a forecasting model with TensorFlow

A key application of Temporian is to clean data and perform feature engineering for machine learning models. It is well suited for forecasting, anomaly detection, fraud detection, and other tasks where data comes continuously.

In this example, we show how to train a TensorFlow model to predict the next day’s sales using past sales for each product individually. We will feed the model various levels of aggregations of sales as well as calendar information.

Let’s first augment our dataset and convert it to a dataset compatible with a tabular ML model.

sales_per_product = sales.add_index("product")

# Create one example per day
daily_sampling = sales_per_product.tick(tp.duration.days(1))

# Compute moving sums with various window length.
# Machine learning models are able to select the ones that matter.

features = []
for w in [3, 7, 14, 28]:
    features.append(
        sales_per_product["price"]
        .moving_sum(tp.duration.days(w), sampling=daily_sampling)
        .rename(f"moving_sum_{w}")
    )

# Calendar information such as the day of the week are
# very informative of human activities.
features.append(daily_sampling.calendar_day_of_week())

# The label is the daily sales shifted / leaked one day into the future.
label = (
    sales_per_product["price"]
    .leak(tp.duration.days(1))
    .moving_sum(
        tp.duration.days(1),
        sampling=daily_sampling,
    )
    .rename("label")
)

# Collect the features and labels together.
dataset = tp.glue(*features, label)

dataset
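
The label construction is the subtle part of this snippet, so here is a plain-Python sketch of the idea that does not use Temporian. It assumes that leaking moves every event earlier in time by the given duration and that the window is (tick - window, tick]; the helper names and sample values are hypothetical.

```python
from datetime import datetime, timedelta
from typing import List, Tuple

Event = Tuple[datetime, float]


def leak(events: List[Event], shift: timedelta) -> List[Event]:
    """Moves every event earlier in time, exposing future values to the past."""
    return [(ts - shift, value) for ts, value in events]


def window_sum(events: List[Event], tick: datetime, window: timedelta) -> float:
    """Sums the values of events in (tick - window, tick]."""
    return sum((v for ts, v in events if tick - window < ts <= tick), 0.0)


one_day = timedelta(days=1)
sample_sales = [
    (datetime(2010, 9, 1, 12), 10.0),
    (datetime(2010, 9, 2, 12), 25.0),
]
tick = datetime(2010, 9, 2)  # midnight at the start of September 2

# Feature: sales during the day before the tick.
feature = window_sum(sample_sales, tick, one_day)
# Label: sales during the day after the tick, obtained by leaking first.
label = window_sum(leak(sample_sales, one_day), tick, one_day)
```

Leaking is only appropriate for building labels. Applying it to features would let the model see the future during training.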

We can then convert the dataset from EventSet to TensorFlow Dataset format, and train a Random Forest.

import tensorflow_decision_forests as tfdf

def extract_label(example):
    example.pop("timestamp")  # Don't use the timestamps as a feature
    label = example.pop("label")
    return example, label

tf_dataset = tp.to_tensorflow_dataset(dataset).map(extract_label).batch(100)

model = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION, verbose=2)
model.fit(tf_dataset)

And that’s it, we have a model trained to forecast sales. We now can look at the variable importance of the model to understand what features matter the most.

model.summary()

In the summary, we can find the INV_MEAN_MIN_DEPTH variable importance:

Type: "RANDOM_FOREST"
Task: REGRESSION
...
Variable Importance: INV_MEAN_MIN_DEPTH:
1. "moving_sum_28" 0.342231 ################
2. "product" 0.294546 ############
3. "calendar_day_of_week" 0.254641 ##########
4. "moving_sum_14" 0.197038 ######
5. "moving_sum_7" 0.124693 #
6. "moving_sum_3" 0.098542

We see that moving_sum_28 is the feature with the highest importance (0.342231). This indicates that the sum of sales in the last 28 days is very important to the model. To further improve our model, we should probably add more temporal aggregation features. The product feature also matters a lot.

And to get an idea of the model itself, we can plot one of the trees of the Random Forest.

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=2)

More on temporal data preprocessing

We demonstrated some simple data preprocessing. If you want to see other examples of temporal data preprocessing on different data domains, check the Temporian tutorials. Notably:

  • Heart rate analysis ❤️ detects individual heartbeats and derives heart rate related features on raw ECG signals from Physionet.
  • M5 Competition 🛒 predicts retail sales in the M5 Makridakis Forecasting competition.
  • Loan outcomes prediction 🏦 prepares relational SQL data to predict outcomes for finished loans.
  • Detecting payment card fraud 💳 detects fraudulent payment card transactions in real time.
  • Supervised and unsupervised anomaly detection 🔎 performs data analysis and feature engineering to detect anomalies in the resource usage metrics of a group of servers.

Next Steps

We demonstrated how to handle temporal data such as transactions in TensorFlow using the Temporian library. Now you can try it too!

To learn more about model training with TensorFlow Decision Forests:


Distributed Fast Fourier Transform in TensorFlow


Posted by Ruijiao Sun, Google Intern – DTensor team

Fast Fourier Transform is an important method of signal processing, which is commonly used in a number of ways, including speeding up convolutions, extracting features, and regularizing models. Distributed Fast Fourier Transform (Distributed FFT) offers a way to compute Fourier Transforms in models that work with image-like datasets that are too large to fit into the memory of a single accelerator device. In a previous Google Research Paper, “Large-Scale Discrete Fourier Transform on TPUs” by Tianjian Lu, a Distributed FFT algorithm was implemented for TensorFlow v1 as a library. This work presents the newly added native support in TensorFlow v2 for Distributed FFT, through the new TensorFlow distribution API, DTensor.

About DTensor

DTensor is an extension to TensorFlow for synchronous distributed computing. It distributes the program and tensors through a procedure called Single Program, Multiple Data (SPMD) expansion. DTensor offers a uniform API for the traditional data and model parallelism patterns used widely in machine learning.

Example Usage

The API interface for distributed FFT is the same as the original FFT in TensorFlow. Users just need to pass a sharded tensor as an input to the existing FFT ops in TensorFlow, such as tf.signal.fft2d. The output of a distributed FFT becomes sharded too.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Set up devices
device_type = dtensor.preferred_device_type()
if device_type == 'CPU':
    cpu = tf.config.list_physical_devices(device_type)
    tf.config.set_logical_device_configuration(cpu[0], [tf.config.LogicalDeviceConfiguration()] * 8)
if device_type == 'GPU':
    gpu = tf.config.list_physical_devices(device_type)
    tf.config.set_logical_device_configuration(gpu[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * 8)
dtensor.initialize_accelerator_system()

# Create a mesh
mesh = dtensor.create_distributed_mesh(mesh_dims=[('x', 1), ('y', 2), ('z', 4)], device_type=device_type)

# Set up a distributed input Tensor
input = tf.complex(
    tf.random.stateless_normal(shape=(2, 2, 4), seed=(1, 2), dtype=tf.float32),
    tf.random.stateless_normal(shape=(2, 2, 4), seed=(2, 4), dtype=tf.float32))
init_layout = dtensor.Layout(['x', 'y', 'z'], mesh)
d_input = dtensor.relayout(input, layout=init_layout)

# Run distributed fft2d. DTensor determines the most efficient
# layout of d_output.
d_output = tf.signal.fft2d(d_input)

Performance Analysis

The following experiment demonstrates that the distributed FFT can process more data than the non-distributed one by utilizing memory across multiple devices. The tradeoff is spending additional time on communication and data transposes that slow down the calculation speed.

Graph of performance on different machines, measuring wall clock time in seconds by size per dimension across single GPU, Distributed FFT and Undistributed FFT

The profiling result of the 10K*10K distributed FFT experiment shows this tradeoff in detail. The current implementation of distributed FFT in TensorFlow follows the simple shuffle+local FFT method, which is also used by other popular distributed FFT libraries such as FFTW and PFFT. Notably, the two local FFT ops take only 3.6% of the total time (15ms), which is around 1/3 of the time for non-distributed fft2d. Most of the computing time is spent on data shuffling, represented by the ncclAllToAll operation. Note that these experiments were conducted on an 8xV100 GPU system.
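
The shuffle+local FFT pattern can be illustrated without any accelerator. The sketch below is plain Python, not DTensor code: a naive DFT stands in for the local FFT op, a transpose stands in for the all-to-all shuffle, and the function names and the two simulated devices are made up for illustration. It assumes the number of rows and columns is divisible by the number of devices.

```python
import cmath

def dft(values):
    """Naive O(n^2) 1-D discrete Fourier transform (stand-in for a local FFT op)."""
    n = len(values)
    return [
        sum(values[k] * cmath.exp(-2j * cmath.pi * j * k / n) for k in range(n))
        for j in range(n)
    ]

def transpose(rows):
    """Stand-in for the all-to-all shuffle between devices."""
    return [list(col) for col in zip(*rows)]

def sharded_fft2d(matrix, num_devices):
    """2-D transform as: local pass, shuffle, local pass, shuffle back."""
    def local_pass(rows):
        # One shard of rows per simulated device; each transforms only its own rows.
        per_device = len(rows) // num_devices
        shards = [rows[d * per_device:(d + 1) * per_device] for d in range(num_devices)]
        return [dft(row) for shard in shards for row in shard]

    step1 = local_pass(matrix)            # transform along the last axis
    step2 = local_pass(transpose(step1))  # shuffle, then transform the other axis
    return transpose(step2)               # shuffle back to the original layout

def direct_dft2d(matrix):
    """Reference 2-D DFT computed straight from the definition."""
    rows, cols = len(matrix), len(matrix[0])
    return [[
        sum(matrix[r][c] * cmath.exp(-2j * cmath.pi * (u * r / rows + v * c / cols))
            for r in range(rows) for c in range(cols))
        for v in range(cols)] for u in range(rows)]

data = [[complex(r * 4 + c, r - c) for c in range(4)] for r in range(4)]
result = sharded_fft2d(data, num_devices=2)
expected = direct_dft2d(data)
max_error = max(abs(a - b) for ra, rb in zip(result, expected) for a, b in zip(ra, rb))
print(f"max difference from direct 2-D DFT: {max_error:.2e}")
```

Each local pass touches only the rows a device already holds, so all cross-device traffic is concentrated in the two shuffles, which is consistent with the profile being dominated by communication.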

Table of Top 10 TensorFlow operations on GPU highlighting two local FFT ops in the top 3

Next steps

The feature is new and we have adopted the simplest distributed FFT algorithm. A few ideas to fine-tune or improve the performance are:

  • Switch to a different DFT/FFT algorithm.
  • Tune the NCCL communication settings for the particular FFT sizes, which may improve utilization of the network bandwidth and increase the speed.
  • Reduce the number of collectives to minimize bandwidth requirements.
  • Use N-d local FFTs, rather than multiple 1-d local FFTs.

Try the new distributed FFT! We welcome your feedback on the TensorFlow Forum and look forward to working with you on improving the performance. Your input would be invaluable!


The TensorFlow Lite Plugin for Flutter is Officially Available


Posted by Paul Ruiz, Developer Relations Engineer

We’re excited to announce that the TensorFlow Lite plugin for Flutter has been officially migrated to the TensorFlow GitHub account and released!

Three years ago, Amish Garg, one of our talented Google Summer of Code contributors, wrote a widely used TensorFlow Lite plugin for Flutter. The plugin was so popular that we decided to migrate it to our official repo, making it easier to maintain directly by the Google team. We are grateful to Amish for his contributions to the TensorFlow Lite Flutter plugin.

Through the efforts of developers in the community, the plugin has been updated to the latest version of TensorFlow Lite, and a collection of new features and example apps have been added, such as object detection through a live camera feed.

Moving image of a live camera feed showing several objects on a work desk being detected

So what is TensorFlow Lite? TensorFlow Lite is a way to run TensorFlow models on devices locally, supporting mobile, embedded, web, and edge devices. TensorFlow Lite’s cross-platform support and on-device performance optimizations make it a great addition to the Flutter development toolbox. Our goal with this plugin is to make it easy to integrate TensorFlow Lite models into Flutter apps across mobile platforms, with desktop support currently in development through the efforts of our developer community. Find pre-trained TensorFlow Lite models on model repos like Kaggle Models or create your own custom TensorFlow Lite models.

Let’s take a look at how you could use the Flutter TensorFlow Lite plugin for image classification:

TensorFlow Lite Image Classification with Flutter

First you will need to install the plugin from pub.dev. Once the plugin is installed, you can load a TensorFlow Lite model into your Flutter app and define the input and output tensor shapes. If you’re using the MobileNet model, then the input tensor will be a 224 by 224 RGB image, and the output will be a list of confidence scores for the trained labels.

// Load model
Future<void> _loadModel() async {
  final options = InterpreterOptions();

  // Load model from assets
  interpreter = await Interpreter.fromAsset(modelPath, options: options);
  // Get tensor input shape [1, 224, 224, 3]
  inputTensor = interpreter.getInputTensors().first;
  // Get tensor output shape [1, 1001]
  outputTensor = interpreter.getOutputTensors().first;
}

To make things a bit more organized, you can also load in the labels for the 1000 items that MobileNet is trained for:

// Load labels from assets
Future<void> _loadLabels() async {
  final labelTxt = await rootBundle.loadString(labelsPath);
  // One label per line
  labels = labelTxt.split('\n');
}

For the sake of being succinct, let’s go ahead and skip some of the pre-processing steps, though you can find them in the repo’s image classification example here.

When you’re ready to run inference, you can create a new input and output based on the tensor shapes that you defined earlier, then call run on the interpreter to get your final results.

// Run inference
Future<void> runInference(
  List<List<List<num>>> imageMatrix,
) async {
  // Tensor input [1, 224, 224, 3]
  final input = [imageMatrix];
  // Tensor output [1, 1001]
  final output = [List<int>.filled(1001, 0)];

  // Run inference
  interpreter.run(input, output);

  // Get first output tensor
  final result = output.first;
}

Now that you have your results, you can match them to your labels and use them in your app.
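
The matching step is the same in any language. As a rough sketch, shown in Python rather than Dart and using made-up labels and scores, you rank the score indices and look up the label at each index. This assumes the labels file has one entry per output score, in the same order.

```python
def top_k(scores, labels, k=3):
    """Return the k (label, score) pairs with the highest scores."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(labels[i], scores[i]) for i in ranked[:k]]

# Hypothetical output of a quantized classifier with four classes.
labels = ["background", "cat", "dog", "bird"]
scores = [3, 200, 40, 12]

best = top_k(scores, labels, k=2)
print(best)
```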

Moving image of a live camera feed showing several objects on a work desk being correctly identified in the app

What’s next?

To explore what else you can do with the Flutter TensorFlow Lite plugin, check out the official GitHub repository where you can find examples for text classification, super resolution, style transfer, and more!

Additionally, we are working on a new plugin specifically for MediaPipe Tasks, a low-code tool for easily performing common on-device machine learning tasks. This includes image classification and object detection, like you’ve just learned about, as well as audio classification, face landmark detection, and gesture recognition, alongside a whole lot more.

We look forward to all the exciting things you make, so be sure to share them with @googledevs, @TensorFlow, and your developer communities!


Simpleperf case study: Fast initialization of TFLite’s Memory Arena


Posted by Alan Kelly, Software Engineer

One of our previous articles, Optimizing TensorFlow Lite Runtime Memory, discusses how TFLite’s memory arena minimizes memory usage by sharing buffers between tensors. This means we can run models on even smaller edge devices. In today’s article, I will describe the performance optimization of the memory arena initialization so that our users get the benefit of low memory usage with little additional overhead.

ML is normally deployed on-device as part of a larger pipeline. TFLite is used because it’s fast and lightweight, but the rest of the pipeline must also be fast. Profiling on the target device with representative data lets us identify the slowest parts of the pipeline so that we can optimize the most important part of the code.

In this article, I will describe the profiling and optimization of TFLite’s memory arena with instructions on how to use Simpleperf and visualize the results. Sample commands are given. It is assumed that the Android NDK is installed and that you have a development device that you can connect to using adb.

Simpleperf

Simpleperf comes with some scripts to make it easier to use. run_simpleperf_on_device.py pushes simpleperf to the device and runs your binary with the given arguments.

/usr/lib/android-ndk/simpleperf/run_simpleperf_on_device.py record --call-graph fp /data/local/tmp/my_binary arg0 arg1 …

This will generate the output file perf.data which you must then copy back to your computer.

adb pull /data/local/tmp/perf.data

You then generate the binary cache which contains all the information needed later to generate a useful profile.

/usr/lib/android-ndk/simpleperf/binary_cache_builder.py -lib /your/binarys/folder -i perf.data

And generate the proto buffer used for visualization:

/usr/lib/android-ndk/simpleperf/pprof_proto_generator.py --ndk_path=/path/to/android-ndk -i perf.data -o profile.proto

You can then display the results using pprof:

pprof -http :8888 profile.proto

And open localhost:8888 in your browser to view the profile. I find flame graphs to be the most useful:

Optimizing TFLite’s Memory Arena

ArenaPlanner::ExecuteAllocations accounts for 54.3% of the runtime of this model. I was expecting ML operators such as fully connected layers or convolutions to be the bottleneck of this model, not runtime overhead. This is a particularly bad case: the memory arena overhead isn’t this bad for every model, but improvements here will impact all models. This model has variable input sizes and many dynamic tensors, whose output size isn’t known until operator evaluation, which triggers frequent tensor re-allocations. This really is as bad as it gets. Let’s zoom in on the profile.

InterpreterInfo::num_tensors() accounts for 10.4% of the runtime. This function is so expensive because it is a virtual function which calls another function, and it is called within a loop. I would never have suspected this.

for (int i = 0; i < static_cast<int>(graph_info_->num_tensors()); ++i) {

}

Arena planner does not create or destroy tensors so the number of tensors is constant. Let’s cache it.

const int num_tensors = static_cast<int>(graph_info_->num_tensors());
for (int i = 0; i < num_tensors; ++i) {

}

Our next piece of low hanging fruit is InterpreterInfo::tensor(unsigned long) which is another virtual function which does bounds checking and then returns a pointer to a tensor. Tensors are stored in an array so let’s add a function to get a pointer to this array. The commits are here and here.

After these simple changes the runtime of this model has been reduced by 25% and the overhead of the memory allocator by half. Simpleperf made identifying these inefficiencies easy! Time to profile again to measure the impact of these changes.

ArenaPlanner::CalculateAllocations is now the most expensive function at 12.7%. This calls two functions: SimpleMemoryArena::Allocate and ArenaPlanner::CreateTensorAllocationVector.

Although ArenaPlanner::CreateTensorAllocationVector is the cheaper of the two, the code is far simpler so it might be easier to optimize. This function identifies which tensors are allocated between two nodes in the graph and then sorts them by size as a Greedy by Size allocation algorithm is used where the largest tensors are allocated first. The structure of the graph is constant so we can store a map of tensors allocated at each node. Instead of checking each tensor in the model to see if it is allocated between the two nodes, we can identify the tensors to be allocated by iterating through the map. The cost of ArenaPlanner::CreateTensorAllocationVector has gone from 4.8% to 0.8% of the runtime. Code can be seen here. Sort does not appear in the profile so we ignore it.
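
The idea of indexing tensors by node can be sketched in a few lines of Python. This is an illustration of the technique, not the TFLite code: the tensor table, node numbers, and function names are hypothetical.

```python
from collections import defaultdict

# Hypothetical tensor metadata: tensor id -> (node that first needs it, size in bytes).
tensors = {0: (0, 256), 1: (0, 64), 2: (1, 1024), 3: (2, 128), 4: (2, 512)}

def tensors_to_allocate_scan(first_node, last_node):
    """Baseline: check every tensor in the model on each call."""
    picked = [t for t, (node, _) in tensors.items() if first_node <= node <= last_node]
    return sorted(picked, key=lambda t: tensors[t][1], reverse=True)

# Built once, because the graph structure does not change between calls.
allocated_at_node = defaultdict(list)
for tensor_id, (node, _) in tensors.items():
    allocated_at_node[node].append(tensor_id)

def tensors_to_allocate_indexed(first_node, last_node):
    """Only visit the nodes in range and the tensors recorded for them."""
    picked = []
    for node in range(first_node, last_node + 1):
        picked.extend(allocated_at_node.get(node, []))
    return sorted(picked, key=lambda t: tensors[t][1], reverse=True)

print(tensors_to_allocate_indexed(1, 2))
```

The indexed version does work proportional to the tensors in the requested node range rather than to the total number of tensors in the model, which matters when the range is small and the call is frequent.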

The next function to look at is ArenaPlanner::ResolveTensorAllocation which is 10.9% of the runtime after the previous optimizations. This function resets each tensor’s data pointer after allocation. However, these pointers don’t always change. How about keeping track of and only updating the ones which change? After this change, ArenaPlanner::ResolveTensorAllocation doesn’t appear in the profile anymore.

Let’s now take a look at allocation and deallocation. SimpleMemoryArena::Allocate accounts for 7% and SimpleMemoryArena::Deallocate accounts for 6.8% of the runtime. Records of all allocations in the arena are stored in a vector ordered by their offsets within the memory arena. Entries in this sorted data structure are inserted, removed and searched. These operations are all O(N) in a vector. Could a std::multimap with the offset as the key be better? A multimap is needed because the records are ordered by their offsets and there may be multiple tensors with the same offset. Removal and insertion are O(log N) but search would still be O(N) as we are searching for the tensor id and not the offset. The best way to find out is to test and profile.

Replacing the vector with a multimap actually slows down the arena code: it is almost three times slower than using a vector! While this goes against intuition, it is commonly found when optimizing code. Operations on a set or a map have linear or logarithmic complexities, but big-O notation hides a constant factor, and that constant is higher for a set or a map than for a vector. We also iterate through the records, which is much cheaper for a vector than for a list or multimap. A list was also tested, coming in at twice as slow as a vector.

Deallocation can still be improved though. SimpleMemoryArena::Deallocate iterates through the records and when it finds the record to deallocate, it removes it from the vector. This has O(N^2) complexity. The memcpy seen in the profile above comes from the frequent calls to std::vector::erase. It is much more efficient to mark records to be erased and then to erase them in one pass using std::remove_if. The second optimization here is to look at how tensors are typically deallocated: ArenaPlanner::ResetAllocationsAfter deallocates all tensors from a node until the end of the graph. To address this, SimpleMemoryArena::DeallocateAfter(int32_t node) was added which iterates once through all the records, marking those which are allocated after the node. SimpleMemoryArena::ResolveDeallocations erases these in one pass making deallocation O(N). After these changes, ResetAllocationsAfter no longer appears in the profile! Commits are here and here.
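
The difference between erasing records one at a time and marking them for a single erase pass can be shown with a Python list standing in for the std::vector. This is a sketch of the pattern only; the record fields and function names are invented for the example and do not mirror the TFLite data structures.

```python
def deallocate_one_by_one(records, tensor_ids):
    """Baseline: each deletion shifts the tail of the list, like std::vector::erase."""
    for tensor_id in tensor_ids:
        for index, record in enumerate(records):
            if record["tensor"] == tensor_id:
                del records[index]
                break
    return records

def deallocate_after(records, node):
    """Mark every record first allocated after `node`, then drop them in one pass."""
    for record in records:
        if record["first_node"] > node:
            record["erase"] = True
    # Single compaction pass, the Python analogue of erase + std::remove_if.
    records[:] = [r for r in records if not r.get("erase")]
    return records

# Hypothetical arena records, ordered by offset.
records = [{"tensor": i, "first_node": i // 2, "offset": 64 * i} for i in range(6)]

slow = deallocate_one_by_one([dict(r) for r in records], [2, 3, 4, 5])
fast = deallocate_after([dict(r) for r in records], node=0)
print([r["tensor"] for r in fast])
```

Both functions leave the same records behind, but the second one moves each surviving record at most once, however many records are removed.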

The profile now looks very different with the overhead of tensor allocation gone from 49.9% of the runtime to 11%. This profile already looks much more reasonable. SimpleMemoryArena::Allocate is the last function left to optimize. For each tensor, this function iterates through the vector of records trying to find space for the current allocation. The complexity of this is O(N^2). This is a fundamental limitation of the Greedy By Size algorithm. Efficient use of memory comes at the cost of increased overhead. Although the complexity can’t be reduced, N can. We process nodes in the order in which they are executed. Allocation information for tensors which have been deallocated on already executed nodes is not needed anymore and only slows things down. Records are purged periodically so that only records which are active are considered. On a large model, this significantly reduces N. ArenaPlanner::ExecuteAllocations is no longer the most expensive function! It has gone from 11% to 6% and a fully connected operator is now the most expensive function in the profile, which is what we expect when profiling neural network inference.
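
For readers unfamiliar with Greedy By Size, here is a compact Python sketch of the general technique: tensors are placed largest first, and each one takes the lowest offset that does not collide with an already placed tensor whose lifetime overlaps. It is a simplified illustration under assumed inputs (name, size, first node, last node), not the TFLite implementation, and it omits alignment and record purging.

```python
def greedy_by_size(tensors):
    """tensors: list of (name, size, first_node, last_node). Returns {name: offset}."""
    placed = []   # (offset, size, first_node, last_node), kept sorted by offset
    offsets = {}
    for name, size, first, last in sorted(tensors, key=lambda t: t[1], reverse=True):
        candidate = 0
        for offset, other_size, other_first, other_last in placed:
            lifetimes_overlap = first <= other_last and other_first <= last
            if not lifetimes_overlap:
                continue  # memory can be shared with tensors never live at the same time
            if candidate + size <= offset:
                break     # the gap before this record is big enough
            candidate = max(candidate, offset + other_size)
        offsets[name] = candidate
        placed.append((candidate, size, first, last))
        placed.sort()
    return offsets

# Hypothetical tensors: (name, size in bytes, first node using it, last node using it).
model_tensors = [("a", 100, 0, 1), ("b", 50, 1, 2), ("c", 100, 2, 3), ("d", 30, 0, 3)]
offsets = greedy_by_size(model_tensors)
sizes = {name: size for name, size, _, _ in model_tensors}
arena_size = max(offsets[name] + sizes[name] for name in offsets)
print(offsets, arena_size)
```

In this example tensors a and c share offset 0 because they are never live together, so the arena needs 180 bytes instead of the 280 bytes a naive layout would use. The inner loop over placed records is the O(N^2) cost discussed above, which is why shrinking the list of records pays off.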

This is what a neural network profile should look like. Time should be spent running your model’s operators, not in the inference runtime.

The optimized memory arena is now publicly available as part of TensorFlow 2.13.

Next Steps

Today’s post walked you through an example of Simpleperf helping to find easy-to-fix inefficiencies in TFLite’s memory arena that would never have been found by just looking at the code. Pprof can display annotated source code, disassembly and graphs, making it easy to find the bottlenecks in your on-device pipelines.


What's new in TensorFlow 2.13 and Keras 2.13?


Posted by the TensorFlow and Keras Teams

TensorFlow 2.13 and Keras 2.13 have been released! Highlights of this release include publishing Apple Silicon wheels, the new Keras V3 format being default for .keras extension files and many more!

TensorFlow Core

Apple Silicon wheels for TensorFlow

TensorFlow 2.13 is the first version to provide Apple Silicon wheels, which means when you install TensorFlow on an Apple Silicon Mac, you will be able to use the latest version of TensorFlow. The nightly builds for Apple Silicon wheels were released in March 2023 and this new support will enable more fine-grained testing, thanks to technical collaboration between Apple, MacStadium, and Google.

tf.lite

The Python TensorFlow Lite Interpreter bindings now have an experimental_disable_delegate_clustering flag that turns off delegate clustering during the delegate graph partitioning phase. You can set this flag in the TensorFlow Lite interpreter Python API:

interpreter = tf.lite.Interpreter(model_path=file_of_a_tensorflowlite_model, experimental_disable_delegate_clustering=True)

The flag is set to False by default. This is an advanced, experimental feature designed for people who insert explicit control dependencies via tf.control_dependencies() or need to change graph execution order.

In addition, there are several operator improvements in TensorFlow Lite in 2.13:

  • The add operation now supports broadcasting up to 6 dimensions. This will remove explicit broadcast ops from many models. The new implementation is also much faster than the previous one, which calculated the entire index for both inputs instead of only calculating the part that changes.
  • Improve the coverage for 16×8 quantization by enabling int16x8 ops for exp, mirror_pad, space_to_batch_nd, batch_to_space_nd
  • Increase the coverage of integer data types
    • enabled int16 for less, greater_than, equal, bitcast, bitwise_xor, right_shift, top_k, mul, and int16 indices for gather and gather_nd
    • enabled int8 for floor_div, floor_mod, and bitwise_xor
    • enabled 32-bit int for bitcast, bitwise_xor, and right_shift

tf.data

We have improved usability and added functionality for tf.data APIs.

tf.data.Dataset.zip now supports Python-style zipping. Previously users were required to provide an extra set of parentheses when zipping datasets as in Dataset.zip((a, b, c)). With this change, users can specify the datasets to be zipped simply as Dataset.zip(a, b, c) making it more intuitive.

Additionally, tf.data.Dataset.shuffle now supports full shuffling. To specify that data should be fully shuffled, use dataset = dataset.shuffle(dataset.cardinality()). This will load the full dataset into memory so that it can be shuffled, so make sure to only use this with datasets of filenames or other small datasets.

We have also added a new tf.data.experimental.pad_to_cardinality transformation which pads a dataset with zero elements up to a specified cardinality. This is useful for avoiding partial batches while not dropping any data.

Example usage:

ds = tf.data.Dataset.from_tensor_slices({'a': [1, 2]})
ds = ds.apply(tf.data.experimental.pad_to_cardinality(3))
list(ds.as_numpy_iterator())
[{'a': 1, 'valid': True}, {'a': 2, 'valid': True}, {'a': 0, 'valid': False}]

This can be useful, e.g. during eval, when partial batches are undesirable but it is also important not to drop any data.
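
To see why the valid flag matters, the following plain-Python sketch mimics the padding behavior shown above and then computes a metric that ignores the padded elements. It models the semantics only and does not call tf.data; the helper name is reused for readability, and raising an error for oversized inputs is an assumption of this sketch.

```python
def pad_to_cardinality(elements, cardinality):
    """Yield each dict with valid=True, then zero-filled copies with valid=False."""
    elements = list(elements)
    if len(elements) > cardinality:
        raise ValueError("dataset is larger than the requested cardinality")
    for element in elements:
        yield {**element, "valid": True}
    template = elements[0] if elements else {}
    for _ in range(cardinality - len(elements)):
        yield {**{key: 0 for key in template}, "valid": False}

padded = list(pad_to_cardinality([{"a": 1}, {"a": 2}], 3))

# A full batch of 3 can now be formed; the mask keeps padding out of the metric.
valid_values = [element["a"] for element in padded if element["valid"]]
masked_mean = sum(valid_values) / len(valid_values)
print(padded, masked_mean)
```

Averaging over all three elements would give 1.0, while the masked mean over the two real elements is 1.5, so evaluation code should always apply the mask.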

oneDNN BF16 Math Mode on CPU

oneDNN supports BF16 math mode where full FP32 tensors are implicitly down-converted to BF16 during computations for faster execution time. TensorFlow CPU users can enable this by setting the environment variable TF_SET_ONEDNN_FPMATH_MODE to BF16. This mode may negatively impact model accuracy. To go back to full FP32 math mode, unset the variable.
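
One way to set the variable from Python is sketched below. The helper function is hypothetical, and setting the variable before TensorFlow is imported is a conservative assumption rather than a documented requirement.

```python
import os

def set_bf16_math_mode(enabled):
    """Set or unset TF_SET_ONEDNN_FPMATH_MODE for the current process."""
    if enabled:
        os.environ["TF_SET_ONEDNN_FPMATH_MODE"] = "BF16"
    else:
        # Unsetting the variable restores full FP32 math mode.
        os.environ.pop("TF_SET_ONEDNN_FPMATH_MODE", None)

set_bf16_math_mode(True)
# import tensorflow as tf  # import after the variable is set
print(os.environ.get("TF_SET_ONEDNN_FPMATH_MODE"))
```

Because the mode may reduce accuracy, it is worth comparing evaluation metrics with the variable set and unset before adopting it.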

Keras

Keras Saving format

The new Keras V3 saving format, released in TF 2.12, is now the default for all files with the .keras extension.

You can start using it now by calling model.save("your_model.keras").

It provides richer Python-side model saving and reloading with numerous advantages:

  • A lightweight, faster format.
  • Human-readable: The new format is name-based, with a more detailed serialization format that makes debugging much easier. What you load is exactly what you saved, from Python’s perspective.
  • Safer: Unlike SavedModel, there is no reliance on loading via bytecode or pickling – a big advancement for secure ML, as pickle files can be exploited to cause arbitrary code execution at loading time.
  • More general: Support for non-numerical states, such as vocabularies and lookup tables, is included in the new format.
  • Extensible: You can add support for saving and loading exotic state elements in custom layers using save_assets(), such as a FIFOQueue – or anything else you want. You have full control of disk I/O for custom assets.

The legacy formats (h5 and Keras SavedModel) will stay supported in perpetuity. However, we recommend that you consider adopting the new Keras v3 format for saving/reloading in Python runtimes, and using model.export() for inference in all other runtimes (such as TF Serving).


On-device fetal ultrasound assessment with TensorFlow Lite


Posted by Angelica Willis and Akib Uddin, Health AI Team, Google Research

How researchers at Google are working to expand global access to maternal healthcare with the help of AI

TensorFlow Lite* is an open-source framework to run machine learning models on mobile and edge devices. It’s popular for use cases including image classification, object detection, speech recognition, and natural language tasks. From helping parents of deaf children learn sign language, to predicting air quality, projects using TensorFlow Lite are demonstrating how on-device ML could directly and positively impact lives by making these socially beneficial applications of AI more accessible, globally. In this post, we describe how TensorFlow Lite is being used to help develop ultrasound tools in under-resourced settings.

Motivation

According to the WHO, complications from pregnancy and childbirth contribute to roughly 287,000 maternal deaths and 2.4 million neonatal deaths worldwide each year. As many as 95% of these deaths occur in under-resourced settings and many are preventable if detected early. Obstetric diagnostics, such as determining gestational age and fetal presentation, are important indicators in planning prenatal care, monitoring the health of the birthing parent and fetus, and determining when intervention is required. Many of these factors are traditionally determined by ultrasound.

Advancements in sensor technology have made ultrasound devices more affordable and portable, integrating directly with smartphones. However, ultrasound requires years of training and experience, and, in many rural or underserved regions, there is a shortage of trained ultrasonography experts, making it difficult for people to access care. Due to this global lack of availability, it has been estimated that as many as two-thirds of pregnant people in these settings do not receive ultrasound screening during pregnancy.

Expanding access by enabling non-experts

Google Research is building AI models to help expand access to ultrasound, including models to predict gestational age and fetal presentation, to allow health workers with no background in ultrasonography to collect clinically useful ultrasound scans. These models make predictions from ultrasound video obtained using an easy-to-teach operating procedure, a blind sweep protocol, in which a user blindly sweeps the ultrasound probe over the patient’s abdomen. In our recent paper, “A mobile-optimized artificial intelligence system for gestational age and fetal malpresentation assessment”, published in Nature Communications Medicine, we demonstrated that, when utilizing blind sweeps, these models enable these non-experts to match standard of care performance in predicting these diagnostics.

Blind Sweep Operating Procedure

Moving images depicting blind-sweep ultrasound technique from a transverse view on the left and aerial view on the right
This blind-sweep ultrasound acquisition procedure can be performed by non-experts with only a few hours of ultrasound training.
Figure A on the left is a chart showing gestational age model performance comparing the blind sweep based gestational age with the clinical standard of care method indicating the percentile absolute error in days. Figure B on the right is a graphical representation of the fetal presentation model performance showing the Receiver Operating Characteristic curves for blind sweep based fetal malpresentation for cases collected by novices or expert sonographers.
Figure A compares our blind sweep-based gestational age regression model performance with that of the clinical standard of care method for fetal age estimation from fetal biometry measured by expert sonographers. Boxes indicate 25th, 50th, and 75th percentile absolute error in days, and whiskers indicate 5th and 95th percentile absolute error (n = 407 study participants). Figure B shows the Receiver Operating Characteristic (ROC) curves for our blind sweep-based fetal malpresentation classification model, as well as specific performance curves for cases in which blind sweeps were collected by expert sonographers or novices (n = 623 study participants). See our recent paper for further details and additional analysis.

Model development

Understanding that our target deployment environment is one in which users might not have reliable access to power and internet, we designed these models to be mobile-optimized. Our grouped convolutional LSTM architecture utilizes MobileNetV2 for feature extraction on each video frame as it is received. The final feature layer produces a sequence of image embeddings which are processed by the convolutional LSTM cell state. Since the recurrent connections only operate on the less memory-intensive embeddings, this model can run efficiently in a mobile environment.

For each subsequence of video frames that make up a sweep, we generate a clip-level diagnostic result, and in the case of gestational age, also produce a model confidence estimate represented as the predicted variance in the detected age. Clip-level gestational age predictions are aggregated via inverse variance weighting to produce a final case-level prediction.
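
Inverse variance weighting gives each clip-level prediction a weight equal to the reciprocal of its predicted variance, so confident clips count for more. The sketch below shows the standard formula with made-up numbers; the function name and the returned combined variance (which assumes independent estimates) are illustrative and not taken from the paper.

```python
def inverse_variance_weighted(predictions, variances):
    """Combine estimates, weighting each one by 1 / variance."""
    weights = [1.0 / v for v in variances]
    total = sum(weights)
    estimate = sum(w * p for w, p in zip(weights, predictions)) / total
    combined_variance = 1.0 / total  # valid if the estimates are independent
    return estimate, combined_variance

# Hypothetical clip-level gestational age predictions (days) and predicted variances.
clip_predictions = [200.0, 210.0, 190.0]
clip_variances = [4.0, 16.0, 16.0]

case_estimate, case_variance = inverse_variance_weighted(clip_predictions, clip_variances)
print(case_estimate, case_variance)
```

Here the first clip has the lowest variance, so it contributes two thirds of the total weight and the case-level estimate lands on 200 days.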

Flow chart depicting Gestational Age LSTM Video Model

Optimization through TensorFlow Lite

On-device ML has many advantages, including providing enhanced privacy and security by ensuring that sensitive input data never needs to leave the device. Another important advantage of on-device ML, particularly for our use case, is the ability to leverage ML offline in regions with low internet connectivity, including where smartphones serve as a stand-in for more expensive traditional devices. Our prioritization of on-device ML made TensorFlow Lite a natural choice for optimizing and evaluating the memory use and execution speed of our existing models, without significant changes to model structure or prediction performance.

After converting our models to TensorFlow Lite using the converter API, we explored various optimization strategies, including post-training quantization and alternative delegate configurations. Leveraging a TensorFlow Lite GPU delegate, optimized for sustained inference speed, provided the most significant boost to execution speed. There was a roughly 2x speed improvement with no loss in model accuracy, which equated to real-time inference of more than 30 frames/second with both the gestational age and fetal presentation models running in parallel on Pixel devices. We benchmarked model initialization time, inference time and memory usage for various delegate configurations using TensorFlow Lite performance measurement tools, finding the optimal configuration across multiple mobile device manufacturers.

These critical speed improvements allow us to leverage the model confidence estimate to provide sweep-quality feedback to the user immediately after the sweep was captured. When low-quality sweeps are detected, users can be provided with tips on how their sweep can be improved (for example, applying more pressure or ultrasound gel), then prompted to re-do the sweep.

Screen capture of sweep exam being conducted
We developed a mobile application that demonstrates what a potential user experience could look like and allows us to evaluate our TensorFlow Lite models in realistic environments. This app enables ultrasound video frames to be received directly from portable ultrasound devices that support this use case.

Looking ahead

Our vision is to enable safer pregnancy journeys using AI-driven ultrasound that could broaden access globally. We want to be thoughtful and responsible in how we develop our AI to maximize positive benefits and address challenges, guided by our AI Principles. TensorFlow Lite has helped enable our research team to explore, prototype, and de-risk impactful care-delivery strategies designed with the needs of lower-resource communities in mind.

This research is in its early stages and we look forward to opportunities to expand our work. To achieve our goals and scale this technology for wider reach globally, partnerships are critical. We are excited about our partnerships with Northwestern Medicine in the US and Jacaranda Health in Kenya to further develop and evaluate these models. With more automated and accurate evaluations of maternal and fetal health risks, we hope to lower barriers and help people get timely care.

Acknowledgements

This work was developed by an interdisciplinary team within Google Research: Ryan G. Gomes, Chace Lee, Angelica Willis, Marcin Sieniek, Christina Chen, James A. Taylor, Scott Mayer McKinney, George E. Dahl, Justin Gilmer, Charles Lau, Terry Spitz, T. Saensuksopa, Kris Liu, Tiya Tiyasirichokchai, Jonny Wong, Rory Pilgrim, Akib Uddin, Greg Corrado, Lily Peng, Katherine Chou, Daniel Tse, & Shravya Shetty.

This work was developed in collaboration with:
Department of Obstetrics and Gynaecology, University of Zambia School of Medicine, Lusaka, Zambia
Department of Obstetrics and Gynecology, University of North Carolina School of Medicine, Chapel Hill, NC, USA
UNC Global Projects—Zambia, LLC, Lusaka, Zambia

Special thanks to: Yun Liu, Cameron Chen, Sami Lachgar, Lauren Winer, Annisah Um’rani, and Sachin Kotwani

*TensorFlow Lite has not been certified or validated for clinical, medical, or diagnostic purposes. TensorFlow Lite users are solely responsible for their use of the framework and independently validating any outputs generated by their project.
