My experience with TensorFlow Quantum

A guest post by Owen Lockwood, Rensselaer Polytechnic Institute

Quantum mechanics was once a very controversial theory. Early detractors such as Albert Einstein famously said of quantum mechanics that “God does not play dice” (referring to the probabilistic nature of quantum measurements), to which Niels Bohr replied, “Einstein, stop telling God what to do”. However, all agreed that, to quote John Wheeler, “If you are not completely confused by quantum mechanics, you do not understand it”. As our understanding of quantum mechanics has grown, it has not only led to numerous important physical discoveries but has also given rise to the field of quantum computing. Quantum computing is a different paradigm of computing from classical computing. It relies on and exploits the principles of quantum mechanics to achieve speedups (in some cases superpolynomial) over classical computers.

In this article I will discuss some of the challenges I faced as a beginning researcher in quantum machine learning (QML), and how TensorFlow Quantum (TFQ) and Cirq enabled my work and can help other researchers investigate the field of quantum computing and QML. I had previous experience with TensorFlow, which made the transition to TensorFlow Quantum seamless. TFQ proved instrumental in enabling my research, which ultimately culminated in my first publication on quantum reinforcement learning, at the 16th AIIDE conference. I hope that this article helps and inspires other researchers, neophytes and experts alike, to leverage TFQ to advance the field of QML.

QML Background

QML has important similarities to and differences from traditional neural network/deep learning approaches to machine learning. Both methodologies can be seen as using “stacked layers” of transformations that make up a larger model. In both cases, data is used to inform updates to model parameters, typically to minimize some loss function (usually, but not exclusively, via gradient-based methods). Where they differ is that QML models have access to the power of quantum mechanics and deep neural networks do not. An important type of QML model that TFQ provides techniques for is the variational quantum circuit (QVC), also known as a quantum neural network (QNN).

A QVC can be visualized below (from the TFQ white paper). The diagram is read from left to right, with the qubits represented by the horizontal lines. In a QVC there are three important and distinct parts: the encoder circuit, the variational circuit, and the measurement operators. The encoder circuit either takes naturally quantum data (i.e. a nonparametrized quantum circuit) or converts classical data into quantum data. It is connected to the variational circuit, which is defined by its learnable parameters; this parametrized part of the circuit is what gets updated during the learning process. The last part of the QVC is the measurement operators. In order to extract information from the QVC, some sort of quantum measurement (such as a Pauli X, Y, or Z basis measurement) must be applied. With the information extracted from these measurements, a loss function (and gradients) can be calculated on a classical computer and the parameters can be updated. The parameters can be optimized with the same optimizers used for traditional neural networks, such as Adam or RMSProp. QVCs can also be combined with traditional neural networks (as shown in the diagram), since the quantum circuit is differentiable and gradients can be backpropagated through it.

Diagram of a variational quantum circuit, from the TFQ white paper
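
To make this concrete, below is a minimal sketch of a hybrid quantum-classical Keras model built with Cirq and TFQ. The circuit layout, qubit count, and hyperparameters are illustrative choices for this post, not the architecture used in the paper.

import cirq
import sympy
import tensorflow as tf
import tensorflow_quantum as tfq

# Four qubits and eight trainable rotation angles (illustrative sizes).
qubits = cirq.GridQubit.rect(1, 4)
params = sympy.symbols('theta0:8')

# Variational circuit: parameterized single-qubit rotations, entangling CZ
# gates, then a second rotation layer.
circuit = cirq.Circuit()
for i, q in enumerate(qubits):
    circuit += cirq.rx(params[i])(q)
for a, b in zip(qubits, qubits[1:]):
    circuit += cirq.CZ(a, b)
for i, q in enumerate(qubits):
    circuit += cirq.ry(params[4 + i])(q)

# Measurement operator: Pauli-Z expectation on the last qubit.
readout = cirq.Z(qubits[-1])

# Hybrid model: quantum data (serialized circuits) in, expectation value out,
# followed by a small classical dense layer (e.g. mapping to Q-values in RL).
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(), dtype=tf.string),
    tfq.layers.PQC(circuit, readout),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss=tf.keras.losses.MeanSquaredError())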

However, the intuitive models and mathematical framework surrounding QVCs have some important differences from traditional neural networks, and in these differences lies the potential for quantum speedups. Quantum computing and QML can harness quantum phenomena, such as superposition and entanglement. Superposition stems from the wavefunction being a linear combination of multiple states and enables a qubit to represent two different states simultaneously (in a probabilistic manner). The ability to operate on these superpositions, i.e. operate on multiple states simultaneously, is integral to the power of quantum computing. Entanglement is a complex phenomenon that is induced via multi-qubit gates. Getting a basic understanding of these concepts and of quantum computing is an important first step for QML. There are a number of great resources available for this such as Preskill’s Quantum Computing Course, and de Wolf’s lecture notes.

Currently, access to real quantum hardware is limited, and as such many quantum computing researchers conduct their work on simulations of quantum computers. Near-term and current quantum devices have tens to hundreds of quantum bits (qubits), like the Google Sycamore processor. Because of their size and noise, this hardware is often called Noisy Intermediate-Scale Quantum (NISQ) technology. TFQ and Cirq are built for these near-term NISQ devices. These devices are far smaller than what some of the most famous quantum algorithms require to achieve quantum speedups given current error correction techniques; e.g. Shor’s algorithm requires upwards of thousands of qubits, and the Quantum Approximate Optimization Algorithm (QAOA) could require at least 420 qubits for quantum advantage. However, there is still significant potential for NISQ devices to achieve quantum speedups (as Google demonstrated with 53 qubits).

My Work With TFQ

TFQ was announced in mid-March 2020 and I began to use it shortly after. Around that time I had begun research into QML, specifically QML for reinforcement learning (RL). While there have been great strides in the accessibility of quantum circuit simulators, QML can be a difficult field to get into. Not only are there difficulties from a mathematical and physical perspective, but the time investment for implementation can be substantial. The time it would take to code QVCs from scratch and properly test and debug them (not to mention optimize them) is a challenge, especially for those coming from classical machine learning. Spending so much time building something for an experiment that has the potential to not work is a big risk – especially for an undergraduate student on a deadline! Thankfully I did not have to take this risk. With the release of TFQ it was easy to start implementing my ideas immediately. Realistically, I would never have done this work if TFQ had not been released. Inspired by previous work, we expanded upon applying QML to RL tasks.

In our work we demonstrate the potential to use QVCs in place of neural networks in contemporary RL algorithms (specifically DQN and DDQN). We also show the potential to use multiple types of QVC models, using QVCs with either a dense layer or quantum pooling layers (denoted hybrid and pure respectively) to shrink the number of qubits to the correct output space.

TensorFlow Graph

The representational power of QVCs is also put on display: using a QVC with ~50 parameters, we were able to achieve performance comparable to neural networks with orders of magnitude more parameters. See the graphs for a comparison of the reward achieved on the canonical CartPole environment (balancing a pole on a cart); the left graph includes all neural networks, and the right shows only the largest neural network. The number in front of the NN represents the size of the parameter space.

We are continuing to work with QML applications to RL and have more manuscripts in submission. Continuation of this work has been accepted into the 2020 NeurIPS workshop: “The pre-registration experiment: an alternative publication model for machine learning research”.

Suggested use of TFQ

TFQ can be an incredible tool for anyone interested in QML research no matter your background. All too common in scientific communities is a ‘publish or perish’ mentality which can stifle innovative work and is prohibitive to intellectual risk taking, especially for experiments that require significant implementation efforts. Not only can TFQ help speed up any experiments you may have, but it also allows for easy implementation of ideas that would otherwise never get tested. Implementation is a common hindrance to new and interesting ideas, and far too many projects never progress out of the idea stage due to difficulties in transitioning the idea to reality, something TFQ makes easy.

For beginners, TFQ enables a first foray into the field without a substantial time investment in coding, and allows for significant learning. Being able to experiment with QVCs without having to build them from the ground up is incredibly valuable. For classical ML researchers with experience in TensorFlow, TFQ makes it easy to transition to and experiment with QML at small or large scales. The API of TFQ and the modules it provides (i.e. Keras-esque layers and differentiators) share design principles with TF, and their similarities make for an easier programming transition. For researchers already in the QML field, TFQ can help speed up experiments and scale them.

In order to get started with TFQ it is important to become familiar with the basics of quantum computing, with either the references mentioned above or any of the many other great resources out there. Another important step that often gets overlooked is reading the TFQ white paper. The white paper is accessible to QML beginners and is an invaluable introduction to QML and the basic as well as advanced usage of TFQ. Just as important is to play around with TFQ. Try different things out and experiment; it is a great way to expand not only your understanding of the software but of the theory and mathematics as well. Reading other contemporary papers and the papers that cite TFQ is a great way to become immersed in the current research going on in the field.

Read More

Videos from the TensorFlow User Group Summit in India

Posted by Siddhant Agarwal and Biswajeet Mallik, Program Managers

Logo of TFUG India Summit

TensorFlow has a strong developer community in India with 13 TensorFlow User Groups and 20+ Google Developer Experts. In September, these groups came together to organise the “TFUG India Summit“, a 4-day online event with four tracks. You can check out the recordings for these talks below.

Read More

A new YouTube show: TensorFlow.js Community Show & Tell

Posted by Jason Mayes, Developer Relations Engineer for TensorFlow.js

The TensorFlow YouTube channel has a new show called “TensorFlow.js Community Show & Tell.” In this program, we highlight amazing tech demos from the TensorFlow.js community every quarter. Our next show will be on 11th December 9AM PT over on the TensorFlow YouTube channel, but if you missed the previous ones, you can find past episodes on this playlist.

TensorFlowJS Community Show and Tell thumbnail

About the show

Do you love great tech demos that push the boundaries of what is possible for a given industry? If that sounds like you and you’re looking for fresh inspiration, along with insights from the engineers themselves, then this may be the YouTube show for you.

After hacking with many wonderful folks in the TensorFlow.js community, it became clear to us that the creativity and the work you were producing were simply incredible. For that reason, we have put together a brand new format, known as the TensorFlow Show & Tell, to showcase the top projects being made and give developers a platform to share their work. With many subscribers who are as passionate about machine learning as we are, we figured this was a great way to showcase that work and connect great minds.

If you missed our latest show you can check our most recent broadcast here:

We have seen a whole bunch of amazing demos in the first 3 shows using machine learning in truly novel ways. From making music out of thin air with your hands, web-connected to a Tesla coil (yes, that actually happened), to combining machine learning models with mind-blowing mixed reality and 3D graphics, we have had a good variety of presenters from all around the world.

Who has presented so far?

We’ve had 25 presenters, and there are more in the making!

1st Show: Rogerio Chaves (Netherlands), Max Bittker (USA), Junya Ishihara (Japan), Ben Farrell (USA), Lizzie Siegle (USA), Manish Raj (India), Yiwen Lin (UK), Jimmy (Thailand)

2nd Show: Cyril Diagne (France), John Cohn (USA), Va Barbosa (USA), FollowTheDarkside (Japan), Jaume Sanchez (UK), Olesya Chernyavskaya (Russia), Alexandre Devaux (France), Amruta (India), Shan Huan (China).

3rd Show: Gant Laborde (USA), Hugo Zanini (Brazil), Charlie Gerard (Amsterdam), Shivay Lamba (India), Anders Jessen (Denmark), Benson Ruan (Australia), Cristina Maillo (Spain), James Seo (USA).

How can I get on the show?

If you have made something using TensorFlow.js simply use the #MadeWithTFJS hashtag over on Twitter or LinkedIn with a post demonstrating your creation and link to try it out. We will select our top picks each quarter to be featured in the show and reach out to you directly if you are selected. You can view existing submissions for Twitter as an example.

I missed an episode – where can I view them?

Catch up on previous live episodes via this playlist. If you prefer shorter, bite-sized videos over a coffee break, the Made With TensorFlow.js playlist lets you watch just the ones that are relevant to you, on demand.

When is the next show?

Join us in December for the next show and tell – we have 6 wonderful new demos lined up. We’re aiming for 11th December, 9AM PT, over on the TensorFlow YouTube channel.

Be sure to subscribe and click the bell icon to set notifications for when new videos are posted (we are aiming for every quarter), or add a Google calendar reminder and see you there!

Also, if you enjoyed the show and would like to see this format replicated for TensorFlow Core, TensorFlow Lite, and more, let us know! You can drop me a message on Twitter or LinkedIn – would love to see what you have made.

Acknowledgements

A huge thank you to all of our show & tell presenters thus far, and to the amazing community who have submitted projects or helped tag amazing finds for us to reach out to. You truly make this show what it is, and we are super excited to see even more in the future and look forward to seeing you all again soon.

Read More

Using Model Card Toolkit for TF Model Transparency

Posted by Karan Shukla, Software Engineer, Google Research

Machine learning (ML) model transparency is important across a wide variety of domains that impact peoples’ lives, from healthcare to personal finance to employment. At Google, this desire for transparency led us to develop Model Cards, a framework for transparent reporting on ML model performance, provenance, ethical considerations and more. It can be time consuming, however, to compile the information necessary to create a useful Model Card. To address this, we recently announced the open-source launch of Model Card Toolkit (MCT), a collection of tools that supports ML developers in compiling the information that goes into a Model Card.

The toolkit consists of:

  • A JSON schema, which specifies the fields to include in the Model Card
  • A ModelCard data API to represent an instance of the JSON schema and visualize it as a Model Card
  • A component that uses the model provenance information stored with ML Metadata (MLMD) to automatically populate the JSON with relevant information

We wanted the toolkit to be modular so that Model Card creators can still leverage the JSON schema and ModelCard data API even if their modeling environment is not integrated with MLMD. In this post, we’ll show you how you can use these components to create a Model Card for a Keras MobileNetV2 model trained on ImageNet and fine-tuned on the cats_vs_dogs dataset available in TensorFlow Datasets (TFDS). While this model and use case may be trivial from a transparency standpoint, it allows us to easily demonstrate the components of MCT.

Model card for Fine-tuned MobileNetV2 Model for Cats vs Dogs
An example Model Card. Click here for a larger version.

Model Card Toolkit Walkthrough

You can follow along and run the code yourself in the Colab notebook. In this walkthrough, we’ll include some additional information about the considerations you’ll want to keep in mind while using the toolkit.

We begin by installing the Model Card Toolkit.

!pip install 'model-card-toolkit>=0.1.1'
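
The snippets that follow assume a handful of standard imports along these lines; this is a sketch based on the Colab, which also defines the cats_vs_dogs and compute_accuracy helpers used below.

# Imports assumed by the snippets below (a sketch; the Colab's import cell may differ).
import os
import tempfile
import zipfile

import requests
import tensorflow as tf
from IPython import display
from model_card_toolkit import ModelCardToolkit

# cats_vs_dogs and compute_accuracy are utility helpers defined in the Colab.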

Now, we load both the MobileNetV2 model and the weights generated by fine-tuning the model on the cats_vs_dogs dataset. For more information on how we fine-tuned our model, you can see the TensorFlow tutorial on the topic.

URL = 'https://storage.googleapis.com/cats_vs_dogs_model/cats_vs_dogs_model.zip'
BASE_PATH = tempfile.mkdtemp()
ZIP_PATH = os.path.join(BASE_PATH, 'cats_vs_dogs_model.zip')
MODEL_PATH = os.path.join(BASE_PATH,'cats_vs_dogs_model')

r = requests.get(URL, allow_redirects=True)
open(ZIP_PATH, 'wb').write(r.content)

with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
    zip_ref.extractall(BASE_PATH)

model = tf.keras.models.load_model(MODEL_PATH)

We also calculate the number of examples, storing it to the “examples” object, and the accuracy scores, disaggregated across class. We’ll use both accuracy and examples later to build graphs to display in our Model Card.

examples = cats_vs_dogs.get_data()
accuracy = compute_accuracy(examples['combined'])
cat_accuracy = compute_accuracy(examples['cat'])
dog_accuracy = compute_accuracy(examples['dog'])

Next, we’ll use the Model Card Toolkit to create our Model Card. The first step is to initialize a ModelCardToolkit object, which maintains assets including a Model Card JSON file and Model Card document. Call ModelCardToolkit.scaffold_assets() to generate these assets and return a ModelCard object.

model_card_dir = tempfile.mkdtemp()
mct = ModelCardToolkit(model_card_dir)
model_card = mct.scaffold_assets()

We then populate the Model Card’s fields. First, we’ll fill in the model_card.model_details section, which contains basic metadata fields.

We begin by specifying the model’s name, writing a brief description of the model in the overview section.

model_card.model_details.name = 'Fine-tuned MobileNetV2 Model for Cats vs. Dogs'
model_card.model_details.overview = (
    'This model distinguishes cat and dog images. It uses the MobileNetV2 '
    'architecture (https://arxiv.org/abs/1801.04381) and is trained on the '
    'Cats vs Dogs dataset '
    '(https://www.tensorflow.org/datasets/catalog/cats_vs_dogs). This model '
    'performed with high accuracy on both Cat and Dog images.'
)

We provide the model’s owners, version, and references.

model_card.model_details.owners = [
    {'name': 'Model Cards Team', 'contact': 'model-cards@google.com'}
]
model_card.model_details.version = {'name': 'v1.0', 'date': '08/28/2020'}
model_card.model_details.references = [
    'https://www.tensorflow.org/guide/keras/transfer_learning',
    'https://arxiv.org/abs/1801.04381',
]

Finally, in the citation section, we share the model’s license information and a URL that future users can cite if they choose to reuse the model.

model_card.model_details.license = 'Apache-2.0'
model_card.model_details.citation = 'https://github.com/tensorflow/model-card-toolkit/blob/master/model_card_toolkit/documentation/examples/Standalone_Model_Card_Toolkit_Demo.ipynb'

The model_card.quantitative_analysis field contains information about a model's performance metrics. Here, we’ve created some synthetic performance metric values for a hypothetical model built on our dataset.

model_card.quantitative_analysis.performance_metrics = [
    {'type': 'accuracy', 'value': accuracy},
    {'type': 'accuracy', 'value': cat_accuracy, 'slice': 'cat'},
    {'type': 'accuracy', 'value': dog_accuracy, 'slice': 'Dog'},
]

model_card.considerations contains qualitative information about your model. In particular, we recommend including some, or all of the following information:

Use cases: What are the intended use cases for this model? This is pretty straightforward for our model:

model_card.considerations.use_cases = [
    'This model classifies images of cats and dogs.'
]

Limitations: What technical limitations should users keep in mind? What kinds of data cause your model to fail, or underperform? In our case, examples that are not dogs or cats will cause our model to fail, so we’ve acknowledged this:

model_card.considerations.limitations = [
    'This model is not able to classify images of other animals.'
]

Ethical considerations: What ethical considerations should users be aware of when deciding whether or not to use the model? In what contexts could the model raise ethical concerns? What steps did you take to mitigate ethical concerns?

model_card.considerations.ethical_considerations = [{
    'name':
        'While distinguishing between cats and dogs is generally agreed to be '
        'a benign application of machine learning, harmful results can occur '
        'when the model attempts to classify images that don’t contain cats or '
        'dogs.',
    'mitigation_strategy':
        'Avoid application on non-dog and non-cat images.'
}]

Lastly, you can include graphs in your Model Card. We recommend including graphs that reflect the distributions in both your training and evaluation datasets, as well as graphs of your model’s performance on evaluation data. model_card has sections for each of these:

  • model_card.model_parameters.data.train.graphics for training dataset statistics
  • model_card.model_parameters.data.eval.graphics for evaluation dataset statistics
  • model_card.quantitative_analysis.graphics for quantitative analysis of model performance

For this Model Card, we’ve included Matplotlib graphs of our validation set size and the model’s accuracy, both separated by class. Please visit the associated Colab if you’d like to see the Matplotlib code. If you are using ML Metadata, these graphs will be generated automatically (as demonstrated in this Colab). You can also use other visualization libraries, like Seaborn.
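
As a rough illustration, a helper along the following lines can turn a Matplotlib bar chart into the base64-encoded PNG string that the graphics fields expect. The function name and the placeholder values are ours, and the Colab's actual plotting code differs.

import base64
import io
import matplotlib.pyplot as plt

def bar_chart_as_base64(names, values, title):
    # Render a bar chart and return it as a base64-encoded PNG string.
    fig, ax = plt.subplots()
    ax.bar(names, values)
    ax.set_title(title)
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode('utf-8')

# Illustrative values only; the real counts and accuracies come from the
# dataset and metrics computed earlier in the walkthrough.
validation_set_size_barchart = bar_chart_as_base64(
    ['cat', 'dog'], [200, 200], 'Validation Set Size')
accuracy_barchart = bar_chart_as_base64(
    ['cat', 'dog'], [cat_accuracy, dog_accuracy], 'Accuracy')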

We add our graphs to our Model Card.

model_card.model_parameters.data.eval.graphics.collection = [
    {'name': 'Validation Set Size', 'image': validation_set_size_barchart},
]
model_card.quantitative_analysis.graphics.collection = [
    {'name': 'Accuracy', 'image': accuracy_barchart},
]

We’re finally ready to generate our Model Card! Let’s do that now. First we need to update the ModelCardToolkit object with the latest ModelCard.

mct.update_model_card_json(model_card)

Lastly, we generate the Model Card document in the chosen output format.

# Generate a model card document in HTML (default)
html_doc = mct.export_format()

# Display the model card document in HTML
display.display(display.HTML(html_doc))



# Generate a model card document in Markdown
md_path = os.path.join(model_card_dir, 'template/md/default_template.md.jinja')
md_doc = mct.export_format(md_path, 'model_card.md')

# Display the model card document in Markdown
display.display(display.Markdown(md_doc))

Model Card for Fine-tuned MobileNetV2 Model for Cats vs Dogs

And we’ve generated our Model Card! It’s a good idea to review the end product with your direct team, as well as members who are further away from the project. In particular, we recommend reviewing the qualitative fields such as “ethical considerations” to ensure you’ve adequately captured all potential use cases and their potential consequences. Does your Model Card answer the questions that people from different backgrounds might have? Is the language accessible to a developer? What about a policy maker, or a downstream user who might interact with the model? In the future, we hope to offer Model Card creators more guidance that they can use to help answer these questions and provide more thorough instructions on how to fill out the considerations fields.

Have questions? Have Model Cards to share? Let us know at model-cards@google.com!

Acknowledgements

Huanming Fang, Hui Miao, Karan Shukla, Dan Nanas, Catherina Xu, Christina Greer, Neoklis Polyzotis, Tulsee Doshi, Tiffany Deng, Margaret Mitchell, Timnit Gebru, Andrew Zaldivar, Mahima Pushkarna, Meena Natarajan, Roy Kim, Parker Barnes, Tom Murray, Susanna Ricco, Lucy Vasserman, and Simone Wu

Read More

Accelerating TensorFlow Performance on Mac

Posted by Pankaj Kanwar and Fred Alcober

Apple M1 logo

With TensorFlow 2, best-in-class training performance on a variety of platforms, devices and hardware enables developers, engineers, and researchers to work on their preferred platform. TensorFlow users on Intel Macs or Macs powered by Apple’s new M1 chip can now take advantage of accelerated training using Apple’s Mac-optimized version of TensorFlow 2.4 and the new ML Compute framework. These improvements, combined with the ability for Apple developers to execute TensorFlow on iOS through TensorFlow Lite, continue to showcase TensorFlow’s breadth and depth in supporting high-performance ML execution on Apple hardware.

Performance on the Mac with ML Compute

The Mac has long been a popular platform for developers, engineers, and researchers. With Apple’s announcement last week, featuring an updated lineup of Macs that contain the new M1 chip, Apple’s Mac-optimized version of TensorFlow 2.4 leverages the full power of the Mac with a huge jump in performance.

ML Compute, Apple’s new framework that powers training for TensorFlow models right on the Mac, now lets you take advantage of accelerated CPU and GPU training on both M1- and Intel-powered Macs.

For example, the M1 chip contains a powerful new 8-Core CPU and up to 8-core GPU that are optimized for ML training tasks right on the Mac. In the graphs below, you can see how Mac-optimized TensorFlow 2.4 can deliver huge performance increases on both M1- and Intel-powered Macs with popular models.

Training impact on common models using ML Compute on M1- and Intel-powered 13-inch MacBook Pro is shown in seconds per batch, with lower numbers indicating faster training time.
Training impact on common models using ML Compute on the Intel-powered 2019 Mac Pro
Training impact on common models using ML Compute on the Intel-powered 2019 Mac Pro is shown in seconds per batch, with lower numbers indicating faster training time.

Getting Started with Mac-optimized TensorFlow

Users do not need to make any changes to their existing TensorFlow scripts to use ML Compute as a backend for TensorFlow and TensorFlow Addons.

To get started, visit Apple’s GitHub repo for instructions to download and install the Mac-optimized TensorFlow 2.4 fork.

In the near future, we’ll make it even easier for users to get these performance gains by integrating the forked version into the TensorFlow master branch.

You can learn more about the ML Compute framework on Apple’s Machine Learning website.

Footnotes:

  1. Testing conducted by Apple in October and November 2020 using a preproduction 13-inch MacBook Pro system with Apple M1 chip, 16GB of RAM, and 256GB SSD, as well as a production 1.7GHz quad-core Intel Core i7-based 13-inch MacBook Pro system with Intel Iris Plus Graphics 645, 16GB of RAM, and 2TB SSD. Tested with prerelease macOS Big Sur, TensorFlow 2.3, prerelease TensorFlow 2.4, ResNet50V2 with fine-tuning, CycleGAN, Style Transfer, MobileNetV3, and DenseNet121. Performance tests are conducted using specific computer systems and reflect the approximate performance of MacBook Pro.
  2. Testing conducted by Apple in October and November 2020 using a production 3.2GHz 16-core Intel Xeon W-based Mac Pro system with 32GB of RAM, AMD Radeon Pro Vega II Duo graphics with 64GB of HBM2, and 256GB SSD. Tested with prerelease macOS Big Sur, TensorFlow 2.3, prerelease TensorFlow 2.4, ResNet50V2 with fine-tuning, CycleGAN, Style Transfer, MobileNetV3, and DenseNet121. Performance tests are conducted using specific computer systems and reflect the approximate performance of Mac Pro.

Read More

Applying MinDiff to Improve Model Fairness

Posted by Summer Misherghi and Thomas Greenspan, Software Engineers, Google Research

Last December, we open-sourced Fairness Indicators, a platform that enables sliced evaluation of machine learning model performance. This type of responsible evaluation is a crucial first step toward avoiding bias as it allows us to determine how our models are working for a wide variety of users. When we do identify that our model underperforms on certain slices of our data, we need a strategy to mitigate this to avoid creating or reinforcing unfair bias, in line with Google’s AI Principles.

Today, we’re announcing MinDiff, a technique for addressing unfair bias in machine learning models. Given two slices of data, MinDiff works by penalizing your model for differences in the distributions of scores between the two sets. As the model trains, it will try to minimize the penalty by bringing the distributions closer together. MinDiff is the first in what will ultimately be a larger Model Remediation Library of techniques, each suitable for different use cases. To learn about the research and theory behind MinDiff, please see our post on the Google AI Blog.

MinDiff Walkthrough

You can follow along and run the code yourself in this MinDiff notebook. In this walkthrough, we’ll emphasize important points in the notebook, while providing context on fairness evaluation and remediation.

In this example, we are training a text classifier to identify written content that could be considered “toxic.” For this task, our baseline model will be a simple Keras sequential model pre-trained on the Civil Comments dataset. Since this text classifier could be used to automatically moderate forums on the internet (for example, to flag potentially toxic comments), we want to ensure that it works well for everyone. You can read more about how fairness problems can arise in automated content moderation in this blog post.
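
For orientation, a simple sequential text classifier might look like the sketch below; the vocabulary size, layer widths, and implied preprocessing are placeholder choices, not the notebook's exact architecture.

import tensorflow as tf

# Illustrative baseline: token ids in, probability that the comment is toxic out.
baseline_model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=10000, output_dim=64),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
baseline_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['accuracy'])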

To attempt to mitigate potential fairness concerns, we will:

  1. Evaluate our baseline model’s performance on text containing references to sensitive groups.
  2. Improve performance on any underperforming groups by training with MinDiff.
  3. Evaluate the new model’s performance on our chosen metric.

Our purpose is to demonstrate usage of the MinDiff technique for you with a minimal workflow, not to lay out a complete approach to fairness in machine learning. Our evaluation will only focus on one sensitive category and a single metric. We also don’t address potential shortcomings in the dataset, nor tune our configurations.

In a production setting, you would want to approach each of these with more rigor. For example:

  • Consider the application space and the potential societal impact of your model; what are the implications of different types of model errors?
  • Consider additional categories for which underperformance might have fairness implications. Do you have sufficient examples for groups in each category?
  • Consider any privacy implications to storing the sensitive categories.
  • Consider any metric for which poor performance could translate into harmful outcomes.
  • Conduct thorough evaluation for all relevant metrics on multiple sensitive categories.
  • Experiment with the configuration of MinDiff by tuning hyperparameters to get optimal performance.

For the purpose of this blog post, we’ll skip building and training our baseline model, and jump right to evaluating its performance. We’ve used some utility functions to compute our metrics and we’re ready to visualize evaluation results (See “Render Evaluation Results” in the notebook):

 widget_view.render_fairness_indicator(eval_result)  
TensorFlow image

Let’s look at the evaluation results. Try selecting the metric false positive rate (FPR) with threshold 0.450. We can see that the model does not perform as well for some religious groups as for others, displaying a much higher FPR. Note the wide confidence intervals on some groups because they have too few examples. This makes it difficult to say with certainty that there is a significant difference in performance for these slices. We may want to collect more examples to address this issue. We can, however, attempt to apply MinDiff for the two groups that we are confident are underperforming.

We’ve chosen to focus on FPR, because a higher FPR means that comments referencing these identity groups are more likely to be incorrectly flagged as toxic. This could lead to inequitable outcomes for users engaging in dialogue about religion, but note that disparities in other metrics can lead to other types of harm.

Now, we’ll try to improve the FPR for religious groups for which our model underperforms. We’ll attempt to do so using MinDiff, a remediation technique that seeks to balance error rates across slices of your data by penalizing disparities in performance during training. When we apply MinDiff, model performance may degrade slightly on other slices. As such, our goals with MinDiff will be to improve performance for underperforming groups, while sustaining strong performance for other groups and overall.

To use MinDiff, we create two additional data splits:

  • A split for non-toxic examples referencing minority groups: In our case, this will include comments with references to our underperforming identity terms. We don’t include some of the groups because there are too few examples, leading to higher uncertainty with wide confidence interval ranges.
  • A split for non-toxic examples referencing the majority group.

It’s important to have sufficient examples belonging to the underperforming classes. Based on your model architecture, data distribution, and MinDiff configuration, the amount of data needed can vary significantly. In past applications, we have seen MinDiff work well with at least 5,000 examples in each data split.

In our case, the groups in the minority splits have example quantities of 9,688 and 3,906. Note the class imbalances in the dataset; in practice, this could be cause for concern, but we won’t seek to address them in this notebook since our intention is just to demonstrate MinDiff.

We select only negative examples for these groups, so that MinDiff can optimize on getting these examples right. It may seem counterintuitive to carve out sets of ground truth negative examples if we’re primarily concerned with disparities in false positive rate, but remember that a false positive prediction is a ground truth negative example that’s incorrectly classified as positive, which is the issue we’re trying to address.

To prepare our data splits, we create masks for the sensitive & non-sensitive groups:

minority_mask = data_train.religion.apply(
    lambda x: any(religion in x for religion in ('jewish', 'muslim')))
majority_mask = data_train.religion.apply(
    lambda x: x == "['christian']")

Next, we select negative examples, so MinDiff will be able to reduce FPR for sensitive groups:

true_negative_mask = data_train['toxicity'] == 0

data_train_main = copy.copy(data_train)
data_train_sensitive = (
    data_train[minority_mask & true_negative_mask])
data_train_nonsensitive = (
    data_train[majority_mask & true_negative_mask])

To start training with MinDiff, we need to convert our data to TensorFlow Datasets (not shown here — see “Create MinDiff Datasets” in the notebook for details). Don’t forget to batch your data for training. In our case, we set the batch sizes to the same value as the original dataset but this is not a requirement and in practice should be tuned.

dataset_train_sensitive = dataset_train_sensitive.batch(BATCH_SIZE)
dataset_train_nonsensitive = (
    dataset_train_nonsensitive.batch(BATCH_SIZE))

Once we have prepared our three datasets, we merge them into one MinDiff dataset using a util function provided in the library.

min_diff_dataset = md.keras.utils.pack_min_diff_data(
    dataset_train_main,
    dataset_train_sensitive,
    dataset_train_nonsensitive)

To train with MinDiff, simply take the original model and wrap it in a MinDiffModel with a corresponding `loss` and `loss_weight`. We are using 1.5 as the default `loss_weight`, but this is a parameter that needs to be tuned for your use case, since it depends on your model and product requirements. You should experiment with changing the value to see how it impacts the model, noting that increasing it pushes the performance of the minority and majority groups closer together but may come with more pronounced tradeoffs.

As specified above, we create the original model, and wrap it in a MinDiffModel. We pass in one of the MinDiff losses and use a moderately high weight of 1.5.

original_model = ...  # Same structure as used for baseline model. 

min_diff_loss = md.losses.MMDLoss()
min_diff_weight = 1.5
min_diff_model = md.keras.MinDiffModel(
    original_model, min_diff_loss, min_diff_weight)

After wrapping the original model, we compile the model as usual. This means using the same loss as for the baseline model:

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss = tf.keras.losses.BinaryCrossentropy()
min_diff_model.compile(
    optimizer=optimizer, loss=loss, metrics=['accuracy'])

We fit the model to train on the MinDiff dataset, and save the original model to evaluate (see API documentation for details on why we don’t save the MinDiff model).

min_diff_model.fit(min_diff_dataset, epochs=20)

min_diff_model.save_original_model(
    min_diff_model_location, save_format='tf')

Finally, we evaluate the new results.

min_diff_eval_subdir = 'eval_results_min_diff'
min_diff_eval_result = util.get_eval_results(
    min_diff_model_location, base_dir, min_diff_eval_subdir,
    validate_tfrecord_file, slice_selection='religion')

To ensure we evaluate a new model correctly, we need to select a threshold the same way that we would the baseline model. In a production setting, this would mean ensuring that evaluation metrics meet launch standards. In our case, we will pick the threshold that results in a similar overall FPR to the baseline model. This threshold may be different from the one you selected for the baseline model. Try selecting false positive rate with threshold 0.400. (Note that the subgroups with very low quantity examples have very wide confidence range intervals and don’t have predictable results.)

 widget_view.render_fairness_indicator(min_diff_eval_result) 
TensorFlow Image

Note: The scale of the y-axis has changed from .04 in the graph for the baseline model to .02 for our MinDiff model.

Reviewing these results, you may notice that the FPRs for our target groups have improved. The gap between our lowest performing group and the majority group has improved from .024 to .006. Given the improvements we’ve observed and the continued strong performance for the majority group, we’ve satisfied both of our goals. Depending on the product, further improvements may be necessary, but this approach has gotten our model one step closer to performing equitably for all users.

MinDiff Chart

To get a better sense of scale, we superimposed the MinDiff model on top of the base model.

You can get started with MinDiff by visiting the MinDiff page on tensorflow.org. More information about the research behind MinDiff is available in our post on the Google AI Blog. You can also learn more about evaluating for fairness in this guide.

Acknowledgements

The MinDiff framework was developed in collaboration with Thomas Greenspan, Summer Misherghi, Sean O’Keefe‎, Christina Greer, Catherina Xu‎, Manasi Joshi, Dan Nanas, Nick Blumm, Jilin Chen, Zhe Zhao, James Chen, Maciej Kula, Lichan Hong, Mahesh Sathiamoorthy. This research effort on ML Fairness in classification was jointly led by (in alphabetical order) Alex Beutel, Ed H. Chi, Flavien Prost, Hai Qian, Jilin Chen, Shuo Chen, and Tulsee Doshi. Further, this work was pursued in collaboration with Christine Luu, Jonathan Bischof, Pierre Kreitmann, and Qiuwen Chen.

Read More

Characterizing quantum advantage in machine learning by understanding the power of data

Posted by Hsin-Yuan Huang, Google/Caltech, Michael Broughton, Google, Jarrod R. McClean, Google, Masoud Mohseni, Google.

Data drives machine learning. Large scale research and production ML both depend on high volume and high quality sources of data where it is often the case that more is better. The use of training data has enabled machine learning algorithms to achieve better performance than traditional algorithms at recognizing photos, understanding human languages, or even tasks that have been championed as premier applications for quantum computers such as predicting properties of novel molecules for drug discovery and the design of better catalysts, batteries, or OLEDs. Without data, these machine learning algorithms would not be any more useful than traditional algorithms.

While existing machine learning models run on classical computers, quantum computers provide the potential to design quantum machine learning algorithms that achieve even better performance, especially for modelling quantum-mechanical systems like molecules, catalysts, or high-temperature superconductors. Because the quantum world allows the superposition of exponentially many states to evolve and interfere at the same time while classical computers struggle at such tasks, one would expect quantum computers to have an advantage in machine learning problems that have a quantum origin. The hope is that such quantum advantage (the advantage of using quantum computers instead of classical computers) extends to machine learning problems in the classical domain, like computer vision or natural language processing.

TensorFlow image

Learning from data obtained in nature (e.g., physical experiments) can enable classical computers to solve some problems that are hard for classical computers without the data. However, an even larger class of problems can be solved using quantum computers. If we assume nature can be simulated using quantum computers, then quantum computers with data obtained in nature will not have further computational power.

To understand the advantage of quantum computers in machine learning problems, the following big questions come to mind:

  1. If data comes from a quantum origin, such as from experiments developing new materials, are quantum models always better at predicting than classical models?
  2. Can we evaluate when there is a potential for a quantum computer to be helpful for modelling a given data set, from either a classical or quantum origin?
  3. Is it possible to find or construct a data set where quantum models have an advantage over their classical counterparts?

We addressed these questions and more in a recent paper by developing a mathematical framework to compare classical modelling approaches (neural networks, tree-based models, etc) against quantum modelling approaches in order to understand potential advantages in making more accurate predictions. This framework is applicable both when the data comes from the classical world (MNIST, product reviews, etc) and when the data comes from a quantum experiment (chemical reaction, quantum sensors, etc).

Conventional wisdom might suggest that the use of data coming from quantum experiments that are hard to reproduce classically would imply the potential for a quantum advantage. However, we show that this is not always the case. It is perhaps no surprise to machine learning experts that with enough data, an arbitrary function can be learned. It turns out that this extends to learning functions that have a quantum origin, too. By taking data from physical experiments obtained in nature, such as experiments for exploring new catalysts, superconductors, or pharmaceuticals, classical ML models can achieve some degree of generalization beyond the training data. This allows classical algorithms with data to solve problems that would be hard to solve using classical algorithms without access to the data (rigorous justification of this claim is given in our paper).

We provide a method to quantitatively determine the amount of samples required to make accurate predictions in datasets coming from a quantum origin. Perhaps surprisingly, sometimes there is no great difference between the number of samples needed by classical and quantum models. This method also provides a constructive approach to generate datasets that are hard to learn with some classical models.

TensorFlow study graphic

An empirical demonstration of prediction advantage using quantum models compared to the best among a list of common classical models under different amounts of training data N. The projected quantum kernel we introduce has a significant advantage over the best tested Classical ML.

We go on to use this framework to engineer datasets for which all attempted traditional machine learning methods fail, but quantum methods do not. An example of one such data set is presented in the above figure. These trends are examined empirically in the largest gate-based quantum machine learning simulations to date, made possible with TensorFlow Quantum, which is an open source library for quantum machine learning. In order to carry out this work at large quantum system sizes a considerable amount of computing power was needed. The combination of quantum circuit simulation (~300 TeraFLOP/s) and analysis code (~800 TeraFLOP/s) written using TensorFlow and TensorFlow Quantum was able to easily reach throughputs as high as 1.1 PetaFLOP/s, a scale that is rarely seen in the crossover field of quantum computing and machine learning (Though it is nothing new for classical ML which has already hit the exaflop scale).

When preparing distributed workloads for quantum machine learning research in TensorFlow Quantum, the setup and infrastructure should feel very familiar to regular TensorFlow. One way to handle distribution in TensorFlow is with the tf.distribute module, which allows users to configure device placement as well as machine configuration in multi-machine workloads. In this work tf.distribute was used to distribute workloads across ~30 Google cloud machines (some containing GPUs) managed by Kubernetes. The major stages of development were:

  1. Develop a functioning single-node prototype using TensorFlow and TensorFlow Quantum.
  2. Incorporate minimal code changes to use MultiWorkerMirroredStrategy in a manually configured two-node environment.
  3. Create a Docker image with the functioning code from step 2 and upload it to the Google container registry.
  4. Using the Google Kubernetes Engine, launch a job following along with the ecosystem templates provided here.
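
For reference, a minimal sketch of the distribution setup in step 2 might look like the following; the cluster addresses and the toy model inside the scope are placeholders, and the actual training code used for this work differs.

import json
import os
import tensorflow as tf

# Each worker describes the cluster and its own role via TF_CONFIG before
# the strategy is created (addresses here are placeholders).
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['host1:12345', 'host2:12345']},
    'task': {'type': 'worker', 'index': 0},
})

# tf.distribute.experimental.MultiWorkerMirroredStrategy on TF versions before 2.4.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Model construction (e.g. a TFQ PQC-based Keras model) goes here.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer='adam', loss='mse')

# model.fit(...) then runs the same code on every worker.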

TensorFlow Quantum has a tutorial showcasing a variation of step 1, where you can create a dataset (in simulation) that is out of reach for classical neural networks.

It is our hope that the framework and tools outlined here help to enable the TensorFlow community to explore precursors of datasets that require quantum computers to make accurate predictions. If you want to get started with quantum machine learning, check out the beginner tutorials on the TensorFlow Quantum website. If you are a quantum machine learning researcher, you can read our paper addressing the power of data in quantum machine learning for more detailed information. We are excited to see what other kinds of large scale QML experiments can be carried out by the community using TensorFlow Quantum. Only through expanding the community of researchers will machine learning with quantum computers reach its full potential.

Read More

TensorFlow Community Spotlight program update

Posted by Marcus Chang, TensorFlow Program Manager

In June we started the TensorFlow Community Spotlight Program to offer the developer community an opportunity to showcase their hard work and passion for ML and AI by submitting their TensorFlow projects for the chance to be featured and recognized on Twitter with the hashtag #TFCommunitySpotlight.

GIF of posture tracking tool in use
Olesya Chernyavskaya, a Community Spotlight winner, created a tool in TensorFlow to track posture and blur the screen if a person is sitting poorly.

Now a little over four months in, we’ve received many great submissions and it’s been amazing to see all of the creative uses of TensorFlow across Python, JavaScript, Android, iOS, and many other areas of TensorFlow.

We’d like to learn about your projects, too. You can share them with us using this form. Here are our previous Spotlight winners:

Pranav Natekar

Pranav used TensorFlow to create a tool that identifies patterns in Indian Classical music to help students learn the Tabla. Pranav’s GitHub → http://goo.gle/2Z5f7Op

Chart of Indian Classical music patterns


Olesya Chernyavskaya

Working from home and trying to improve your posture? Olesya created a tool in TensorFlow to track posture and blur the screen if a person is sitting poorly. Olesya’s GitHub → https://goo.gle/2CAHvz9

GIF of posture tracking tool in use

Javier Gamazo Tejero

Javier used TensorFlow to capture movement with a webcam and transfer it to Google Street View to give a virtual experience of walking through different cities. Javier’s GitHub → https://goo.gle/3hgkmBc

GIF of virtual walking through cities

Hugo Zanini

Hugo used TensorFlow.js to create real-time semantic segmentation in a browser. Hugo’s GitHub → https://goo.gle/310RDKc

GIF of real-time semantic segmentation

Yana Vasileva

Ambianic.ai created a fall detection surveillance system in which user data is never sent to any 3rd party cloud servers. Yana’s GitHub → goo.gle/2XvYY3q

GIF of fall detection surveillance system

Samarth Gulati and Praveen Sinha

These developers had artists upload an image as a texture for TensorFlow facemesh 3D model and used CSS blend modes to give an illusion of face paint on the user’s face. Samarth’s GitHub → https://goo.gle/2Qe3Gyx

GIF of facemesh 3D model

Laetitia Hebert

Laetitia created a model system for understanding genes, neurons and behavior of the Roundworm, as it naturally moves through a variety of complex postures. Laetitia’s GitHub → https://goo.gle/2ZgLZ6L

GIF of Worm Pose

Henry Ruiz

Rigging.js is a React.js application that utilizes the facemesh TensorFlow.js model. Using a camera, it maps the movements of a person onto a 3D model. Henry’s GitHub → https://goo.gle/3iCXBZj

GIF of Rigging JS

DeepPavlov.ai

The DeepPavlov AI library solves numerous NLP and NLU problems in a short amount of time using pre-trained models or training your own variations of the models. DeepPavlov’s GitHub → https://goo.gle/3jl967S

DeepPavlov AI library

Mayank Thakur

Mayank created a special hand gesture feature to go with the traditional face recognition lock systems on mobile phones that will help increase security. Mayank’s GitHub → https://goo.gle/3j7evyN

GIF of hand gesture software

Firiuza Shigapova

Using TensorFlow 2.x, Firiuza built a library for Graph Neural Networks containing GraphSage and GAT models for node and graph classification problems. Firiuza’s GitHub → https://goo.gle/3kFcmvz

GIF of Graph Neural Networks

Thank you for all the submissions thus far. Congrats to the winners, and we look forward to growing this community of Community Spotlight recipients, so be sure to submit your projects here.

More information

Read More

New Coral APIs and tools for AI at the edge

Posted by Carlos Mendonça, Coral

Coral Fall 2020 image

Fall has finally arrived and with it a new release of Coral’s C++ and Python APIs and tools, along with new models optimized for the Edge TPU and further support for TensorFlow 2.0-based workflows.

Coral is a complete toolkit to build products with local AI. Our on-device inferencing capabilities allow you to build products that are efficient, private, fast and offline with the help of TensorFlow Lite and the Edge TPU.

From the beginning, we’ve provided APIs in Python and C++ that enable developers to take advantage of the Edge TPU’s local inference speed. Offline processing for machine learning models allows for considerable savings on bandwidth and cloud compute costs, it keeps data local, and it preserves user privacy. More recently, we’ve been hard at work to refactor our APIs and make them more modular, reusable and performant, while at the same time eliminating unnecessary API abstractions and surfacing more of the native TensorFlow Lite APIs that developers are familiar with.

So in our latest release, we’re now offering two separate reusable libraries, each built upon the powerful TensorFlow Lite APIs and each isolated in their own repositories: libcoral for C++ and PyCoral for Python.

libcoral (C++)

Unlike some of our previous APIs, libcoral doesn’t hide tflite::Interpreter. Instead, we’re making this native TensorFlow Lite class a first-class component and offering some additional helper APIs that simplify some of your code when working with common models such as classification and detection.

With our new libcoral library, developers should typically follow the pattern below to perform an inference in C++:

  1. Create tflite::Interpreter instance with the Edge TPU context and allocate memory.

    To simplify this step, libcoral provides the MakeEdgeTpuInterpreter() function:

     
    // Load the model
    auto model = coral::LoadModelOrDie(absl::GetFlag(FLAGS_model_path));

    // Get the Edge TPU context
    auto tpu_context = coral::ContainsEdgeTpuCustomOp(*model) ?
        coral::GetEdgeTpuContextOrDie() :
        nullptr;

    // Get the interpreter
    auto interpreter = coral::MakeEdgeTpuInterpreterOrDie(
        *model,
        tpu_context.get());
  2. Configure the interpreter’s input.
  3. Invoke the interpreter:

     interpreter->Invoke();

    As an alternative to Invoke(), you can achieve higher performance with the InvokeWithMemBuffer() and InvokeWithDmaBuffer() functions, which enable processing the input data without copying from another region of memory or from a DMA file descriptor, respectively.

  4. Process the interpreter’s output.

To simplify this step, libcoral provides some adapters, requiring less code from you:


auto result = coral::GetClassificationResults(
    *interpreter,
    /*threshold=*/0.0f,
    /*top_k=*/3);

The above is an example of the classification adapter, where developers can specify the minimum confidence threshold, as well as the maximum number of results to return. The API also features a detection adapter with its own result filtering parameters.

For a full view of the example application source code, see classify_image.cc on GitHub and for instructions on how to integrate libcoral into your application, refer to README.md on GitHub.

This new release also brings updates to on-device retraining with the decoupling of imprinting functions from inference on the updated ImprintingEngine. The new design makes the imprinting engine work with the tflite::Interpreter directly.

To easily address the Edge TPUs available on the host, libcoral supports labels such as "usb:0" or "pci:1". This should make it easier to manage resources on multi-Edge TPU systems.

Finally, we’ve made a number of performance improvements such as more efficient memory usage and memory-based instead of file-based abstractions. Also, the design of the API is more consistent by leveraging the Abseil library for error propagation, generic interfaces and other common patterns, which should provide a more consistent and stable developer experience.

PyCoral (Python)

The new PyCoral library (provided in a new pycoral Python module) follows some of the design patterns introduced with libcoral, and brings parity across our C++ and Python APIs. PyCoral implements the same imprinting decoupling design, model adapters for classification and detection, and the same label-based TPU addressing semantics.

On PyCoral, the “run inference” functionality is now entirely delegated to the native TensorFlow Lite library, as we’ve done-away with the model “engines” that abstracted the TensorFlow interpreter. This change allowed us to eliminate the code duplication introduced by the Coral-specific BasicEngine, ClassificationEngine and DetectionEngine classes (those APIs—from the “Edge TPU Python library”—are now deprecated).

To perform an inference with PyCoral, we follow a similar pattern to that of libcoral (a consolidated sketch follows the steps below):

  1. Create an interpreter:

     interpreter = edgetpu.make_interpreter(model_file)
     interpreter.allocate_tensors()

  2. Configure the interpreter’s input:

     common.set_input(interpreter, image)

  3. Invoke the interpreter:

     interpreter.invoke()

  4. Process the interpreter’s output:

     classes = classify.get_classes(interpreter, top_k=3)

     For fully detailed example code, check out our documentation for Python.
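Putting the four steps together, here is a consolidated sketch of a PyCoral classification inference. The model, label, and image file names are placeholders, and the Pillow library is assumed for image loading; see the official examples for the canonical version.

# Consolidated sketch of the steps above, assuming a classification model
# compiled for the Edge TPU plus a matching label file and test image
# (all file names below are placeholders).
from PIL import Image

from pycoral.adapters import classify, common
from pycoral.utils import edgetpu
from pycoral.utils.dataset import read_label_file

# 1. Create an interpreter bound to the Edge TPU and allocate tensors.
interpreter = edgetpu.make_interpreter('model_edgetpu.tflite')
interpreter.allocate_tensors()

# 2. Configure the input: resize the image to the model's expected size.
image = Image.open('parrot.jpg').convert('RGB').resize(common.input_size(interpreter))
common.set_input(interpreter, image)

# 3. Invoke the interpreter.
interpreter.invoke()

# 4. Process the output: top-3 classes above a small confidence threshold.
labels = read_label_file('labels.txt')
for c in classify.get_classes(interpreter, top_k=3, score_threshold=0.1):
    print(labels.get(c.id, c.id), c.score)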

Updates to the Coral model garden

With this release, we’re further expanding the Coral model garden with MobileDet. MobileDet refers to a family of lightweight, single-shot detectors built with the TensorFlow Object Detection API that achieve a state-of-the-art accuracy-latency tradeoff on Edge TPUs. Compared to the MobileNet family of models, MobileDet offers lower latency and better detection accuracy.
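As an illustration (not an excerpt from the Coral docs), a MobileDet model compiled for the Edge TPU can be run with the PyCoral detection adapter; the model and image file names below are placeholders.

# Hedged sketch of object detection with PyCoral, e.g. using an Edge
# TPU-compiled MobileDet model (file names are placeholders).
from PIL import Image

from pycoral.adapters import common, detect
from pycoral.utils import edgetpu

interpreter = edgetpu.make_interpreter('ssdlite_mobiledet_edgetpu.tflite')
interpreter.allocate_tensors()

# Resize the input to the detector's expected size and keep the scale so
# that bounding boxes can be mapped back to the original image.
image = Image.open('street.jpg').convert('RGB')
_, scale = common.set_resized_input(
    interpreter, image.size, lambda size: image.resize(size))

interpreter.invoke()

# Keep detections with a confidence score of at least 0.4.
for obj in detect.get_objects(interpreter, score_threshold=0.4, image_scale=scale):
    print(obj.id, obj.score, obj.bbox)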

Check out the full collection of models available from Coral for the Edge TPU, including Classification, Detection, Segmentation and models specially prepared for on-device training.

Migrating our entire workflow and model collection to TensorFlow 2 is an ongoing effort. This release of the Coral machine learning API starts introducing support for TensorFlow 2-based workflows. For now, MobileNet v1 (ImageNet), MobileNet v2 (ImageNet), MobileNet v3 (ImageNet), ResNet50 v1 (ImageNet), and UNet MobileNet v2 (Oxford pets) all support training and conversion with TensorFlow 2.

Model Pipelining

Both libcoral and PyCoral have graduated the model pipelining functionality from Beta to General Availability. Model pipelining makes it possible to partition a large model and distribute its segments across multiple Edge TPUs so that it runs considerably faster.

Refer to the documentation for examples of the API in C++ and Python.

The partitioning of models is done with the Edge TPU Compiler, which employs a parameter count algorithm, splitting the model into segments with similar parameter sizes. For cases where this algorithm doesn’t provide the throughput you need, this release introduces a new tool that supports a profiling-based algorithm, which divides the segments based on the latency observed by actually running the model multiple times, possibly resulting in a more balanced output.

The new profiling_partition tool can be used as follows:

./profiling_partition \
  --edgetpu_compiler_binary $PATH_TO_COMPILER \
  --model_path $PATH_TO_MODEL \
  --output_dir $OUT_DIR \
  --num_segments $NUM_SEGMENTS

Learn more

For more information about the Coral APIs mentioned above, see the following documentation:

Read More

Iris landmark tracking in the browser with MediaPipe and TensorFlow.js

Iris landmark tracking in the browser with MediaPipe and TensorFlow.js

Posted by Ann Yuan and Andrey Vakunov, Software Engineers at Google

Iris tracking enables a wide range of applications, such as hands-free interfaces for assistive technologies and understanding user behavior beyond clicks and gestures. Iris tracking is also a challenging computer vision problem. Eyes appear under variable light conditions, are often occluded by hair, and can be perceived as differently shaped depending on the head’s angle of rotation and the person’s expression. Existing solutions rely heavily on specialized hardware, often requiring a costly headset or a remote eye tracker system. These approaches are ill-suited for mobile devices with limited computing resources.

GIF of eye re-coloring tool in use
An example of eye re-coloring enabled.

In March we announced the release of a new package for detecting facial landmarks in the browser. Today, we’re excited to add iris tracking to this package through the TensorFlow.js face landmarks detection model. This work is made possible by the MediaPipe Iris model. We have deprecated the original facemesh model, and future updates will be made to the face landmarks detection model.

Note that iris tracking does not infer the location at which people are looking, nor does it provide any form of identity recognition. In our model’s documentation and the accompanying Model Card, we detail the model’s intended uses, limitations and fairness attributes (aligned with Google’s AI Principles).

The MediaPipe iris model is able to track landmarks for the iris and pupil using a single RGB camera, in real-time, without the need for specialized hardware. The model also returns landmarks for the eyelids and eyebrow regions, enabling detection of slight eye movements such as blinking. Try the model out yourself right now in your browser.

Introducing @tensorflow/face-landmarks-detection

GIF of Facemesh predictions
Above left are predictions from @tensorflow-models/facemesh@0.0.4, above right are predictions from @tensorflow-models/face-landmarks-detection@0.0.1. Iris landmarks are in red.

Users familiar with our existing facemesh model will be able to upgrade to the new faceLandmarksDetection model with only a few code changes, detailed below. faceLandmarksDetection offers three major improvements over facemesh:

  1. Iris keypoints detection
  2. Improved eyelid contour detection
  3. Improved detection for rotated faces

These improvements are highlighted in the GIF above, which demonstrates how the landmarks returned by faceLandmarksDetection and facemesh differ for the same image sequence.

Installation

There are two ways to install the faceLandmarksDetection package:

  1. Through script tags:

     <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.6.0/dist/tf.js"></script>
     <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/face-landmarks-detection"></script>

  2. Through NPM (via the yarn package manager):

     $ yarn add @tensorflow-models/face-landmarks-detection@0.0.1
     $ yarn add @tensorflow/tfjs@2.6.0

Usage

Once the package is installed, you only need to load the model weights and then pass in an image to start detecting facial landmarks:

// If you are using NPM, first require the model. If you are using script tags, you can skip this step because `faceLandmarksDetection` will already be available in the global scope.
const faceLandmarksDetection = require('@tensorflow-models/face-landmarks-detection');

// Load the faceLandmarksDetection model assets.
const model = await faceLandmarksDetection.load(
    faceLandmarksDetection.SupportedPackages.mediapipeFacemesh);

// Pass in a video stream to the model to obtain an array of detected faces from the MediaPipe graph.
// For Node users, the `estimateFaces` API also accepts a `tf.Tensor3D`, or an ImageData object.
const video = document.querySelector("video");
const faces = await model.estimateFaces({ input: video });

The input to estimateFaces can be a video, a static image, a `tf.Tensor3D` or even an ImageData object for use in node.js pipelines. FaceLandmarksDetection then returns an array of prediction objects for the faces in the input, which include information about each face (e.g. a confidence score, and the locations of 478 landmarks within the face).

Here is a sample prediction object:

{
  faceInViewConfidence: 1,
  boundingBox: {
    topLeft: [232.28, 145.26], // [x, y]
    bottomRight: [449.75, 308.36],
  },
  mesh: [
    [92.07, 119.49, -17.54], // [x, y, z]
    [91.97, 102.52, -30.54],
    ...
  ],
  // x, y, z positions of each facial landmark within the input space.
  scaledMesh: [
    [322.32, 297.58, -17.54],
    [322.18, 263.95, -30.54]
  ],
  // Semantic groupings of x, y, z positions.
  annotations: {
    silhouette: [
      [326.19, 124.72, -3.82],
      [351.06, 126.30, -3.00],
      ...
    ],
    ...
  }
}

Refer to our README for more details about the API.

Performance

FaceLandmarksDetection is a lightweight package containing only ~3MB of weights, making it ideally suited for real-time inference on a variety of mobile devices. When testing, note that TensorFlow.js also provides several different backends to choose from, including WebGL and WebAssembly (WASM) with XNNPACK for devices with lower-end GPUs. The table below shows how the package performs across a few different devices and TensorFlow.js backends:

Desktop:

Chart of desktop performance

Mobile:

All benchmarks were collected in the Chrome browser. See our earlier blog post for details on how to activate SIMD for the TF.js WebAssembly backend.

Looking ahead

Both the TensorFlow.js and MediaPipe teams plan to add depth estimation capabilities to our face landmark detection solutions using the improved iris coordinates. We strongly believe in sharing code that enables reproducible research and rapid experimentation, and are looking forward to seeing how the wider community makes use of the MediaPipe iris model.

Try the demo!

Use this link to try our new package in your web browser. We look forward to seeing how you use it in your apps.

More information

Read More