2021 Year in Review: Google Quantum AI

Google’s Quantum AI team has had a productive 2021. Despite ongoing global challenges, we’ve made significant progress in our effort to build a fully error-corrected quantum computer, working towards our next hardware milestone of building an error-corrected quantum bit (qubit) prototype. At the same time, we have continued our commitment to realizing the potential of quantum computers in various applications. That’s why we published results in top journals, collaborated with researchers across academia and industry, and expanded our team to bring on new talent and expertise.

An update on hardware

The Quantum AI team is determined to build an error-corrected quantum computer within the next decade, and to simultaneously use what we learn along the way to deliver helpful, and even transformational, quantum computing applications. This long-term commitment breaks down into three key questions for our quantum hardware:

  1. Can we demonstrate that quantum computers can outperform the classical supercomputers of today in a specific task? We demonstrated beyond-classical computation in 2019.
  2. Can we build a prototype of an error-corrected qubit? In order to use quantum computers to their full potential, we will need to realize quantum error correction to overcome the noise that is present during our computations. As a key step in this direction, we aim to realize the primitives of quantum error correction by redundantly encoding quantum information across several physical qubits, demonstrating that such redundancy leads to an improvement over using individual physical qubits. This is our current target.
  3. Can we build a logical qubit which does not have errors for an arbitrarily long time? Logical qubits encode information redundantly across several physical qubits, and are able to reduce the impact of noise on the overall quantum computation. Putting together a few thousand logical qubits would allow us to realize the full potential of quantum computers for various applications.

Progress toward building an error-corrected qubit prototype

The distance between the noisy quantum computers of today and the fully error-corrected quantum computers of the future is vast. In 2021, we made significant progress in closing this gap by working toward building a prototype logical qubit whose errors are smaller than those of the physical qubits on our chips.

This work requires improvements across the entire quantum computing stack. We have made chips with better qubits, improved the methods that we use to package these chips to better connect them with our control electronics, and developed techniques to calibrate large chips with several dozens of qubits simultaneously.

These improvements culminated in two key results. First, we are now able to reset our qubits with high fidelity, allowing us to reuse qubits in quantum computations. Second, we have realized mid-circuit measurement that allows us to keep track of computation within quantum circuits. Together, the high-fidelity resets and mid-circuit measurements were used in our recent demonstration of exponential suppression of bit and phase flip errors using repetition codes, resulting in 100x suppression of these errors as the size of the code grows from 5 to 21 qubits.

Chart chronicling repetition code

Suppression of logical errors as the number of qubits in the repetition code is increased. As we increase the code size from 5 to 21 qubits, we see a 100x reduction in logical error. Image acknowledgement: Kevin Satzinger/Google Quantum AI

Repetition codes, an error correction tool, let us trade off resources (more qubits) against performance (lower error), a trade-off that will be central in guiding our hardware research and development going forward. This year we showed how error decreases as we increase the number of qubits included in a one-dimensional code. We are currently running experiments to extend these results to two-dimensional surface codes, which will correct errors more comprehensively.
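To make the resources-versus-error trade-off concrete, here is a minimal sketch of a purely classical repetition code decoded by majority vote. It assumes independent bit flips with a fixed probability per bit, which is a simplification and not the noise model of our hardware. The function name `logical_error_rate` and the flip probability are illustrative assumptions. The toy counts only the copies being voted on and does not model how the experiment's qubits split into data and measurement roles.

```python
from math import comb


def logical_error_rate(n_bits: int, p_flip: float) -> float:
    """Probability that a majority vote over n_bits noisy copies is wrong.

    Assumption: each copy flips independently with probability p_flip.
    """
    if n_bits % 2 == 0:
        raise ValueError("use an odd number of bits so the vote cannot tie")
    threshold = n_bits // 2 + 1  # fewest flips that overturn the majority
    return sum(
        comb(n_bits, k) * p_flip**k * (1 - p_flip) ** (n_bits - k)
        for k in range(threshold, n_bits + 1)
    )


# Logical error shrinks roughly geometrically as copies are added.
for n in (3, 5, 7, 9, 11):
    print(n, logical_error_rate(n, 0.05))
```

In this toy model each additional pair of bits multiplies the logical error by a roughly constant factor, as long as the per-bit flip probability stays well below one half. That is the qualitative behavior the exponential-suppression result refers to.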

Applications of quantum computation

In addition to building quantum hardware, our team is also looking for clear margins of quantum advantage in real world applications. With our collaborators in academia and industry, we are exploring fields where quantum computers can provide significant speedups, with realistic expectations that error-corrected quantum computers will likely require better than quadratic speedups for meaningful improvements.

As always, our collaborations with academic and industry partners were invaluable in 2021. One notable collaboration with Caltech showed that, under certain conditions, quantum machines can learn about physical systems from exponentially fewer experiments than what is conventionally required. This novel method was validated experimentally using 40 qubits and 1300 quantum operations, demonstrating a substantial quantum advantage even with the noisy quantum processors we have today. This paves the way to more innovation in quantum machine learning and quantum sensing, with potential near-term use cases.

In collaboration with researchers at Columbia University, we combined one of the most powerful techniques for chemical simulation, Quantum Monte Carlo, with quantum computation. This approach surpasses previous methods as a promising quantum approach to ground state many-electron calculations, which are critical in creating new materials and understanding their chemical properties. When we run a component of this technique on a real quantum computer, we are able to double the size of prior calculations without sacrificing accuracy of the measurements, even in the presence of noise on a device with up to 16 qubits. The resilience of this method to noise is an indication of its potential for scalability even on today’s quantum computers.

We continue to study how quantum computers can be used to simulate quantum physical phenomena, as was most recently reflected in our experimental observation of a time crystal on a quantum processor (Ask a Techspert: What exactly is a time crystal?). This was a great moment for theorists, who’ve pondered the possibility of time crystals for nearly a decade. In other work, we also explored the emergence of quantum chaotic dynamics by experimentally measuring out-of-time-ordered correlations on one of our quantum computers, which was done jointly with collaborators at the NASA Ames Research Center; and experimentally measuring the entanglement entropy of the ground state of the Toric code Hamiltonian by creating its eigenstates using shallow quantum circuits with collaborators at the Technical University of Munich.

Our collaborators contributed to, and even inspired, some of our most impactful research in 2021. Quantum AI remains committed to discovering and realizing meaningful quantum applications in collaboration with scientists and researchers from across the world in 2022 and beyond as we continue our focus on machine learning, chemistry, and many-body quantum physics.

You can find a list of all our publications here.

Continuing investment in the quantum computing ecosystem

This year, at Google’s annual developer conference, Google I/O, we reaffirmed our commitment to the roadmap and investments required to make a useful quantum computer within the decade. While we were busy growing in Santa Barbara, we also continued to support researchers in the quantum community through our open source software. Our quantum programming framework, Cirq, continues to improve with contributions from the community. 2021 also saw the release of specialized tools in collaboration with partners in the ecosystem. Two examples are:

  • The release of a new Fermionic Quantum Simulator for quantum chemistry applications in collaboration with QSimulate, taking advantage of the symmetry in quantum chemistry problems to provide efficient simulations.
  • A significant upgrade to qsim which allows for simulation of noisy quantum circuits on high performance processors such as GPUs via Google Cloud, and qsim integration with NVIDIA’s cuQuantum SDK to enable qsim users to make the most of NVIDIA GPUs when developing quantum algorithms and applications.

We also released an open-source tool called stim, which provides a 10000x speedup when simulating error correction circuits.

You can access our portfolio of open-source software here.

Looking toward 2022

Resident quantum scientist Qubit the Dog taking part in a holiday sing-along led by team members Jimmy Chen and Ofer Naaman.

Through teamwork, collaboration, and some innovative science, we are excited about the progress that we have seen in 2021. We have big expectations for 2022 as we focus on progressing through our hardware milestones, the discovery of new quantum algorithms, and the realization of quantum applications on the quantum processors of today. To tackle our difficult mission, we are growing our team, building on our existing network of collaborators, and expanding our Santa Barbara campus. Together with the broader quantum community, we are excited to see the progress that quantum computing makes in 2022 and beyond.

This archaeologist fights tomb raiders with Google Earth

In the summer, Dr. Gino Caspari’s day starts at 5:30 a.m. in Siberia, where he studies the ancient Scythians with the Swiss National Science Foundation. There, he looks for burial places of these nomadic warriors who rode through Asia 2,500 years ago. The work isn’t easy, from dealing with extreme temperatures, to swamps covered with mosquitos. But the biggest challenge is staying one step ahead of tomb raiders.

It’s believed that more than 90% of the tombs — called kurgans — have already been destroyed by raiders looking to profit off what they find, but Gino is looking for the thousands he believes remain scattered across Russia, Mongolia and Western China. To track his progress, he began mapping these burial sites using Google Earth. “There’s a plethora of open data sources out there, but most of them don’t have the resolution necessary to detect individual archaeological structures,” Dr. Caspari says, pointing out that getting quality data is also very expensive. “Google Earth updates high-res data across the globe, and, especially in remote regions, it was a windfall for archaeologists. Google Earth expanded our possibilities to plan surveys and understand cultural heritage on a broader geographic scale.”

While Google Earth helped Dr. Caspari plan his expeditions, he still couldn’t stay ahead of the looters. He needed to get there faster. That’s when he met data scientist Pablo Crespo and started using another Google tool, TensorFlow.

“Since I started my PhD in 2013, I have been interested in automatic detection of archaeological sites from remote sensing data,” Gino says. “It was clear we needed to look at landscapes and human environmental interaction to understand past cultures. The problem was that our view was obscured by a lack of data and a focus on individual sites.” Back then, he tried some simple automatization processes to detect the places he needed for his research with the available technology, but only got limited results. In 2020, though, Gino and Pablo created a machine learning model using TensorFlow that could analyze satellite images they pulled from Google Earth. This model would look for places on the images that had the characteristics of a Scythian tomb.

The progress in the field of machine learning has been insanely fast, improving the quality of classification and detection to a point where it has become much more than just a theoretical possibility. Google’s freely available technologies have help

This technology sped up the discovery process for Gino, giving him an advantage over looters and even deterioration caused by climate change.

“Frankly, I think that without these tools, I probably wouldn’t have gotten this far in my understanding of technology and what it can do to make a difference in the study of our shared human past,” Gino says. “As a young scholar, I just lack the funds to access a lot of the resources I need. Working with Pablo and others has widened my perspective on what is possible and where we can go.”

Technology solutions have given Dr. Caspari’s work a new set of capabilities, supercharging what he’s able to do. And it’s also made him appreciate the importance of the human touch. “The deeper we dive into our past with the help of technology, the more apparent it becomes how patchy and incomplete our knowledge really is,” he says. “Technology often serves as an extension of our senses and mitigates our reality. Weaving the fabric of our reality will remain the task of the storyteller in us.”

A Scalable Approach for Partially Local Federated Learning

Federated learning enables users to train a model without sending raw data to a central server, thus avoiding the collection of privacy-sensitive data. Often this is done by learning a single global model for all users, even though the users may differ in their data distributions. For example, users of a mobile keyboard application may collaborate to train a suggestion model but have different preferences for the suggestions. This heterogeneity has motivated algorithms that can personalize a global model for each user.

However, in some settings privacy considerations may prohibit learning a fully global model. Consider models with user-specific embeddings, such as matrix factorization models for recommender systems. Training a fully global federated model would involve sending user embedding updates to a central server, which could potentially reveal the preferences encoded in the embeddings. Even for models without user-specific embeddings, having some parameters be completely local to user devices would reduce server-client communication and responsibly personalize those parameters to each user.

Left: A matrix factorization model with a user matrix P and items matrix Q. The user embedding for a user u (P_u) and item embedding for item i (Q_i) are trained to predict the user’s rating for that item (R_ui). Right: Applying federated learning approaches to learn a global model can involve sending updates for P_u to a central server, potentially leaking individual user preferences.

In “Federated Reconstruction: Partially Local Federated Learning”, presented at NeurIPS 2021, we introduce an approach that enables scalable partially local federated learning, where some model parameters are never aggregated on the server. For matrix factorization, this approach trains a recommender model while keeping user embeddings local to each user device. For other models, this approach trains a portion of the model to be completely personal for each user while avoiding communication of these parameters. We successfully deployed partially local federated learning to Gboard, resulting in better recommendations for hundreds of millions of keyboard users. We’re also releasing a TensorFlow Federated tutorial demonstrating how to use Federated Reconstruction.

Federated Reconstruction
Previous approaches for partially local federated learning used stateful algorithms, which require user devices to store a state across rounds of federated training. Specifically, these approaches required devices to store local parameters across rounds. However, these algorithms tend to degrade in large-scale federated learning settings. In these cases, the majority of users do not participate in training, and users who do participate likely only do so once, resulting in a state that is rarely available and can get stale across rounds. Also, all users who do not participate are left without trained local parameters, preventing practical applications.

Federated Reconstruction is stateless and avoids the need for user devices to store local parameters by reconstructing them whenever needed. When a user participates in training, before updating any globally aggregated model parameters, they randomly initialize and train their local parameters using gradient descent on local data with global parameters frozen. They can then calculate updates to global parameters with local parameters frozen. A round of Federated Reconstruction training is depicted below.

Models are partitioned into global and local parameters. For each round of Federated Reconstruction training: (1) The server sends the current global parameters g to each user i; (2) Each user i freezes g and reconstructs their local parameters l_i; (3) Each user i freezes l_i and updates g to produce g_i; (4) Users’ g_i are averaged to produce the global parameters for the next round. Steps (2) and (3) generally use distinct parts of the local data.
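The four steps of a round can be sketched for matrix factorization as follows. This is an illustrative NumPy sketch and not the TensorFlow Federated implementation. The function names, learning rates, squared-error loss, and synthetic ratings are all assumptions made for the example. The global parameters are the item embeddings Q, and the local parameters are a single user embedding p that is rebuilt from scratch each round.

```python
import numpy as np


def reconstruct_local(Q, items, ratings, dim, rng, steps=1, lr=0.1):
    """Step 2: with Q frozen, fit a freshly initialized user embedding."""
    p = rng.normal(scale=0.1, size=dim)  # no state carried over between rounds
    for _ in range(steps):
        err = Q[items] @ p - ratings     # prediction error on the support data
        p -= lr * (Q[items].T @ err) / len(items)
    return p


def update_global(Q, p, items, ratings, lr=0.1):
    """Step 3: with p frozen, take one gradient step on a copy of Q."""
    Q_i = Q.copy()
    err = Q_i[items] @ p - ratings
    # Gradient of 0.5 * err**2 with respect to Q[item] is err * p.
    np.subtract.at(Q_i, items, lr * np.outer(err, p))
    return Q_i


def fedrecon_round(Q, users, dim, rng):
    """Steps 1-4: broadcast Q, reconstruct, update, then average on the server."""
    updated = []
    for support, query in users:
        p = reconstruct_local(Q, *support, dim=dim, rng=rng)
        updated.append(update_global(Q, p, *query))
    return np.mean(updated, axis=0)      # only item embeddings are aggregated


rng = np.random.default_rng(0)
n_items, dim = 6, 3
Q = rng.normal(scale=0.1, size=(n_items, dim))
# Each user holds (support, query) splits of (item ids, ratings); values are made up.
users = [
    ((np.array([0, 1]), np.array([4.0, 3.0])), (np.array([2]), np.array([5.0]))),
    ((np.array([1, 3]), np.array([2.0, 1.0])), (np.array([0]), np.array([4.0]))),
]
Q_next = fedrecon_round(Q, users, dim=dim, rng=rng)
```

Note that the user embedding `p` exists only inside the per-user loop, which mirrors the property that local parameters are never sent to the server. Splitting each user's data into support and query sets corresponds to using distinct parts of the local data for steps (2) and (3).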

This simple approach avoids the challenges of previous methods. It does not assume users have a state from previous rounds of training, enabling large-scale training, and local parameters are always freshly reconstructed, preventing staleness. Users unseen during training can still get trained models and perform inference by simply reconstructing local parameters using local data.

Federated Reconstruction trains better performing models for unseen users compared to other approaches. For a matrix factorization task with unseen users, the approach significantly outperforms both centralized training and baseline Federated Averaging.

Method                  RMSE ↓    Accuracy ↑
Centralized             1.36      40.8%
FedAvg                  0.934     40.0%
FedRecon (this work)    0.907     43.3%

Root-mean-square-error (lower is better) and accuracy for a matrix factorization task with unseen users. Centralized training and Federated Averaging (FedAvg) both reveal privacy-sensitive user embeddings to a central server, while Federated Reconstruction (FedRecon) avoids this.

These results can be explained via a connection to meta learning (i.e., learning to learn); Federated Reconstruction trains global parameters that lead to fast and accurate reconstruction of local parameters for unseen users. That is, Federated Reconstruction is learning to learn local parameters. In practice, we observe that just one gradient descent step can yield successful reconstruction, even for models with about one million local parameters.

Federated Reconstruction also provides a way to personalize models for heterogeneous users while reducing communication of model parameters — even for models without user-specific embeddings. To evaluate this, we apply Federated Reconstruction to personalize a next word prediction language model and observe a substantial increase in performance, attaining accuracy on par with other personalization methods despite reduced communication. Federated Reconstruction also outperforms other personalization methods when executed at a fixed communication level.

Method                  Accuracy ↑    Communication ↓
FedYogi                 24.3%         Whole Model
FedYogi + Finetuning    30.8%         Whole Model
FedRecon (this work)    30.7%         Partial Model

Accuracy and server-client communication for a next word prediction task without user-specific embeddings. FedYogi communicates all model parameters, while FedRecon avoids this.

Real-World Deployment in Gboard
To validate the practicality of Federated Reconstruction in large-scale settings, we deployed the algorithm to Gboard, a mobile keyboard application with hundreds of millions of users. Gboard users use expressions (e.g., GIFs, stickers) to communicate with others. Users have highly heterogeneous preferences for these expressions, making the setting a good fit for using matrix factorization to predict new expressions a user might want to share.

Gboard users can communicate with expressions, preferences for which are highly personal.

We trained a matrix factorization model over user-expression co-occurrences using Federated Reconstruction, keeping user embeddings local to each Gboard user. We then deployed the model to Gboard users, leading to a 29.3% increase in click-through-rate for expression recommendations. Since most Gboard users were unseen during federated training, Federated Reconstruction played a key role in this deployment.

Further Explorations
We’ve presented Federated Reconstruction, a method for partially local federated learning. Federated Reconstruction enables personalization to heterogeneous users while reducing communication of privacy-sensitive parameters. We scaled the approach to Gboard in alignment with our AI Principles, improving recommendations for hundreds of millions of users.

For a technical walkthrough of Federated Reconstruction for matrix factorization, check out the TensorFlow Federated tutorial. We’ve also released general-purpose TensorFlow Federated libraries and open-source code for running experiments.

Acknowledgements
Karan Singhal, Hakim Sidahmed, Zachary Garrett, Shanshan Wu, Keith Rush, and Sushant Prakash co-authored the paper. Thanks to Wei Li, Matt Newton, and Yang Lu for their partnership on Gboard deployment. We’d also like to thank Brendan McMahan, Lin Ning, Zachary Charles, Warren Morningstar, Daniel Ramage, Jakub Konecný, Alex Ingerman, Blaise Agüera y Arcas, Jay Yagnik, Bradley Green, and Ewa Dominowska for their helpful comments and support.

Training Machine Learning Models More Efficiently with Dataset Distillation

For a machine learning (ML) algorithm to be effective, useful features must be extracted from (often) large amounts of training data. However, this process can be made challenging due to the costs associated with training on such large datasets, both in terms of compute requirements and wall clock time. The idea of distillation plays an important role in these situations by reducing the resources required for the model to be effective. The most widely known form of distillation is model distillation (a.k.a. knowledge distillation), where the predictions of large, complex teacher models are distilled into smaller models.

An alternative option to this model-space approach is dataset distillation [1, 2], in which a large dataset is distilled into a synthetic, smaller dataset. Training a model with such a distilled dataset can reduce the required memory and compute. For example, instead of using all 50,000 images and labels of the CIFAR-10 dataset, one could use a distilled dataset consisting of only 10 synthesized data points (1 image per class) to train an ML model that can still achieve good performance on the unseen test set.

Top: Natural (i.e., unmodified) CIFAR-10 images. Bottom: Distilled dataset (1 image per class) on CIFAR-10 classification task. Using only these 10 synthetic images as training data, a model can achieve test set accuracy of ~51%.

In “Dataset Meta-Learning from Kernel Ridge Regression”, published in ICLR 2021, and “Dataset Distillation with Infinitely Wide Convolutional Networks”, presented at NeurIPS 2021, we introduce two novel dataset distillation algorithms, Kernel Inducing Points (KIP) and Label Solve (LS), which optimize datasets using the loss function arising from kernel regression (a classical machine learning algorithm that fits a linear model to features defined through a kernel). Applying the KIP and LS algorithms, we obtain very efficient distilled datasets for image classification, reducing the datasets to 1, 10, or 50 data points per class while still obtaining state-of-the-art results on a number of benchmark image classification datasets. Additionally, we are also excited to release our distilled datasets to benefit the wider research community.

Methodology
One of the key theoretical insights of deep neural networks (DNN) in recent years has been that increasing the width of DNNs results in more regular behavior that makes them easier to understand. As the width is taken to infinity, DNNs trained by gradient descent converge to the familiar and simpler class of models arising from kernel regression with respect to the neural tangent kernel (NTK), a kernel that measures input similarity by computing dot products of gradients of the neural network. Thanks to the Neural Tangents library, neural kernels for various DNN architectures can be computed in a scalable manner.
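To illustrate what "dot products of gradients" means, here is a small sketch of an empirical (finite-width) tangent kernel for a toy two-layer ReLU network, with gradients written out by hand in NumPy. This is not the Neural Tangents API and it does not compute the infinite-width limit. The network form, the scaling by the square root of the width, and the function names are assumptions chosen for the example.

```python
import numpy as np


def param_gradient(x, W, a):
    """Gradient of f(x) = a . relu(W x) / sqrt(m) w.r.t. (a, W), flattened."""
    m = W.shape[0]
    pre = W @ x                           # hidden pre-activations, shape (m,)
    grad_a = np.maximum(pre, 0.0) / np.sqrt(m)
    # d f / d W[j] = a[j] * 1[pre[j] > 0] * x / sqrt(m)
    grad_W = ((a * (pre > 0))[:, None] * x[None, :]) / np.sqrt(m)
    return np.concatenate([grad_a, grad_W.ravel()])


def empirical_ntk(X, W, a):
    """Kernel whose (i, j) entry is the dot product of parameter gradients."""
    G = np.stack([param_gradient(x, W, a) for x in X])
    return G @ G.T


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))   # four toy inputs
W = rng.normal(size=(8, 3))   # hidden layer of width 8
a = rng.normal(size=8)        # readout weights
K = empirical_ntk(X, W, a)
print(K.shape)
```

Two inputs get a large kernel value when a small change to the parameters would move the network's outputs on both of them in the same direction, which is the sense in which the kernel measures input similarity.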

We utilized the above infinite-width limit theory of neural networks to tackle dataset distillation. Dataset distillation can be formulated as a two-stage optimization process: an “inner loop” that trains a model on learned data, and an “outer loop” that optimizes the learned data for performance on natural (i.e., unmodified) data. The infinite-width limit replaces the inner loop of training a finite-width neural network with a simple kernel regression. With the addition of a regularizing term, the kernel regression becomes a kernel ridge-regression (KRR) problem. This is a highly valuable outcome because the kernel ridge regressor (i.e., the predictor from the algorithm) has an explicit formula in terms of its training data (unlike a neural network predictor), which means that one can easily optimize the KRR loss function during the outer loop.

The original data labels can be represented by one-hot vectors, i.e., the true label is given a value of 1 and all other labels are given values of 0. Thus, an image of a cat would have the label “cat” assigned a 1 value, while the labels for “dog” and “horse” would be 0. The labels we use involve a subsequent mean-centering step, where we subtract the reciprocal of the number of classes from each component (so 0.1 for 10-way classification) so that the expected value of each label component across the dataset is normalized to zero.
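As a small worked example of the mean-centering step, the sketch below builds one-hot labels and subtracts the reciprocal of the number of classes. The helper name `centered_one_hot` is an assumption for illustration.

```python
import numpy as np


def centered_one_hot(class_ids, num_classes):
    """One-hot labels with 1 / num_classes subtracted from every component."""
    labels = np.eye(num_classes)[np.asarray(class_ids)]
    return labels - 1.0 / num_classes


# Class 3 in a 10-way task: 0.9 at the true class, -0.1 everywhere else.
print(centered_one_hot([3], 10))
```

Each row sums to zero, and when classes are balanced each label component also averages to zero across the dataset.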

While the labels for natural images appear in this standard form, the labels for our learned distilled datasets are free to be optimized for performance. Having obtained the kernel ridge regressor from the inner loop, the KRR loss function in the outer loop computes the mean-square error between the original labels of natural images and the labels predicted by the kernel ridge regressor. KIP optimizes the support data (images and possibly labels) by minimizing the KRR loss function through gradient-based methods. The Label Solve algorithm directly solves for the set of support labels that minimizes the KRR loss function, generating a unique dense label vector for each (natural) support image.
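The outer-loop objective and the label-solving step can be sketched in a few lines of NumPy. This is a simplified illustration and not our released code. It assumes an RBF kernel as a stand-in for a neural kernel, a fixed ridge parameter `reg`, and synthetic data, and all function names are hypothetical. Because the KRR prediction is linear in the support labels, solving for the best labels reduces to an ordinary least-squares problem.

```python
import numpy as np


def rbf_kernel(A, B, gamma=0.5):
    """Stand-in kernel (assumption); the posts use neural kernels instead."""
    sq = (A**2).sum(1)[:, None] + (B**2).sum(1)[None, :] - 2 * A @ B.T
    return np.exp(-gamma * np.maximum(sq, 0.0))


def label_map(X_s, X_t, reg, kernel):
    """Matrix M with M @ Y_s equal to the KRR predictions on the target inputs."""
    K_reg = kernel(X_s, X_s) + reg * np.eye(len(X_s))
    # K_ts @ inv(K_reg), computed with a linear solve (K_reg is symmetric).
    return np.linalg.solve(K_reg, kernel(X_t, X_s).T).T


def krr_loss(X_s, Y_s, X_t, Y_t, reg=1e-2, kernel=rbf_kernel):
    """Mean-square error between target labels and KRR predictions."""
    pred = label_map(X_s, X_t, reg, kernel) @ Y_s
    return 0.5 * np.mean((Y_t - pred) ** 2)


def label_solve(X_s, X_t, Y_t, reg=1e-2, kernel=rbf_kernel):
    """Support labels that minimize the KRR loss for fixed support inputs."""
    M = label_map(X_s, X_t, reg, kernel)
    Y_s, *_ = np.linalg.lstsq(M, Y_t, rcond=None)
    return Y_s


rng = np.random.default_rng(0)
X_t = rng.normal(size=(40, 5))                 # synthetic "natural" inputs
classes = rng.integers(0, 3, size=40)
Y_t = np.eye(3)[classes] - 1.0 / 3             # centered one-hot labels
X_s, Y_s_onehot = X_t[:6], Y_t[:6]             # six support points
Y_s_solved = label_solve(X_s, X_t, Y_t)

loss_onehot = krr_loss(X_s, Y_s_onehot, X_t, Y_t)
loss_solved = krr_loss(X_s, Y_s_solved, X_t, Y_t)
print(loss_onehot, loss_solved)
```

A KIP-style procedure would instead treat `X_s` (and optionally `Y_s`) as trainable and minimize `krr_loss` with a gradient-based optimizer. That step is omitted here because it needs automatic differentiation.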

Example of labels obtained by label solving. Left and Middle: Sample images with possible labels listed below. The raw, one-hot label is shown in blue and the final LS generated dense label is shown in orange. Right: The covariance matrix between original labels and learned labels. Here, 500 labels were distilled from the CIFAR-10 dataset. A test accuracy of 69.7% is achieved using these labels for kernel ridge-regression.

Distributed Computation
For simplicity, we focus on architectures that consist of convolutional neural networks with pooling layers. Specifically, we focus on the so-called “ConvNet” architecture and its variants because it has been featured in other dataset distillation studies. We used a slightly modified version of ConvNet that has a simple architecture given by three blocks of convolution, ReLU, and 2×2 average pooling and then a final linear readout layer, with an additional 3×3 convolution and ReLU layer prepended (see our GitHub for precise details).

ConvNet architecture used in DC/DSA. Ours has an additional 3×3 Conv and ReLU prepended.

To compute the neural kernels needed in our work, we used the Neural Tangents library.

The first stage of this work, in which we applied KRR, focused on fully-connected networks, whose kernel elements are cheap to compute. But a hurdle facing neural kernels for models with convolutional layers plus pooling is that the computation of each kernel element between two images scales as the square of the number of input pixels (due to the capturing of pixel-pixel correlations by the kernel). So, for the second stage of this work, we needed to distribute the computation of the kernel elements and their gradients across many devices.

Distributed computation for large scale metalearning.

We invoke a client-server model of distributed computation in which a server distributes independent workloads to a large pool of client workers. A key part of this is to divide the backpropagation step in a way that is computationally efficient (explained in detail in the paper).
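The idea of handing independent kernel blocks to a pool of workers can be sketched with Python's standard library. This is only a single-machine analogy: threads stand in for remote client workers, a linear kernel stands in for the neural kernel, and the names `kernel_block` and `distributed_kernel` are hypothetical. It does not use the Courier API and does not cover the division of the backpropagation step.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def kernel_block(A, B):
    """Worker task: one block of the kernel matrix (linear kernel placeholder)."""
    return A @ B.T


def distributed_kernel(X, Z, block=4, workers=4):
    """Server side: tile the kernel matrix, farm out tiles, assemble the result."""
    K = np.empty((len(X), len(Z)))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {
            (i, j): pool.submit(kernel_block, X[i:i + block], Z[j:j + block])
            for i in range(0, len(X), block)
            for j in range(0, len(Z), block)
        }
        for (i, j), future in futures.items():
            K[i:i + block, j:j + block] = future.result()
    return K


rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))
Z = rng.normal(size=(7, 3))
K = distributed_kernel(X, Z)
print(K.shape)
```

Because each tile depends only on its own slices of the inputs, the workloads are independent and can be computed in any order.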

We accomplish this using the open-source tools Courier (part of DeepMind’s Launchpad), which allows us to distribute computations across GPUs working in parallel, and JAX, for which novel usage of the jax.vjp function enables computationally efficient gradients. This distributed framework allows us to utilize hundreds of GPUs per distillation of the dataset, for both the KIP and LS algorithms. Given the compute required for such experiments, we are releasing our distilled datasets to benefit the wider research community.

Examples
Our first set of distilled images above used KIP to distill CIFAR-10 down to 1 image per class while keeping the labels fixed. Next, in the figure below, we compare the test accuracy of training on natural MNIST images, KIP distilled images with labels fixed, and KIP distilled images with labels optimized. We highlight that learning the labels provides an effective, albeit mysterious, benefit to distilling datasets. Indeed, the resulting set of images provides the best test performance (for infinite-width networks) despite being less interpretable.

MNIST dataset distillation with trainable and non-trainable labels. Top: Natural MNIST data. Middle: Kernel Inducing Point distilled data with fixed labels. Bottom: Kernel Inducing Point distilled data with learned labels.

Results
Our distilled datasets achieve state-of-the-art performance on benchmark image classification datasets, improving performance beyond previous state-of-the-art models that used convolutional architectures, Dataset Condensation (DC) and Dataset Condensation with Differentiable Siamese Augmentation (DSA). In particular, for CIFAR-10 classification tasks, a model trained on a dataset consisting of only 10 distilled data entries (1 image / class, 0.02% of the whole dataset) achieves a 64% test set accuracy. Here, learning labels and an additional image preprocessing step lead to a significant increase in performance beyond the ~51% test accuracy shown in our first figure (see our paper for details). With 500 images (50 images / class, 1% of the whole dataset), the model reaches 80% test set accuracy. While these numbers are with respect to neural kernels (using the KRR infinite-width limit), these distilled datasets can be used to train finite-width neural networks as well. In particular, on CIFAR-10, a finite-width ConvNet achieves 50% test accuracy with 10 images and 68% test accuracy with 500 images, which are still state-of-the-art results. We provide a simple Colab notebook demonstrating this transfer to a finite-width neural network.

Dataset distillation using Kernel Inducing Points (KIP) with a convolutional architecture outperforms prior state-of-the-art models (DC/DSA) on all benchmark settings on image classification tasks. Label Solve (LS, middle columns), while only distilling information in the labels, often outperforms prior state-of-the-art models as well (e.g., CIFAR-10 with 10 or 50 data points per class).

In some cases, our learned datasets are more effective than a natural dataset one hundred times larger in size.

Conclusion
We believe that our work on dataset distillation opens up many interesting future directions. For instance, our algorithms KIP and LS have demonstrated the effectiveness of using learned labels, an area that remains relatively underexplored. Furthermore, we expect that utilizing efficient kernel approximation methods can help to reduce computational burden and scale up to larger datasets. We hope this work encourages researchers to explore other applications of dataset distillation, including neural architecture search and continual learning, and even potential applications to privacy.

Anyone interested in the KIP and LS learned datasets for further analysis is encouraged to check out our papers [ICLR 2021, NeurIPS 2021] and open-sourced code and datasets available on Github.

Acknowledgement
This project was done in collaboration with Zhourong Chen, Roman Novak and Lechao Xiao. We would like to give special thanks to Samuel S. Schoenholz, who proposed and helped develop the overall strategy for our distributed KIP learning methodology.


1Now at DeepMind.  


Interpretable Deep Learning for Time Series Forecasting

Multi-horizon forecasting, i.e. predicting variables-of-interest at multiple future time steps, is a crucial challenge in time series machine learning. Most real-world datasets have a time component, and forecasting the future can unlock great value. For example, retailers can use future sales to optimize their supply chain and promotions, investment managers are interested in forecasting the future prices of financial assets to maximize their performance, and healthcare institutions can use the number of future patient admissions to have sufficient personnel and equipment.

Deep neural networks (DNNs) have increasingly been used in multi-horizon forecasting, demonstrating strong performance improvements over traditional time series models. While many models (e.g., DeepAR, MQRNN) have focused on variants of recurrent neural networks (RNNs), recent improvements, including Transformer-based models, have used attention-based layers to enhance the selection of relevant time steps in the past beyond the inductive bias of RNNs (the sequential, ordered processing of information). However, these often do not consider the different inputs commonly present in multi-horizon forecasting and either assume that all exogenous inputs are known into the future or neglect important static covariates.

Multi-horizon forecasting with static covariates and various time-dependent inputs.

Additionally, conventional time series models are controlled by complex nonlinear interactions between many parameters, making it difficult to explain how such models arrive at their predictions. Unfortunately, common methods to explain the behavior of DNNs have limitations. For example, post-hoc methods (e.g., LIME and SHAP) do not consider the order of input features. Some attention-based models are proposed with inherent interpretability for sequential data, primarily language or speech, but multi-horizon forecasting has many different types of inputs, not just language or speech. Attention-based models can provide insights into relevant time steps, but they cannot distinguish the importance of different features at a given time step. New methods are needed to tackle the heterogeneity of data in multi-horizon forecasting for high performance and to render these forecasts interpretable.

To that end, we announce “Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting”, published in the International Journal of Forecasting, where we propose the Temporal Fusion Transformer (TFT), an attention-based DNN model for multi-horizon forecasting. TFT is designed to explicitly align the model with the general multi-horizon forecasting task for both superior accuracy and interpretability, which we demonstrate across various use cases.

Temporal Fusion Transformer
We design TFT to efficiently build feature representations for each input type (i.e., static, known, or observed inputs) for high forecasting performance. The major constituents of TFT (shown below) are:

  1. Gating mechanisms to skip over any unused components of the model (learned from the data), providing adaptive depth and network complexity to accommodate a wide range of datasets.
  2. Variable selection networks to select relevant input variables at each time step. While conventional DNNs may overfit to irrelevant features, attention-based variable selection can improve generalization by encouraging the model to anchor most of its learning capacity on the most salient features.
  3. Static covariate encoders integrate static features to control how temporal dynamics are modeled. Static features can have an important impact on forecasts, e.g., a store location could have different temporal dynamics for sales (e.g., a rural store may see higher weekend traffic, but a downtown store may see daily peaks after working hours).
  4. Temporal processing to learn both long- and short-term temporal relationships from both observed and known time-varying inputs. A sequence-to-sequence layer is employed for local processing as the inductive bias it has for ordered information processing is beneficial, whereas long-term dependencies are captured using a novel interpretable multi-head attention block. This can cut the effective path length of information, i.e., any past time step with relevant information (e.g. sales from last year) can be focused on directly.
  5. Prediction intervals show quantile forecasts to determine the range of target values at each prediction horizon, which help users understand the distribution of the output, not just the point forecasts.

TFT inputs static metadata, time-varying past inputs and time-varying a priori known future inputs. Variable Selection is used for judicious selection of the most salient features based on the input. Gated information is added as a residual input, followed by normalization. Gated residual network (GRN) blocks enable efficient information flow with skip connections and gating layers. Time-dependent processing is based on LSTMs for local processing, and multi-head attention for integrating information from any time step.
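
The gating idea described above ("gated information is added as a residual input, followed by normalization") can be illustrated with a minimal sketch. This is not the TFT implementation: the function names are hypothetical, and the candidate values and gate logits are passed in directly, whereas in a real model they would be produced by learned layers.

```python
import math


def sigmoid(z):
    """Logistic function used as the gate activation."""
    return 1.0 / (1.0 + math.exp(-z))


def layer_norm(values, eps=1e-6):
    """Normalize a vector to zero mean and (approximately) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def gated_residual(x, candidate, gate_logits):
    """Add a gated candidate to the skip connection x, then normalize.

    x           -- input vector (the skip / residual path)
    candidate   -- output of some processing block (assumed given here)
    gate_logits -- per-element gate pre-activations (assumed given here)
    """
    gated = [sigmoid(g) * c for g, c in zip(gate_logits, candidate)]
    return layer_norm([xi + gi for xi, gi in zip(x, gated)])


x = [1.0, 2.0, 4.0]
candidate = [5.0, -3.0, 2.0]
# A strongly negative gate suppresses the block, so only the skip path remains.
print(gated_residual(x, candidate, [-50.0] * 3))
# A strongly positive gate lets the block's output through in full.
print(gated_residual(x, candidate, [50.0] * 3))
```

When the gate is closed, the output matches `layer_norm(x)`, which is how a gate can effectively skip a component that the data does not need.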

Forecasting Performance
We compare TFT to a wide range of models for multi-horizon forecasting, including various deep learning models with iterative methods (e.g., DeepAR, DeepSSM, ConvTrans) and direct methods (e.g., LSTM Seq2Seq, MQRNN), as well as traditional models such as ARIMA, ETS, and TRMF. Below is a comparison to a truncated list of models.

Model     Electricity      Traffic          Volatility      Retail
ARIMA     0.154 (+180%)    0.223 (+135%)    -               -
ETS       0.102 (+85%)     0.236 (+148%)    -               -
DeepAR    0.075 (+36%)     0.161 (+69%)     0.050 (+28%)    0.574 (+62%)
Seq2Seq   0.067 (+22%)     0.105 (+11%)     0.042 (+7%)     0.411 (+16%)
MQRNN     0.077 (+40%)     0.117 (+23%)     0.042 (+7%)     0.379 (+7%)
TFT       0.055            0.095            0.039           0.354
P50 quantile losses (lower is better) for TFT vs. alternative models.

As shown above, TFT outperforms all benchmarks over a variety of datasets. This applies to both point forecasts and uncertainty estimates, with TFT yielding an average 7% lower P50 and 9% lower P90 losses, respectively, compared to the next best model.
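
For readers unfamiliar with quantile losses such as P50 and P90, the following is a minimal sketch of the standard quantile (pinball) loss. The normalization in `normalized_quantile_loss` (twice the summed loss divided by the summed absolute targets) is an assumption made for illustration and may differ from the exact evaluation protocol used for the numbers above.

```python
def quantile_loss(y, y_hat, q):
    """Pinball loss for one target y and one quantile forecast y_hat."""
    diff = y - y_hat
    # Under-prediction is weighted by q, over-prediction by (1 - q).
    return max(q * diff, (q - 1.0) * diff)


def normalized_quantile_loss(targets, forecasts, q):
    """Summed pinball loss, normalized by the scale of the targets."""
    total = sum(quantile_loss(y, f, q) for y, f in zip(targets, forecasts))
    return 2.0 * total / sum(abs(y) for y in targets)


targets = [10.0, 20.0]
forecasts = [8.0, 22.0]
print(normalized_quantile_loss(targets, forecasts, q=0.5))  # P50
print(normalized_quantile_loss(targets, forecasts, q=0.9))  # P90
```

With q = 0.9, under-predicting costs nine times as much as over-predicting by the same amount, which pushes the forecast toward the upper end of the target distribution.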

Interpretability Use Cases
We demonstrate how TFT’s design allows for analysis of its individual components for enhanced interpretability with three use cases.

  • Variable Importance
    One can observe how different variables impact retail sales by observing their model weights. For example, the largest weights for static variables were the specific store and item, while the largest weights for future variables were promotion period and national holiday (shown below).

    Variable importance for the retail dataset. The 10th, 50th, and 90th percentiles of the variable selection weights are shown, with values larger than 0.1 in bold purple.
  • Persistent Temporal Patterns
    Visualizing persistent temporal patterns can help in understanding the time-dependent relationships present in a given dataset. We identify similar persistent patterns by measuring the contributions of features at fixed lags in the past to forecasts at various horizons. As shown below, attention weights reveal the most important past time steps on which TFT bases its decisions.

    Persistent temporal patterns for the traffic dataset (𝛕 denotes the forecasting horizon) for the 10%, 50% and 90% quantile levels. Clear periodicity is observed with peaks being separated by ~24 hours, i.e., the model attends the most to the time steps that are at the same time of the day from past days, which is aligned with the expected daily traffic patterns.

    The above shows the attention weight patterns across time, indicating how TFT learns persistent temporal patterns without any hard-coding. Such capability can help build trust with users because the output confirms expected known patterns. Model developers can also use these towards model improvements, e.g., via specific feature engineering or data collection.

  • Identifying Significant Events
    Identifying sudden changes can be useful, as temporary shifts can occur due to the presence of significant events. TFT uses the distance between the attention pattern at each point and the average pattern to identify significant deviations. The figures below show that TFT can alter its attention between events, placing equal attention across past inputs when volatility is low while attending more to sharp trend changes during high-volatility periods.

    Event identification for S&P 500 realized volatility from 2002 through 2014.

    Significant deviations in attention patterns can be observed above around periods of high volatility, corresponding to the peaks observed in dist(t), distance between attention patterns (red line). We use a threshold to denote significant events, as highlighted in purple.

    Focusing on the period around the 2008 financial crisis, the bottom plot below zooms in on the middle of the significant event (evident from the increased attention on sharp trend changes), compared to the normal period in the top plot (where attention is spread equally over low-volatility periods).

    Event identification for S&P 500 realized volatility, a zoom of the above on a period from 2004 and 2005.

    Event identification for S&P 500 realized volatility, a zoom of the above on a period from 2008 and 2009.
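
The event-identification idea above (compare each attention pattern with the average pattern, then apply a threshold) can be sketched as follows. The specific distance used here, based on the Bhattacharyya coefficient between two probability vectors, is an assumption for illustration, as are the function names and the threshold value.

```python
import math


def attention_distance(p, q):
    """Distance between two attention patterns (each sums to 1)."""
    coefficient = sum(math.sqrt(a * b) for a, b in zip(p, q))
    # Clamp at zero to guard against tiny negative values from rounding.
    return math.sqrt(max(0.0, 1.0 - coefficient))


def significant_events(patterns, threshold):
    """Return indices whose pattern deviates from the average pattern."""
    n = len(patterns)
    width = len(patterns[0])
    average = [sum(p[j] for p in patterns) / n for j in range(width)]
    distances = [attention_distance(p, average) for p in patterns]
    events = [t for t, d in enumerate(distances) if d > threshold]
    return events, distances


# Three "calm" steps with uniform attention and one step with a sharp peak.
patterns = [[0.25] * 4, [0.25] * 4, [0.25] * 4, [0.97, 0.01, 0.01, 0.01]]
events, distances = significant_events(patterns, threshold=0.3)
print(events)
print([round(d, 3) for d in distances])
```

Only the peaked pattern exceeds the threshold, mirroring how periods of concentrated attention stand out from the equal attention seen in low-volatility regimes.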

Real-World Impact
Finally, TFT has been used to help retail and logistics companies with demand forecasting by both improving forecasting accuracy and providing interpretability capabilities.

Additionally, TFT has potential applications for climate-related challenges: for example, reducing greenhouse gas emissions by balancing electricity supply and demand in real time, and improving the accuracy and interpretability of rainfall forecasting results.

Conclusion
We present a novel attention-based model for high-performance multi-horizon forecasting. In addition to improved performance across a range of datasets, TFT also contains specialized components for inherent interpretability — i.e., variable selection networks and interpretable multi-head attention. With three interpretability use-cases, we also demonstrate how these components can be used to extract insights on feature importance and temporal dynamics.

Acknowledgements
We gratefully acknowledge contributions of Bryan Lim, Nicolas Loeff, Minho Jin, Yaguang Li, and Andrew Moore.


A Fast WordPiece Tokenization System

Tokenization is a fundamental pre-processing step for most natural language processing (NLP) applications. It involves splitting text into smaller units called tokens (e.g., words or word segments) in order to turn an unstructured input string into a sequence of discrete elements that is suitable for a machine learning (ML) model. In deep learning–based models (e.g., BERT), each token is mapped to an embedding vector to be fed into the model.

Tokenization in a typical deep learning model, like BERT.

A fundamental tokenization approach is to break text into words. However, using this approach, words that are not included in the vocabulary are treated as “unknown”. Modern NLP models address this issue by tokenizing text into subword units, which often retain linguistic meaning (e.g., morphemes). So, even though a word may be unknown to the model, individual subword tokens may retain enough information for the model to infer the meaning to some extent. One such subword tokenization technique that is commonly used and can be applied to many other NLP models is called WordPiece. Given text, WordPiece first pre-tokenizes the text into words (by splitting on punctuation and whitespaces) and then tokenizes each word into subword units, called wordpieces.

The WordPiece tokenization process with an example sentence.

In “Fast WordPiece Tokenization”, presented at EMNLP 2021, we developed an improved end-to-end WordPiece tokenization system that speeds up the tokenization process, reducing the overall model latency and saving computing resources. In comparison to traditional algorithms that have been used for decades, this approach reduces the complexity of the computation by an order of magnitude, resulting in significantly improved performance, up to 8x faster than standard approaches. The system has been applied successfully in a number of systems at Google and has been publicly released in TensorFlow Text.

Single-Word WordPiece Tokenization
WordPiece uses a greedy longest-match-first strategy to tokenize a single word — i.e., it iteratively picks the longest prefix of the remaining text that matches a word in the model’s vocabulary. This approach is known as maximum matching or MaxMatch, and has also been used for Chinese word segmentation since the 1980s. Yet despite its wide use in NLP for decades, it is still relatively computation intensive, with the commonly adopted MaxMatch approaches’ computation being quadratic with respect to the input word length (n). This is because two pointers are needed to scan over the input: one to mark a start position, and the other to search for the longest substring matching a vocabulary token at that position.
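
As a point of reference, the greedy longest-match-first strategy described above can be sketched in a few lines. This is a simplified illustration of the baseline rather than any library's implementation; the function name and the `[UNK]` placeholder are assumptions. The two nested loops (one over start positions, one shrinking the end position) are what make the worst case quadratic in the word length.

```python
def max_match(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece tokenization of a single word."""
    tokens = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # pieces inside a word carry the "##" mark
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no prefix matches: the whole word is unknown
        tokens.append(match)
        start = end
    return tokens


vocab = {"a", "abcd", "##b", "##bc", "##z"}
print(max_match("abcz", vocab))
print(max_match("abcd", vocab))
```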

We propose an alternative to the MaxMatch algorithm for WordPiece tokenization, called LinMaxMatch, which has a tokenization time that is strictly linear with respect to n. First, we organize the vocabulary tokens in a trie (also called a prefix tree), where each trie edge is labeled by a character, and a tree path from the root to some node represents a prefix of some token in the vocabulary. In the figure below, nodes are depicted as circles and tree edges are black solid arrows. Given a trie, a vocabulary token can be located to match an input text by traversing from the root and following the trie edges to match the input character by character; this process is referred to as trie matching.

The figure below shows the trie created from the vocabulary consisting of “a”, “abcd”, “##b”, “##bc”, and “##z”. An input text “abcd” can be matched to a vocabulary token by walking from the root (upper left) and following the trie edges with labels “a”, “b”, “c”, “d” one by one. (The leading “##” symbols are special characters used in WordPiece tokenization that are described in more detail below.)

Trie diagram of the vocabulary [“a”, “abcd”, “##b”, “##bc”, “##z”]. Circles and arrows represent nodes and edges along the trie, respectively.
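
A minimal sketch of building such a trie and performing trie matching is shown below. It uses nested dictionaries for brevity; the `"<end>"` marker key and the function names are illustrative choices, not part of the published system.

```python
END = "<end>"  # marker key; cannot collide with single-character edge labels


def build_trie(vocab):
    """Build a character trie where each edge is labeled by one character."""
    root = {}
    for token in vocab:
        node = root
        for ch in token:
            node = node.setdefault(ch, {})
        node[END] = token  # this node completes a vocabulary token
    return root


def longest_match(trie, text):
    """Walk from the root along the input and return the longest token seen."""
    node = trie
    best = None
    for ch in text:
        if ch not in node:
            break  # no edge for this character: matching stops here
        node = node[ch]
        if END in node:
            best = node[END]
    return best


trie = build_trie(["a", "abcd", "##b", "##bc", "##z"])
print(longest_match(trie, "abcd"))
print(longest_match(trie, "abcz"))
```

For the input "abcz", the walk reaches the node for "abc" and then fails on "z"; the longest token passed along the way is "a".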

Second, inspired by the Aho-Corasick algorithm, a classical string-searching algorithm invented in 1975, we introduce a method that breaks out of a trie branch that fails to match the given input and skips directly to an alternative branch to continue matching. As in standard trie matching, during tokenization, we follow the trie edges to match the input characters one by one. When trie matching cannot match an input character for a given node, a standard algorithm would backtrack to the last character where a token was matched and then restart the trie matching procedure from there, which results in repetitive and wasteful iterations. Instead of backtracking, our method triggers a failure transition, which is done in two steps: (1) it collects the precomputed tokens stored at that node, which we call failure pops; and (2) it then follows the precomputed failure link to a new node from which the trie matching process continues.

For example, given a model with the vocabulary described above (“a”, “abcd”, “##b”, “##bc”, and “##z”), WordPiece tokenization distinguishes subword tokens matching at the start of the input word from the subword tokens starting in the middle (the latter being marked with two leading hashes “##”). Hence, for input text “abcz”, the expected tokenization output is [“a”, “##bc”, “##z”], where “a” matches at the beginning of the input while “##bc” and “##z” match in the middle. For this example, the figure below shows that, after successfully matching three characters ‘a’, ‘b’, ‘c’, trie matching cannot match the next character ‘z’ because “abcz” is not in the vocabulary. In this situation, LinMaxMatch conducts a failure transition by outputting the first recognized token (using the failure pop token “a”) and following the failure link to a new node to continue the matching process (in this case, node with “##bc” as the failure pop tokens). The process then repeats from the new node.

Trie structure for the same vocabulary as shown in the example above, now illustrating the approach taken by our new Fast WordPiece Tokenizer algorithm. Failure pops are bracketed and shown in purple. Failure links between nodes are indicated with dashed red line arrows.

Since at least n operations are required to read the entire input, the LinMaxMatch algorithm is asymptotically optimal for the MaxMatch problem.
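
The failure-transition idea can be sketched as follows. This is a simplified reconstruction for illustration, not the released implementation: the class and function names are hypothetical, input words are assumed not to contain the "#" character, and the precomputation is written for clarity rather than optimal construction time.

```python
from collections import deque


class Node:
    def __init__(self):
        self.children = {}  # character -> Node
        self.token = None   # vocabulary token ending at this node, if any
        self.fail = None    # failure link
        self.pops = []      # failure pops: tokens emitted when following fail


def build(vocab):
    """Build the trie, then precompute failure links and failure pops."""
    root = Node()

    def insert(text):
        node = root
        for ch in text:
            node = node.children.setdefault(ch, Node())
        return node

    suffix_root = insert("##")  # start node for pieces inside a word
    for token in vocab:
        insert(token).token = token
    queue = deque([root, suffix_root])  # breadth-first over both subtrees
    while queue:
        v = queue.popleft()
        for ch, u in v.children.items():
            if v is root and ch == "#":
                continue  # the "##" path is reached via suffix_root only
            if u.token is not None:
                u.fail, u.pops = suffix_root, [u.token]
            else:
                z, pops = v.fail, list(v.pops)
                while z is not None and ch not in z.children:
                    pops.extend(z.pops)
                    z = z.fail
                if z is not None:
                    u.fail, u.pops = z.children[ch], pops
            queue.append(u)
    return root, suffix_root


def lin_max_match(word, root, suffix_root, unk="[UNK]"):
    """Tokenize one word using trie matching plus failure transitions."""
    if not word:
        return []
    tokens, u = [], root
    for ch in word:
        while ch not in u.children:
            if u.fail is None:
                return [unk]
            tokens.extend(u.pops)  # emit the precomputed tokens ...
            u = u.fail             # ... and jump instead of backtracking
        u = u.children[ch]
    while u is not suffix_root:    # flush whatever is still pending
        if u.fail is None:
            return [unk]
        tokens.extend(u.pops)
        u = u.fail
    return tokens


root, suffix_root = build(["a", "abcd", "##b", "##bc", "##z"])
print(lin_max_match("abcz", root, suffix_root))
```

In this sketch, the node for "abc" ends up with failure pop ["a"] and a failure link to the node for "##bc", matching the example in the text: on reading "z", the tokenizer emits "a", jumps to "##bc", emits "##bc", and then matches "##z" from the suffix root.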

End-to-End WordPiece Tokenization
Whereas the existing systems pre-tokenize the input text (splitting it into words by punctuation and whitespace characters) and then call WordPiece tokenization on each resulting word, we propose an end-to-end WordPiece tokenizer that combines pre-tokenization and WordPiece into a single, linear-time pass. It uses the LinMaxMatch trie matching and failure transitions as much as possible and only checks for punctuation and whitespace characters among the relatively few input characters that are not handled by the loop. It is more efficient as it traverses the input only once, performs fewer punctuation / whitespace checks, and skips the creation of intermediate words.

End-to-End WordPiece Tokenization.

Benchmark Results
We benchmark our method against two widely-adopted WordPiece tokenization implementations, HuggingFace Tokenizers, from the HuggingFace Transformer library, one of the most popular open-source NLP tools, and TensorFlow Text, the official library of text utilities for TensorFlow. We use the WordPiece vocabulary released with the BERT-Base, Multilingual Cased model.

We compared our algorithms with HuggingFace and TensorFlow Text on a large corpus (several million words) and found that the way the strings are split into tokens is identical to other implementations for both single-word and end-to-end tokenization.

To generate the test data, we sample 1,000 sentences from the multilingual Wikipedia dataset, covering 82 languages. On average, each word has four characters, and each sentence has 82 characters or 17 words. We found this dataset large enough because a much larger dataset (consisting of hundreds of thousands of sentences) generated similar results.

We compare the average runtime when tokenizing a single word or general text (end-to-end) for each system. Fast WordPiece tokenizer is 8.2x faster than HuggingFace and 5.1x faster than TensorFlow Text, on average, for general text end-to-end tokenization.

Average runtime of each system. Note that for better visualization, single-word tokenization and end-to-end tokenization are shown in different scales.

We also examine how the runtime grows with respect to the input length for single-word tokenization. Because of its linear-time complexity, the runtime of LinMaxMatch increases at most linearly with the input length, which is much slower growth than that of the quadratic-time approaches.

The average runtime of each system with respect to the input length for single-word tokenization.

Conclusion
We proposed LinMaxMatch for single-word WordPiece tokenization, which solves the decades-old MaxMatch problem in the asymptotically-optimal time with respect to the input length. LinMaxMatch extends the Aho-Corasick Algorithm, and the idea can be applied to more string search and transducer challenges. We also proposed an End-to-End WordPiece algorithm that combines pre-tokenization and WordPiece tokenization into a single, linear-time pass for even higher efficiency.

Acknowledgements
We gratefully acknowledge the key contributions and useful advice from other team members and colleagues, including Abbas Bazzi, Alexander Frömmgen, Alex Salcianu, Andrew Hilton, Bradley Green, Ed Chi, Chen Chen, Dave Dopson, Eric Lehman, Fangtao Li, Gabriel Schubiner, Gang Li, Greg Billock, Hong Wang, Jacob Devlin, Jayant Madhavan, JD Chen, Jifan Zhu, Jing Li, John Blitzer, Kirill Borozdin, Kristina Toutanova, Majid Hadian-Jazi, Mark Omernick, Max Gubin, Michael Fields, Michael Kwong, Namrata Godbole, Nathan Lintz, Pandu Nayak, Pew Putthividhya, Pranav Khaitan, Robby Neale, Ryan Doherty, Sameer Panwar, Sundeep Tirumalareddy, Terry Huang, Thomas Strohmann, Tim Herrmann, Tom Small, Tomer Shani, Wenwei Yu, Xiaoxue Zang, Xin Li, Yang Guo, Yang Song, Yiming Xiao, Yuan Shen, and many more.


More Efficient In-Context Learning with GLaM

Large language models (e.g., GPT-3) have many significant capabilities, such as performing few-shot learning across a wide array of tasks, including reading comprehension and question answering with very few or no training examples. While these models can perform better by simply using more parameters, training and serving these large models can be very computationally intensive. Is it possible to train and use these models more efficiently?

In pursuit of that question, today we introduce the Generalist Language Model (GLaM), a trillion weight model that can be trained and served efficiently (in terms of computation and energy use) thanks to sparsity, and achieves competitive performance on multiple few-shot learning tasks. GLaM’s performance compares favorably to a dense language model, GPT-3 (175B) with significantly improved learning efficiency across 29 public NLP benchmarks in seven categories, spanning language completion, open-domain question answering, and natural language inference tasks.

Dataset
To build GLaM, we began by building a high-quality 1.6 trillion token dataset containing language usage representative of a wide range of downstream use-cases for the model. Web pages make up the vast majority of the data in this unlabelled corpus, but their quality ranges from professional writing to low-quality comment and forum pages. We then developed a text quality filter that was trained on a collection of text from Wikipedia and books (both of which are generally higher quality sources) to determine the quality of the content for a webpage. Finally, we applied this filter to generate the final subset of webpages and combined this with books and Wikipedia to create the final training dataset.

Model and Architecture
GLaM is a mixture of experts (MoE) model, a type of model that can be thought of as having different submodels (or experts) that are each specialized for different inputs. The experts in each layer are controlled by a gating network that activates experts based on the input data. For each token (generally a word or part of a word), the gating network selects the two most appropriate experts to process the data. The full version of GLaM has 1.2T total parameters across 64 experts per MoE layer with 32 MoE layers in total, but only activates a subnetwork of 97B (8% of 1.2T) parameters per token prediction during inference.

The architecture of GLaM where each input token is dynamically routed to two selected expert networks out of 64 for prediction.

Similar to the GShard MoE Transformer, we replace the single feedforward network (the simplest layer of an artificial neural network, “Feedforward or FFN” in the blue boxes) of every other transformer layer with a MoE layer. This MoE layer has multiple experts, each a feedforward network with identical architecture but different weight parameters. Even though this MoE layer has many more parameters, the experts are sparsely activated, meaning that for a given input token, only two experts are used, giving the model more capacity while limiting computation. During training, each MoE layer’s gating network is trained to use its input to activate the best two experts for each token, which are then used for inference. For a MoE layer of E experts, this essentially provides a collection of E×(E-1) different feedforward network combinations (instead of one as in the classic Transformer architecture), leading to more computational flexibility.

The final learned representation of a token will be the weighted combination of the outputs from the two experts. This allows different experts to activate on different types of inputs. To enable scaling to larger models, each expert within the GLaM architecture can span multiple computational devices. We use the GSPMD compiler backend to solve the challenges in scaling the experts and train several variants (based on expert size and number of experts) of this architecture to understand the scaling effects of sparsely activated language models.
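
The routing described above (a gating network picks two experts per token, and their outputs are combined with gate weights) can be sketched in plain Python. This is a toy illustration, not the GLaM implementation: the function name is hypothetical, the experts are stand-in functions rather than feedforward networks, and normalizing the two selected gate scores with a softmax is an assumption about how the mixing weights are formed.

```python
import math


def top2_moe(x, gate_weights, experts):
    """Route one token vector x to its two highest-scoring experts.

    gate_weights -- one weight vector per expert (the gating network)
    experts      -- one callable per expert, mapping a vector to a vector
    """
    logits = [sum(w * v for w, v in zip(weights, x)) for weights in gate_weights]
    ranked = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)
    top2 = ranked[:2]
    # Softmax over the two selected logits gives the mixing weights.
    peak = max(logits[e] for e in top2)
    exps = [math.exp(logits[e] - peak) for e in top2]
    mix = [v / sum(exps) for v in exps]
    # Only the two selected experts are evaluated (sparse activation).
    outputs = [experts[e](x) for e in top2]
    combined = [
        sum(g * out[j] for g, out in zip(mix, outputs)) for j in range(len(x))
    ]
    return combined, top2, mix


def make_expert(scale):
    """Stand-in expert that simply scales its input."""
    return lambda vec: [scale * v for v in vec]


experts = [make_expert(s) for s in (1.0, 2.0, 3.0, 4.0)]
gate_weights = [[0.0, 0.0], [2.0, 0.0], [1.0, 0.0], [-1.0, 0.0]]
combined, chosen, mix = top2_moe([1.0, 0.0], gate_weights, experts)
print(chosen, [round(m, 3) for m in mix], [round(c, 3) for c in combined])
```

Although four experts exist, only two are called per token, which is the sense in which the layer adds capacity while limiting per-token computation.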

Evaluation
We use a zero-shot and one-shot setting where the tasks are never seen during training. The benchmarks for evaluation include (1) cloze and completion tasks [1,2,3]; (2) open-domain question answering [4,5,6]; (3) Winograd-style tasks [7,8]; (4) commonsense reasoning [9,10,11]; (5) in-context reading comprehension [12,13,14,15,16]; (6) the SuperGLUE tasks; and (7) natural language inference [17]. In total, there are eight natural language generation tasks (NLG) where the generated phrases are evaluated against the ground truth targets via Exact Match (EM) accuracy and F1 measure, and 21 language understanding tasks (NLU) where the prediction from several options is chosen via conditional log-likelihood. Some tasks have variants and SuperGLUE consists of multiple tasks. Both EM accuracy and F1 are scaled from 0 to 100 across all our results and averaged for the NLG score below. The NLU score is an average of accuracy and F1 scores.
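
As a rough illustration of the two generation metrics, the sketch below computes Exact Match and a token-level F1 on a 0 to 100 scale. The text normalization (lowercasing and whitespace splitting) and the function names are assumptions; benchmark-specific scoring scripts typically apply additional normalization.

```python
from collections import Counter


def exact_match(prediction, target):
    """100 if the normalized strings are identical, otherwise 0."""
    return 100.0 * float(prediction.strip().lower() == target.strip().lower())


def token_f1(prediction, target):
    """Harmonic mean of token precision and recall, scaled to 0-100."""
    pred_tokens = prediction.lower().split()
    target_tokens = target.lower().split()
    # Multiset intersection counts each shared token at most as often
    # as it appears in both strings.
    overlap = sum((Counter(pred_tokens) & Counter(target_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(target_tokens)
    return 100.0 * 2.0 * precision * recall / (precision + recall)


print(exact_match("The Cat", "the cat"))
print(token_f1("the cat sat", "the cat"))
```

F1 gives partial credit when a generated answer overlaps with the target, whereas Exact Match only rewards a complete match.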

Results
GLaM reduces to a basic dense Transformer-based language model architecture when each MoE layer only has one expert. In all experiments, we adopt the notation of (base dense model size) / (number of experts per MoE layer) to describe the GLaM model. For example, 1B/64E represents the architecture of a 1B parameter dense model with every other layer replaced by a 64 expert MoE layer. In the following sections, we explore GLaM’s performance and scaling properties, including baseline dense models trained on the same datasets. Compared with the recently announced Megatron-Turing model, GLaM is on-par on the seven respective tasks if using a 5% margin, while using 5x less computation during inference.

Below, we show the 1.2T-parameter sparsely activated model (GLaM) achieved higher results on average and on more tasks than the 175B-parameter dense GPT-3 model while using less computation during inference.

Average score for GLaM and GPT-3 on NLG (left) and NLU (right) tasks (higher is better).

Below we show a summary of the performance on 29 benchmarks compared to the dense model (GPT-3, 175B). GLaM exceeds or is on-par with the performance of the dense model on almost 80% of zero-shot tasks and almost 90% of one-shot tasks.

Evaluation   Higher (>+5%)   On-par (within 5%)   Lower (<-5%)
Zero-shot    13              11                   5
One-shot     14              10                   5

Moreover, while the full version of GLaM has 1.2T total parameters, it only activates a subnetwork of 97B parameters (8% of 1.2T) per token during inference.

                       GLaM (64B/64E)   GPT-3 (175B)
Total Parameters       1.162T           0.175T
Activated Parameters   0.097T           0.175T

Scaling Behavior
GLaM has two ways to scale: 1) scale the number of experts per layer, where each expert is hosted within one computation device, or 2) scale the size of each expert to go beyond the limit of a single device. To evaluate the scaling properties, we compare the respective dense model (FFN layers instead of MoE layers) of similar FLOPS per token at inference time.

Average zero-shot and one-shot performance by increasing the size of each expert. The FLOPS per token prediction at inference time increases as the expert size grows.

As shown above, performance across tasks scales with the size of the experts. GLaM sparsely activated models also perform better than dense models for similar FLOPs during inference for generation tasks. For understanding tasks, we observed that they perform similarly at smaller scales, but sparsely activated models outperform at larger scales.

Data Efficiency
Training large language models is computationally intensive, so efficiency improvements are useful to reduce energy consumption.

Below we show the computation costs for the full version of GLaM.

Computation cost in GFLOPS both for inference, per token (left) and for training (right).

These compute costs show that GLaM uses more computation during training since it trains on more tokens, but uses significantly less computation during inference. We show comparisons using different numbers of tokens to train below.

We also evaluated the learning curves of our models compared to the dense baseline.

Average zero-shot and one-shot performance of sparsely-activated and dense models on eight generative tasks as more tokens are processed in training.
Average zero-shot and one-shot performance of sparsely-activated and dense models on 21 understanding tasks as more tokens are processed in training.

The results above show that sparsely activated models need to train with significantly less data than dense models to reach similar zero-shot and one-shot performance, and if the same amount of data is used, sparsely activated models perform significantly better.

Finally, we assessed the energy efficiency of GLaM.

Comparison of power consumption during training.

While GLaM uses more computation during training, thanks to the more efficient software implementation powered by GSPMD and the advantage of TPUv4, it uses less power to train than other models.

Conclusions
Our large-scale sparsely activated language model, GLaM, achieves competitive results on zero-shot and one-shot learning and is a more efficient model than prior monolithic dense counterparts. We also show quantitatively that a high-quality dataset is essential for large language models. We hope that our work will spark more research into compute-efficient language models.

Acknowledgements
We wish to thank Claire Cui, Zhifeng Chen, Yonghui Wu, Quoc Le, Macduff Hughes, Fernando Pereira, Zoubin Ghahramani and Jeff Dean for their support and invaluable input. Special thanks to our collaborators: Yanping Huang, Simon Tong, Yanqi Zhou, Yuanzhong Xu, Dmitry Lepikhin, Orhan Firat, Maxim Krikun, Tao Wang, Noam Shazeer, Barret Zoph, Liam Fedus, Maarten Bosma, Kun Zhang, Emma Wang, David Patterson, Zongwei Zhou, Naveen Kumar, Adams Yu, Laurent Shafey, Jonathan Shen, Ben Lee, Anmol Gulati, David So, Marie Pellat, Kevin Robinson, Kathy Meier-Hellstern, Aakanksha Chowdhery, Sharan Narang, Erica Moreira and Eric Ni for helpful discussions and inspirations; and the larger Google Research team. We would also like to thank Tom Small for the animated figure used in this post.


General and Scalable Parallelization for Neural Networks

Scaling neural networks, whether it be the amount of training data used, the model size or the computation being utilized, has been critical for improving model quality in many real-world machine learning applications, such as computer vision, language understanding and neural machine translation. This, in turn, has motivated recent studies to scrutinize the factors that play a critical role in the success of scaling a neural model. Although increasing model capacity can be a sound approach to improve model quality, doing so presents a number of systems and software engineering challenges that must be overcome. For instance, in order to train large models that exceed the memory capacity of an accelerator, it becomes necessary to partition the weights and the computation of the model across multiple accelerators. This process of parallelization increases the network communication overhead and can result in device under-utilization. Moreover, a given algorithm for parallelization, which typically requires a significant amount of engineering effort, may not work with different model architectures.

To address these scaling challenges, we present “GSPMD: General and Scalable Parallelization for ML Computation Graphs”, in which we describe an open-source automatic parallelization system based on the XLA compiler. GSPMD is capable of scaling most deep learning network architectures and has already been applied to many deep learning models, such as GShard-M4, LaMDA, BigSSL, ViT, and MetNet-2, leading to state-of-the-art results across several domains. GSPMD has also been integrated into multiple ML frameworks, including TensorFlow and JAX, which use XLA as a shared compiler.

Overview
GSPMD separates the task of programming an ML model from the challenge of parallelization. It allows model developers to write programs as if they were run on a single device with very high memory and computation capacity — the user simply needs to add a few lines of annotation code to a subset of critical tensors in the model code to indicate how to partition the tensors. For example, to train a large model-parallel Transformer, one may only need to annotate fewer than 10 tensors (less than 1% of all tensors in the entire computation graph), one line of additional code per tensor. Then GSPMD runs a compiler pass that determines the entire graph’s parallelization plan, and transforms it into a mathematically equivalent, parallelized computation that can be executed on each device. This allows users to focus on model building instead of parallelization implementation, and enables easy porting of existing single-device programs to run at a much larger scale.

The separation of model programming and parallelism also allows developers to minimize code duplication. With GSPMD, developers may employ different parallelism algorithms for different use cases without the need to reimplement the model. For example, the model code that powered the GShard-M4 and LaMDA models can apply a variety of parallelization strategies appropriate for different models and cluster sizes with the same model implementation. Similarly, by applying GSPMD, the BigSSL large speech models can share the same implementation with previous smaller models.

Generality and Flexibility
Because different model architectures may be better suited to different parallelization strategies, GSPMD is designed to support a large variety of parallelism algorithms appropriate for different use cases. For example, with smaller models that fit within the memory of a single accelerator, data parallelism is preferred, in which devices train the same model using different input data. In contrast, models that are larger than a single accelerator’s memory capacity are better suited for a pipelining algorithm (like that employed by GPipe) that partitions the model into multiple, sequential stages, or operator-level parallelism (e.g., Mesh-TensorFlow), in which individual computation operators in the model are split into smaller, parallel operators.

GSPMD supports all the above parallelization algorithms with a uniform abstraction and implementation. Moreover, GSPMD supports nested patterns of parallelism. For example, it can be used to partition models into individual pipeline stages, each of which can be further partitioned using operator-level parallelism.

GSPMD also facilitates innovation on parallelism algorithms by allowing performance experts to focus on algorithms that best utilize the hardware, instead of the implementation that involves lots of cross-device communications. For example, for large Transformer models, we found a novel operator-level parallelism algorithm that partitions multiple dimensions of tensors on a 2D mesh of devices. It reduces peak accelerator memory usage linearly with the number of training devices, while maintaining a high utilization of accelerator compute due to its balanced data distribution over multiple dimensions.

To illustrate this, consider a simplified feedforward layer in a Transformer model that has been annotated in the above way. To execute the first matrix multiply on fully partitioned input data, GSPMD applies an MPI-style AllGather communication operator to partially merge with partitioned data from another device. It then executes the matrix multiply locally and produces a partitioned result. Before the second matrix multiply, GSPMD adds another AllGather on the right-hand side input, and executes the matrix multiply locally, yielding intermediate results that will then need to be combined and partitioned. For this, GSPMD adds an MPI-style ReduceScatter communication operator that accumulates and partitions these intermediate results. While the tensors generated with the AllGather operator at each stage are larger than the original partition size, they are short-lived and the corresponding memory buffers will be freed after use, which does not affect peak memory usage in training.

Left: A simplified feedforward layer of a Transformer model. Blue rectangles represent tensors with dashed red & blue lines overlaid representing the desired partitioning across a 2×2 mesh of devices. Right: A single partition, after GSPMD has been applied.
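The communication pattern described above can be mimicked in a few lines of plain Python. The following is a toy, single-process simulation and not GSPMD's implementation: the 2×2 mesh is just a dictionary keyed by (i, j), and the particular sharding layout, the helper names (block, concat_rows, concat_cols, feedforward_2x2) and the ReLU between the two matrix multiplies are assumptions made for illustration. It only aims to show that AllGather, local matrix multiplies and a final ReduceScatter reproduce the single-device result while every device stores just one block of each tensor.

```python
def matmul(a, b):
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]

def relu(a):
    return [[max(v, 0) for v in row] for row in a]

def add(a, b):
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]

def block(m, r, c, nr=2, nc=2):
    """Block (r, c) of matrix m when m is cut into an nr x nc grid."""
    rh, cw = len(m) // nr, len(m[0]) // nc
    return [row[c * cw:(c + 1) * cw] for row in m[r * rh:(r + 1) * rh]]

def concat_rows(blocks):
    return [row for b in blocks for row in b]

def concat_cols(blocks):
    return [[v for part in rows for v in part] for rows in zip(*blocks)]

def feedforward_2x2(x, w1, w2):
    """Simulate y = relu(x @ w1) @ w2 on a 2x2 mesh with axes X (index i) and Y (index j)."""
    n = 2
    mesh = [(i, j) for i in range(n) for j in range(n)]
    # Assumed layout: every tensor is split along both of its dimensions.
    x_s = {(i, j): block(x, i, j) for i, j in mesh}    # batch on X, model dim on Y
    w1_s = {(i, j): block(w1, i, j) for i, j in mesh}  # model dim on X, hidden dim on Y
    w2_s = {(i, j): block(w2, j, i) for i, j in mesh}  # hidden dim on Y, model dim on X

    # First matmul: AllGather the contracted dimension, then multiply locally.
    h_s = {}
    for i, j in mesh:
        x_g = concat_cols([x_s[(i, k)] for k in range(n)])    # AllGather over Y
        w1_g = concat_rows([w1_s[(k, j)] for k in range(n)])  # AllGather over X
        h_s[(i, j)] = relu(matmul(x_g, w1_g))                 # block (i, j) of the hidden tensor

    # Second matmul: AllGather the right-hand side; each device gets a partial sum.
    partial = {}
    for i, j in mesh:
        w2_g = concat_cols([w2_s[(k, j)] for k in range(n)])  # AllGather over X
        partial[(i, j)] = matmul(h_s[(i, j)], w2_g)

    # ReduceScatter over Y: sum the partial results, keep one column block per device.
    y_s = {}
    for i, j in mesh:
        total = add(partial[(i, 0)], partial[(i, 1)])
        y_s[(i, j)] = block(total, 0, j, nr=1, nc=n)
    return y_s

def reassemble(shards, n=2):
    return concat_rows([concat_cols([shards[(i, j)] for j in range(n)]) for i in range(n)])

x = [[1, -2, 3, 0], [0, 1, -1, 2], [2, 0, 1, -3], [-1, 1, 0, 1]]
w1 = [[1, 0, -1, 2], [2, 1, 0, -1], [0, -2, 1, 1], [1, 1, 1, 0]]
w2 = [[0, 1, 2, -1], [1, 0, -1, 1], [-1, 2, 0, 1], [2, -1, 1, 0]]
assert reassemble(feedforward_2x2(x, w1, w2)) == matmul(relu(matmul(x, w1)), w2)
```

In this sketch the gathered operands (x_g, w1_g, w2_g) exist only inside the loop body, which mirrors the point that the larger AllGather outputs are short-lived and do not add to peak memory.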

A Transformer Example with Nested Parallelism
As a shared, robust mechanism for different parallelism modes, GSPMD allows users to conveniently switch between modes in different parts of a model. This is particularly valuable for models that may have different components with distinct performance characteristics, for example, multimodal models that handle both images and audio. Consider a model with the Transformer encoder-decoder architecture, which has an embedding layer, an encoder stack with Mixture-of-Expert layers, a decoder stack with dense feedforward layers, and a final softmax layer. In GSPMD, a complex combination of several parallelism modes that treats each layer separately can be achieved with simple configurations.

In the figure below, we show a partitioning strategy over 16 devices organized as a logical 4×4 mesh. Blue represents partitioning along the first mesh dimension X, and yellow represents partitioning along the second mesh dimension Y. X and Y are repurposed for different model components to achieve different parallelism modes. For example, the X dimension is used for data parallelism in the embedding and softmax layers, but used for pipeline parallelism in the encoder and decoder. The Y dimension is also used in different ways to partition the vocabulary, batch or model expert dimensions.

Computation Efficiency
GSPMD provides industry-leading performance in large model training. Parallel models require extra communication to coordinate multiple devices during computation, so parallel model efficiency can be estimated by examining the fraction of time spent on communication overhead: the higher the utilization and the less time spent on communication, the better. In the recent MLPerf set of performance benchmarks, a BERT-like encoder-only model with ~500 billion parameters, to which we applied GSPMD for parallelization over 2048 TPUv4 chips, yielded highly competitive results (see table below), utilizing up to 63% of the peak FLOPS that the TPUv4 chips offer. The table also lists efficiency benchmarks for some representative large models. These example model configs are open sourced in the Lingvo framework along with instructions to run them on Google Cloud. More benchmark results can be found in the experiment section of our paper.

Model Family | Parameter Count | % of model activated* | No. of Experts** | No. of Layers | No. of TPU | FLOPS utilization
Dense Decoder (LaMDA) | 137B | 100% | 1 | 64 | 1024 TPUv3 | 56.5%
Dense Encoder (MLPerf-Bert) | 480B | 100% | 1 | 64 | 2048 TPUv4 | 63%
Sparsely Activated Encoder-Decoder (GShard-M4) | 577B | 0.25% | 2048 | 32 | 1024 TPUv3 | 46.8%
Sparsely Activated Decoder | 1.2T | 8% | 64 | 64 | 1024 TPUv3 | 53.8%
*The fraction of the model activated during inference, which is a measure of model sparsity.
**Number of experts included in the Mixture of Experts layer. A value of 1 corresponds to a standard Transformer, without a Mixture of Experts layer.

Conclusion
The ongoing development and success of many useful machine learning applications, such as NLP, speech recognition, machine translation, and autonomous driving, depend on achieving the highest accuracy possible. As this often requires building larger and even more complex models, we are pleased to share the GSPMD paper and the corresponding open-source library with the broader research community, and we hope it is useful for efficient training of large-scale deep neural networks.

Acknowledgements
We wish to thank Claire Cui, Zhifeng Chen, Yonghui Wu, Naveen Kumar, Macduff Hughes, Zoubin Ghahramani and Jeff Dean for their support and invaluable input. Special thanks to our collaborators Dmitry Lepikhin, HyoukJoong Lee, Dehao Chen, Orhan Firat, Maxim Krikun, Blake Hechtman, Rahul Joshi, Andy Li, Tao Wang, Marcello Maggioni, David Majnemer, Noam Shazeer, Ankur Bapna, Sneha Kudugunta, Quoc Le, Mia Chen, Shibo Wang, Jinliang Wei, Ruoming Pang, Zongwei Zhou, David So, Yanqi Zhou, Ben Lee, Jonathan Shen, James Qin, Yu Zhang, Wei Han, Anmol Gulati, Laurent El Shafey, Andrew Dai, Kun Zhang, Nan Du, James Bradbury, Matthew Johnson, Anselm Levskaya, Skye Wanderman-Milne, and Qiao Zhang for helpful discussions and inspirations.


Unlocking human rights information with machine learning

Human rights defenders need information from many sources to do their work effectively. But as issues evolve and new precedents are set, finding the right information to defend a particular case can be like looking for a needle in a haystack.

For example, a human rights advocate campaigning for LGBTQ rights may want to know which countries have made the most progress and what resolutions they’ve passed. To do so, they have to manually sift through thousands of pages of dense documentation covering global laws and victims’ testimonies to find what they’re looking for.

The curation and cataloging of documents makes this process much easier, but still relies on the manual work of skilled experts. To help, the non-profit organization HURIDOCS looked to machine learning. With support from Google.org Fellows and grant funding, they’ve built new tools that can automatically tag human rights documents so they are searchable — making the curation process 13 times faster.

How machine learning can make information more accessible

Typically, non-governmental organizations collect and curate large bodies of human rights information, with the goal of making these collections useful for advocates. Manually processing these documents can take several days, particularly when they’re published in unfamiliar languages or in PDF format, which is difficult to search through. As a result, many NGOs face a large backlog of documents that remain to be processed, and by the time they’re added to collections, new documentation often supersedes them.

Based in Geneva, HURIDOCS has been developing tools to manage and analyze collections of human rights evidence, law and research for nearly four decades. In 2016, they had an idea: What if machine learning could skim through documents, make terms extractable, and classify the content to catalog documents more quickly?

HURIDOCS took their idea to the Google AI Impact Challenge and was selected for a $1 million grant from Google.org and six months of technical support from a team of seven full-time pro bono Google.org Fellows. As one of the Fellows, I helped train AI models and make sure that the tool was useful to human rights experts, not just machine learning experts.

The curation process of human rights documents gets a boost

Since then, HURIDOCS has launched ML-powered features to improve platforms they’ve built with other NGO partners, and, earlier this year, they began integrating the technology into more of their tools, including their flagship application Uwazi. As a result, updating documents now takes one week instead of two to three months, and curators have been able to catch up on multi-year document backlogs.

In June, HURIDOCS won a CogX Award for its machine learning work, and now the organization is continuing to explore what else its machine learning models can do — from creating automatic tables of contents for documents to identifying references within text. With the power of artificial intelligence, HURIDOCS hopes to solve the trickiest challenges facing human rights defenders.


Improving Vision Transformer Efficiency and Accuracy by Learning to Tokenize

Transformer models consistently obtain state-of-the-art results in computer vision tasks, including object detection and video classification. In contrast to standard convolutional approaches that process images pixel-by-pixel, the Vision Transformers (ViT) treat an image as a sequence of patch tokens (i.e., a smaller part, or “patch”, of an image made up of multiple pixels). This means that at every layer, a ViT model recombines and processes patch tokens based on relations between each pair of tokens, using multi-head self-attention. In doing so, ViT models have the capability to construct a global representation of the entire image.

At the input level, the tokens are formed by uniformly splitting the image into multiple segments, e.g., splitting an image that is 512 by 512 pixels into patches that are 16 by 16 pixels. At the intermediate levels, the outputs from the previous layer become the tokens for the next layer. In the case of videos, video ‘tubelets’ such as 16×16×2 video segments (16×16 images over 2 frames) become tokens. The quality and quantity of the visual tokens decide the overall quality of the Vision Transformer.

The main challenge in many Vision Transformer architectures is that they often require too many tokens to obtain reasonable results. Even with 16×16 patch tokenization, for instance, a single 512×512 image corresponds to 1024 tokens. For videos with multiple frames, that results in tens of thousands of tokens needing to be processed at every layer. Considering that the Transformer computation increases quadratically with the number of tokens, this can often make Transformers intractable for larger images and longer videos. This leads to the question: is it really necessary to process that many tokens at every layer?
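The token counts above follow directly from the patch size, as the short calculation below shows. The function names and the 64-frame clip length are hypothetical choices for illustration; the 512×512 image, the 16×16 patches and the 2-frame tubelets come from the text.

```python
def num_patch_tokens(height, width, patch):
    """Tokens produced by uniformly splitting an image into patch x patch squares."""
    return (height // patch) * (width // patch)

def num_tubelet_tokens(height, width, frames, patch, tubelet_frames):
    """Tokens produced by splitting a video into patch x patch x tubelet_frames tubelets."""
    return num_patch_tokens(height, width, patch) * (frames // tubelet_frames)

def attention_pairs(num_tokens):
    """Token pairs that self-attention relates at each layer (grows quadratically)."""
    return num_tokens ** 2

image_tokens = num_patch_tokens(512, 512, 16)           # 32 * 32 = 1024
video_tokens = num_tubelet_tokens(512, 512, 64, 16, 2)  # 1024 * 32 = 32768 (64 frames assumed)
print(image_tokens, attention_pairs(image_tokens))
print(video_tokens, attention_pairs(video_tokens))
```

Going from one image to the assumed 64-frame clip multiplies the token count by 32 but the number of attended pairs by 32² = 1024, which is the quadratic growth that makes long videos expensive.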

In “TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?”, an earlier version of which is presented at NeurIPS 2021, we show that adaptively generating a smaller number of tokens, rather than always relying on tokens formed by uniform splitting, enables Vision Transformers to run much faster and perform better. TokenLearner is a learnable module that takes an image-like tensor (i.e., input) and generates a small set of tokens. This module can be placed at various locations within the model of interest, significantly reducing the number of tokens to be handled in all subsequent layers. Our experiments show that TokenLearner cuts memory and computation by half or more without damaging classification performance and, because it adapts to its inputs, even increases accuracy.

The TokenLearner
We implement TokenLearner using a straightforward spatial attention approach. In order to generate each learned token, we compute a spatial attention map highlighting regions-of-importance (using convolutional layers or MLPs). Such a spatial attention map is then applied to the input to weight each region differently (and discard unnecessary regions), and the result is spatially pooled to generate the final learned tokens. This is repeated multiple times in parallel, resulting in a few (~10) tokens out of the original input. This can also be viewed as performing a soft-selection of the pixels based on the weight values, followed by global average pooling. Note that the functions to compute the attention maps are governed by different sets of learnable parameters, and are trained in an end-to-end fashion. This allows the attention functions to be optimized in capturing different spatial information in the input. The figure below illustrates the process.

The TokenLearner module learns to generate a spatial attention map for each output token, and uses it to abstract the input to tokenize. In practice, multiple spatial attention functions are learned, are applied to the input, and generate different token vectors in parallel.
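To make the weighting and pooling steps concrete, here is a minimal pure-Python sketch of the idea, not the released TokenLearner code. It assumes the input has already been flattened to a list of spatial positions, and it stands in for the convolutional layers or MLPs with a single linear map followed by a sigmoid. That simplification, the function name token_learner and the argument layout are all assumptions for illustration.

```python
import math

def token_learner(x, attention_weights):
    """Summarize an input into a few learned tokens.

    x: list of N spatial positions, each a list of C channel values.
    attention_weights: list of S weight vectors of length C, one per output token
        (a stand-in for the learnable attention functions).
    Returns S tokens, each a list of C values.
    """
    num_positions = len(x)
    num_channels = len(x[0])
    tokens = []
    for w in attention_weights:
        # Spatial attention map: one weight in (0, 1) per position.
        alpha = [1.0 / (1.0 + math.exp(-sum(wi * xi for wi, xi in zip(w, pos))))
                 for pos in x]
        # Weight every position by its attention value, then average over space.
        token = [sum(a * pos[c] for a, pos in zip(alpha, x)) / num_positions
                 for c in range(num_channels)]
        tokens.append(token)
    return tokens

# 196 positions (a 14x14 grid) with 4 channels, reduced to 8 tokens.
grid = [[((p * 7 + c * 3) % 11) / 10.0 for c in range(4)] for p in range(196)]
weights = [[((s + c) % 5 - 2) / 2.0 for c in range(4)] for s in range(8)]
learned = token_learner(grid, weights)
print(len(learned), len(learned[0]))
```

Each weight vector yields its own attention map, so the tokens can be computed independently of one another, which corresponds to the parallel attention functions described above. In a trained model the weights would be learned end to end rather than fixed as they are here.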

As a result, instead of processing fixed, uniformly tokenized inputs, TokenLearner enables models to process a smaller number of tokens that are relevant to the specific recognition task. That is, (1) we enable adaptive tokenization so that the tokens can be dynamically selected conditioned on the input, and (2) this effectively reduces the total number of tokens, greatly reducing the computation performed by the network. These dynamically and adaptively generated tokens can be used in standard transformer architectures such as ViT for images and ViViT for videos.

Where to Place TokenLearner
After building the TokenLearner module, we had to determine where to place it. We first tried placing it at different locations within the standard ViT architecture with 224×224 images. The number of tokens TokenLearner generated was 8 and 16, far fewer than the 196 or 576 tokens that standard ViTs use. The figure below shows ImageNet few-shot classification accuracies and FLOPS of the models with TokenLearner inserted at various relative locations within ViT B/16, which is the base model with 12 attention layers operating on 16×16 patch tokens.

Top: ImageNet 5-shot transfer accuracy with JFT 300M pre-training, with respect to the relative TokenLearner locations within ViT B/16. Location 0 means TokenLearner is placed before any Transformer layer. Base is the original ViT B/16. Bottom: Computation, measured in terms of billions of floating point operations (GFLOPS), per relative TokenLearner location.

We found that inserting TokenLearner after the initial quarter of the network (at 1/4) achieves almost the same accuracy as the baseline, while reducing the computation to less than a third of the baseline. In addition, placing TokenLearner at a later layer (after 3/4 of the network) achieves even better performance than not using TokenLearner while running faster, thanks to its adaptiveness. Due to the large difference between the number of tokens before and after TokenLearner (e.g., 196 before and 8 after), the relative computation of the transformers after the TokenLearner module becomes almost negligible.
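A back-of-the-envelope estimate shows why the layers after TokenLearner cost so little. The sketch below uses a rough per-layer cost model that is an assumption of ours, not a measurement from the paper: about 12·n·d² operations for the projections and MLP plus 2·n²·d for attention, with n tokens and hidden width d = 768. It also ignores the cost of the TokenLearner module itself.

```python
def layer_cost(num_tokens, width):
    """Rough operation count for one Transformer layer (assumed cost model)."""
    projections_and_mlp = 12 * num_tokens * width ** 2  # grows linearly with tokens
    attention = 2 * num_tokens ** 2 * width              # grows quadratically with tokens
    return projections_and_mlp + attention

def relative_cost(location, num_layers=12, tokens_before=196, tokens_after=8, width=768):
    """Cost relative to the baseline when TokenLearner sits at a relative depth in [0, 1]."""
    layers_before = round(location * num_layers)
    baseline = num_layers * layer_cost(tokens_before, width)
    reduced = (layers_before * layer_cost(tokens_before, width)
               + (num_layers - layers_before) * layer_cost(tokens_after, width))
    return reduced / baseline

print(relative_cost(0.25))  # about 0.28 under these assumptions
print(relative_cost(0.75))  # about 0.76 under these assumptions
```

Under this model, a layer operating on 8 tokens costs roughly 4% of a layer operating on 196 tokens, so the total is dominated by the layers that precede TokenLearner. That is consistent with the reduction to less than a third of the baseline reported for the 1/4 placement.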

Comparing Against ViTs
We compared the standard ViT models with TokenLearner against those without it while following the same setting on ImageNet few-shot transfer. TokenLearner was placed in the middle of each ViT model at various locations such as at 1/2 and at 3/4. The figure below shows the performance/computation trade-off of the models with and without TokenLearner.

Performance of various versions of ViT models with and without TokenLearner, on ImageNet classification. The models were pre-trained with JFT 300M. The closer a model is to the top-left of each graph the better, meaning that it runs faster and performs better. Observe how TokenLearner models perform better than ViT in terms of both accuracy and computation.

We also inserted TokenLearner within larger ViT models, and compared them against the giant ViT G/14 model. Here, we applied TokenLearner to ViT L/10 and L/8, which are the ViT models with 24 attention layers taking 10×10 (or 8×8) patches as initial tokens. The figure below shows that despite using many fewer parameters and less computation, TokenLearner performs comparably to the giant G/14 model with 48 layers.

Left: Classification accuracy of large-scale TokenLearner models compared to ViT G/14 on ImageNet datasets. Right: Comparison of the number of parameters and FLOPS.

High-Performing Video Models
Video understanding is one of the key challenges in computer vision, so we evaluated TokenLearner on multiple video classification datasets. This was done by adding TokenLearner into Video Vision Transformers (ViViT), which can be thought of as a spatio-temporal version of ViT. TokenLearner learned 8 (or 16) tokens per timestep.

When combined with ViViT, TokenLearner obtains state-of-the-art (SOTA) performance on multiple popular video benchmarks, including Kinetics-400, Kinetics-600, Charades, and AViD, outperforming the previous Transformer models on Kinetics-400 and Kinetics-600 as well as previous CNN models on Charades and AViD.

Models with TokenLearner outperform state-of-the-art on popular video benchmarks (captured from Nov. 2021). Left: popular video classification tasks. Right: comparison to ViViT models.
Visualization of the spatial attention maps in TokenLearner, over time. As the person is moving in the scene, TokenLearner pays attention to different spatial locations to tokenize.

Conclusion
While Vision Transformers serve as powerful models for computer vision, the large number of tokens and the associated computation have been a bottleneck for their application to larger images and longer videos. In this project, we illustrate that retaining such a large number of tokens and fully processing them over the entire set of layers is not necessary. Further, we demonstrate that learning a module that extracts tokens adaptively based on the input image attains even better performance while saving compute. The proposed TokenLearner was particularly effective in video representation learning tasks, which we confirmed with multiple public datasets. A preprint of our work as well as code are publicly available.

Acknowledgement
We thank our co-authors: AJ Piergiovanni, Mostafa Dehghani, and Anelia Angelova. We also thank the Robotics at Google team members for the motivating discussions.
