
Building a board game with the TFLite plugin for Flutter

Posted by Wei Wei, Developer Advocate

In our previous blog posts Building a board game app with TensorFlow: a new TensorFlow Lite reference app and Building a reinforcement learning agent with JAX, and deploying it on Android with TensorFlow Lite, we demonstrated how to train a reinforcement learning (RL) agent with TensorFlow, TensorFlow Agents and JAX respectively, and then deploy the converted TFLite model in an Android app using TensorFlow Lite, to play a simple board game ‘Plane Strike’.

While these end-to-end tutorials are helpful for Android developers, we have heard from the Flutter developer community that it would be interesting to make the app cross-platform. Inspired by the recently released official TensorFlow Lite Plugin for Flutter, we are going to write one last tutorial and port the app to Flutter.
Flow Chart illustrating training a Reinforcement Learning (RL) Agent with TensorFlow, TensorFlow Agents and JAX, deploying the converted model in an Android app and Flutter using the TensorFlow Lite plugin

Since we already have the model trained with TensorFlow and converted to TFLite, we can just load the model with TFLite interpreter:

void _loadModel() async {
  // Create the interpreter
  _interpreter = await Interpreter.fromAsset(_modelFile);
}

Then we pass in the user board state and help the game agent identify the most promising position to strike next (please refer to our previous blog posts if you need a refresher on the game rules) by running TFLite inference:

int predict(List<List<double>> boardState) {
  var input = [boardState];
  var output = List.filled(_boardSize * _boardSize, 0)
      .reshape([1, _boardSize * _boardSize]);

  // Run inference
  _interpreter.run(input, output);

  // Argmax
  double max = output[0][0];
  int maxIdx = 0;
  for (int i = 1; i < _boardSize * _boardSize; i++) {
    if (max < output[0][i]) {
      maxIdx = i;
      max = output[0][i];
    }
  }

  return maxIdx;
}

That’s it! With some additional Flutter frontend code to render the game boards and track game progress, we can immediately run the game on both Android and iOS (currently the plugin only supports these two mobile platforms). You can find the complete code on GitHub.

If you want to dig deeper, there are a couple of things you can try:
  1. Convert the TFAgents-trained model to TFLite and run it with the plugin
  2. Leverage the RL technique we have used and build a new agent for the tic tac toe game in the Flutter Casual Games Toolkit. You will need to create a new RL environment and train the model from scratch before deployment, but the core concept and technique are pretty much the same.

This concludes this mini-series of blogs on leveraging TensorFlow/JAX to build games for Android and Flutter. And we very much look forward to all the exciting things you build with our tooling, so be sure to share them with @googledevs, @TensorFlow, and your developer communities!



People of AI: Season 2

Posted by Ashley Oldacre

If you are joining us for the first time, you can binge listen to our amazing 8 episodes from Season 1 wherever you get your podcasts.

We are back for another season of People of AI with a new lineup of incredible guests! I am so excited to introduce my new co-host Luiz Gustavo Martins as we meet inspiring people with interesting stories in the field of Artificial Intelligence.

Last season we focused on the incredible journeys that our guests took to get into the field of AI. Through our stories, we highlighted that no matter who you are, what your interests are, or what you work on, there is a place for anyone to get into this field. We also explored how much more accessible the technology has become over the years, as well as the importance of building AI-related products responsibly and ethically. It is easier than ever to use tools, platforms and services powered by machine learning to leverage the benefits of AI, and break down the barrier of entry.

For season 2, we will feature amazing conversations, focusing on Generative AI! Specifically, we will be discussing the explosive growth of Generative AI tools and the major technology shift that has happened in recent months. We will dive into various topics to explore areas where Generative AI can contribute tremendous value, as well as boost both productivity and economic growth. We will also continue to explore the personal paths and career development of this season’s guests as they share how their interest in technology was sparked, how they worked hard to get to where they are today, and explore what it is that they are currently working on.

Starting today, we will release one new episode of season 2 per week. Listen to the first episode on the People of AI site or wherever you get your podcasts. And stay tuned for later in the season when we premiere our first video podcasts as well!

  • Episode 1: meet your hosts, Ashley and Gus and learn about Generative AI, Bard and the big shift that has dramatically changed the industry. 
  • Episode 2: meet Sunita Verma, a long-time Googler, as she shares her personal journey from Engineering to CS, and into Google. An early pioneer of AI and Google Ads, she talks with us about the evolution of AI and how Generative AI will transform the way we work. 
  • Episode 3: meet Sayak Paul, a Google Developer Expert (GDE) as we explore what it means to be a GDE and how to leverage the power of your community through community contributions. 
  • Episode 4: meet Crispin Velez, the lead for Cloud’s Vertex AI as we dig into his experience in Cloud working with customers and partners on how to integrate and deploy AI. We also learn how he grew his AI developer community in LATAM from scratch. 
  • Episode 5: meet Joyce Shen, venture capital/private equity investor. She shares her fascinating career in AI and how she has worked with businesses to spot AI talent, incorporate AI technology into workflows and implement responsible AI into their products. 
  • Episode 6: meet Anne Simonds and Brian Gary, founders of Muse https://www.museml.com. Join us as we talk about their recent journeys into AI and their new company which uses the power of Generative AI to spark creativity. 
  • Episode 7: meet Tulsee Doshi, product lead for Google’s Responsible AI efforts as we discuss the development of Google-wide resources and best practices for developing more inclusive, diverse, and ethical algorithm driven products. 
  • Episode 8: meet Jeanine Banks, Vice President and General Manager of Google Developer X and Head of Developer Relations. Join us as we debunk AI and get down to what Generative AI really is, how it has changed over the past few months and will continue to change the developer landscape. 
  • Episode 9: meet Simon Tokumine, Director of Product Management at Google. We will talk about how AI has brought us into the era of task-oriented products and is fueling a new community of makers.

Listen now to the first episode of Season 2. We can’t wait to share the stories of these exceptional People of AI with you!

This podcast is sponsored by Google. Any remarks made by the speakers are their own and are not endorsed by Google.



Pre-processing temporal data made easier with TensorFlow Decision Forests and Temporian

Posted by Google: Mathieu Guillame-Bert, Richard Stotz, Robert Crowe, Luiz Gustavo Martins (Gus), Ashley Oldacre, Kris Tonthat, Glenn Cameron, and Tryolabs: Ian Spektor, Braulio Rios, Guillermo Etchebarne, Diego Marvid, Lucas Micol, Gonzalo Marín, Alan Descoins, Agustina Pizarro, Lucía Aguilar, Martin Alcala Rubi

Temporal data is omnipresent in applied machine learning applications. Data often changes over time or is only available or valuable at a certain point in time. For example, market prices and weather conditions change constantly. Temporal data is also often highly discriminative in decision-making tasks. For example, the rate of change and interval between two consecutive heartbeats provides valuable insights into a person’s physical health, and temporal patterns of network logs are used to detect configuration issues and intrusions. Hence, it is essential to incorporate temporal data and temporal information in ML applications.

INFO:  Temporian is a new open-source Python library for preprocessing and feature engineering temporal data for machine learning applications. It is developed in collaboration between Google and Tryolabs. Check the sister blog post for more details.

This blog post demonstrates how to train a forecasting model on transactional data. Specifically, we will show how to forecast the total weekly sales from individual sales records. For the modeling part, we will use TensorFlow Decision Forests as they are well suited to handle temporal data. To feed the transaction data to our model, and to compute temporal specific features, we will use Temporian, a newly released library designed for ingesting and aggregating transactional data from multiple non-synchronized sources.


Time series are the most commonly used representation for temporal data. They consist of uniformly sampled values, which can be useful for representing aggregate signals. However, time series are sometimes not sufficient to represent the richness of available data. Instead, multivariate time series can represent multiple signals together, while time sequences or event sets can represent non-uniformly sampled measurements. Multi-index time sequences can be used to represent relations between different time sequences. In this blog post, we will use multivariate multi-index time sequences, also known as event sets. Don’t worry, they’re not as complex as they sound.
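For illustration, here is a minimal sketch of creating a small EventSet directly in memory with Temporian; the timestamps and feature values are made up:

import temporian as tp

# A tiny in-memory EventSet with two features.
events = tp.event_set(
    timestamps=["2023-01-01", "2023-01-02", "2023-01-05"],
    features={
        "product": ["p1", "p1", "p2"],
        "price": [10.0, 12.5, 8.0],
    },
)

# Indexing by "product" turns it into a multi-index time sequence.
events_per_product = events.add_index("product")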

Examples of temporal data include:

  • Weather and other environmental data for weather forecasting, soil profile forecasting and crop yield optimization, temperature tracking, and climate change characterization.

  • Sensory data for quality monitoring, and predictive maintenance.

  • Health data for early treatment, personalized medicine, and epidemic detection.

  • Retail customer data for sales forecasting, sales optimization, and targeted advertising.

  • Banking customer data for fraud detection and loan risk analysis.

  • Economic and financial data for risk analysis, budgetary analysis, stock market analysis, and yield projections.

A simple example

Let’s start with a simple example. We have collected sales records from a fictitious online shop. Each time a client makes a purchase, we record the following information: time of the purchase, client id, product purchased, and price of the product.

The dataset is stored in a single CSV file, with one transaction per line:

$ head -n 5 sales.csv
timestamp,client,product,price
2010-10-05 11:09:56,c64,p35,405.35
2010-09-27 15:00:49,c87,p29,605.35
2010-09-09 12:58:33,c97,p10,108.99
2010-09-06 12:43:45,c60,p85,443.35

Looking at the data is crucial for understanding it and spotting potential issues. Our first task is to load the sales data into an EventSet and plot it.

INFO: A Temporian EventSet is a general-purpose container for temporal data. It can represent multivariate time series, time sequences, and indexed data.

# Import Temporian
import temporian as tp

# Load the csv dataset
sales = tp.from_csv("/tmp/sales.csv")

# Print details about the EventSet
sales

This code snippet loads and prints the data.

We can also plot the data:
# Plot "price" feature of the EventSet
sales["price"].plot()


We have shown how to load and visualize temporal data in just a few lines of code. However, the resulting plot is very busy, as it shows all transactions for all clients in the same view.

A common operation on temporal data is to calculate the moving sum. Let’s calculate and plot the sum of sales for each transaction in the previous seven days. The moving sum can be computed using the moving_sum operator.

weekly_sales = sales["price"].moving_sum(tp.duration.days(7))

weekly_sales.plot()


BONUS: To make the plots interactive, you can add the interactive=True argument to the plot function. 

Sales per product

In the previous step, we computed the overall moving sum of sales for the entire shop. However, what if we wanted to calculate the rolling sum of sales for each product or client separately?

For this task, we can use an index.

# Index the data by "product"
sales_per_product = sales.add_index("product")

# Compute the moving sum for each product
weekly_sales_per_product = sales_per_product["price"].moving_sum(
        tp.duration.days(7)
)

# Plot the results
weekly_sales_per_product.plot()


NOTE: Many operators such as moving_sum are applied independently to each index.

Aggregate transactions into time series

Our dataset contains individual client transactions. To use this data with a machine learning model, it is often useful to aggregate it into time series, where the data is sampled uniformly over time. For example, we could aggregate the sales weekly, or calculate the total sales in the last week for each day.

However, it is important to note that aggregating transaction data into time series can result in some data loss. For example, the individual transaction timestamps and values would be lost. This is because the aggregated time series would only represent the total sales for each time period.

Let’s compute the total sales in the last week for each day for each product individually.

# The data is sampled daily
daily_sampling = sales_per_product.tick(tp.duration.days(1))

weekly_sales_daily = sales_per_product["price"].moving_sum(
    tp.duration.days(7),
    sampling=daily_sampling, # The new bit
)

weekly_sales_daily.plot()


NOTE: The current plot is a continuous line, while the previous plots have markers. This is because Temporian uses continuous lines by default when the data is uniformly sampled, and markers otherwise.

After the data preparation stage is finished, the data can be exported to a Pandas DataFrame as a final step.

tp.to_pandas(weekly_sales_daily)

Train a forecasting model with TensorFlow Decision Forests

A key application of Temporian is to clean data and perform feature engineering for machine learning models. It is well suited for forecasting, anomaly detection, fraud detection, and other tasks where data comes continuously.

In this example, we show how to train a TensorFlow model to predict the next day’s sales using past sales for each product individually. We will feed the model various levels of aggregations of sales as well as calendar information.

Let’s first augment our dataset and convert it to a dataset compatible with a tabular ML model.

sales_per_product = sales.add_index("product")

# Create one example per day
daily_sampling = sales_per_product.tick(tp.duration.days(1))

# Compute moving sums with various window lengths.
# Machine learning models are able to select the ones that matter.

features = []
for w in [3, 7, 14, 28]:
  features.append(
      sales_per_product["price"]
      .moving_sum(tp.duration.days(w), sampling=daily_sampling)
      .rename(f"moving_sum_{w}"))

# Calendar information such as the day of the week is
# very informative of human activities.
features.append(daily_sampling.calendar_day_of_week())

# The label is the daily sales shifted / leaked one day into the future.
label = (sales_per_product["price"]
         .leak(tp.duration.days(1))
         .moving_sum(tp.duration.days(1), sampling=daily_sampling)
         .rename("label"))

# Collect the features and labels together.
dataset = tp.glue(*features, label)

dataset


We can then convert the dataset from EventSet to TensorFlow Dataset format, and train a Random Forest.

import tensorflow_decision_forests as tfdf

def extract_label(example):
  example.pop("timestamp")  # Don't use the timestamp as a feature
  label = example.pop("label")
  return example, label

tf_dataset = tp.to_tensorflow_dataset(dataset).map(extract_label).batch(100)

model = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION,verbose=2)
model.fit(tf_dataset)

And that’s it: we have a model trained to forecast sales. We can now look at the variable importance of the model to understand which features matter the most.

model.summary()

In the summary, we can find the INV_MEAN_MIN_DEPTH variable importance:

Type: "RANDOM_FOREST"
Task: REGRESSION
...
Variable Importance: INV_MEAN_MIN_DEPTH:
1. "moving_sum_28" 0.342231 ################
2. "product" 0.294546 ############
3. "calendar_day_of_week" 0.254641 ##########
4. "moving_sum_14" 0.197038 ######
5. "moving_sum_7" 0.124693 #
6. "moving_sum_3" 0.098542

We see that moving_sum_28 is the feature with the highest importance (0.342231). This indicates that the sum of sales in the last 28 days is very important to the model. To further improve our model, we should probably add more temporal aggregation features. The product feature also matters a lot.
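As a sketch of what adding such features could look like with Temporian (assuming the simple_moving_average operator is applied the same way as moving_sum above; the window lengths are illustrative):

# Additional aggregations appended to the feature list built earlier.
# After adding them, re-run tp.glue(*features, label) to rebuild the dataset.
for w in [7, 28]:
  features.append(
      sales_per_product["price"]
      .simple_moving_average(tp.duration.days(w), sampling=daily_sampling)
      .rename(f"moving_average_{w}"))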

And to get an idea of the model itself, we can plot one of the trees of the Random Forest.

tfdf.model_plotter.plot_model_in_colab(model, tree_idx=0, max_depth=2)

More on temporal data preprocessing

We demonstrated some simple data preprocessing. If you want to see other examples of temporal data preprocessing on different data domains, check the Temporian tutorials. Notably:

  • Heart rate analysis ❤️ detects individual heartbeats and derives heart rate related features on raw ECG signals from Physionet.
  • M5 Competition 🛒 predicts retail sales in the M5 Makridakis Forecasting competition.
  • Loan outcomes prediction 🏦 prepares relational SQL data to predict outcomes for finished loans.
  • Detecting payment card fraud 💳 detects fraudulent payment card transactions in real time.
  • Supervised and unsupervised anomaly detection 🔎 performs data analysis and feature engineering to detect anomalies in a group of servers’ resource usage metrics.

Next Steps

We demonstrated how to handle temporal data such as transactions in TensorFlow using the Temporian library. Now you can try it too!

To learn more about model training with TensorFlow Decision Forests, check out the TensorFlow Decision Forests documentation and tutorials.



Distributed Fast Fourier Transform in TensorFlow

Posted by Ruijiao Sun, Google Intern – DTensor team

Fast Fourier Transform is an important method of signal processing, which is commonly used in a number of ways, including speeding up convolutions, extracting features, and regularizing models. Distributed Fast Fourier Transform (Distributed FFT) offers a way to compute Fourier Transforms in models that work with image-like datasets that are too large to fit into the memory of a single accelerator device. In a previous Google Research Paper, “Large-Scale Discrete Fourier Transform on TPUs” by Tianjian Lu, a Distributed FFT algorithm was implemented for TensorFlow v1 as a library. This work presents the newly added native support in TensorFlow v2 for Distributed FFT, through the new TensorFlow distribution API, DTensor.

About DTensor

DTensor is an extension to TensorFlow for synchronous distributed computing. It distributes the program and tensors through a procedure called single program, multiple data (SPMD) expansion. DTensor offers a uniform API for the traditional data and model parallelism patterns widely used in machine learning.

Example Usage

The API interface for distributed FFT is the same as the original FFT in TensorFlow. Users just need to pass a sharded tensor as an input to the existing FFT ops in TensorFlow, such as tf.signal.fft2d. The output of a distributed FFT becomes sharded too.

import tensorflow as tf
from tensorflow.experimental import dtensor

# Set up devices
device_type = dtensor.preferred_device_type()
if device_type == 'CPU':
  cpu = tf.config.list_physical_devices(device_type)
  tf.config.set_logical_device_configuration(cpu[0], [tf.config.LogicalDeviceConfiguration()] * 8)
if device_type == 'GPU':
  gpu = tf.config.list_physical_devices(device_type)
  tf.config.set_logical_device_configuration(gpu[0], [tf.config.LogicalDeviceConfiguration(memory_limit=1000)] * 8)
dtensor.initialize_accelerator_system()

# Create a mesh
mesh = dtensor.create_distributed_mesh(mesh_dims=[('x', 1), ('y', 2), ('z', 4)], device_type=device_type)

# Set up a distributed input Tensor
input = tf.complex(
    tf.random.stateless_normal(shape=(2, 2, 4), seed=(1, 2), dtype=tf.float32),
    tf.random.stateless_normal(shape=(2, 2, 4), seed=(2, 4), dtype=tf.float32))
init_layout = dtensor.Layout(['x', 'y', 'z'], mesh)
d_input = dtensor.relayout(input, layout=init_layout)

# Run distributed fft2d. DTensor determines the most efficient
# layout of d_output.
d_output = tf.signal.fft2d(d_input)

Performance Analysis

The following experiment demonstrates that the distributed FFT can process more data than the non-distributed one by utilizing memory across multiple devices. The tradeoff is spending additional time on communication and data transposes that slow down the calculation speed.

Graph of performance on different machines, measuring wall clock time in seconds by size per dimension across single GPU, Distributed FFT and Undistributed FFT

This phenomenon is shown in detail from the profiling result of the 10K*10K distributed FFT experiment. The current implementation of distributed FFT in TensorFlow follows the simple shuffle+local FFT method, which is also used by other popular distributed FFT libraries such as FFTW and PFFT. Notably, the two local FFT ops only take 3.6% of the total time (15ms). This is around 1/3 of the time for non-distributed fft2d. Most of the computing time is spent on data shuffling, represented by the ncclAllToAll Operation. Note that these experiments were conducted on an 8xV100 GPU system.

Table of Top 10 TensorFlow operations on GPU highlighting two local FFT ops in the top 3

Next steps

The feature is new and we have adopted the simplest distributed FFT algorithm. A few ideas to fine-tune or improve the performance are:

  • Switch to a different DFT/FFT algorithm.
  • Tweak the NCCL communication settings for particular FFT sizes to improve utilization of the network bandwidth and increase the speed.
  • Reduce the number of collectives to minimize bandwidth requirements.
  • Use N-d local FFTs, rather than multiple 1-d local FFTs.

Try the new distributed FFT! We welcome your feedback on the TensorFlow Forum and look forward to working with you on improving the performance. Your input would be invaluable!



The TensorFlow Lite Plugin for Flutter is Officially Available

Posted by Paul Ruiz, Developer Relations Engineer

We’re excited to announce that the TensorFlow Lite plugin for Flutter has been officially migrated to the TensorFlow GitHub account and released!

Three years ago, Amish Garg, one of our talented Google Summer of Code contributors, wrote a widely used TensorFlow Lite plugin for Flutter. The plugin was so popular that we decided to migrate it to our official repo, making it easier to maintain directly by the Google team. We are grateful to Amish for his contributions to the TensorFlow Lite Flutter plugin.

Through the efforts of developers in the community, the plugin has been updated to the latest version of TensorFlow Lite, and a collection of new features and example apps have been added, such as object detection through a live camera feed.

Moving image of a live camera feed showing several objects on a work desk being detected

So what is TensorFlow Lite? TensorFlow Lite is a way to run TensorFlow models on devices locally, supporting mobile, embedded, web, and edge devices. TensorFlow Lite’s cross-platform support and on-device performance optimizations make it a great addition to the Flutter development toolbox. Our goal with this plugin is to make it easy to integrate TensorFlow Lite models into Flutter apps across mobile platforms, with desktop support currently in development through the efforts of our developer community. Find pre-trained TensorFlow Lite models on model repos like Kaggle Models or create your own custom TensorFlow Lite models.

Let’s take a look at how you could use the Flutter TensorFlow Lite plugin for image classification:

TensorFlow Lite Image Classification with Flutter

First you will need to install the plugin from pub.dev. Once the plugin is installed, you can load a TensorFlow Lite model into your Flutter app and define the input and output tensor shapes. If you’re using the MobileNet model, then the input tensor will be a 224 by 224 RGB image, and the output will be a list of confidence scores for the trained labels.

// Load model
Future<void> _loadModel() async {
  final options = InterpreterOptions();

  // Load model from assets
  interpreter = await Interpreter.fromAsset(modelPath, options: options);
  // Get tensor input shape [1, 224, 224, 3]
  inputTensor = interpreter.getInputTensors().first;
  // Get tensor output shape [1, 1001]
  outputTensor = interpreter.getOutputTensors().first;
}

To make things a bit more organized, you can also load in the labels for the 1000 items that MobileNet is trained for:

// Load labels from assets
Future<void> _loadLabels() async {
  final labelTxt = await rootBundle.loadString(labelsPath);
  labels = labelTxt.split('\n');
}

For the sake of being succinct, let’s go ahead and skip some of the pre-processing steps, though you can find them in the repo’s image classification example here.

When you’re ready to run inference, you can create a new input and output based on the tensor shapes that you defined earlier, then call run on the interpreter to get your final results.

// Run inference
Future<void> runInference(
  List<List<List<num>>> imageMatrix,
) async {
  // Tensor input [1, 224, 224, 3]
  final input = [imageMatrix];
  // Tensor output [1, 1001]
  final output = [List<int>.filled(1001, 0)];

  // Run inference
  interpreter.run(input, output);

  // Get first output tensor
  final result = output.first;
}
Now that you have your results, you can match them to your labels and use them in your app.

Moving image of a live camera feed showing several objects on a work desk being correctly identified in the app

What’s next?

To explore what else you can do with the Flutter TensorFlow Lite plugin, check out the official GitHub repository where you can find examples for text classification, super resolution, style transfer, and more!

Additionally, we are working on a new plugin specifically for MediaPipe Tasks, a low-code tool for easily performing common on-device machine learning tasks. This includes image classification and object detection, like you’ve just learned about, as well as audio classification, face landmark detection, and gesture recognition, alongside a whole lot more.

We look forward to all the exciting things you make, so be sure to share them with @googledevs, @TensorFlow, and your developer communities!



Simpleperf case study: Fast initialization of TFLite’s Memory Arena

Posted by Alan Kelly, Software Engineer

One of our previous articles, Optimizing TensorFlow Lite Runtime Memory, discusses how TFLite’s memory arena minimizes memory usage by sharing buffers between tensors. This means we can run models on even smaller edge devices. In today’s article, I will describe the performance optimization of the memory arena initialization so that our users get the benefit of low memory usage with little additional overhead.

ML is normally deployed on-device as part of a larger pipeline. TFLite is used because it’s fast and lightweight, but the rest of the pipeline must also be fast. Profiling on the target device with representative data lets us identify the slowest parts of the pipeline so that we can optimize the most important part of the code.

In this article, I will describe the profiling and optimization of TFLite’s memory arena with instructions on how to use Simpleperf and visualize the results. Sample commands are given. It is assumed that the Android NDK is installed and that you have a development device that you can connect to using adb.

Simpleperf

Simpleperf comes with some scripts to make it easier to use. run_simpleperf_on_device.py pushes simpleperf to the device and runs your binary with the given arguments.

/usr/lib/android-ndk/simpleperf/run_simpleperf_on_device.py record --call-graph fp /data/local/tmp/my_binary arg0 arg1 …

This will generate the output file perf.data which you must then copy back to your computer.

adb pull /data/local/tmp/perf.data

You then generate the binary cache which contains all the information needed later to generate a useful profile.

/usr/lib/android-ndk/simpleperf/binary_cache_builder.py -lib /your/binarys/folder -i perf.data

And generate the proto buffer used for visualization:

/usr/lib/android-ndk/simpleperf/pprof_proto_generator.py --ndk_path=/path/to/android-ndk -i perf.data -o profile.proto

You can then display the results using pprof:

pprof -http :8888 profile.proto

And open localhost:8888 in your browser to view the profile. I find flame graphs to be the most useful:

Optimizing TFLite’s Memory Arena

ArenaPlanner::ExecuteAllocations accounts for 54.3% of the runtime of this model. I was expecting ML operators such as fully connected layers or convolutions to be the bottleneck of this model, not runtime overhead. This is a particularly bad case: the memory arena overhead isn’t this bad for every model, but improvements here will impact all models. This model has variable input sizes and many dynamic tensors, whose output size isn’t known until operator evaluation, which trigger frequent tensor re-allocations. This really is as bad as it gets. Let’s zoom in on the profile.

InterpreterInfo::num_tensors() accounts for 10.4% of the runtime. The reason this function is so expensive is that it is a virtual function which calls another function, and it is called within a loop. I would never have suspected this.

for (int i = 0; i < static_cast<int>(graph_info_->num_tensors()); ++i) {

}

Arena planner does not create or destroy tensors so the number of tensors is constant. Let’s cache it.

const int num_tensors = static_cast<int>(graph_info_->num_tensors());
for (int i = 0; i < num_tensors; ++i) {

}

Our next piece of low hanging fruit is InterpreterInfo::tensor(unsigned long) which is another virtual function which does bounds checking and then returns a pointer to a tensor. Tensors are stored in an array so let’s add a function to get a pointer to this array. The commits are here and here.

After these simple changes, the runtime of this model has been reduced by 25% and the overhead of the memory allocator by half. Simpleperf made identifying these inefficiencies easy! Time to profile again to measure the impact of these changes.

ArenaPlanner::CalculateAllocations is now the most expensive function at 12.7%. This calls two functions: SimpleMemoryArena::Allocate and ArenaPlanner::CreateTensorAllocationVector.

Although ArenaPlanner::CreateTensorAllocationVector is the cheaper of the two, the code is far simpler so it might be easier to optimize. This function identifies which tensors are allocated between two nodes in the graph and then sorts them by size as a Greedy by Size allocation algorithm is used where the largest tensors are allocated first. The structure of the graph is constant so we can store a map of tensors allocated at each node. Instead of checking each tensor in the model to see if it is allocated between the two nodes, we can identify the tensors to be allocated by iterating through the map. The cost of ArenaPlanner::CreateTensorAllocationVector has gone from 4.8% to 0.8% of the runtime. Code can be seen here. Sort does not appear in the profile so we ignore it.

The next function to look at is ArenaPlanner::ResolveTensorAllocation which is 10.9% of the runtime after the previous optimizations. This function resets each tensor’s data pointer after allocation. However, these pointers don’t always change. How about keeping track of and only updating the ones which change? After this change, ArenaPlanner::ResolveTensorAllocation doesn’t appear in the profile anymore.

Let’s now take a look at allocation and deallocation. SimpleMemoryArena::Allocate accounts for 7% and SimpleMemoryArena::Deallocate accounts for 6.8% of the runtime. A record of all allocations in the arena is stored in a vector ordered by their offsets within the memory arena. Entries in this sorted data structure are inserted, removed and searched. These operations are all O(N) in a vector. Could an std::multimap with the offset as the key be better? A multimap is needed because the records are ordered by their offsets and there may be multiple tensors with the same offset. Removal and insertion are O(logN) but search would still be O(N) as we are searching for the tensor id and not the offset. The best way to find out is to test and profile.

Replacing the vector with a multimap actually slows down the arena code: it is almost three times slower than using a vector! While this goes against intuition, this is commonly found when optimizing code. Operations on a set or a map have linear or logarithmic complexities, however, there is also a constant value in the complexity. This value is higher than the constant value for the complexity of a vector. We also iterate through the records, which is much cheaper for a vector than for a list or multimap. A list was also tested, coming in at twice as slow as a vector.

Deallocation can still be improved though. SimpleMemoryArena::Deallocate iterates through the records and when it finds the record to deallocate, it removes it from the vector. This has O(N²) complexity. The memcpy seen in the profile above comes from the frequent calls to std::vector::erase. It is much more efficient to mark records to be erased and then to erase them in one pass using std::remove_if. The second optimization here is to look at how tensors are typically deallocated: ArenaPlanner::ResetAllocationsAfter deallocates all tensors from a node until the end of the graph. To address this, SimpleMemoryArena::DeallocateAfter(int32_t node) was added which iterates once through all the records, marking those which are allocated after the node. SimpleMemoryArena::ResolveDeallocations erases these in one pass making deallocation O(N). After these changes, ResetAllocationsAfter no longer appears in the profile! Commits are here and here.

The profile now looks very different with the overhead of tensor allocation gone from 49.9% of the runtime to 11%. This profile already looks much more reasonable. SimpleMemoryArena::Allocate is the last function left to optimize. For each tensor, this function iterates through the vector of records trying to find space for the current allocation. The complexity of this is O(N²). This is a fundamental limitation of the Greedy By Size algorithm. Efficient use of memory comes at the cost of increased overhead. Although the complexity can’t be reduced, N can. We process nodes in the order in which they are executed. Allocation information for tensors which have been deallocated on already executed nodes is not needed anymore; it is only slowing things down. Records are purged periodically so that only records which are active are considered. On a large model, this significantly reduces N. ArenaPlanner::ExecuteAllocations is no longer the most expensive function! It has gone from 11% to 6% and a fully connected operator is now the most expensive function in the profile, which is what we expect when profiling neural network inference.

This is what a neural network profile should look like. Time should be spent running your model’s operators, not in the inference runtime.

The optimized memory arena is now publicly available as part of TensorFlow 2.13.

Next Steps

Today’s post walked you through an example of Simpleperf helping to find easy to fix inefficiencies in TFLite’s memory arena that would never have been found by just looking at the code. Pprof can display annotated source code, disassembly and graphs making it easy to find the bottlenecks in your on-device pipelines.



What’s new in TensorFlow 2.13 and Keras 2.13?

Posted by the TensorFlow and Keras Teams

TensorFlow 2.13 and Keras 2.13 have been released! Highlights of this release include publishing Apple Silicon wheels, the new Keras V3 format being the default for .keras extension files, and much more!

TensorFlow Core

Apple Silicon wheels for TensorFlow

TensorFlow 2.13 is the first version to provide Apple Silicon wheels, which means when you install TensorFlow on an Apple Silicon Mac, you will be able to use the latest version of TensorFlow. The nightly builds for Apple Silicon wheels were released in March 2023 and this new support will enable more fine-grained testing, thanks to technical collaboration between Apple, MacStadium, and Google.

tf.lite

The Python TensorFlow Lite Interpreter bindings now have an option, the experimental_disable_delegate_clustering flag, to turn off delegate clustering during the delegate graph partitioning phase. You can set this flag in the TensorFlow Lite Interpreter Python API:

interpreter = tf.lite.Interpreter(model_path=file_of_a_tensorflowlite_model, experimental_disable_delegate_clustering=True)

The flag is set to False by default. This is an advanced, experimental feature designed for people who insert explicit control dependencies via with tf.control_dependencies() or need to change graph execution order.

In addition, there are several operator improvements in TensorFlow Lite in 2.13:

  • The add operation now supports broadcasting up to 6 dimensions. This will remove explicit broadcast ops from many models. The new implementation is also much faster than the previous one, which calculates the entire index for both inputs instead of only calculating the part that changes.
  • Improve the coverage for 16×8 quantization by enabling int16x8 ops for exp, mirror_pad, space_to_batch_nd, batch_to_space_nd
  • Increase the coverage of integer data types
    • enabled int16 for less, greater_than, equal, bitcast, bitwise_xor, right_shift, top_k, mul, and int16 indices for gather and gather_nd
    • enabled int8 for floor_div, floor_mod, and bitwise_xor
    • enabled 32-bit int for bitcast, bitwise_xor, and right_shift

tf.data

We have improved usability and added functionality for tf.data APIs.

tf.data.Dataset.zip now supports Python-style zipping. Previously users were required to provide an extra set of parentheses when zipping datasets as in Dataset.zip((a, b, c)). With this change, users can specify the datasets to be zipped simply as Dataset.zip(a, b, c) making it more intuitive.
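For example, a minimal sketch:

import tensorflow as tf

a = tf.data.Dataset.range(3)      # 0, 1, 2
b = tf.data.Dataset.range(3, 6)   # 3, 4, 5
c = tf.data.Dataset.range(6, 9)   # 6, 7, 8

# Python-style zipping: no extra parentheses required.
ds = tf.data.Dataset.zip(a, b, c)
print(list(ds.as_numpy_iterator()))  # [(0, 3, 6), (1, 4, 7), (2, 5, 8)]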

Additionally, tf.data.Dataset.shuffle now supports full shuffling. To specify that data should be fully shuffled, use dataset = dataset.shuffle(dataset.cardinality()). This will load the full dataset into memory so that it can be shuffled, so make sure to only use this with datasets of filenames or other small datasets.

We have also added a new tf.data.experimental.pad_to_cardinality transformation which pads a dataset with zero elements up to a specified cardinality. This is useful for avoiding partial batches while not dropping any data.

Example usage:

ds = tf.data.Dataset.from_tensor_slices({'a': [1, 2]})
ds = ds.apply(tf.data.experimental.pad_to_cardinality(3))
list(ds.as_numpy_iterator())
[{'a': 1, 'valid': True}, {'a': 2, 'valid': True}, {'a': 0, 'valid': False}]

This can be useful, e.g. during eval, when partial batches are undesirable but it is also important not to drop any data.

oneDNN BF16 Math Mode on CPU

oneDNN supports BF16 math mode where full FP32 tensors are implicitly down-converted to BF16 during computations for faster execution time. TensorFlow CPU users can enable this by setting the environment variable TF_SET_ONEDNN_FPMATH_MODE to BF16. This mode may negatively impact model accuracy. To go back to full FP32 math mode, unset the variable.
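For example, one way to opt in from Python is to set the variable before TensorFlow runs any oneDNN-backed ops (setting it in the shell before launching the process works just as well); a minimal sketch:

import os

# Enable BF16 math mode for oneDNN on CPU.
os.environ["TF_SET_ONEDNN_FPMATH_MODE"] = "BF16"

import tensorflow as tf
# ... build and run your model as usual. Unset the variable to return to FP32.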

Keras

Keras Saving format

The new Keras V3 saving format, released in TF 2.12, is now the default for all files with the .keras extension.

You can start using it now by calling model.save("your_model.keras").

It provides richer Python-side model saving and reloading with numerous advantages:

  • A lightweight, faster format.
  • Human-readable: The new format is name-based, with a more detailed serialization format that makes debugging much easier. What you load is exactly what you saved, from Python’s perspective.
  • Safer: Unlike SavedModel, there is no reliance on loading via bytecode or pickling – a big advancement for secure ML, as pickle files can be exploited to cause arbitrary code execution at loading time.
  • More general: Support for non-numerical states, such as vocabularies and lookup tables, is included in the new format.
  • Extensible: You can add support for saving and loading exotic state elements in custom layers using save_assets(), such as a FIFOQueue – or anything else you want. You have full control of disk I/O for custom assets.

The legacy formats (h5 and Keras SavedModel) will stay supported in perpetuity. However, we recommend that you consider adopting the new Keras v3 format for saving/reloading in Python runtimes, and using model.export() for inference in all other runtimes (such as TF Serving).
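For instance, saving, reloading, and exporting could look like the following small sketch (the model and file names are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build(input_shape=(None, 4))

# The .keras extension now selects the new Keras V3 format by default.
model.save("your_model.keras")
reloaded = tf.keras.models.load_model("your_model.keras")

# Export a SavedModel artifact for inference in other runtimes (e.g. TF Serving).
model.export("exported_model")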



On-device fetal ultrasound assessment with TensorFlow Lite

Posted by Angelica Willis and Akib Uddin, Health AI Team, Google Research

How researchers at Google are working to expand global access to maternal healthcare with the help of AI

TensorFlow Lite* is an open-source framework to run machine learning models on mobile and edge devices. It’s popular for use cases including image classification, object detection, speech recognition, natural language tasks, and more. From helping parents of deaf children learn sign language, to predicting air quality, projects using TensorFlow Lite are demonstrating how on-device ML could directly and positively impact lives by making these socially beneficial applications of AI more accessible, globally. In this post, we describe how TensorFlow Lite is being used to help develop ultrasound tools in under-resourced settings.

Motivation

According to the WHO, complications from pregnancy and childbirth contribute to roughly 287,000 maternal deaths and 2.4 million neonatal deaths worldwide each year. As many as 95% of these deaths occur in under-resourced settings and many are preventable if detected early. Obstetric diagnostics, such as determining gestational age and fetal presentation, are important indicators in planning prenatal care, monitoring the health of the birthing parent and fetus, and determining when intervention is required. Many of these factors are traditionally determined by ultrasound.

Advancements in sensor technology have made ultrasound devices more affordable and portable, integrating directly with smartphones. However, ultrasound requires years of training and experience, and, in many rural or underserved regions, there is a shortage of trained ultrasonography experts, making it difficult for people to access care. Due to this global lack of availability, it has been estimated that as many as two-thirds of pregnant people in these settings do not receive ultrasound screening during pregnancy.

Expanding access by enabling non-experts

Google Research is building AI models to help expand access to ultrasound, including models to predict gestational age and fetal presentation, to allow health workers with no background in ultrasonography to collect clinically useful ultrasound scans. These models make predictions from ultrasound video obtained using an easy-to-teach operating procedure, a blind sweep protocol, in which a user blindly sweeps the ultrasound probe over the patient’s abdomen. In our recent paper, “A mobile-optimized artificial intelligence system for gestational age and fetal malpresentation assessment”, published in Nature Communications Medicine, we demonstrated that, when utilizing blind sweeps, these models enable these non-experts to match standard of care performance in predicting these diagnostics.

Blind Sweep Operating Procedure

Moving images depicting blind-sweep ultrasound technique from a transverse view on the left and aerial view on the right
This blind-sweep ultrasound acquisition procedure can be performed by non-experts with only a few hours of ultrasound training.
Figure A compares our blind sweep-based gestational age regression model performance with that of the clinical standard of care method for fetal age estimation from fetal biometry measured by expert sonographers. Boxes indicate 25th, 50th, and 75th percentile absolute error in days, and whiskers indicate 5th and 95th percentile absolute error (n = 407 study participants). Figure B shows the Receiver Operating Characteristic (ROC) curves for our blind sweep-based fetal malpresentation classification model, as well as specific performance curves for cases in which blind sweeps were collected by expert sonographers or novices (n = 623 study participants). See our recent paper for further details and additional analysis.

Model development

Understanding that our target deployment environment is one in which users might not have reliable access to power and internet, we designed these models to be mobile-optimized. Our grouped convolutional LSTM architecture utilizes MobileNetV2 for feature extraction on each video frame as it is received. The final feature layer produces a sequence of image embeddings which are processed by the convolutional LSTM cell state. Since the recurrent connections only operate on the less memory-intensive embeddings, this model can run efficiently in a mobile environment.

For each subsequence of video frames that make up a sweep, we generate a clip-level diagnostic result, and in the case of gestational age, also produce a model confidence estimate represented as the predicted variance in the detected age. Clip-level gestational age predictions are aggregated via inverse variance weighting to produce a final case-level prediction.
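For intuition, inverse variance weighting gives each clip-level prediction a weight equal to the reciprocal of its predicted variance, so more confident clips contribute more to the case-level estimate. The following is a small illustrative sketch, not the production implementation, and the numbers are made up:

import numpy as np

def aggregate_clips(means, variances):
  """Combine clip-level gestational age predictions via inverse variance weighting."""
  means = np.asarray(means, dtype=np.float32)
  variances = np.asarray(variances, dtype=np.float32)
  weights = 1.0 / variances
  case_mean = np.sum(weights * means) / np.sum(weights)
  case_variance = 1.0 / np.sum(weights)
  return case_mean, case_variance

# Three sweep clips predicting gestational age in days.
print(aggregate_clips([210.0, 215.0, 205.0], [4.0, 9.0, 16.0]))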

Flow chart depicting Gestational Age LSTM Video Model

Optimization through TensorFlow Lite

On-device ML has many advantages, including providing enhanced privacy and security by ensuring that sensitive input data never needs to leave the device. Another important advantage of on-device ML, particularly for our use case, is the ability to leverage ML offline in regions with low internet connectivity, including where smartphones serve as a stand-in for more expensive traditional devices. Our prioritization of on-device ML made TensorFlow Lite a natural choice for optimizing and evaluating the memory use and execution speed of our existing models, without significant changes to model structure or prediction performance.

After converting our models to TensorFlow Lite using the converter API, we explored various optimization strategies, including post-training quantization and alternative delegate configurations. Leveraging a TensorFlow Lite GPU delegate, optimized for sustained inference speed, provided the most significant boost to execution speed. There was a roughly 2x speed improvement with no loss in model accuracy, which equated to real-time inference of more than 30 frames/second with both the gestational age and fetal presentation models running in parallel on Pixel devices. We benchmarked model initialization time, inference time and memory usage for various delegate configurations using TensorFlow Lite performance measurement tools, finding the optimal configuration across multiple mobile device manufacturers.
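As an illustration only (not the team’s actual conversion pipeline), a conversion with post-training quantization might look like the sketch below; the SavedModel path is hypothetical, and delegate selection (such as the GPU delegate) is configured later, at inference time on the device:

import tensorflow as tf

# Convert a trained model to TensorFlow Lite with post-training quantization.
converter = tf.lite.TFLiteConverter.from_saved_model("gestational_age_model/")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

with open("gestational_age_model.tflite", "wb") as f:
  f.write(tflite_model)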

These critical speed improvements allow us to leverage the model confidence estimate to provide sweep-quality feedback to the user immediately after the sweep is captured. When low-quality sweeps are detected, users can be provided with tips on how their sweep can be improved (for example, applying more pressure or ultrasound gel), then prompted to re-do the sweep.

Screen capture of sweep exam being conducted
We developed a mobile application that demonstrates what a potential user experience could look like and allows us to evaluate our TensorFlow Lite models in realistic environments. This app enables ultrasound video frames to be received directly from portable ultrasound devices that support this use case.

Looking ahead

Our vision is to enable safer pregnancy journeys using AI-driven ultrasound that could broaden access globally. We want to be thoughtful and responsible in how we develop our AI to maximize positive benefits and address challenges, guided by our AI Principles. TensorFlow Lite has helped enable our research team to explore, prototype, and de-risk impactful care-delivery strategies designed with the needs of lower-resource communities in mind.

This research is in its early stages and we look forward to opportunities to expand our work. To achieve our goals and scale this technology for wider reach globally, partnerships are critical. We are excited about our partnerships with Northwestern Medicine in the US and Jacaranda Health in Kenya to further develop and evaluate these models. With more automated and accurate evaluations of maternal and fetal health risks, we hope to lower barriers and help people get timely care.

Acknowledgements

This work was developed by an interdisciplinary team within Google Research: Ryan G. Gomes, Chace Lee, Angelica Willis, Marcin Sieniek, Christina Chen, James A. Taylor, Scott Mayer McKinney, George E. Dahl, Justin Gilmer, Charles Lau, Terry Spitz, T. Saensuksopa, Kris Liu, Tiya Tiyasirichokchai, Jonny Wong, Rory Pilgrim, Akib Uddin, Greg Corrado, Lily Peng, Katherine Chou, Daniel Tse, & Shravya Shetty.

This work was developed in collaboration with:
Department of Obstetrics and Gynaecology, University of Zambia School of Medicine, Lusaka, Zambia
Department of Obstetrics and Gynecology, University of North Carolina School of Medicine, Chapel Hill, NC, USA
UNC Global Projects—Zambia, LLC, Lusaka, Zambia

Special thanks to: Yun Liu, Cameron Chen, Sami Lachgar, Lauren Winer, Annisah Um’rani, and Sachin Kotwani

*TensorFlow Lite has not been certified or validated for clinical, medical, or diagnostic purposes. TensorFlow Lite users are solely responsible for their use of the framework and independently validating any outputs generated by their project.



Augmenting recommendation systems with LLMs

Posted by Wei Wei, Developer Advocate

Large language models (LLMs) are taking the world by storm, thanks to their powerful ability to generate text, translate languages, and answer questions in a coherent and informative way. At Google I/O 2023, we released the PaLM API as ‘public preview’ so that many developers can start building apps with it. While PaLM API already has excellent documentation on its extensive usage and best practices, in this blog we are going to take a more focused approach to explore how to leverage LLMs to augment your ML systems in a practical application: recommendation systems.

As a refresher, modern recommendation systems often follow a retrieval-ranking architecture, which enables them to effectively and efficiently filter and rank relevant items to maximize the utility in production. You can go through this codelab to learn about building a fullstack movie recommendation system using TensorFlow and Flutter.

Demonstration of the retrieval-ranking architecture where candidate items move from retrieval to ranking, then post-ranking.

We will discuss how LLMs can be incorporated into this retrieval-ranking pipeline.

Conversational recommendations

If you already have access to Bard, you can ask it to create recommendations for you interactively in a dialogue. Here is an example of asking Bard for movie recommendations:

A user asks 'I'm in the mood for some drama movies with artistic elements tonight. Could you recommend three? Titles only. No other text' Bard responds 'Sure, here are three drama movies with artistic elements that you might enjoy: The Tree of Life, The Piano Teacher, The Passion of Joan of Arc'

As a developer, you can build a similar functionality in your own applications, using the PaLM API Chat service with minimal effort:

prompt = """You are a movie recommender and your job is to recommend new movies based on user input.
So for user 42, he is in the mood for some drama movies with artistic elements tonight.
Could you recommend three? Output the titles only. Do not include other text."""

response = palm.chat(messages=prompt)
print(response.last)

# Sure, here are three drama movies with artistic elements that I recommend for user 42:
#
# 1. The Tree of Life (2011)
# 2. 20th Century Women (2016)
# 3. The Florida Project (2017)
#
# I hope you enjoy these movies!

The PaLM API also allows you to help your user continue the exploration and interactively refine the recommendations (e.g., asking to swap The Florida Project for another one) in a dialogue, which is what Chat service is designed for. This kind of conversational recommendation interface (think having a knowledgeable chatbot that guides a customer along the way in your shopping app) provides a fluid and personalized experience for the users, and can sometimes be a very appealing addition to your existing recommendation surfaces.
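For example, a follow-up turn can be sent on the same response object returned by palm.chat() above; this sketch assumes the user wants to swap one of the recommendations:

# Continue the same conversation to refine the recommendations.
response = response.reply(
    "Thanks! Could you swap The Florida Project for a different movie?")
print(response.last)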

Sequential recommendations

Recommendations would be much more useful if your system knows what your users may like. One way to find out your users’ interests is to look at their historical activities and then extrapolate. This is often called ‘sequential recommendation’ because the recommender looks at the sequence of items that have been interacted with and infers what to recommend. Usually you need to use an ML library (i.e., TensorFlow Recommenders) to achieve this. But now with the power of LLMs, you can also do this with the PaLM API Text service:

prompt = """You are a movie recommender and your job is to recommend new movies based on the sequence of movies that a user has watched. You pay special attention to the order of movies because it matters.

User 42 has watched the following movies sequentially:

"Margin Call",
"The Big Short",
"Moneyball",
"The Martian",

Recommend three movies and rank them in terms of priority. Titles only. Do not include any other text.
"""

response = palm.generate_text(
    model="models/text-bison-001", prompt=prompt, temperature=0
)
print(response.result)

# 1. The Wolf of Wall Street
# 2. The Social Network
# 3. Inside Job

This example prompts the Text service with 4 movies that have been watched and asks the PaLM API to generate new recommendations based on the sequence of past movies.

Rating predictions

In the ranking phase of modern recommendation engines, a list of candidates needs to be sorted based on certain criteria. This is usually done by using a learning-to-rank library (such as, TensorFlow Ranking) to predict the ordering. Now you can do this with the PaLM API. Here is an example of predicting movie ratings:

prompt = """You are a movie recommender and your job is to predict a user's rating (ranging from 1 to 5, with 5 being the highest) on a movie, based on that user's previous ratings.

User 42 has rated the following movies:
"Moneyball" 4.5
"The Martian" 4
"Pitch Black" 3.5
"12 Angry Men" 5

Predict the user's rating on "The Matrix". Output the rating score only. Do not include other text.
"""
response = palm.generate_text(model="models/text-bison-001", prompt=prompt)
print(response.result)

# 4.5

The PaLM API predicted a high score for The Matrix. You can ask the PaLM API to predict a rating for a list of candidate movies one by one and then sort them in order before making final recommendations; this process is called ‘pointwise ranking’. You can even leverage the PaLM API to do pairwise ranking or listwise ranking, if you adjust the prompt accordingly.
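As a sketch of pointwise ranking, you could score each candidate with the rating-prediction prompt above and then sort; the candidate list here is illustrative:

candidates = ["The Matrix", "Inception", "Interstellar"]
user_history = (
    '"Moneyball" 4.5\n'
    '"The Martian" 4\n'
    '"Pitch Black" 3.5\n'
    '"12 Angry Men" 5')

scores = {}
for title in candidates:
  prompt = f"""You are a movie recommender and your job is to predict a user's rating (ranging from 1 to 5, with 5 being the highest) on a movie, based on that user's previous ratings.

User 42 has rated the following movies:
{user_history}

Predict the user's rating on "{title}". Output the rating score only. Do not include other text.
"""
  response = palm.generate_text(model="models/text-bison-001", prompt=prompt)
  # Assumes the model returns just a number, as requested in the prompt.
  scores[title] = float(response.result)

# Sort candidates by predicted rating, highest first.
ranked = sorted(scores, key=scores.get, reverse=True)
print(ranked)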

For a more comprehensive study on rating prediction with LLMs, you can refer to this paper from Google.

Text embedding-based recommendations

At this point you may be asking: all the use cases so far involve well-known movies that the LLM is already aware of, so does that mean candidate items must have been seen by the LLM in advance (during training)? What if I have private items the LLM has never seen before? How could I use the PaLM API then?

Not to worry. The PaLM API for Embeddings can help you out in this case. The basic idea is to embed the text associated with your items (for example, product descriptions or movie plots) into vectors, and then use nearest neighbor search techniques (e.g., the tf.math.top_k op from TensorFlow for brute-force search, or Google ScaNN/Chroma for approximate search) to identify similar items to recommend, based on a user query. Let’s walk through a simple example.

Suppose you are building a news app and at the bottom of each news article you want to recommend similar news to your users. First you can embed all news articles by calling the PaLM API Embedding service like below:

embedding = palm.generate_embeddings(model='embedding-gecko-001', text='example news article text')['embedding']
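
Extending that to all articles, a minimal sketch of precomputing the embeddings into a Pandas DataFrame might look like this (news_articles is a hypothetical list of article strings, not from the original post; palm is configured as in the snippets above):

import pandas as pd

# news_articles is a hypothetical list of article strings
rows = []
for text in news_articles:
    emb = palm.generate_embeddings(
        model='embedding-gecko-001', text=text)['embedding']
    rows.append({'news_text': text, 'embedding': emb})

dataframe = pd.DataFrame(rows)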

With all the news texts and their embeddings stored in a simple Pandas DataFrame with two columns (news_text and embedding), you can then recommend interesting news to your users using the following:

import numpy as np
import tensorflow as tf

def recommend_news(query_text, df, topk=5):
    """Recommend news articles most similar to the user query."""
    # Embed the query text with the PaLM API Embedding service
    query_embedding = palm.generate_embeddings(
        model='embedding-gecko-001', text=query_text)

    # Dot-product similarity between the query and all precomputed article embeddings
    dot_products = np.dot(np.stack(df['embedding']), query_embedding['embedding'])

    # Brute-force top-k search with TensorFlow
    result = tf.math.top_k(dot_products, k=topk)
    indices = result.indices.numpy()
    return df.loc[indices]['news_text']

recommend_news('news currently being read', dataframe, 5)

The recommend_news function computes the query embedding’s dot product similarity with all news articles using the pre-computed embeddings, and then identifies 5 news articles most similar to what your user is reading.

This approach is often a quick and effective way to generate candidates and create recommendations based on item similarities. It may be sufficient for many use cases and can be particularly useful in the item cold start situation.

In practice, the candidate generation phase of modern large-scale recommenders often draws on multiple sources: for example, a mix of text embedding-based retrieval, collaborative filtering, users’ subscriptions (e.g., new uploads from followed accounts on YouTube), real-time trending items (e.g., breaking news), and so on. Thus, the PaLM API Embedding service can be a helpful addition to the retrieval stage of your existing recommendation system.

An example of text embedding-based retrieval.

Text embeddings as side features

In addition, you can use the text embeddings as side features in a recommendation model. The text embeddings capture the semantic information of the candidate items via their description text and can potentially improve model accuracy. For example, in this TensorFlow Recommenders feature preprocessing tutorial, if you have pre-computed text embeddings for movie plots using LLMs, it’s fairly easy to inject them into the model as side features when concatenating all the embeddings:

class MovieModel(tf.keras.Model):

  # ......

  def call(self, inputs):
    return tf.concat(
        [
            self.title_embedding(inputs["movie_title"]),
            self.title_text_embedding(inputs["movie_title"]),
            inputs["movie_plot_embedding"],  # inject movie plot embedding
        ],
        axis=1,
    )

The default PaLM Embedding service returns a vector of 768 floating-point numbers for any text, which may be more dimensions than you need. You can reduce the dimensionality by initializing a tf.keras.layers.Embedding layer with the movie plot embedding matrix and stacking a fully connected layer on top of it to project the embeddings down to fewer dimensions.
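
A minimal sketch of that projection, assuming a precomputed NumPy array plot_embedding_matrix of shape [num_movies, 768] (the variable names and the 32-dimensional output are illustrative choices, not from the original tutorial):

import tensorflow as tf

# plot_embedding_matrix is an assumed [num_movies, 768] array of
# precomputed PaLM plot embeddings, indexed by movie id
plot_embedding_layer = tf.keras.layers.Embedding(
    input_dim=plot_embedding_matrix.shape[0],
    output_dim=768,
    embeddings_initializer=tf.keras.initializers.Constant(plot_embedding_matrix),
    trainable=False,  # keep the precomputed embeddings frozen
)

# Project the 768-dim plot embedding down to a smaller size (e.g., 32)
plot_projection = tf.keras.Sequential([
    plot_embedding_layer,
    tf.keras.layers.Dense(32),
])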

Conclusion

We have shared several ideas for leveraging LLMs to augment recommenders. This only scratches the surface; there are many more possibilities not covered here. Also note that there may still be a long way to go before these techniques make it into production (e.g., latency and cost concerns). But we hope this blog inspires you to start thinking about how you can improve your own recommendation systems with LLMs.

Lastly, we are holding an online Developer Summit on Recommendation Systems on June 9, 2023. If you want to learn more about Google products related to building recommendation systems, feel free to sign up here to attend.

Read More

Visualizing and interpreting decision trees

Visualizing and interpreting decision trees

Posted by Terence Parr, Google

Decision trees are the fundamental building block of Gradient Boosted Trees and Random Forests, the two most popular machine learning models for tabular data. To learn how decision trees work and how to interpret your models, visualization is essential.

TensorFlow recently published a new tutorial that shows how to use dtreeviz, a state-of-the-art visualization library, to visualize and interpret TensorFlow Decision Forest Trees.

The dtreeviz library, first released in 2018, is now the most popular visualization library for decision trees. The library is constantly being updated and improved, and there is a large community of users who can provide support and answer questions. There is a helpful YouTube video and article on the design of dtreeviz.

Let’s demonstrate how to use dtreeviz to interpret decision tree predictions.

At a basic level, a decision tree is a machine learning model that learns the relationship between observations and target values by examining and condensing training data into a binary tree. Each leaf in the decision tree is responsible for making a specific prediction. For regression trees, the prediction is a value, such as price. For classifier trees, the prediction is a target category, such as cancer or not-cancer.

Any path from the root of the decision tree to a specific leaf predictor passes through a series of (internal) decision nodes. Each decision node compares a single feature’s value with a specific split point value learned during training. Making a prediction means walking from the root down the tree, comparing feature values, until we reach a leaf. Consider the following simple decision tree that tries to classify animals based upon two features, the number of legs and the number of eyes.

Illustration of a simple decision tree that classifies an animal: if the number of legs is less than four, predict penguin; otherwise test the number of eyes, predicting spider for three or more eyes and dog otherwise

Let’s say that our test animal has four legs and two eyes. To classify it, we start at the root of the tree and compare the animal’s number of legs to four. Since the number of legs is equal to four, we move to the left. Next, we compare its number of eyes to three. Since our test animal only has two eyes, we move to the right and arrive at a leaf node, which gives us a prediction of dog. To learn more, check out this class on decision trees.
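
To make that walk concrete, here is a hand-written sketch of the same toy tree as nested conditionals (purely illustrative; this is not dtreeviz or TensorFlow Decision Forests code):

def classify_animal(num_legs, num_eyes):
    """Walk the toy decision tree from the figure above."""
    if num_legs >= 4:        # root decision node
        if num_eyes >= 3:    # internal decision node
            return "spider"
        return "dog"
    return "penguin"

print(classify_animal(num_legs=4, num_eyes=2))  # prints "dog"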

To interpret decision tree predictions, we use dtreeviz to visualize how each decision node in the tree splits up a specific feature’s domain, and to show the distribution of training instances in each leaf. For example, here are the first few levels of a classification tree from a Random Forest trained on the Penguin data set:

Illustration of the first few levels of a classification tree from a Random Forest trained on the Penguin data set

To make a prediction for a test penguin, this decision tree first tests the flipper_length_mm feature and if it’s less than 206, it descends to the left and then tests the island feature; otherwise, if the flipper length were >= 206, it would descend to the right and test the bill_length_mm feature. (Check out the tutorial for a description of the visualization elements.)

The code used to generate that tree is short. Given a classifier model called cmodel, we collect and wrap up all of the information about the data and model, and then ask dtreeviz to visualize the tree:

penguin_features = [f.name for f in cmodel.make_inspector().features()]
penguin_label = "species"   # Name of the classification target label

viz_cmodel = dtreeviz.model(cmodel,
                            tree_index=3,  # pick tree from forest
                            X_train=train_ds_pd[penguin_features],
                            y_train=train_ds_pd[penguin_label],
                            feature_names=penguin_features,
                            target_name=penguin_label,
                            class_names=classes)

viz_cmodel.view()

And here are the first few layers of a regressor tree from a Random Forest trained on the Abalone data set:

Illustration of the first few layers of a regressor tree from a Random Forest trained on the Abalone data set

Another useful tool for interpretation is to visualize how a specific test instance (feature vector) weaves its way down the tree from the root to a specific leaf. By looking at the path taken by a decision tree when making a prediction, we learn why a test instance was classified in a particular way. We know which features are tested and against what range of values. Imagine being rejected for a bank loan. Looking at the decision tree could tell us exactly why we were rejected (e.g., credit score too low or debt to income ratio too high). Here’s an example showing the decisions made by the decision tree for a specific Penguin instance, with the path highlighted in orange boxes and the test instance features shown at the bottom left:

Illustration of the decisions made by the decision tree for a specific Penguin instance, with the path highlighted in orange boxes and the test instance features
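
A sketch of how such a view can be produced, assuming the viz_cmodel object built above and picking an arbitrary row of the training data as the test instance (the row index here is illustrative):

# Pick a single instance (one row of features) to trace through the tree
x = train_ds_pd[penguin_features].iloc[20]

# Render the tree with this instance's decision path highlighted
viz_cmodel.view(x=x)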

You can also look at information about the leaf contents by calling viz_cmodel.ctree_leaf_distributions(). For example, here’s a plot showing the leaf ID versus samples-per-class for the Penguin dataset:

Bar diagram showing the leaf ID versus samples-per-class for the Penguin dataset

For regressors, the leaf plot shows the distribution of the target (predicted) variable for the instances in each leaf, such as in this plot from an Abalone decision tree:

Plot showing the distribution of target values in each leaf of an Abalone decision tree

Each “row” in this plot represents a specific leaf and the blue dots indicate the distribution of the rings prediction values for instances associated with that leaf by the training process.
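
Assuming you have wrapped the Abalone regressor the same way (a hypothetical viz_rmodel built with dtreeviz.model(...) like viz_cmodel above), this plot comes from the regression counterpart of the leaf-distribution call:

viz_rmodel.rtree_leaf_distributions()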

The library can do lots more; this is just a taste. Your next step is to check out the tutorial! Then, try dtreeviz on your own tree models. To dig deeper into how decision trees are built and how they carve up feature space to make predictions, you can watch the YouTube video or the article on the design of dtreeviz. Enjoy!

Read More