WeatherBench 2: A benchmark for the next generation of data-driven weather models

In 1950, weather forecasting started its digital revolution when researchers used the first programmable, general-purpose computer ENIAC to solve mathematical equations describing how weather evolves. In the more than 70 years since, continuous advancements in computing power and improvements to the model formulations have led to steady gains in weather forecast skill: a 7-day forecast today is about as accurate as a 5-day forecast in 2000 and a 3-day forecast in 1980. While improving forecast accuracy at the pace of approximately one day per decade may not seem like a big deal, every day improved is important in far reaching use cases, such as for logistics planning, disaster management, agriculture and energy production. This “quiet” revolution has been tremendously valuable to society, saving lives and providing economic value across many sectors.

Now we are seeing the start of yet another revolution in weather forecasting, this time fueled by advances in machine learning (ML). Rather than hard-coding approximations of the physical equations, the idea is to have algorithms learn how weather evolves from looking at large volumes of past weather data. Early attempts at doing so go back to 2018 but the pace picked up considerably in the last two years when several large ML models demonstrated weather forecasting skill comparable to the best physics-based models. Google’s MetNet [1, 2], for instance, demonstrated state-of-the-art capabilities for forecasting regional weather one day ahead. For global prediction, Google DeepMind created GraphCast, a graph neural network to make 10 day predictions at a horizontal resolution of 25 km, competitive with the best physics-based models in many skill metrics.

Apart from potentially providing more accurate forecasts, one key advantage of such ML methods is that, once trained, they can create forecasts in a matter of minutes on inexpensive hardware. In contrast, traditional weather forecasts require large supercomputers that run for hours every day. Clearly, ML represents a tremendous opportunity for the weather forecasting community. Leading weather forecasting centers have recognized this as well, as reflected in the European Centre for Medium-Range Weather Forecasts’ (ECMWF) machine learning roadmap and the National Oceanic and Atmospheric Administration’s (NOAA) artificial intelligence strategy.

To ensure that ML models are trusted and optimized for the right goal, forecast evaluation is crucial. Evaluating weather forecasts isn’t straightforward, however, because weather is an incredibly multi-faceted problem. Different end-users are interested in different properties of forecasts. For example, renewable energy producers care about wind speeds and solar radiation, while crisis response teams are concerned about the track of a potential cyclone or an impending heat wave. In other words, there is no single metric to determine what a “good” weather forecast is, and the evaluation has to reflect the multi-faceted nature of weather and its downstream applications. Furthermore, differences in the exact evaluation setup (e.g., which resolution and ground truth data are used) can make it difficult to compare models. Having a way to compare novel and established methods in a fair and reproducible manner is crucial to measure progress in the field.

To this end, we are announcing WeatherBench 2 (WB2), a benchmark for the next generation of data-driven, global weather models. WB2 is an update to the original benchmark published in 2020, which was based on initial, lower-resolution ML models. The goal of WB2 is to accelerate the progress of data-driven weather models by providing a trusted, reproducible framework for evaluating and comparing different methodologies. The official website contains scores from several state-of-the-art models (at the time of writing, these are Keisler (2022), an early graph neural network, Google DeepMind’s GraphCast and Huawei’s Pangu-Weather, a transformer-based ML model). In addition, forecasts from ECMWF’s high-resolution and ensemble forecasting systems are included, which represent some of the best traditional weather forecasting models.

Making evaluation easier

The key component of WB2 is an open-source evaluation framework that allows users to evaluate their forecasts in the same manner as other baselines. Weather forecast data at high resolutions can be quite large, making even evaluation a computational challenge. For this reason, we built our evaluation code on Apache Beam, which allows users to split computations into smaller chunks and evaluate them in a distributed fashion, for example using Dataflow on Google Cloud. The code comes with a quick-start guide to help people get up to speed.
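To make the idea of chunked evaluation concrete, here is a minimal sketch in plain Python. It is not the WB2 code and does not use the Apache Beam API. It only illustrates the pattern of reducing each chunk to partial sums and then combining them. The function names and the toy data are assumptions made for illustration.

```python
import math

def chunk_stats(forecast_chunk, truth_chunk):
    """Map step: reduce one chunk to (sum of squared errors, number of points)."""
    sq_err = sum((f - t) ** 2 for f, t in zip(forecast_chunk, truth_chunk))
    return sq_err, len(forecast_chunk)

def combine(stats):
    """Combine step: merge per-chunk partial sums into a single RMSE."""
    total_sq_err = sum(s for s, _ in stats)
    total_count = sum(n for _, n in stats)
    return math.sqrt(total_sq_err / total_count)

def chunked_rmse(forecast, truth, chunk_size):
    """Split the data into chunks, reduce each one independently, then combine."""
    stats = [
        chunk_stats(forecast[i:i + chunk_size], truth[i:i + chunk_size])
        for i in range(0, len(forecast), chunk_size)
    ]
    return combine(stats)

# Hypothetical toy values, not real forecast data.
forecast = [1.0, 2.0, 3.0, 4.0, 5.0]
truth = [1.0, 2.0, 3.0, 4.0, 7.0]
rmse = chunked_rmse(forecast, truth, chunk_size=2)
```

Because each chunk is reduced independently, the map step could in principle run on separate workers, and the result does not depend on the chunk size.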

Additionally, we provide most of the ground-truth and baseline data on Google Cloud Storage in cloud-optimized Zarr format at different resolutions, for example, a comprehensive copy of the ERA5 dataset used to train most ML models. This is part of a larger Google effort to provide analysis-ready, cloud-optimized weather and climate datasets to the research community and beyond. Since downloading these data from the respective archives and converting them can be time-consuming and compute-intensive, we hope that this will considerably lower the entry barrier for the community.

Assessing forecast skill

Together with our collaborators from ECMWF, we defined a set of headline scores that best capture the quality of global weather forecasts. As the figure below shows, several of the ML-based forecasts have lower errors than the state-of-the-art physical models on deterministic metrics. This holds for a range of variables and regions, and underlines the competitiveness and promise of ML-based approaches.

This scorecard shows the skill of different models compared to ECMWF’s Integrated Forecasting System (IFS), one of the best physics-based weather forecasts, for several variables. IFS forecasts are evaluated against IFS analysis. All other models are evaluated against ERA5. The order of ML models reflects publication date.

Toward reliable probabilistic forecasts

However, a single forecast often isn’t enough. Weather is inherently chaotic because of the butterfly effect. For this reason, operational weather centers now run ~50 slightly perturbed realizations of their model, called an ensemble, to estimate the forecast probability distribution across various scenarios. This is important, for example, if one wants to know the likelihood of extreme weather.
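As a small illustration of how an ensemble can be turned into a probability, the sketch below estimates the chance of exceeding a threshold as the fraction of members above it, together with the ensemble mean and spread. The member values, the threshold, and the function names are hypothetical. Operational centers and WB2 may compute probabilistic quantities differently.

```python
def exceedance_probability(members, threshold):
    """Fraction of ensemble members whose value exceeds the threshold."""
    return sum(1 for m in members if m > threshold) / len(members)

def ensemble_mean_and_spread(members):
    """Ensemble mean and sample standard deviation (spread)."""
    n = len(members)
    mean = sum(members) / n
    variance = sum((m - mean) ** 2 for m in members) / (n - 1)
    return mean, variance ** 0.5

# Hypothetical 10-member ensemble of wind speed (m/s) at one location and lead time.
members = [18.0, 21.5, 19.0, 26.0, 24.5, 20.0, 27.5, 22.0, 19.5, 25.5]
prob_strong_wind = exceedance_probability(members, threshold=25.0)
mean, spread = ensemble_mean_and_spread(members)
```

Here three of the ten members exceed 25 m/s, so the estimated probability is 0.3. With roughly 50 members, as mentioned above, such estimates become less coarse.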

Creating reliable probabilistic forecasts will be one of the next key challenges for global ML models. Regional ML models, such as Google’s MetNet, already estimate probabilities. To anticipate this next generation of global models, WB2 already provides probabilistic metrics and baselines, among them ECMWF’s IFS ensemble, to accelerate research in this direction.

As mentioned above, weather forecasting has many aspects, and while the headline metrics try to capture the most important aspects of forecast skill, they are by no means sufficient. One example is forecast realism. Currently, many ML forecast models tend to “hedge their bets” in the face of the intrinsic uncertainty of the atmosphere. In other words, they tend to predict smoothed out fields that give lower average error but do not represent a realistic, physically consistent state of the atmosphere. An example of this can be seen in the animation below. The two data-driven models, Pangu-Weather and GraphCast (bottom), predict the large-scale evolution of the atmosphere remarkably well. However, they also have less small-scale structure compared to the ground truth or the physical forecasting model IFS HRES (top). In WB2 we include a range of these case studies and also a spectral metric that quantifies such blurring.

Forecasts of a front passing through the continental United States initialized on January 3, 2020. Maps show temperature at a pressure level of 850 hPa (roughly equivalent to an altitude of 1.5 km) and geopotential at a pressure level of 500 hPa (roughly 5.5 km) in contours. ERA5 is the corresponding ground-truth analysis, IFS HRES is ECMWF’s physics-based forecasting model.
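One way to see why a spectral view exposes blurring is to compare power spectra of a sharp and a smoothed field. The sketch below is an illustration only, assuming NumPy and a synthetic 1-D periodic field. The exact spectral metric used in WB2 is not described here, and the moving average is just a stand-in for an over-smooth forecast.

```python
import numpy as np

def power_spectrum(field):
    """Power per wavenumber of a 1-D periodic field (e.g., along a latitude circle)."""
    coeffs = np.fft.rfft(field)
    return np.abs(coeffs) ** 2 / field.size

n = 128
x = 2 * np.pi * np.arange(n) / n
# Synthetic field: a large-scale wave (wavenumber 2) plus small-scale detail (wavenumber 30).
sharp = np.sin(2 * x) + 0.5 * np.sin(30 * x)

# Periodic 5-point moving average as a stand-in for a blurred forecast.
blurred = sum(np.roll(sharp, shift) for shift in range(-2, 3)) / 5

ps_sharp = power_spectrum(sharp)
ps_blurred = power_spectrum(blurred)
large_scale_ratio = ps_blurred[2] / ps_sharp[2]    # close to 1: large scales survive
small_scale_ratio = ps_blurred[30] / ps_sharp[30]  # far below 1: small scales are damped
```

The smoothed field keeps almost all of its large-scale power but loses most of its small-scale power, which is the signature of blurring that a spectral metric can quantify.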

Conclusion

WeatherBench 2 will continue to evolve alongside ML model development. The official website will be updated with the latest state-of-the-art models. (To submit a model, please follow these instructions). We also invite the community to provide feedback and suggestions for improvements through issues and pull requests on the WB2 GitHub page.

Designing evaluation well and targeting the right metrics is crucial in order to make sure ML weather models benefit society as quickly as possible. WeatherBench 2 as it is now is just the starting point. We plan to extend it in the future to address key issues for the future of ML-based weather forecasting. Specifically, we would like to add station observations and better precipitation datasets. Furthermore, we will explore the inclusion of nowcasting and subseasonal-to-seasonal predictions to the benchmark.

We hope that WeatherBench 2 can aid researchers and end-users as weather forecasting continues to evolve.

Acknowledgements

WeatherBench 2 is the result of collaboration across many different teams at Google and external collaborators at ECMWF. From ECMWF, we would like to thank Matthew Chantry, Zied Ben Bouallegue and Peter Dueben. From Google, we would like to thank the core contributors to the project: Stephan Rasp, Stephan Hoyer, Peter Battaglia, Alex Merose, Ian Langmore, Tyler Russell, Alvaro Sanchez, Antonio Lobato, Laurence Chiu, Rob Carver, Vivian Yang, Shreya Agrawal, Thomas Turnbull, Jason Hickey, Carla Bromberg, Jared Sisk, Luke Barrington, Aaron Bell, and Fei Sha. We also would like to thank Kunal Shah, Rahul Mahrsee, Aniket Rawat, and Satish Kumar. Thanks to John Anderson for sponsoring WeatherBench 2. Furthermore, we would like to thank Kaifeng Bi from the Pangu-Weather team and Ryan Keisler for their help in adding their models to WeatherBench 2.

Modeling and improving text stability in live captions

Automatic speech recognition (ASR) technology has made conversations more accessible with live captions in remote conferencing software, mobile applications, and head-worn displays. However, to maintain real-time responsiveness, live caption systems often display interim predictions that are updated as new utterances are received. This can cause text instability (a “flicker” where previously displayed text is updated, shown in the captions on the left in the video below), which can impair users’ reading experience due to distraction, fatigue, and difficulty following the conversation.

In “Modeling and Improving Text Stability in Live Captions”, presented at ACM CHI 2023, we formalize this problem of text stability through a few key contributions. First, we quantify the text instability by employing a vision-based flicker metric that uses luminance contrast and discrete Fourier transform. Second, we also introduce a stability algorithm to stabilize the rendering of live captions via tokenized alignment, semantic merging, and smooth animation. Finally, we conducted a user study (N=123) to understand viewers’ experience with live captioning. Our statistical analysis demonstrates a strong correlation between our proposed flicker metric and viewers’ experience. Furthermore, it shows that our proposed stabilization techniques significantly improve viewers’ experience (e.g., the captions on the right in the video above).

Raw ASR captions vs. stabilized captions

Metric

Inspired by previous work, we propose a flicker-based metric to quantify text stability and objectively evaluate the performance of live captioning systems. Specifically, our goal is to quantify the flicker in a grayscale live caption video. We achieve this by comparing the difference in luminance between individual frames (frames in the figures below) that constitute the video. Large visual changes in luminance are obvious (e.g., addition of the word “bright” in the figure on the bottom), but subtle changes (e.g., update from “… this gold. Nice..” to “… this. Gold is nice”) may be difficult to discern for readers. However, converting the change in luminance to its constituting frequencies exposes both the obvious and subtle changes.

Thus, for each pair of contiguous frames, we convert the difference in luminance into its constituting frequencies using discrete Fourier transform. We then sum over each of the low and high frequencies to quantify the flicker in this pair. Finally, we average over all of the frame-pairs to get a per-video flicker.

For instance, we can see below that two identical frames (top) yield a flicker of 0, while two non-identical frames (bottom) yield a non-zero flicker. It is worth noting that higher values of the metric indicate high flicker in the video and thus, a worse user experience than lower values of the metric.

Illustration of the flicker metric between two identical frames.
Illustration of the flicker between two non-identical frames.
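The steps described above (frame difference, Fourier transform, sum over frequencies, average over frame pairs) can be sketched as follows. This is a minimal illustration assuming NumPy, a 2-D spatial DFT, and an unweighted sum of magnitudes. The normalization and frequency weighting used in the paper are not given here, and the toy frames are hypothetical.

```python
import numpy as np

def pair_flicker(frame_a, frame_b):
    """Flicker between two grayscale frames: summed DFT magnitude of their luminance difference."""
    diff = np.asarray(frame_b, dtype=float) - np.asarray(frame_a, dtype=float)
    spectrum = np.abs(np.fft.fft2(diff))  # magnitude at every spatial frequency
    return float(spectrum.sum())          # sum over low and high frequencies alike

def video_flicker(frames):
    """Average the pairwise flicker over all contiguous frame pairs."""
    pairs = list(zip(frames[:-1], frames[1:]))
    return sum(pair_flicker(a, b) for a, b in pairs) / len(pairs)

# Toy 8x16 "caption" frames: a bright block stands in for rendered text.
base = np.zeros((8, 16))
base[2:6, 1:9] = 1.0
changed = base.copy()
changed[2:6, 10:14] = 1.0  # a new "word" appears

static_flicker = video_flicker([base, base, base])
updated_flicker = video_flicker([base, changed, changed])
```

In this sketch, identical frames give a flicker of exactly 0, while the sequence with a text update gives a positive value, matching the behavior described above.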

Stability algorithm

To improve the stability of live captions, we propose an algorithm that takes as input the already rendered sequence of tokens (e.g., “Previous” in the figure below) and the new sequence of ASR predictions, and outputs an updated stabilized text (e.g., “Updated text (with stabilization)” below). It considers both the natural language understanding (NLU) aspect as well as the ergonomic aspect (display, layout, etc.) of the user experience in deciding when and how to produce a stable updated text. Specifically, our algorithm performs tokenized alignment, semantic merging, and smooth animation to achieve this goal. In what follows, a token is defined as a word or punctuation produced by ASR.

We show (a) the previously already rendered text, (b) the baseline layout of updated text without our merging algorithm, and (c) the updated text as generated by our stabilization algorithm.

Our algorithm addresses the challenge of producing stabilized updated text by first identifying three classes of changes (highlighted in red, green, and blue below):

  1. Case A (red): Addition of tokens to the end of previously rendered captions (e.g., “How about”).
  2. Case B (green): Addition or deletion of tokens in the middle of already rendered captions.
    • B1: Addition of tokens (e.g., “I” and “friends”). These may or may not affect the overall comprehension of the captions, but may lead to a layout change. Such layout changes are not desired in live captions as they cause significant jitter and a poorer user experience. Here “I” does not add to the comprehension but “friends” does. Thus, it is important to balance updates with stability, especially for B1 type tokens.
    • B2: Removal of tokens (e.g., “in” is removed in the updated sentence).
  3. Case C (blue): Re-captioning of tokens. This includes token edits that may or may not have an impact on the overall comprehension of the captions.
    • C1: Proper nouns like “disney land” are updated to “Disneyland”.
    • C2: Grammatical shorthands like “it’s” are updated to “It was”.

Classes of changes between previously displayed and updated text.

Alignment, merging, and smoothing

To maximize text stability, our goal is to align the old sequence with the new sequence using updates that make minimal changes to the existing layout while ensuring accurate and meaningful captions. To achieve this, we leverage a variant of the Needleman-Wunsch algorithm with dynamic programming to merge the two sequences depending on the class of tokens as defined above:

  • Case A tokens: We directly add case A tokens, and line breaks as needed to fit the updated captions.
  • Case B tokens: Our preliminary studies showed that users preferred stability over accuracy for previously displayed captions. Thus, we only update case B tokens if the updates do not break an existing line layout.
  • Case C tokens: We compare the semantic similarity of case C tokens by transforming original and updated sentences into sentence embeddings, measuring their dot-product, and updating them only if they are semantically different (similarity < 0.85) and the update will not cause new line breaks.
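The alignment step can be illustrated with a textbook Needleman-Wunsch implementation over tokens. This is a sketch, not the variant used in the paper: the scoring values (match, mismatch, gap), the labeling helper, and the example sentences are assumptions, and the layout and semantic-similarity checks described above are left out.

```python
def align_tokens(old, new, match=1, mismatch=-1, gap=-1):
    """Globally align two token lists; returns (old_token, new_token) pairs, None marks a gap."""
    rows, cols = len(old) + 1, len(new) + 1
    score = [[0] * cols for _ in range(rows)]
    for i in range(1, rows):
        score[i][0] = i * gap
    for j in range(1, cols):
        score[0][j] = j * gap
    for i in range(1, rows):
        for j in range(1, cols):
            step = match if old[i - 1] == new[j - 1] else mismatch
            score[i][j] = max(score[i - 1][j - 1] + step,
                              score[i - 1][j] + gap,
                              score[i][j - 1] + gap)

    # Trace back from the bottom-right corner to recover one optimal alignment.
    pairs, i, j = [], len(old), len(new)
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            step = match if old[i - 1] == new[j - 1] else mismatch
            if score[i][j] == score[i - 1][j - 1] + step:
                pairs.append((old[i - 1], new[j - 1]))
                i, j = i - 1, j - 1
                continue
        if i > 0 and score[i][j] == score[i - 1][j] + gap:
            pairs.append((old[i - 1], None))
            i -= 1
        else:
            pairs.append((None, new[j - 1]))
            j -= 1
    return pairs[::-1]

def label_changes(pairs):
    """Label each aligned pair as keep, A (appended), B1 (inserted), B2 (removed), or C (edited)."""
    last_old = max((k for k, (o, _) in enumerate(pairs) if o is not None), default=-1)
    labels = []
    for k, (o, n) in enumerate(pairs):
        if o is None:
            labels.append("A" if k > last_old else "B1")
        elif n is None:
            labels.append("B2")
        else:
            labels.append("keep" if o == n else "C")
    return labels

old_tokens = "we went to the park".split()
new_tokens = "we went to the big park and played".split()
labels = label_changes(align_tokens(old_tokens, new_tokens))
```

In this example, “big” is labeled B1 because it is inserted in the middle of already rendered text, while “and” and “played” are labeled A because they extend the caption at the end.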

Finally, we leverage animations to reduce visual jitter. We implement smooth scrolling and fading of newly added tokens to further stabilize the overall layout of the live captions.

User evaluation

We conducted a user study with 123 participants to (1) examine the correlation of our proposed flicker metric with viewers’ experience of the live captions, and (2) assess the effectiveness of our stabilization techniques.

We manually selected 20 videos on YouTube to obtain a broad coverage of topics including video conferences, documentaries, academic talks, tutorials, news, comedy, and more. For each video, we selected a 30-second clip with at least 90% speech.

We prepared four types of renderings of live captions to compare:

  1. Raw ASR: raw speech-to-text results from a speech-to-text API.
  2. Raw ASR + thresholding: only display interim speech-to-text result if its confidence score is higher than 0.85.
  3. Stabilized captions: captions using our algorithm described above with alignment and merging.
  4. Stabilized and smooth captions: stabilized captions with smooth animation (scrolling + fading) to assess whether softened display experience helps improve the user experience.

We collected user ratings by asking the participants to watch the recorded live captions and rate their assessments of comfort, distraction, ease of reading, ease of following the video, fatigue, and whether the captions impaired their experience.

Correlation between flicker metric and user experience

We calculated Spearman’s coefficient between the flicker metric and each of the behavioral measurements (values range from -1 to 1, where negative values indicate a negative relationship between the two variables, positive values indicate a positive relationship, and zero indicates no relationship). Shown below, our study demonstrates statistically significant (p < 0.001) correlations between our flicker metric and users’ ratings. The absolute values of the coefficient are around 0.3, indicating a moderate relationship.

Behavioral Measurement      Correlation to Flickering Metric*
Comfort                     -0.29
Distraction                  0.33
Easy to read                -0.31
Easy to follow videos       -0.29
Fatigue                      0.36
Impaired Experience          0.31

Spearman correlation tests of our proposed flickering metric. *p < 0.001.
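For readers less familiar with Spearman’s coefficient, the sketch below computes it as the Pearson correlation of ranks, with tied values sharing their average rank. The flicker values and ratings in the example are invented for illustration and are not data from the study.

```python
def average_ranks(values):
    """1-based ranks; tied values share the average of the positions they occupy."""
    order = sorted(range(len(values)), key=lambda k: values[k])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared = (start + end) / 2 + 1
        for k in range(start, end + 1):
            ranks[order[k]] = shared
        start = end + 1
    return ranks

def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    std_x = sum((a - mean_x) ** 2 for a in x) ** 0.5
    std_y = sum((b - mean_y) ** 2 for b in y) ** 0.5
    return cov / (std_x * std_y)

def spearman(x, y):
    """Spearman's coefficient: Pearson correlation computed on ranks."""
    return pearson(average_ranks(x), average_ranks(y))

# Hypothetical per-video flicker values and comfort ratings (1 to 7).
flicker = [0.1, 0.4, 0.2, 0.8, 0.6]
comfort = [7, 5, 6, 2, 3]
rho = spearman(flicker, comfort)
```

In this toy example comfort falls whenever flicker rises, so the coefficient is -1. The study’s coefficients of about ±0.3 indicate a weaker but still consistent monotonic relationship.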

Stabilization of live captions

Our proposed technique (stabilized smooth captions) received consistently better ratings, significant as measured by the Mann-Whitney U test (p < 0.01 in the figure below), in five out of six aforementioned survey statements. That is, users considered the stabilized captions with smoothing to be more comfortable and easier to read, while feeling less distraction, fatigue, and impairment to their experience than other types of rendering.

User ratings from 1 (Strongly Disagree) – 7 (Strongly Agree) on survey statements. (**: p<0.01, ***: p<0.001; ****: p<0.0001; ns: non-significant)

Conclusion and future direction

Text instability in live captioning significantly impairs users’ reading experience. This work proposes a vision-based metric to model caption stability that statistically significantly correlates with users’ experience, and an algorithm to stabilize the rendering of live captions. Our proposed solution can be potentially integrated into existing ASR systems to enhance the usability of live captions for a variety of users, including those with translation needs or those with hearing accessibility needs.

Our work represents a substantial step towards measuring and improving text stability. This can be evolved to include language-based metrics that focus on the consistency of the words and phrases used in live captions over time. These metrics may provide a reflection of user discomfort as it relates to language comprehension and understanding in real-world scenarios. We are also interested in conducting eye-tracking studies (e.g., videos shown below) to track viewers’ gaze patterns, such as eye fixation and saccades, allowing us to better understand the types of errors that are most distracting and how to improve text stability for those.

Illustration of tracking a viewer’s gaze when reading raw ASR captions.

Illustration of tracking a viewer’s gaze when reading stabilized and smoothed captions.

By improving text stability in live captions, we can create more effective communication tools and improve how people connect in everyday conversations in familiar or, through translation, unfamiliar languages.

Acknowledgements

This work is a collaboration across multiple teams at Google. Key contributors include Xingyu “Bruce” Liu, Jun Zhang, Leonardo Ferrer, Susan Xu, Vikas Bahirwani, Boris Smus, Alex Olwal, and Ruofei Du. We wish to extend our thanks to our colleagues who provided assistance, including Nishtha Bhatia, Max Spear, and Darcy Philippon. We would also like to thank Lin Li, Evan Parker, and CHI 2023 reviewers.

SayTap: Language to quadrupedal locomotion

Simple and effective interaction between human and quadrupedal robots paves the way towards creating intelligent and capable helper robots, forging a future where technology enhances our lives in ways beyond our imagination. Key to such human-robot interaction systems is enabling quadrupedal robots to respond to natural language instructions. Recent developments in large language models (LLMs) have demonstrated the potential to perform high-level planning. Yet, it remains a challenge for LLMs to comprehend low-level commands, such as joint angle targets or motor torques, especially for inherently unstable legged robots, necessitating high-frequency control signals. Consequently, most existing work presumes the provision of high-level APIs for LLMs to dictate robot behavior, inherently limiting the system’s expressive capabilities.

In “SayTap: Language to Quadrupedal Locomotion”, we propose an approach that uses foot contact patterns (which refer to the sequence and manner in which a four-legged agent places its feet on the ground while moving) as an interface to bridge human commands in natural language and a locomotion controller that outputs low-level commands. This results in an interactive quadrupedal robot system that allows users to flexibly craft diverse locomotion behaviors (e.g., a user can ask the robot to walk, run, jump or make other movements using simple language). We contribute an LLM prompt design, a reward function, and a method to expose the SayTap controller to the feasible distribution of contact patterns. We demonstrate that SayTap is a controller capable of achieving diverse locomotion patterns that can be transferred to real robot hardware.

SayTap method

The SayTap approach uses a contact pattern template, which is a 4 X T matrix of 0s and 1s, with 0s representing an agent’s feet in the air and 1s for feet on the ground. From top to bottom, each row in the matrix gives the foot contact patterns of the front left (FL), front right (FR), rear left (RL) and rear right (RR) feet. SayTap’s control frequency is 50 Hz, so each 0 or 1 lasts 0.02 seconds. In this work, a desired foot contact pattern is defined by a cyclic sliding window of size Lw and of shape 4 X Lw. The sliding window extracts from the contact pattern template four foot ground contact flags, which indicate if a foot is on the ground or in the air between t + 1 and t + Lw. The figure below provides an overview of the SayTap method.
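The cyclic sliding window can be sketched in a few lines. The 4 × 8 template below is a shortened, hypothetical trot-like pattern chosen for readability, and the function name is an assumption. Only the 50 Hz control frequency and the window definition (steps t + 1 to t + Lw, wrapping around the template) come from the description above.

```python
CONTROL_HZ = 50  # control frequency from the text; each flag lasts 1 / 50 = 0.02 s

def desired_pattern(template, t, window):
    """Return the 4 x window slice of contact flags for steps t+1 .. t+window.

    The template is treated as cyclic, so indices wrap around at its length T.
    """
    T = len(template[0])
    return [[row[(t + 1 + k) % T] for k in range(window)] for row in template]

# Rows are FL, FR, RL, RR; 1 = foot on the ground, 0 = foot in the air.
rows = ["11110000", "00001111", "00001111", "11110000"]
template = [[int(c) for c in row] for row in rows]

window_at_5 = desired_pattern(template, t=5, window=5)
cycle_seconds = len(rows[0]) / CONTROL_HZ  # duration of one full cycle of the template
```

At t = 5 the window covers template columns 6, 7, 0, 1 and 2, so the front-left row reads 0, 0, 1, 1, 1: the window wraps from the end of the template back to its start.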

SayTap introduces these desired foot contact patterns as a new interface between natural language user commands and the locomotion controller. The locomotion controller is used to complete the main task (e.g., following specified velocities) and to place the robot’s feet on the ground at the specified time, such that the realized foot contact patterns are as close to the desired contact patterns as possible. To achieve this, the locomotion controller takes the desired foot contact pattern at each time step as its input in addition to the robot’s proprioceptive sensory data (e.g., joint positions and velocities) and task-related inputs (e.g., user-specified velocity commands). We use deep reinforcement learning to train the locomotion controller and represent it as a deep neural network. During controller training, a random generator samples the desired foot contact patterns, and the policy is then optimized to output low-level robot actions to achieve the desired foot contact pattern. Then, at test time, an LLM translates user commands into foot contact patterns.

SayTap approach overview.

SayTap uses foot contact patterns (e.g., 0 and 1 sequences for each foot in the inset, where 0s are foot in the air and 1s are foot on the ground) as an interface that bridges natural language user commands and low-level control commands. With a reinforcement learning-based locomotion controller that is trained to realize the desired contact patterns, SayTap allows a quadrupedal robot to take both simple and direct instructions (e.g., “Trot forward slowly.”) as well as vague user commands (e.g., “Good news, we are going to a picnic this weekend!”) and react accordingly.

We demonstrate that the LLM is capable of accurately mapping user commands into foot contact pattern templates in specified formats when given properly designed prompts, even in cases when the commands are unstructured or vague. In training, we use a random pattern generator to produce contact pattern templates with various pattern lengths T and foot-ground contact ratios within a cycle, based on a given gait type G, so that the locomotion controller learns from a wide distribution of movements, leading to better generalization. See the paper for more details.

Results

With a simple prompt that contains only three in-context examples of commonly seen foot contact patterns, an LLM can translate various human commands accurately into contact patterns and even generalize to those that do not explicitly specify how the robot should react.

SayTap prompts are concise and consist of four components: (1) general instruction that describes the tasks the LLM should accomplish; (2) gait definition that reminds the LLM of basic knowledge about quadrupedal gaits and how they can be related to emotions; (3) output format definition; and (4) examples that give the LLM chances to learn in-context. We also specify five velocities that allow a robot to move forward or backward, fast or slow, or remain still.

General instruction block
You are a dog foot contact pattern expert.
Your job is to give a velocity and a foot contact pattern based on the input.
You will always give the output in the correct format no matter what the input is.

Gait definition block
The following are description about gaits:
1. Trotting is a gait where two diagonally opposite legs strike the ground at the same time.
2. Pacing is a gait where the two legs on the left/right side of the body strike the ground at the same time.
3. Bounding is a gait where the two front/rear legs strike the ground at the same time. It has a longer suspension phase where all feet are off the ground, for example, for at least 25% of the cycle length. This gait also gives a happy feeling.

Output format definition block
The following are rules for describing the velocity and foot contact patterns:
1. You should first output the velocity, then the foot contact pattern.
2. There are five velocities to choose from: [-1.0, -0.5, 0.0, 0.5, 1.0].
3. A pattern has 4 lines, each of which represents the foot contact pattern of a leg.
4. Each line has a label. "FL" is front left leg, "FR" is front right leg, "RL" is rear left leg, and "RR" is rear right leg.
5. In each line, "0" represents foot in the air, "1" represents foot on the ground.

Example block
Input: Trot slowly
Output: 0.5
FL: 11111111111111111000000000
FR: 00000000011111111111111111
RL: 00000000011111111111111111
RR: 11111111111111111000000000

Input: Bound in place
Output: 0.0
FL: 11111111111100000000000000
FR: 11111111111100000000000000
RL: 00000011111111111100000000
RR: 00000011111111111100000000

Input: Pace backward fast
Output: -1.0
FL: 11111111100001111111110000
FR: 00001111111110000111111111
RL: 11111111100001111111110000
RR: 00001111111110000111111111

Input:


SayTap prompt to the LLM. Text in blue (the block labels) is used for illustration and is not input to the LLM.
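Since the prompt fixes an output format (a velocity followed by four labeled lines of 0s and 1s), a downstream program can check a response before passing it to the controller. The parser below is a hypothetical sketch of such a check. It is not taken from the paper, and the shortened 8-step response is made up for illustration.

```python
ALLOWED_VELOCITIES = (-1.0, -0.5, 0.0, 0.5, 1.0)  # the five velocities listed in the prompt
LEG_LABELS = ("FL", "FR", "RL", "RR")

def parse_response(text):
    """Parse 'velocity + 4 labeled contact lines' into (velocity, 4 x T list of ints)."""
    lines = [ln.strip() for ln in text.strip().splitlines() if ln.strip()]
    if len(lines) != 5:
        raise ValueError("expected one velocity line and four pattern lines")
    # Accept either '0.5' or 'Output: 0.5' on the first line.
    velocity = float(lines[0].split(":")[-1])
    if velocity not in ALLOWED_VELOCITIES:
        raise ValueError(f"velocity {velocity} is not one of {ALLOWED_VELOCITIES}")
    pattern = []
    for label, line in zip(LEG_LABELS, lines[1:]):
        name, _, flags = line.partition(":")
        flags = flags.strip()
        if name.strip() != label or not flags or set(flags) - {"0", "1"}:
            raise ValueError(f"malformed line for leg {label}: {line!r}")
        pattern.append([int(c) for c in flags])
    if len({len(row) for row in pattern}) != 1:
        raise ValueError("all four legs must have the same pattern length")
    return velocity, pattern

# Hypothetical, shortened response in the format defined by the prompt.
response = "Output: 0.5\nFL: 11110000\nFR: 00001111\nRL: 00001111\nRR: 11110000"
velocity, pattern = parse_response(response)
```

Rejecting malformed responses early is one simple way to guard against the occasional output that does not follow the requested format.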

Following simple and direct commands

We demonstrate in the videos below that the SayTap system can successfully perform tasks where the commands are direct and clear. Although some commands are not covered by the three in-context examples, we are able to guide the LLM to express its internal knowledge from the pre-training phase via the “Gait definition block” (see the second block in our prompt above) in the prompt.


Following unstructured or vague commands

But what is more interesting is SayTap’s ability to process unstructured and vague instructions. With only a little hint in the prompt to connect certain gaits with general impressions of emotions, the robot bounds up and down when hearing exciting messages, like “We are going to a picnic!” Furthermore, it also presents the scenes accurately (e.g., moving quickly with its feet barely touching the ground when told the ground is very hot).




Conclusion and future work

We present SayTap, an interactive system for quadrupedal robots that allows users to flexibly craft diverse locomotion behaviors. SayTap introduces desired foot contact patterns as a new interface between natural language and the low-level controller. This new interface is straightforward and flexible. Moreover, it allows a robot to follow both direct instructions and commands that do not explicitly state how the robot should react.

One interesting direction for future work is to test if commands that imply a specific feeling will allow the LLM to output a desired gait. In the gait definition block shown in the results section above, we provide a sentence that connects a happy mood with bounding gaits. We believe that providing more information can augment the LLM’s interpretations (e.g., implied feelings). In our evaluation, the connection between a happy feeling and a bounding gait led the robot to act vividly when following vague human commands. Another interesting direction for future work is to introduce multi-modal inputs, such as videos and audio. Foot contact patterns translated from those signals will, in theory, still work with our pipeline and will unlock many more interesting use cases.

Acknowledgements

Yujin Tang, Wenhao Yu, Jie Tan, Heiga Zen, Aleksandra Faust and Tatsuya Harada conducted this research. This work was conceived and performed while the team was in Google Research and will be continued at Google DeepMind. The authors would like to thank Tingnan Zhang, Linda Luu, Kuang-Huei Lee, Vincent Vanhoucke and Douglas Eck for their valuable discussions and technical support in the experiments.

RO-ViT: Region-aware pre-training for open-vocabulary object detection with vision transformers

The ability to detect objects in the visual world is crucial for computer vision and machine intelligence, enabling applications like adaptive autonomous agents and versatile shopping systems. However, modern object detectors are limited by the manual annotations of their training data, resulting in a vocabulary size significantly smaller than the vast array of objects encountered in reality. To overcome this, the open-vocabulary detection task (OVD) has emerged, utilizing image-text pairs for training and incorporating new category names at test time by associating them with the image content. By treating categories as text embeddings, open-vocabulary detectors can predict a wide range of unseen objects. Various techniques such as image-text pre-training, knowledge distillation, pseudo labeling, and frozen models, often employing convolutional neural network (CNN) backbones, have been proposed. With the growing popularity of vision transformers (ViTs), it is important to explore their potential for building proficient open-vocabulary detectors.

The existing approaches assume the availability of pre-trained vision-language models (VLMs) and focus on fine-tuning or distillation from these models to address the disparity between image-level pre-training and object-level fine-tuning. However, as VLMs are primarily designed for image-level tasks like classification and retrieval, they do not fully leverage the concept of objects or regions during the pre-training phase. Thus, it could be beneficial for open-vocabulary detection if we build locality information into the image-text pre-training.

In “RO-ViT: Region-Aware Pretraining for Open-Vocabulary Object Detection with Vision Transformers”, presented at CVPR 2023, we introduce a simple method to pre-train vision transformers in a region-aware manner to improve open-vocabulary detection. In vision transformers, positional embeddings are added to image patches to encode information about the spatial position of each patch within the image. Standard pre-training typically uses full-image positional embeddings, which does not generalize well to detection tasks. Thus, we propose a new positional embedding scheme, called “cropped positional embedding”, that better aligns with the use of region crops in detection fine-tuning. In addition, we replace the softmax cross entropy loss with focal loss in contrastive image-text learning, allowing us to learn from more challenging and informative examples. Finally, we leverage recent advances in novel object proposals to enhance open-vocabulary detection fine-tuning, which is motivated by the observation that existing methods often miss novel objects during the proposal stage due to overfitting to foreground categories. We are also releasing the code here.

Region-aware image-text pre-training

Existing VLMs are trained to match an image as a whole to a text description. However, we observe there is a mismatch between the way the positional embeddings are used in the existing contrastive pre-training approaches and open-vocabulary detection. The positional embeddings are important to transformers as they provide the information of where each element in the set comes from. This information is often useful for downstream recognition and localization tasks. Pre-training approaches typically apply full-image positional embeddings during training, and use the same positional embeddings for downstream tasks, e.g., zero-shot recognition. However, the recognition occurs at region-level for open-vocabulary detection fine-tuning, which requires the full-image positional embeddings to generalize to regions that they never see during the pre-training.

To address this, we propose cropped positional embeddings (CPE). With CPE, we upsample positional embeddings from the image size typical for pre-training, e.g., 224×224 pixels, to that typical for detection tasks, e.g., 1024×1024 pixels. Then we randomly crop and resize a region, and use it as the image-level positional embeddings during pre-training. The position, scale, and aspect ratio of the crop are randomly sampled. Intuitively, this causes the model to view an image not as a full image in itself, but as a region crop from some larger unknown image. This better matches the downstream use case of detection, where recognition occurs at the region level rather than the image level.
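The upsample, crop, and resize steps described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the released implementation: the function names, the use of bilinear interpolation, and the way the crop size is sampled (height and width drawn independently) are all assumptions made for the example.

```python
import random

def bilinear_resize(grid, out_h, out_w):
    """Resize an H x W grid of embedding vectors (nested lists) bilinearly."""
    in_h, in_w = len(grid), len(grid[0])
    dim = len(grid[0][0])
    out = []
    for r in range(out_h):
        # Map the output row to a fractional input row (corners aligned).
        y = r * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        fy = y - y0
        row = []
        for c in range(out_w):
            x = c * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            fx = x - x0
            row.append([
                (1 - fy) * ((1 - fx) * grid[y0][x0][k] + fx * grid[y0][x1][k])
                + fy * ((1 - fx) * grid[y1][x0][k] + fx * grid[y1][x1][k])
                for k in range(dim)
            ])
        out.append(row)
    return out

def cropped_positional_embedding(pos_emb, up_size, rng):
    """Return a G x G grid taken from a random crop of the upsampled grid.

    pos_emb: G x G grid of embedding vectors (the pre-training grid).
    up_size: side length of the upsampled grid (the detection-sized grid).
    """
    g = len(pos_emb)
    up = bilinear_resize(pos_emb, up_size, up_size)
    # Hypothetical sampling scheme: independent height and width give a
    # random scale and aspect ratio; top/left give a random position.
    crop_h = rng.randint(2, up_size)
    crop_w = rng.randint(2, up_size)
    top = rng.randint(0, up_size - crop_h)
    left = rng.randint(0, up_size - crop_w)
    crop = [row[left:left + crop_w] for row in up[top:top + crop_h]]
    # Resize the crop back to the pre-training grid size.
    return bilinear_resize(crop, g, g)

# Toy usage: a 4 x 4 grid of 2-d embeddings, upsampled to 16 x 16.
toy_pe = [[[float(r), float(c)] for c in range(4)] for r in range(4)]
cpe = cropped_positional_embedding(toy_pe, 16, random.Random(0))
```

With a 16×16 patch size, the sizes quoted above would correspond to a 14×14 grid upsampled to 64×64; the toy sizes here are only chosen to keep the example fast.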

For the pre-training, we propose cropped positional embedding (CPE) which randomly crops and resizes a region of positional embeddings instead of using the whole-image positional embedding (PE). In addition, we use focal loss instead of the common softmax cross entropy loss for contrastive learning.

We also find it beneficial to learn from hard examples with a focal loss. Focal loss enables finer control over how hard examples are weighted than what the softmax cross entropy loss can provide. We adopt the focal loss in place of the softmax cross entropy loss in both the image-to-text and text-to-image losses. Both CPE and focal loss introduce no extra parameters and minimal computation costs.
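To make the weighting of hard examples concrete, here is a minimal sketch of a focal loss applied to a batch of image-text similarities. It assumes a sigmoid-based formulation in which each image-text pair is treated as a binary match or non-match; that choice, the function name, and the default `gamma` and `temperature` values are illustrative assumptions rather than the exact loss used in the paper.

```python
import math

def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def contrastive_focal_loss(sim, gamma=2.0, temperature=1.0):
    """Image-to-text focal loss over a B x B similarity matrix.

    sim[i][j] is the similarity between image i and text j; matched pairs
    sit on the diagonal. Transpose sim to get the text-to-image direction.
    """
    batch = len(sim)
    total = 0.0
    for i in range(batch):
        for j in range(batch):
            p = sigmoid(sim[i][j] / temperature)
            # Probability assigned to the correct label for this pair.
            p_true = p if i == j else 1.0 - p
            p_true = min(max(p_true, 1e-12), 1.0)
            # (1 - p_true)^gamma shrinks the contribution of easy pairs.
            total += -((1.0 - p_true) ** gamma) * math.log(p_true)
    return total / batch

easy = [[5.0, -5.0], [-5.0, 5.0]]   # confidently correct pairs
hard = [[-5.0, 5.0], [5.0, -5.0]]   # confidently wrong pairs
print(contrastive_focal_loss(easy), contrastive_focal_loss(hard))
```

Setting `gamma=0` recovers a plain binary cross entropy, which makes it easy to compare how much the focal term down-weights the easy pairs.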

Open-vocabulary detector fine-tuning

An open-vocabulary detector is trained with the detection labels of ‘base’ categories, but needs to detect the union of ‘base’ and ‘novel’ (unlabeled) categories at test time. Despite the backbone features pre-trained from the vast open-vocabulary data, the added detector layers (neck and heads) are newly trained with the downstream detection dataset. Existing approaches often miss novel/unlabeled objects in the object proposal stage because the proposals tend to classify them as background. To remedy this, we leverage recent advances in a novel object proposal method and adopt the localization quality-based objectness (i.e., centerness score) instead of object-or-not binary classification score, which is combined with the detection score. During training, we compute the detection scores for each detected region as the cosine similarity between the region’s embedding (computed via RoI-Align operation) and the text embeddings of the base categories. At test time, we append the text embeddings of novel categories, and the detection score is now computed with the union of the base and novel categories.
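The scoring step described above, in which a region embedding is compared against text embeddings of the base categories during training and against the union of base and novel categories at test time, can be sketched as follows. The three-dimensional toy embeddings and the helper names are hypothetical; in practice the embeddings would come from the RoI-Align features and the text encoder.

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def detection_scores(region_emb, text_embs):
    """Map each category name to its cosine similarity with the region."""
    return {name: cosine(region_emb, emb) for name, emb in text_embs.items()}

# Toy text embeddings (made up for illustration).
base = {"dog": [1.0, 0.0, 0.0], "car": [0.0, 1.0, 0.0]}
novel = {"gargoyle": [0.0, 0.0, 1.0]}
region = [0.1, 0.2, 0.9]

# Training: only base categories are scored.
train_scores = detection_scores(region, base)
# Test time: novel category embeddings are appended to the base ones.
test_scores = detection_scores(region, {**base, **novel})
best = max(test_scores, key=test_scores.get)
```

Because categories are represented only by their text embeddings, adding a novel category at test time requires no retraining, only an extra entry in the dictionary.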

The pre-trained ViT backbone is transferred to the downstream open-vocabulary detection by replacing the global average pooling with detector heads. The RoI-Align embeddings are matched with the cached category embeddings to obtain the VLM score, which is combined with the detection score into the open-vocabulary detection score.

Results

We evaluate RO-ViT on the LVIS open-vocabulary detection benchmark. At the system-level, our best model achieves 33.6 box average precision on rare categories (APr) and 32.1 mask APr, which outperforms the best existing ViT-based approach OWL-ViT by 8.0 APr and the best CNN-based approach ViLD-Ens by 5.8 mask APr. It also exceeds the performance of many other approaches based on knowledge distillation, pre-training, or joint training with weak supervision.

RO-ViT outperforms both the state-of-the-art (SOTA) ViT-based and CNN-based methods on the LVIS open-vocabulary detection benchmark. We show mask AP on rare categories (APr), except for the SOTA ViT-based method (OWL-ViT), where we show box AP.

Apart from evaluating region-level representation through open-vocabulary detection, we evaluate the image-level representation of RO-ViT in image-text retrieval through the MS-COCO and Flickr30K benchmarks. Our model with 303M ViT outperforms the state-of-the-art CoCa model with 1B ViT on MS COCO, and is on par on Flickr30K. This shows that our pre-training method not only improves the region-level representation but also the global image-level representation for retrieval.

We show zero-shot image-text retrieval on MS COCO and Flickr30K benchmarks, and compare with dual-encoder methods. We report recall@1 (top-1 recall) on image-to-text (I2T) and text-to-image (T2I) retrieval tasks. RO-ViT outperforms the state-of-the-art CoCa with the same backbone.

RO-ViT open-vocabulary detection on LVIS. We only show the novel categories for clarity. RO-ViT detects many novel categories that it has never seen during detection training: “fishbowl”, “sombrero”, “persimmon”, “gargoyle”.

Visualization of positional embeddings

We visualize and compare the learned positional embeddings of RO-ViT with the baseline. Each tile is the cosine similarity between positional embeddings of one patch and all other patches. For example, the tile in the top-left corner (marked in red) visualizes the similarity between the positional embedding of the location (row=1, column=1) and those positional embeddings of all other locations in 2D. The brightness of the patch indicates how close the learned positional embeddings of different locations are. RO-ViT forms more distinct clusters at different patch locations showing symmetrical global patterns around the center patch.

Each tile shows the cosine similarity between the positional embedding of the patch (at the indicated row-column position) and the positional embeddings of all other patches. ViT-B/16 backbone is used.

Conclusion

We present RO-ViT, a contrastive image-text pre-training framework to bridge the gap between image-level pre-training and open-vocabulary detection fine-tuning. Our methods are simple, scalable, and easy to apply to any contrastive backbone with minimal computation overhead and no increase in parameters. RO-ViT achieves state-of-the-art results on the LVIS open-vocabulary detection benchmark and on the image-text retrieval benchmarks, showing that the learned representation is not only beneficial at the region level but also highly effective at the image level. We hope this study can help the research on open-vocabulary detection from the perspective of image-text pre-training, which can benefit both region-level and image-level tasks.

Acknowledgements

Dahun Kim, Anelia Angelova, and Weicheng Kuo conducted this work and are now at Google DeepMind. We would like to thank our colleagues at Google Research for their advice and helpful discussions.


Responsible AI at Google Research: Perception Fairness

Google’s Responsible AI research is built on a foundation of collaboration — between teams with diverse backgrounds and expertise, between researchers and product developers, and ultimately with the community at large. The Perception Fairness team drives progress by combining deep subject-matter expertise in both computer vision and machine learning (ML) fairness with direct connections to the researchers building the perception systems that power products across Google and beyond. Together, we are working to intentionally design our systems to be inclusive from the ground up, guided by Google’s AI Principles.

Perception Fairness research spans the design, development, and deployment of advanced multimodal models including the latest foundation and generative models powering Google’s products.

Our team’s mission is to advance the frontiers of fairness and inclusion in multimodal ML systems, especially related to foundation models and generative AI. This encompasses core technology components including classification, localization, captioning, retrieval, visual question answering, text-to-image or text-to-video generation, and generative image and video editing. We believe that fairness and inclusion can and should be top-line performance goals for these applications. Our research is focused on unlocking novel analyses and mitigations that enable us to proactively design for these objectives throughout the development cycle. We answer core questions, such as: How can we use ML to responsibly and faithfully model human perception of demographic, cultural, and social identities in order to promote fairness and inclusion? What kinds of system biases (e.g., underperforming on images of people with certain skin tones) can we measure and how can we use these metrics to design better algorithms? How can we build more inclusive algorithms and systems and react quickly when failures occur?

Measuring representation of people in media

ML systems that can edit, curate or create images or videos can affect anyone exposed to their outputs, shaping or reinforcing the beliefs of viewers around the world. Research to reduce representational harms, such as reinforcing stereotypes or denigrating or erasing groups of people, requires a deep understanding of both the content and the societal context. It hinges on how different observers perceive themselves, their communities, or how others are represented. There’s considerable debate in the field regarding which social categories should be studied with computational tools and how to do so responsibly. Our research focuses on working toward scalable solutions that are informed by sociology and social psychology, are aligned with human perception, embrace the subjective nature of the problem, and enable nuanced measurement and mitigation. One example is our research on differences in human perception and annotation of skin tone in images using the Monk Skin Tone scale.

Our tools are also used to study representation in large-scale content collections. Through our Media Understanding for Social Exploration (MUSE) project, we’ve partnered with academic researchers, nonprofit organizations, and major consumer brands to understand patterns in mainstream media and advertising content. We first published this work in 2017, with a co-authored study analyzing gender equity in Hollywood movies. Since then, we’ve increased the scale and depth of our analyses. In 2019, we released findings based on over 2.7 million YouTube advertisements. In the latest study, we examine representation across intersections of perceived gender presentation, perceived age, and skin tone in over twelve years of popular U.S. television shows. These studies provide insights for content creators and advertisers and further inform our own research.

An illustration (not actual data) of computational signals that can be analyzed at scale to reveal representational patterns in media collections. [Video Collection / Getty Images]

Moving forward, we’re expanding the ML fairness concepts on which we focus and the domains in which they are responsibly applied. Looking beyond photorealistic images of people, we are working to develop tools that model the representation of communities and cultures in illustrations, abstract depictions of humanoid characters, and even images with no people in them at all. Finally, we need to reason about not just who is depicted, but how they are portrayed — what narrative is communicated through the surrounding image content, the accompanying text, and the broader cultural context.

Analyzing bias properties of perceptual systems

Building advanced ML systems is complex, with multiple stakeholders informing various criteria that decide product behavior. Overall quality has historically been defined and measured using summary statistics (like overall accuracy) over a test dataset as a proxy for user experience. But not all users experience products in the same way.

Perception Fairness enables practical measurement of nuanced system behavior beyond summary statistics, and makes these metrics core to the system quality that directly informs product behaviors and launch decisions. This is often much harder than it seems. Distilling complex bias issues (e.g., disparities in performance across intersectional subgroups or instances of stereotype reinforcement) to a small number of metrics without losing important nuance is extremely challenging. Another challenge is balancing the interplay between fairness metrics and other product metrics (e.g., user satisfaction, accuracy, latency), which are often phrased as conflicting despite being compatible. It is common for researchers to describe their work as optimizing an “accuracy-fairness” tradeoff when in reality widespread user satisfaction is aligned with meeting fairness and inclusion objectives.

We built and released the MIAP dataset as part of Open Images, leveraging our research on perception of socially relevant concepts and detection of biased behavior in complex systems to create a resource that furthers ML fairness research in computer vision. Original photo credits (left: Boston Public Library; middle: jen robinson; right: Garin Fons), all used with permission under the CC BY 2.0 license.

To these ends, our team focuses on two broad research directions. First, democratizing access to well-understood and widely-applicable fairness analysis tooling, engaging partner organizations in adopting them into product workflows, and informing leadership across the company in interpreting results. This work includes developing broad benchmarks, curating widely-useful high-quality test datasets and tooling centered around techniques such as sliced analysis and counterfactual testing — often building on the core representation signals work described earlier. Second, advancing novel approaches towards fairness analytics — including partnering with product efforts that may result in breakthrough findings or inform launch strategy.

Advancing AI responsibly

Our work does not stop with analyzing model behavior. Rather, we use this as a jumping-off point for identifying algorithmic improvements in collaboration with other researchers and engineers on product teams. Over the past year we’ve launched upgraded components that power Search and Memories features in Google Photos, leading to more consistent performance and drastically improving robustness through added layers that keep mistakes from cascading through the system. We are working on improving ranking algorithms in Google Images to diversify representation. We updated algorithms that may reinforce historical stereotypes, using additional signals responsibly, such that it’s more likely for everyone to see themselves reflected in Search results and find what they’re looking for.

This work naturally carries over to the world of generative AI, where models can create collections of images or videos seeded from image and text prompts and can answer questions about images and videos. We’re excited about the potential of these technologies to deliver new experiences to users and as tools to further our own research. To enable this, we’re collaborating across the research and responsible AI communities to develop guardrails that mitigate failure modes. We’re leveraging our tools for understanding representation to power scalable benchmarks that can be combined with human feedback, and investing in research from pre-training through deployment to steer the models to generate higher quality, more inclusive, and more controllable output. We want these models to inspire people, producing diverse outputs, translating concepts without relying on tropes or stereotypes, and providing consistent behaviors and responses across counterfactual variations of prompts.

Opportunities and ongoing work

Despite over a decade of focused work, the field of perception fairness technologies still seems like a nascent and fast-growing space, rife with opportunities for breakthrough techniques. We continue to see opportunities to contribute technical advances backed by interdisciplinary scholarship. The gap between what we can measure in images and the underlying aspects of human identity and expression is large; closing this gap will require increasingly complex media analytics solutions. Data metrics that indicate true representation, situated in the appropriate context and heeding a diversity of viewpoints, remain an open challenge for us. Can we reach a point where we can reliably identify depictions of nuanced stereotypes, continually update them to reflect an ever-changing society, and discern situations in which they could be offensive? Algorithmic advances driven by human feedback point to a promising path forward.

Recent focus on AI safety and ethics in the context of modern large model development has spurred new ways of thinking about measuring systemic biases. We are exploring multiple avenues to use these models — along with recent developments in concept-based explainability methods, causal inference methods, and cutting-edge UX research — to quantify and minimize undesired biased behaviors. We look forward to tackling the challenges ahead and developing technology that is built for everybody.

Acknowledgements

We would like to thank every member of the Perception Fairness team, and all of our collaborators.


How to compare a noisy quantum processor to a classical computer

A full-scale error-corrected quantum computer will be able to solve some problems that are impossible for classical computers, but building such a device is a huge endeavor. We are proud of the milestones that we have achieved toward a fully error-corrected quantum computer, but that large-scale computer is still some number of years away. Meanwhile, we are using our current noisy quantum processors as flexible platforms for quantum experiments.

In contrast to an error-corrected quantum computer, experiments in noisy quantum processors are currently limited to a few thousand quantum operations or gates, before noise degrades the quantum state. In 2019 we implemented a specific computational task called random circuit sampling on our quantum processor and showed for the first time that it outperformed state-of-the-art classical supercomputing.

Although these experiments have not yet reached beyond-classical capabilities, we have also used our processors to observe novel physical phenomena, such as time crystals and Majorana edge modes, and have made new experimental discoveries, such as robust bound states of interacting photons and the noise-resilience of Majorana edge modes of Floquet evolutions.

We expect that even in this intermediate, noisy regime, we will find applications for the quantum processors in which useful quantum experiments can be performed much faster than can be calculated on classical supercomputers — we call these “computational applications” of the quantum processors. No one has yet demonstrated such a beyond-classical computational application. So as we aim to achieve this milestone, the question is: What is the best way to compare a quantum experiment run on such a quantum processor to the computational cost of a classical application?

We already know how to compare an error-corrected quantum algorithm to a classical algorithm. In that case, the field of computational complexity tells us that we can compare their respective computational costs — that is, the number of operations required to accomplish the task. But with our current experimental quantum processors, the situation is not so well defined.

In “Effective quantum volume, fidelity and computational cost of noisy quantum processing experiments”, we provide a framework for measuring the computational cost of a quantum experiment, introducing the experiment’s “effective quantum volume”, which is the number of quantum operations or gates that contribute to a measurement outcome. We apply this framework to evaluate the computational cost of three recent experiments: our random circuit sampling experiment, our experiment measuring quantities known as “out of time order correlators” (OTOCs), and a recent experiment on a Floquet evolution related to the Ising model. We are particularly excited about OTOCs because they provide a direct way to experimentally measure the effective quantum volume of a circuit (a sequence of quantum gates or operations), which is itself a computationally difficult task for a classical computer to estimate precisely. OTOCs are also important in nuclear magnetic resonance and electron spin resonance spectroscopy. Therefore, we believe that OTOC experiments are a promising candidate for a first-ever computational application of quantum processors.

Plot of computational cost and impact of some recent quantum experiments. While some (e.g., QC-QMC 2022) have had high impact and others (e.g., RCS 2023) have had high computational cost, none have yet been both useful and hard enough to be considered a “computational application.” We hypothesize that our future OTOC experiment could be the first to pass this threshold. Other experiments plotted are referenced in the text.

Random circuit sampling: Evaluating the computational cost of a noisy circuit

When it comes to running a quantum circuit on a noisy quantum processor, there are two competing considerations. On one hand, we aim to do something that is difficult to achieve classically. The computational cost — the number of operations required to accomplish the task on a classical computer — depends on the quantum circuit’s effective quantum volume: the larger the volume, the higher the computational cost, and the more a quantum processor can outperform a classical one.

But on the other hand, on a noisy processor, each quantum gate can introduce an error to the calculation. The more operations, the higher the error, and the lower the fidelity of the quantum circuit in measuring a quantity of interest. Under this consideration, we might prefer simpler circuits with a smaller effective volume, but these are easily simulated by classical computers. The balance of these competing considerations, which we want to maximize, is called the “computational resource”, shown below.

Graph of the tradeoff between quantum volume and noise in a quantum circuit, captured in a quantity called the “computational resource.” For a noisy quantum circuit, this will initially increase with the computational cost, but eventually, noise will overrun the circuit and cause it to decrease.

We can see how these competing considerations play out in a simple “hello world” program for quantum processors, known as random circuit sampling (RCS), which was the first demonstration of a quantum processor outperforming a classical computer. Any error in any gate is likely to make this experiment fail. Inevitably, this is a hard experiment to achieve with significant fidelity, and thus it also serves as a benchmark of system fidelity. But it also corresponds to the highest known computational cost achievable by a quantum processor. We recently reported the most powerful RCS experiment performed to date, with a low measured experimental fidelity of 1.7×10⁻³, and a high theoretical computational cost of ~10²³. These quantum circuits had 700 two-qubit gates. We estimate that this experiment would take ~47 years to simulate on the world’s largest supercomputer. While this checks one of the two boxes needed for a computational application (it outperforms a classical supercomputer), it is not a particularly useful application per se.

OTOCs and Floquet evolution: The effective quantum volume of a local observable

There are many open questions in quantum many-body physics that are classically intractable, so running some of these experiments on our quantum processor has great potential. We typically think of these experiments a bit differently than we do the RCS experiment. Rather than measuring the quantum state of all qubits at the end of the experiment, we are usually concerned with more specific, local physical observables. Because not every operation in the circuit necessarily impacts the observable, a local observable’s effective quantum volume might be smaller than that of the full circuit needed to run the experiment.

We can understand this by applying the concept of a light cone from relativity, which determines which events in space-time can be causally connected: some events cannot possibly influence one another because information takes time to propagate between them. We say that two such events are outside their respective light cones. In a quantum experiment, we replace the light cone with something called a “butterfly cone,” where the growth of the cone is determined by the butterfly speed — the speed with which information spreads throughout the system. (This speed is characterized by measuring OTOCs, discussed later.) The effective quantum volume of a local observable is essentially the volume of the butterfly cone, including only the quantum operations that are causally connected to the observable. So, the faster information spreads in a system, the larger the effective volume and therefore the harder it is to simulate classically.
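To illustrate why a local observable can depend on far fewer gates than the full circuit contains, the following sketch counts the gates inside the causal (light) cone of a single qubit in a one-dimensional brickwork circuit. The circuit layout and function names are assumptions chosen for the example, and the count is only the geometric light cone: the butterfly cone discussed above depends on how quickly information actually spreads and can be smaller.

```python
def brickwork_layers(num_qubits, depth):
    """Layer t applies two-qubit gates on pairs (i, i + 1), i starting at t % 2."""
    return [
        [(i, i + 1) for i in range(t % 2, num_qubits - 1, 2)]
        for t in range(depth)
    ]

def light_cone_gate_count(layers, observable_qubit):
    """Count gates that can causally influence a qubit measured at the end.

    Walks the circuit backward in time, growing the set of qubits that the
    observable can depend on whenever a gate touches that set.
    """
    support = {observable_qubit}
    count = 0
    for layer in reversed(layers):
        touched = [g for g in layer if g[0] in support or g[1] in support]
        count += len(touched)
        for a, b in touched:
            support.update((a, b))
    return count, support

layers = brickwork_layers(num_qubits=10, depth=4)
total_gates = sum(len(layer) for layer in layers)
cone_gates, cone_qubits = light_cone_gate_count(layers, observable_qubit=5)
print(cone_gates, "of", total_gates, "gates lie in the light cone")
```

In this toy circuit only part of the gate count contributes to the observable, and the gap widens as the number of qubits grows relative to the circuit depth.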

A depiction of the effective volume Veff of the gates contributing to the local observable B. A related quantity called the effective area Aeff is represented by the cross-section of the plane and the cone. The perimeter of the base corresponds to the front of information travel that moves with the butterfly velocity vB.

We apply this framework to a recent experiment implementing a so-called Floquet Ising model, a physical model related to the time crystal and Majorana experiments. From the data of this experiment, one can directly estimate an effective fidelity of 0.37 for the largest circuits. With the measured gate error rate of ~1%, this gives an estimated effective volume of ~100. This is much smaller than the light cone, which included two thousand gates on 127 qubits. So, the butterfly velocity of this experiment is quite small. Indeed, we argue that the effective volume covers only ~28 qubits, not 127, using numerical simulations that obtain a higher precision than the experiment. This small effective volume has also been corroborated with the OTOC technique. Although this was a deep circuit, the estimated computational cost is 5×10¹¹, almost one trillion times less than the recent RCS experiment. Correspondingly, this experiment can be simulated in less than a second per data point on a single A100 GPU. So, while this is certainly a useful application, it does not fulfill the second requirement of a computational application: substantially outperforming a classical simulation.
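The step from an effective fidelity of 0.37 and a ~1% gate error to an effective volume of ~100 can be reproduced with a simple model. The sketch below assumes that each contributing gate independently multiplies the fidelity by (1 − error), so that fidelity ≈ (1 − error)^volume; this model and the function name are assumptions for illustration, not necessarily the estimator used in the paper.

```python
import math

def effective_volume(fidelity, gate_error):
    """Solve fidelity = (1 - gate_error) ** volume for volume."""
    return math.log(fidelity) / math.log(1.0 - gate_error)

# Numbers quoted in the text for the Floquet Ising experiment.
volume = effective_volume(fidelity=0.37, gate_error=0.01)
print(round(volume))  # close to the ~100 gates quoted above
```

Under the same model, halving the gate error roughly doubles the volume reachable at a fixed fidelity, since log(1 − error) is approximately −error for small error rates.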

Information scrambling experiments with OTOCs are a promising avenue for a computational application. OTOCs can tell us important physical information about a system, such as the butterfly velocity, which is critical for precisely measuring the effective quantum volume of a circuit. OTOC experiments with fast entangling gates offer a potential path for a first beyond-classical demonstration of a computational application with a quantum processor. Indeed, in our experiment from 2021 we achieved an effective fidelity of Feff ~ 0.06 with an experimental signal-to-noise ratio of ~1, corresponding to an effective volume of ~250 gates and a computational cost of 2×10¹².

While these early OTOC experiments are not sufficiently complex to outperform classical simulations, there is a deep physical reason why OTOC experiments are good candidates for the first demonstration of a computational application. Most of the interesting quantum phenomena accessible to near-term quantum processors that are hard to simulate classically correspond to a quantum circuit exploring many, many quantum energy levels. Such evolutions are typically chaotic and standard time-order correlators (TOC) decay very quickly to a purely random average in this regime. There is no experimental signal left. This does not happen for OTOC measurements, which allows us to grow complexity at will, only limited by the error per gate. We anticipate that a reduction of the error rate by half would double the computational cost, pushing this experiment to the beyond-classical regime.

Conclusion

Using the effective quantum volume framework we have developed, we have determined the computational cost of our RCS and OTOC experiments, as well as a recent Floquet evolution experiment. While none of these meet the requirements yet for a computational application, we expect that with improved error rates, an OTOC experiment will be the first beyond-classical, useful application of a quantum processor.


Teaching language models to reason algorithmically

Large language models (LLMs), such as GPT-3 and PaLM, have shown impressive progress in recent years, driven by scaling up models and training data sizes. Nonetheless, a long-standing debate has been whether LLMs can reason symbolically (i.e., manipulate symbols based on logical rules). For example, LLMs are able to perform simple arithmetic operations when numbers are small, but struggle when the numbers are large. This suggests that LLMs have not learned the underlying rules needed to perform these arithmetic operations.

While neural networks have powerful pattern matching capabilities, they are prone to overfitting to spurious statistical patterns in the data. This does not hinder good performance when the training data is large and diverse and the evaluation is in-distribution. However, for tasks that require rule-based reasoning (such as addition), LLMs struggle with out-of-distribution generalization as spurious correlations in the training data are often much easier to exploit than the true rule-based solution. As a result, despite significant progress in a variety of natural language processing tasks, performance on simple arithmetic tasks like addition has remained a challenge. Even with modest improvement of GPT-4 on the MATH dataset, errors are still largely due to arithmetic and calculation mistakes. Thus, an important question is whether LLMs are capable of algorithmic reasoning, which involves solving a task by applying a set of abstract rules that define the algorithm.

In “Teaching Algorithmic Reasoning via In-Context Learning”, we describe an approach that leverages in-context learning to enable algorithmic reasoning capabilities in LLMs. In-context learning refers to a model’s ability to perform a task after seeing a few examples of it within the context of the model. The task is specified to the model using a prompt, without the need for weight updates. We also present a novel algorithmic prompting technique that enables general purpose language models to achieve strong generalization on arithmetic problems that are more difficult than those seen in the prompt. Finally, we demonstrate that a model can reliably execute algorithms on out-of-distribution examples with an appropriate choice of prompting strategy.

By providing algorithmic prompts, we can teach a model the rules of arithmetic via in-context learning. In this example, the LLM (word predictor) outputs the correct answer when prompted with an easy addition question (e.g., 267+197), but fails when asked a similar addition question with more digits. However, when an algorithmic prompt for addition (blue box with white + shown below the word predictor) is added to the more difficult question, the model is able to answer correctly. Moreover, the model is capable of simulating the multiplication algorithm (X) by composing a series of addition calculations.

Teaching an algorithm as a skill

In order to teach a model an algorithm as a skill, we develop algorithmic prompting, which builds upon other rationale-augmented approaches (e.g., scratchpad and chain-of-thought). Algorithmic prompting extracts algorithmic reasoning abilities from LLMs, and has two notable distinctions compared to other prompting approaches: (1) it solves tasks by outputting the steps needed for an algorithmic solution, and (2) it explains each algorithmic step with sufficient detail so there is no room for misinterpretation by the LLM.

To gain intuition for algorithmic prompting, let’s consider the task of two-number addition. In a scratchpad-style prompt, we process each digit from right to left and keep track of the carry value at each step (i.e., we add a 1 to the next digit if the sum at the current digit is greater than 9). However, the rule of carry is ambiguous after seeing only a few examples of carry values. We find that including explicit equations to describe the rule of carry helps the model focus on the relevant details and interpret the prompt more accurately. We use this insight to develop an algorithmic prompt for two-number addition, where we provide explicit equations for each step of computation and describe various indexing operations in unambiguous formats.
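To make the idea of "explicit equations for each step" concrete, here is a minimal sketch that writes out every digit-level step of an addition, including the carry rule as an equation. The function name `addition_rationale` and the wording of each step are illustrative assumptions, not the exact prompt format used in the paper.

```python
def addition_rationale(a: int, b: int) -> tuple[list[str], int]:
    """Return explicit per-digit steps and the sum of two non-negative ints.

    Hypothetical helper: the step format is an illustration, not the
    paper's actual algorithmic prompt.
    """
    xs = [int(d) for d in str(a)][::-1]  # least-significant digit first
    ys = [int(d) for d in str(b)][::-1]
    n = max(len(xs), len(ys))
    xs += [0] * (n - len(xs))  # pad the shorter number with zeros
    ys += [0] * (n - len(ys))

    steps: list[str] = []
    digits: list[int] = []
    carry = 0
    for i in range(n):
        total = xs[i] + ys[i] + carry
        digit, new_carry = total % 10, total // 10
        # State the carry rule as an equation instead of leaving it implicit.
        steps.append(
            f"position {i}: {xs[i]}+{ys[i]}+{carry}={total}, "
            f"digit={total}%10={digit}, carry={total}//10={new_carry}"
        )
        digits.append(digit)
        carry = new_carry
    if carry:
        digits.append(carry)
        steps.append(f"final carry {carry} becomes the leading digit")

    answer = int("".join(str(d) for d in reversed(digits)))
    return steps, answer


steps, answer = addition_rationale(267, 197)
for line in steps:
    print(line)
print("answer:", answer)
```

Because every step names its inputs, the intermediate sum, and the resulting carry, the same procedure applies unchanged to numbers of any length, which is the property the prompt is meant to convey.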

Illustration of various prompt strategies for addition.

Using only three prompt examples of addition with answer length up to five digits, we evaluate performance on additions of up to 19 digits. Accuracy is measured over 2,000 total examples sampled uniformly over the length of the answer. As shown below, the use of algorithmic prompts maintains high accuracy for questions significantly longer than what’s seen in the prompt, which demonstrates that the model is indeed solving the task by executing an input-agnostic algorithm.
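The evaluation protocol described above can be sketched roughly as follows. This is an assumption-laden illustration: the helper names, the way operands are drawn for a given answer length, and the use of an equal number of examples per length are ours, and `solver` stands in for whatever prompted model is being evaluated.

```python
import random
from typing import Callable


def sample_addition_with_answer_length(length: int, rng: random.Random) -> tuple[int, int]:
    """Sample non-negative (a, b) such that a + b has exactly `length` digits."""
    low = 0 if length == 1 else 10 ** (length - 1)
    high = 10 ** length - 1
    total = rng.randint(low, high)
    a = rng.randint(0, total)
    return a, total - a


def accuracy_by_length(
    solver: Callable[[int, int], int],
    max_length: int = 19,
    per_length: int = 10,
    seed: int = 0,
) -> dict[int, float]:
    """Exact-match accuracy of `solver`, grouped by number of digits in the answer."""
    rng = random.Random(seed)
    results: dict[int, float] = {}
    for length in range(1, max_length + 1):
        correct = 0
        for _ in range(per_length):
            a, b = sample_addition_with_answer_length(length, rng)
            correct += int(solver(a, b) == a + b)
        results[length] = correct / per_length
    return results


# Stand-in solver (exact addition); a real run would call a prompted model here.
print(accuracy_by_length(lambda a, b: a + b, per_length=5))
```

Grouping accuracy by answer length is what makes it possible to see whether performance holds up beyond the lengths shown in the prompt.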

Test accuracy on addition questions of increasing length for different prompting methods.

Leveraging algorithmic skills as tool use

To evaluate if the model can leverage algorithmic reasoning in a broader reasoning process, we evaluate performance using grade school math word problems (GSM8k). We specifically attempt to replace addition calculations from GSM8k with an algorithmic solution.

Motivated by context length limitations and possible interference between different algorithms, we explore a strategy where differently-prompted models interact with one another to solve complex tasks. In the context of GSM8k, we have one model that specializes in informal mathematical reasoning using chain-of-thought prompting, and a second model that specializes in addition using algorithmic prompting. The informal mathematical reasoning model is prompted to output specialized tokens in order to call on the addition-prompted model to perform the arithmetic steps. We extract the queries between these tokens, send them to the addition model, and return the answer to the first model, after which the first model continues its output. We evaluate our approach on a harder variant of GSM8k (GSM8k-Hard), which we build by randomly selecting 50 addition-only questions and increasing the numerical values in them.
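The routing step can be sketched as below. This is a simplified illustration under stated assumptions: the `<ADD>` and `</ADD>` delimiters are hypothetical stand-ins for the specialized tokens, `add_model` is any callable playing the role of the addition-prompted model, and the interleaved generate-pause-resume loop is collapsed into a single substitution pass over already generated text.

```python
import re
from typing import Callable

# Hypothetical delimiters; the actual specialized tokens are not given in this post.
CALL_PATTERN = re.compile(r"<ADD>\s*(\d+)\s*\+\s*(\d+)\s*</ADD>")


def resolve_addition_calls(text: str, add_model: Callable[[int, int], int]) -> str:
    """Replace each delimited addition query with the answer from `add_model`."""

    def _replace(match: re.Match) -> str:
        a, b = int(match.group(1)), int(match.group(2))
        return str(add_model(a, b))

    return CALL_PATTERN.sub(_replace, text)


# Stand-in for the addition-prompted model: exact integer addition.
reasoning = "She has <ADD>1234567 + 7654321</ADD> apples in total."
print(resolve_addition_calls(reasoning, lambda a, b: a + b))
```

Keeping the addition solver behind a narrow call interface means the reasoning model never needs the long algorithmic prompt in its own context, which is the motivation given above for using separate contexts.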

An example from the GSM8k-Hard dataset. The chain-of-thought prompt is augmented with brackets to indicate when an algorithmic call should be performed.

We find that using separate contexts and models with specialized prompts is an effective way to tackle GSM8k-Hard. Below, we observe that the performance of the model with algorithmic call for addition is 2.3x the chain-of-thought baseline. Finally, this strategy presents an example of solving complex tasks by facilitating interactions between LLMs specialized to different skills via in-context learning.

Chain-of-thought (CoT) performance on GSM8k-Hard with or without algorithmic call.

Conclusion

We present an approach that leverages in-context learning and a novel algorithmic prompting technique to unlock algorithmic reasoning abilities in LLMs. Our results suggest that it may be possible to transform longer context into better reasoning performance by providing more detailed explanations. Thus, these findings point to using or otherwise simulating long contexts, and generating more informative rationales, as promising research directions.

Acknowledgements

We thank our co-authors Behnam Neyshabur, Azade Nova, Hugo Larochelle and Aaron Courville for their valuable contributions to the paper and great feedback on the blog. We thank Tom Small for creating the animations in this post. This work was done during Hattie Zhou’s internship at Google Research.
