An International Scientific Challenge for the Diagnosis and Gleason Grading of Prostate Cancer

In recent years, machine learning (ML) competitions in health have attracted ML scientists to work together to solve challenging clinical problems. These competitions provide access to relevant data and well-defined problems where experienced data scientists come to compete for solutions and learn new methods. However, a fundamental difficulty in organizing such challenges is obtaining and curating high quality datasets for model development and independent datasets for model evaluation. Importantly, to reduce the risk of bias and to ensure broad applicability of the algorithm, evaluation of the generalisability of resulting algorithms should ideally be performed on multiple independent evaluation datasets by an independent group of scientists.

One clinical problem that has attracted substantial ML research is prostate cancer, a condition that 1 in 9 men develop in their lifetime. A prostate cancer diagnosis requires pathologists to examine biological tissue samples under a microscope to identify cancer and grade the cancer for signs of aggressive growth patterns in the cells. However, this cancer grading task (called Gleason grading) is difficult and subjective due to the need for visual assessment of cell differentiation and Gleason pattern predominance. Building a large dataset of samples with expert annotations can help with the development of ML systems to aid in prostate cancer grading.

To help accelerate and enable more research in this area, Google Health, Radboud University Medical Center and Karolinska Institutet joined forces to organize a global competition, the Prostate cANcer graDe Assessment (PANDA) Challenge, on the open Kaggle platform. In “Artificial Intelligence for Diagnosis and Gleason Grading of Prostate Cancer: the PANDA challenge”, published in Nature Medicine, we present the results of the challenge. The study design of the PANDA challenge provided the largest public whole-slide image dataset available and was open to participants from April 21st until July 23rd, 2020. The development datasets remain available for further research. In this effort, we compiled and publicly released a European cohort of prostate cancer cases for algorithm development and pioneered a standardized evaluation setup for digital pathology that enabled independent, blinded external validation of the algorithms on data from both the United States and EU.

The global competition attracted participants from 65 countries (the size of the circle for each country illustrates the number of participants).

Design of the PANDA Challenge
The challenge had two phases: a development phase (i.e., the Kaggle competition) and a validation phase. During the competition, 1,290 developers from 65 countries competed in building the best performing Gleason grading algorithm, having full access to a development set for algorithm training. Throughout the competition teams submitted algorithms that were evaluated on a hidden tuning set.

In the validation phase, a selection of top performing algorithms were independently evaluated on internal and external validation datasets with high quality reference grades from panels of expert prostate pathologists. In addition, a group of general pathologists graded a subset of the same cases to put the difficulty of the task and dataset in context. The algorithms submitted by the teams were then compared to grades done by groups of international and US general pathologists on these subsets.

Overview of the PANDA challenge’s phases for development and validation.

Research Velocity During the Challenge
We found that a group of Gleason grading ML algorithms developed during a global competition could achieve pathologist-level performance and generalize well to intercontinental and multinational cohorts. On all external validation sets, these algorithms achieved high agreement with urologic pathologists (prostate specialists) and high sensitivity for detecting tumor in biopsies. The Kaggle platform enabled the tracking of teams’ performance throughout the competition. Impressively, within the first 10 days of the competition, one team had already reached agreement with the prostate pathologists above 0.90 (quadratically weighted Cohen’s kappa) on the internal validation set. By the 33rd day, the median performance of all teams exceeded a score of 0.85.

Progression of algorithms’ performances throughout the competition, as shown by the highest score on the tuning and internal validation sets among all participating teams. During the competition teams could submit their algorithm for evaluation on the tuning set, after which they received their score. At the same time, algorithms were evaluated on the internal validation set, without disclosing these results to the participating teams. The development of the top score obtained by any team shows the rapid improvement of the algorithms.
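
For readers unfamiliar with the agreement metric mentioned above, the following is a minimal sketch of quadratically weighted Cohen’s kappa for two lists of integer grades. It is not the challenge’s evaluation code. The function name and the six-level grade scale in the usage note are assumptions made for illustration.

```python
def quadratic_weighted_kappa(rater_a, rater_b, num_classes):
    """Cohen's kappa with quadratic weights for two lists of integer grades."""
    n = len(rater_a)
    if n == 0 or n != len(rater_b):
        raise ValueError("need two non-empty rating lists of equal length")

    # Observed confusion matrix and per-rater grade histograms.
    observed = [[0] * num_classes for _ in range(num_classes)]
    hist_a = [0] * num_classes
    hist_b = [0] * num_classes
    for a, b in zip(rater_a, rater_b):
        observed[a][b] += 1
        hist_a[a] += 1
        hist_b[b] += 1

    disagreement = 0.0  # weighted disagreement actually observed
    chance = 0.0        # weighted disagreement expected by chance
    for i in range(num_classes):
        for j in range(num_classes):
            # Penalty grows with the squared distance between the two grades.
            weight = (i - j) ** 2 / (num_classes - 1) ** 2
            disagreement += weight * observed[i][j]
            chance += weight * hist_a[i] * hist_b[j] / n

    if chance == 0.0:
        # Both raters used a single identical grade; treated here as full agreement.
        return 1.0
    return 1.0 - disagreement / chance
```

With a hypothetical six-level scale, `quadratic_weighted_kappa([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], 6)` returns 1.0. Because of the quadratic weights, a disagreement of two grades is penalized four times as heavily as a disagreement of one grade.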

Learning from the Challenge
By moderating the discussion forum on the Kaggle platform, we learned that the teams’ openness in sharing code via Colab notebooks led to rapid improvement across the board, a promising sign for future public challenges, and a clear indication of the power of sharing knowledge on a common platform.

Organizing a public challenge that evaluates algorithm generalization across independent cohorts using high quality reference standard panels presents substantial logistical difficulties. Assembling this size of a dataset across countries and organizations was a massive undertaking. This work benefited from an amazing collaboration between the three organizing institutions which have all contributed respective publications in this space, two in Lancet Oncology and one in JAMA Oncology. Combining these efforts provided a high quality foundation on which this competition could be based. With the publication, Radboud and Karolinska research groups are also open sourcing the PANDA challenge development datasets to facilitate the further improvement of prostate Gleason grading algorithms. We look forward to seeing many more advancements in this field, and more challenges that can catalyze extensive international knowledge sharing and collaborative research.

Acknowledgements
Key contributors to this project at Google include Po-Hsuan Cameron Chen, Kunal Nagpal, Yuannan Cai, David F. Steiner, Maggie Demkin, Sohier Dane, Fraser Tan, Greg S. Corrado, Lily Peng, Craig H. Mermel. Collaborators on this project include Wouter Bulten, Kimmo Kartasalo, Peter Ström, Hans Pinckaers, Hester van Boven, Robert Vink, Christina Hulsbergen-van de Kaa, Jeroen van der Laak, Mahul B. Amin, Andrew J. Evans, Theodorus van der Kwast, Robert Allan, Peter A. Humphrey, Henrik Grönberg, Hemamali Samaratunga, Brett Delahunt, Toyonori Tsuzuki, Tomi Häkkinen, Lars Egevad, Masi Valkonen, Pekka Ruusuvuori, Geert Litjens, Martin Eklund and the PANDA Challenge consortium. We thank Ellery Wulczyn, Annisah Um’rani, Yun Liu, and Dale Webster for their feedback on the manuscript and guidance on the project. We thank our collaborators at NMCSD, particularly Niels Olson, for internal re-use of de-identified data which contributed to the US external validation set. Sincere appreciation also goes to Sami Lachgar, Ashley Zlatinov, and Lauren Winer for their feedback on the blogpost.


Guiding Frozen Language Models with Learned Soft Prompts

Large pre-trained language models, which are continuing to grow in size, achieve state-of-art results on many natural language processing (NLP) benchmarks. Since the development of GPT and BERT, standard practice has been to fine-tune models on downstream tasks, which involves adjusting every weight in the network (i.e., model tuning). However, as models become larger, storing and serving a tuned copy of the model for each downstream task becomes impractical.

An appealing alternative is to share across all downstream tasks a single frozen pre-trained language model, in which all weights are fixed. In an exciting development, GPT-3 showed convincingly that a frozen model can be conditioned to perform different tasks through “in-context” learning. With this approach, a user primes the model for a given task through prompt design, i.e., hand-crafting a text prompt with a description or examples of the task at hand. For instance, to condition a model for sentiment analysis, one could attach the prompt, “Is the following movie review positive or negative?” before the input sequence, “This movie was amazing!”

Sharing the same frozen model across tasks greatly simplifies serving and allows for efficient mixed-task inference, but unfortunately, this is at the expense of task performance. Text prompts require manual effort to design, and even well-designed prompts still far underperform compared to model tuning. For instance, the performance of a frozen GPT-3 175B parameter model on the SuperGLUE benchmark is 5 points below a fine-tuned T5 model that uses 800 times fewer parameters.

In “The Power of Scale for Parameter-Efficient Prompt Tuning”, presented at EMNLP 2021, we explore prompt tuning, a more efficient and effective method for conditioning frozen models using tunable soft prompts. Just like engineered text prompts, soft prompts are concatenated to the input text. But rather than selecting from existing vocabulary items, the “tokens” of the soft prompt are learnable vectors. This means a soft prompt can be optimized end-to-end over a training dataset. In addition to removing the need for manual design, this allows the prompt to condense information from datasets containing thousands or millions of examples. By comparison, discrete text prompts are typically limited to under 50 examples due to constraints on model input length. We are also excited to release the code and checkpoints to fully reproduce our experiments.

Prompt tuning retains the strong task performance of model tuning, while keeping the pre-trained model frozen, enabling efficient multitask serving.

Prompt Tuning
To create a soft prompt for a given task, we first initialize the prompt as a fixed-length sequence of vectors (e.g., 20 tokens long). We attach these vectors to the beginning of each embedded input and feed the combined sequence into the model. The model’s prediction is compared to the target to calculate a loss, and the error is back-propagated to calculate gradients. However, we only apply these gradient updates to our new learnable vectors, keeping the core model frozen. While soft prompts learned in this way are not immediately interpretable, at an intuitive level, the soft prompt is extracting evidence about how to perform a task from the labeled dataset, performing the same role as a manually written text prompt, but without the need to be constrained to discrete language.
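
The update rule described above can be sketched on a toy problem. This is not the released T5X codebase. The “frozen model” here is a hypothetical stand-in (mean pooling followed by a fixed logistic layer), and all names and numbers are assumptions chosen so the gradient can be written by hand. It shows only that the prompt vectors receive updates while the model weights do not.

```python
import math

def forward(prompt, tokens, frozen_w):
    """Toy frozen model: mean-pool [prompt; tokens], then a fixed logistic layer."""
    sequence = prompt + tokens  # the soft prompt is prepended to the embedded input
    pooled = [sum(vec[d] for vec in sequence) / len(sequence)
              for d in range(len(frozen_w))]
    logit = sum(w * x for w, x in zip(frozen_w, pooled))
    return 1.0 / (1.0 + math.exp(-logit))

def binary_cross_entropy(prob, label):
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))

def prompt_tuning_step(prompt, tokens, label, frozen_w, learning_rate=0.3):
    """One gradient step on the prompt only; frozen_w is read but never updated."""
    prob = forward(prompt, tokens, frozen_w)
    seq_len = len(prompt) + len(tokens)
    # For this toy model: d(loss)/d(prompt[i][d]) = (prob - label) * frozen_w[d] / seq_len
    scale = (prob - label) / seq_len
    return [[p - learning_rate * scale * w for p, w in zip(vec, frozen_w)]
            for vec in prompt]

frozen_w = [0.5, -1.0, 0.25, 2.0]        # hypothetical frozen weights
tokens = [[0.1, 0.2, -0.3, 0.0],         # hypothetical embedded input (3 tokens)
          [0.4, -0.1, 0.2, 0.1],
          [0.0, 0.3, 0.1, -0.2]]
prompt = [[0.0] * 4 for _ in range(2)]   # 2 learnable prompt vectors
for _ in range(100):
    prompt = prompt_tuning_step(prompt, tokens, 1, frozen_w)
```

After the loop, the loss for this example is lower than at initialization even though `frozen_w` is unchanged. In a real setup an automatic differentiation library would compute the gradient, and only the prompt parameters would be passed to the optimizer.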

Our codebase, implemented in the new JAX-based T5X framework, makes it easy for anyone to replicate this procedure, and provides practical hyperparameter settings, including a large learning rate (0.3), which we found was important for achieving good results.

Since soft prompts have a small parameter footprint (we train prompts with as few as 512 parameters), one can easily pass the model a different prompt along with each input example. This enables mixed-task inference batches, which can streamline serving by sharing one core model across many tasks.

Left: With model tuning, incoming data are routed to task-specific models. Right: With prompt tuning, examples and prompts from different tasks can flow through a single frozen model in large batches, better utilizing serving resources.
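
As a rough illustration of mixed-task batching, the sketch below prepends a different soft prompt to each example, so that one batch can carry several tasks through the same frozen model. The function name, task names, and vectors are hypothetical. A real system would operate on padded tensors rather than Python lists.

```python
def build_mixed_batch(examples, task_prompts):
    """Prepend each example's task-specific soft prompt to its embedded input.

    examples: list of (task_name, embedded_tokens) pairs.
    task_prompts: dict mapping task_name to a list of prompt vectors.
    """
    return [task_prompts[task] + tokens for task, tokens in examples]

# Hypothetical prompts of length 2 in a 3-dimensional embedding space.
task_prompts = {
    "sentiment": [[0.1, 0.0, 0.2], [0.3, 0.1, 0.0]],
    "paraphrase": [[0.0, 0.5, 0.1], [0.2, 0.2, 0.2]],
}
examples = [
    ("sentiment", [[1.0, 0.0, 0.0]]),
    ("paraphrase", [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
]
batch = build_mixed_batch(examples, task_prompts)
```

Each row of `batch` starts with the prompt of its own task, and the frozen model itself never needs to know which task a row belongs to.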

Improvement with Scale
When evaluated on SuperGLUE and using a frozen T5 model, prompt tuning significantly outperforms prompt design using either GPT-3 or T5. Furthermore, as model size increases, prompt tuning catches up to the performance level of model tuning. Intuitively, the larger the pre-trained model, the less of a “push” it needs to perform a specific task, and the more capable it is of being adapted in a parameter-efficient way.

As scale increases, prompt tuning matches model tuning, despite tuning 25,000 times fewer parameters.

The effectiveness of prompt tuning at large model scales is especially important, since serving separate copies of a large model can incur significant computational overhead. In our paper, we demonstrate that larger models can be conditioned successfully even with soft prompts as short as 5 tokens. For T5 XXL, this means tuning just 20 thousand parameters to guide the behavior of an 11 billion parameter model.
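
The parameter counts quoted here follow from multiplying prompt length by embedding width. The sketch below makes that arithmetic explicit. The embedding widths (4096 for T5 XXL, 512 for a small model) are assumptions not stated in the text above.

```python
def soft_prompt_parameter_count(prompt_length, embedding_dim):
    """A soft prompt is one learnable vector of size embedding_dim per prompt token."""
    return prompt_length * embedding_dim

# Assumed embedding width of 4096 for T5 XXL: a 5-token prompt.
xxl_prompt_params = soft_prompt_parameter_count(5, 4096)   # 20480

# Assumed embedding width of 512 for a small model: a single-token prompt.
small_prompt_params = soft_prompt_parameter_count(1, 512)  # 512
```

Under these assumptions, the 5-token case gives 20,480 parameters, which is consistent with the “20 thousand” figure above.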

Resilience to Domain Shift
Another advantage of prompt tuning is its resilience to domain shift. Since model tuning touches every weight in the network, it has the capacity to easily overfit on the provided fine-tuning data and may not generalize well to variations in the task at inference time. By comparison, our learned soft prompts have a small number of parameters, so the solutions they represent may be more generalizable.

To test generalizability, we train prompt tuning and model tuning solutions on one task, and evaluate zero-shot on a closely related task. For example, when we train on the Quora Question Pairs task (i.e., detecting if two questions are duplicates) and evaluate on MRPC (i.e., detecting if two sentences from news articles are paraphrases), prompt tuning achieves 3.2 points higher accuracy than model tuning.

Train   Eval    Tuning    Accuracy     F1
QQP     MRPC    Model     73.1 ±0.9    81.2 ±2.1
QQP     MRPC    Prompt    76.3 ±0.1    84.3 ±0.3
MRPC    QQP     Model     74.9 ±1.3    70.9 ±1.2
MRPC    QQP     Prompt    75.4 ±0.8    69.7 ±0.3
On zero-shot domain transfer between two paraphrase detection tasks, prompt tuning matches or outperforms model tuning, depending on the direction of transfer.

Looking Forward
Prompt-based learning is an exciting new area that is quickly evolving. While several similar methods have been proposed, such as Prefix Tuning, WARP, and P-Tuning, we discuss their pros and cons and demonstrate that prompt tuning is the simplest and the most parameter efficient method.

In addition to the Prompt Tuning codebase, we’ve also released our LM-adapted T5 checkpoints, which we found to be better-suited for prompt tuning compared to the original T5. This codebase was used for the prompt tuning experiments in FLAN, and the checkpoints were used as a starting point for training the BigScience T0 model. We hope that the research community continues to leverage and extend prompt tuning in future research.

Acknowledgements
This project was a collaboration between Brian Lester, Rami Al-Rfou and Noah Constant. We are grateful to the following people for feedback, discussion and assistance: Waleed Ammar, Lucas Dixon, Slav Petrov, Colin Raffel, Adam Roberts, Sebastian Ruder, Noam Shazeer, Tu Vu and Linting Xue.


Nested Hierarchical Transformer: Towards Accurate, Data-Efficient, and Interpretable Visual Understanding

In visual understanding, the Vision Transformer (ViT) and its variants have received significant attention recently due to their superior performance on many core visual applications, such as image classification, object detection, and video understanding. The core idea of ViT is to utilize the power of self-attention layers to learn global relationships between small patches of images. However, the number of connections between patches increases quadratically with image size. Such a design has been observed to be data inefficient: although the original ViT can perform better than convolutional networks with hundreds of millions of images for pre-training, such a data requirement is not always practical, and it still underperforms compared to convolutional networks when given less data. Many researchers are exploring more suitable architectural re-designs that can learn visual representations effectively, such as by adding convolutional layers and building hierarchical structures with local self-attention.

The principle of hierarchical structure is one of the core ideas in vision models, where bottom layers learn more local object structures on the high-dimensional pixel space and top layers learn more abstracted and high-level knowledge at low-dimensional feature space. Existing ViT-based methods focus on designing a variety of modifications inside self-attention layers to achieve such a hierarchy, but while these offer promising performance improvements, they often require substantial architectural re-designs. Moreover, these approaches lack an interpretable design, so it is difficult to explain the inner-workings of trained models.

To address these challenges, in “Nested Hierarchical Transformer: Towards Accurate, Data-Efficient and Interpretable Visual Understanding”, we present a rethinking of existing hierarchical structure–driven designs, and provide a novel and orthogonal approach to significantly simplify them. The central idea of this work is to decouple feature learning and feature abstraction (pooling) components: nested transformer layers encode visual knowledge of image patches separately, and then the processed information is aggregated. This process is repeated in a hierarchical manner, resulting in a pyramid network structure. The resulting architecture achieves competitive results on ImageNet and outperforms prior results on data-efficient benchmarks. We show that such a design can meaningfully improve data efficiency with faster convergence and provide valuable interpretability benefits. Moreover, we introduce GradCAT, a new technique for interpreting the decision process of a trained model at inference time.

Architecture Design
The overall architecture is simple to implement by adding just a few lines of Python code to the source code of the original ViT. The original ViT architecture divides an input image into small patches, projects pixels of each patch to a vector with predefined dimension, and then feeds the sequences of all vectors to the overall ViT architecture containing multiple stacked identical transformer layers. While every layer in ViT processes information of the whole image, with this new method, stacked transformer layers are used to process only a region (i.e., block) of the image containing a few spatially adjacent image patches. This step is independent for each block and is also where feature learning occurs. Finally, a new computational layer called block aggregation then combines all of the spatially adjacent blocks. After each block aggregation, the features corresponding to four spatially adjacent blocks are fed to another module with a stack of transformer layers, which then process those four blocks jointly. This design naturally builds a pyramid hierarchical structure of the network, where bottom layers can focus on local features (such as textures) and upper layers focus on global features (such as object shape) at reduced dimensionality thanks to the block aggregation.

A visualization of the network processing an image: Given an input image, the network first partitions images into blocks, where each block contains 4 image patches. Image patches in every block are linearly projected as vectors and processed by a stack of identical transformer layers. Then the proposed block aggregation layer aggregates information from each block and reduces its spatial size by 4 times. The number of blocks is reduced to 1 at the top hierarchy and classification is conducted after the output of it.
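
The partitioning and the shrinking hierarchy can be sketched without any neural network code. The example below only illustrates how a grid of patches is split into non-overlapping blocks and how the number of blocks drops by a factor of four per level. It does not implement the transformer layers or the block aggregation itself, and the function names are hypothetical.

```python
def blockify(grid, block_side):
    """Split a square grid (list of rows) into non-overlapping square blocks.

    Returns a 2D list of blocks; each block is a flat list of the cells it covers.
    """
    side = len(grid)
    if side % block_side != 0:
        raise ValueError("grid side must be divisible by block_side")
    blocks_per_side = side // block_side
    return [
        [
            [grid[br * block_side + r][bc * block_side + c]
             for r in range(block_side) for c in range(block_side)]
            for bc in range(blocks_per_side)
        ]
        for br in range(blocks_per_side)
    ]

def hierarchy_block_counts(num_blocks_bottom):
    """Block count per level when every aggregation merges 4 adjacent blocks."""
    counts = [num_blocks_bottom]
    while counts[-1] > 1:
        if counts[-1] % 4 != 0:
            raise ValueError("bottom block count must be a power of 4")
        counts.append(counts[-1] // 4)
    return counts

# A 4x4 grid of patch ids, split into 2x2 blocks of 4 patches each.
patch_ids = [[row * 4 + col for col in range(4)] for row in range(4)]
blocks = blockify(patch_ids, 2)
```

Here `blocks[0][0]` holds patches `[0, 1, 4, 5]`, and `hierarchy_block_counts(16)` gives `[16, 4, 1]`, matching the description of a single block at the top of the hierarchy.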

Interpretability
This architecture has a non-overlapping information processing mechanism, independent at every node. This design resembles a decision tree-like structure, which manifests unique interpretability capabilities because every tree node contains independent information of an image block that is being received by its parent nodes. We can trace the information flow through the nodes to understand the importance of each feature. In addition, our hierarchical structure retains the spatial structure of images throughout the network, leading to learned spatial feature maps that are effective for interpretation. Below we showcase two kinds of visual interpretability.

First, we present a method to interpret the trained model on test images, called gradient-based class-aware tree-traversal (GradCAT). GradCAT traces the feature importance of each block (a tree node) from top to bottom of the hierarchy structure. The main idea is to find the most valuable traversal from the root node at the top layer to a child node at the bottom layer that contributes the most to the classification outcomes. Since each node processes information from a certain region of the image, such traversal can be easily mapped to the image space for interpretation (as shown by the overlaid dots and lines in the image below).
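
The traversal idea can be sketched as a greedy walk down a quadtree. The example assumes that a per-node importance score has already been computed at every level (how those scores are derived from gradients is not shown), and the function name and score values are hypothetical.

```python
def gradcat_path(importance_levels):
    """Greedy top-down traversal over a quadtree of per-node importance scores.

    importance_levels[0] is the 1x1 root level. Each following level doubles
    the grid side, so node (r, c) has children (2r or 2r+1, 2c or 2c+1).
    Returns the chosen (row, col) at every level, root first.
    """
    row, col = 0, 0
    path = [(row, col)]
    for level in importance_levels[1:]:
        children = [(2 * row + dr, 2 * col + dc) for dr in (0, 1) for dc in (0, 1)]
        # Follow the child with the highest importance score.
        row, col = max(children, key=lambda rc: level[rc[0]][rc[1]])
        path.append((row, col))
    return path

# Hypothetical scores for a three-level hierarchy (1x1, 2x2, 4x4).
levels = [
    [[1.0]],
    [[0.1, 0.7],
     [0.3, 0.2]],
    [[0.0, 0.0, 0.2, 0.9],
     [0.0, 0.0, 0.4, 0.1],
     [0.0, 0.0, 0.0, 0.0],
     [0.0, 0.0, 0.0, 0.0]],
]
path = gradcat_path(levels)
```

For these scores the path is `[(0, 0), (0, 1), (0, 3)]`: the upper-right quadrant is chosen first, then the top-right block inside it. Because every node covers a fixed image region, each step of the path maps directly to a location in the image.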

The following is an example of the model’s top-4 predictions and corresponding interpretability results on the left input image (containing 4 animals). As shown below, GradCAT highlights the decision path along the hierarchical structure as well as the corresponding visual cues in local image regions on the images.

Given the left input image (containing four objects), the figure visualizes the interpretability results of the top-4 prediction classes. The traversal locates the model decision path along the tree and simultaneously locates the corresponding image patch (shown by the dotted line on images) that has the highest impact on the predicted target class.

Moreover, the following figures visualize results on the ImageNet validation set and show how this approach enables some intuitive observations. For instance, the example of the lighter below (upper left panel) is particularly interesting because the ground truth class — lighter/matchstick — actually defines the bottom-right matchstick object, while the most salient visual features (with the highest node values) are actually from the upper-left red light, which conceptually shares visual cues with a lighter. This can also be seen from the overlaid red lines, which indicate the image patches with the highest impact on the prediction. Thus, although the visual cue is a mistake, the output prediction is correct. In addition, the four child nodes of the wooden spoon below have similar feature importance values (see numbers visualized in the nodes; higher indicates more importance), which is because the wooden texture of the table is similar to that of the spoon.

Visualization of the results obtained by the proposed GradCAT. Images are from the ImageNet validation dataset.

Second, different from the original ViT, our hierarchical architecture retains spatial relationships in learned representations. The top layers output low-resolution feature maps of input images, enabling the model to easily perform attention-based interpretation by applying Class Attention Map (CAM) on the learned representations at the top hierarchical level. This enables high-quality weakly-supervised object localization with just image-level labels. See the following figure for examples.

Visualization of CAM-based attention results. Warmer colors indicate higher attention. Images are from the ImageNet validation dataset.

Convergence Advantages
With this design, feature learning only happens at local regions independently, and feature abstraction happens inside the aggregation function. This design and simple implementation is general enough for other types of visual understanding tasks beyond classification. It also improves the model convergence speed greatly, significantly reducing the training time to reach the desired maximum accuracy.

We validate this advantage in two ways. First, we compare ImageNet accuracy against the ViT structure for different numbers of total training epochs. The results are shown on the left side of the figure below, demonstrating much faster convergence than the original ViT, e.g., around 20% improvement in accuracy over ViT with 30 total training epochs.

Second, we modify the architecture to conduct unconditional image generation tasks, since training ViT-based models for image generation tasks is challenging due to convergence and speed issues. Creating such a generator is straightforward by transposing the proposed architecture: the input is an embedding vector, the output is a full image in RGB channels, and the block aggregation is replaced by a block de-aggregation component supported by Pixel Shuffling. Surprisingly, we find our generator is easy to train and demonstrates faster convergence speed, as well as better FID score (which measures how similar generated images are to real ones), than the capacity-comparable SAGAN.

Left: ImageNet accuracy given different number of total training epochs compared with standard ViT architecture. Right: ImageNet 64×64 image generation FID scores (lower is better) with single 1000-epoch training. On both tasks, our method shows better convergence speed.

Conclusion
In this work we demonstrate the simple idea that decoupled feature learning and feature information extraction in this nested hierarchy design leads to better feature interpretability through a new gradient-based class-aware tree traversal method. Moreover, the architecture improves convergence on not only classification tasks but also image generation tasks. The proposed idea focuses on the aggregation function and is therefore orthogonal to advanced architecture design for self-attention. We hope this new research encourages future architecture designers to explore more interpretable and data-efficient ViT-based models for visual understanding, like the adoption of this work for high-resolution image generation. We have also released the source code for the image classification portion of this work.

Acknowledgements
We gratefully acknowledge the contributions of other co-authors, including Han Zhang, Long Zhao, Ting Chen, Sercan Arik, and Tomas Pfister. We also thank Xiaohua Zhai, Jeremy Kubica, Kihyuk Sohn, and Madeleine Udell for their valuable feedback on the work.


Robot See, Robot Do

People learn to do things by watching others, from mimicking new dance moves to watching YouTube cooking videos. We’d like robots to do the same, i.e., to learn new skills by watching people do things during training. Today, however, the predominant paradigm for teaching robots is to remote control them using specialized hardware for teleoperation and then train them to imitate pre-recorded demonstrations. This limits both who can provide the demonstrations (programmers & roboticists) and where they can be provided (lab settings). If robots could instead self-learn new tasks by watching humans, this capability could allow them to be deployed in more unstructured settings like the home, and make it dramatically easier for anyone to teach or communicate with them, expert or otherwise. Perhaps one day, they might even be able to use YouTube videos to grow their collection of skills over time.

Our motivation is to have robots watch people do tasks, naturally with their hands, and then use that data as demonstrations for learning. Video by Teh Aik Hui and Nathaniel Lim. License: CC-BY

However, an obvious but often overlooked problem is that a robot is physically different from a human, which means it often completes tasks differently than we do. For example, in the pen manipulation task below, the hand can grab all the pens together and quickly transfer them between containers, whereas the two-fingered gripper must transport one at a time. Prior research assumes that humans and robots can do the same task similarly, which makes manually specifying one-to-one correspondences between human and robot actions easy. But with stark differences in physique, defining such correspondences for seemingly easy tasks can be surprisingly difficult and sometimes impossible.

Physically different end-effectors (i.e., “grippers”, the part that interacts with the environment) induce different control strategies when solving the same task. Left: The hand grabs all pens and quickly transfers them between containers. Right: The two-fingered gripper transports one pen at a time.

In “XIRL: Cross-Embodiment Inverse RL”, presented as an oral paper at CoRL 2021, we explore these challenges further and introduce a self-supervised method for Cross-embodiment Inverse Reinforcement Learning (XIRL). Rather than focusing on how individual human actions should correspond to robot actions, XIRL learns the high-level task objective from videos, and summarizes that knowledge in the form of a reward function that is invariant to embodiment differences, such as shape, actions and end-effector dynamics. The learned rewards can then be used together with reinforcement learning to teach the task to agents with new physical embodiments through trial and error. Our approach is general and scales autonomously with data — the more embodiment diversity presented in the videos, the more invariant and robust the reward functions become. Experiments show that our learned reward functions lead to significantly more sample efficient (roughly 2 to 4 times) reinforcement learning on new embodiments compared to alternative methods. To extend and build on our work, we are releasing an accompanying open-source implementation of our method along with X-MAGICAL, our new simulated benchmark for cross-embodiment imitation.

Cross-Embodiment Inverse Reinforcement Learning (XIRL)
The underlying observation in this work is that in spite of the many differences induced by different embodiments, there still exist visual cues that reflect progression towards a common task objective. For example, in the pen manipulation task above, the presence of pens in the cup but not the mug, or the absence of pens on the table, are key frames that are common to different embodiments and indirectly provide cues for how close to being complete a task is. The key idea behind XIRL is to automatically discover these key moments in videos of different length and cluster them meaningfully to encode task progression. This motivation shares many similarities with unsupervised video alignment research, from which we can leverage a method called Temporal Cycle Consistency (TCC), which aligns videos accurately while learning useful visual representations for fine-grained video understanding without requiring any ground-truth correspondences.

We leverage TCC to train an encoder to temporally align video demonstrations of different experts performing the same task. The TCC loss tries to maximize the number of cycle-consistent frames (or mutual nearest-neighbors) between pairs of sequences using a differentiable formulation of soft nearest-neighbors. Once the encoder is trained, we define our reward function as simply the negative Euclidean distance between the current observation and the goal observation in the learned embedding space. We can subsequently insert the reward into a standard MDP and use an RL algorithm to learn the demonstrated behavior. Surprisingly, we find that this simple reward formulation is effective for cross-embodiment imitation.
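
To make the notion of a cycle-consistent frame concrete, here is a minimal sketch on hand-written embeddings. It is not the training loss: the soft nearest neighbor is computed as described, but the step back to the first sequence uses a hard nearest neighbor for readability, which is a simplification and not differentiable. The function names and the toy sequences are assumptions.

```python
import math

def soft_nearest_neighbor(query, candidates):
    """Softmax-weighted average of candidates, using negative squared distance as logits."""
    logits = [-sum((q - c) ** 2 for q, c in zip(query, cand)) for cand in candidates]
    peak = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    return [sum(w * cand[d] for w, cand in zip(weights, candidates)) / total
            for d in range(len(query))]

def cycle_back_index(i, seq_u, seq_v):
    """Map frame i of seq_u to its soft neighbor in seq_v, then back to the closest frame of seq_u."""
    v_soft = soft_nearest_neighbor(seq_u[i], seq_v)
    distances = [math.dist(v_soft, u) for u in seq_u]
    return distances.index(min(distances))

# Two hypothetical embedded videos of different lengths showing the same progression.
seq_u = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]
seq_v = [[0.1, 0.0], [0.9, 0.0], [2.1, 0.0], [2.0, 0.1]]
cycle_consistent = [cycle_back_index(i, seq_u, seq_v) == i for i in range(len(seq_u))]
```

A frame is cycle-consistent when the round trip returns to the frame it started from. In this toy example all three frames of `seq_u` are, so `cycle_consistent` is `[True, True, True]`.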

XIRL self-supervises reward functions from expert demonstrations using temporal cycle consistency (TCC), then uses them for downstream reinforcement learning to learn new skills from third-person demonstrations.
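
The reward itself is simple enough to write down directly. The sketch below assumes the observation and the goal have already been passed through the trained encoder. The function name and the example vectors are hypothetical, and any scaling a real implementation might apply is omitted.

```python
import math

def embedding_distance_reward(observation_embedding, goal_embedding):
    """Negative Euclidean distance to the goal in the learned embedding space."""
    return -math.dist(observation_embedding, goal_embedding)

goal = [3.0, 4.0]                                     # hypothetical goal embedding
far = embedding_distance_reward([0.0, 0.0], goal)     # -5.0
near = embedding_distance_reward([3.0, 3.0], goal)    # -1.0
```

The reward is never positive and rises toward zero as the embedded observation approaches the embedded goal, which gives the RL algorithm a dense signal of task progress.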

X-MAGICAL Benchmark
To evaluate the performance of XIRL and baseline alternatives (e.g., TCN, LIFS, Goal Classifier) in a consistent environment, we created X-MAGICAL, which is a simulated benchmark for cross-embodiment imitation. X-MAGICAL features a diverse set of agent embodiments, with differences in their shapes and end-effectors, designed to solve tasks in different ways. This leads to differences in execution speeds and state-action trajectories, which poses challenges for current imitation learning techniques, e.g., ones that use time as a heuristic for weak correspondences between two trajectories. The ability to generalize across embodiments is precisely what X-MAGICAL evaluates.

The SweepToTop task we considered for our experiments is a simplified 2D equivalent of a common household robotic sweeping task, where an agent has to push three objects into a goal zone in the environment. We chose this task specifically because its long-horizon nature highlights how different agent embodiments can generate entirely different trajectories (shown below). X-MAGICAL features a Gym API and is designed to be easily extendable to new tasks and embodiments. You can try it out today with pip install x-magical.

Different agent shapes in the SweepToTop task in the X-MAGICAL benchmark need to use different strategies to reposition objects into the target area (pink), i.e., to “clear the debris”. For example, the long-stick can clear them all in one fell swoop, whereas the short-stick needs to do multiple consecutive back-and-forths.
Left: Heatmap of state visitation for each embodiment across all expert demonstrations. Right: Examples of expert trajectories for each embodiment.

Highlights
In our first set of experiments, we checked whether our learned embodiment-invariant reward function can enable successful reinforcement learning when the expert demonstrations are provided through the agent itself. We find that XIRL significantly outperforms alternative methods, especially on the tougher agents (e.g., short-stick and gripper).

Same-embodiment setting: Comparison of XIRL with baseline reward functions, using SAC for RL policy learning. XIRL is roughly 2 to 4 times more sample efficient than some of the baselines on the harder agents (short-stick and gripper).

We also find that our approach shows great potential for learning reward functions that generalize to novel embodiments. For instance, when reward learning is performed on embodiments that are different from the ones on which the policy is trained, we find that it results in significantly more sample efficient agents compared to the same baselines. Below, in the gripper subplot (bottom right) for example, the reward is first learned on demonstration videos from long-stick, medium-stick and short-stick, after which the reward function is used to train the gripper agent.

Cross-embodiment setting: XIRL performs favorably when compared with other baseline reward functions, trained on observation-only demonstrations from different embodiments. Each agent (long-stick, medium-stick, short-stick, gripper) had its reward trained using demonstrations from the other three embodiments.

We also find that we can train on real-world human demonstrations, and use the learned reward to train a Sawyer arm in simulation to push a puck to a designated target zone. In these experiments as well, our method outperforms baseline alternatives. For example, our XIRL variant trained only on the real-world demonstrations (purple in the plots below) reaches 80% of the total performance roughly 85% faster than the RLV baseline (orange).

What Do The Learned Reward Functions Look Like?
To further explore the qualitative nature of our learned rewards in more challenging real-world scenarios, we collect a dataset of the pen transfer task using various household tools.

Below, we show rewards extracted from a successful (top) and unsuccessful (bottom) demonstration. Both demonstrations follow a similar trajectory at the start of the task execution. The successful one nets a high reward for placing the pens consecutively into the mug then into the glass cup, while the unsuccessful one obtains a low reward because it drops the pens outside the glass cup towards the end of the execution (orange circle). These results are promising because they show that our learned encoder can represent fine-grained visual differences relevant to a task.

Conclusion
We highlighted XIRL, our approach to tackling the cross-embodiment imitation problem. XIRL learns an embodiment-invariant reward function that encodes task progress using a temporal cycle-consistency objective. Policies learned using our reward functions are significantly more sample-efficient than baseline alternatives. Furthermore, the reward functions do not require manually paired video frames between the demonstrator and the learner, giving them the ability to scale to an arbitrary number of embodiments or experts with varying skill levels. Overall, we are excited about this direction of work, and hope that our benchmark promotes further research in this area. For more details, please check out our paper and download the code from our GitHub repository.

Acknowledgments
Kevin and Andy summarized research performed together with Pete Florence, Jonathan Tompson, Jeannette Bohg (faculty at Stanford University) and Debidatta Dwibedi. All authors would additionally like to thank Alex Nichol, Nick Hynes, Sean Kirmani, Brent Yi, Jimmy Wu, Karl Schmeckpeper and Minttu Alakuijala for fruitful technical discussions, and Sam Toyer for invaluable help with setting up the simulated benchmark.

Unlocking the Full Potential of Datacenter ML Accelerators with Platform-Aware Neural Architecture Search

Continuing advances in the design and implementation of datacenter (DC) accelerators for machine learning (ML), such as TPUs and GPUs, have been critical for powering modern ML models and applications at scale. These improved accelerators exhibit peak performance (e.g., FLOPs) that is orders of magnitude better than traditional computing systems. However, there is a fast-widening gap between the available peak performance offered by state-of-the-art hardware and the actual achieved performance when ML models run on that hardware.

One approach to address this gap is to design hardware-specific ML models that optimize both performance (e.g., throughput and latency) and model quality. Recent applications of neural architecture search (NAS), an emerging paradigm to automate the design of ML model architectures, have employed a platform-aware multi-objective approach that includes a hardware performance objective. While this approach has yielded improved model performance in practice, the details of the underlying hardware architecture are opaque to the model. As a result, there is untapped potential to build full capability hardware-friendly ML model architectures, with hardware-specific optimizations, for powerful DC ML accelerators.

In “Searching for Fast Model Families on Datacenter Accelerators”, published at CVPR 2021, we advanced the state of the art of hardware-aware NAS by automatically adapting model architectures to the hardware on which they will be executed. The approach we propose finds optimized families of models for which additional hardware performance gains cannot be achieved without loss in model quality (called Pareto optimization). To accomplish this, we infuse a deep understanding of hardware architecture into the design of the NAS search space for discovery of both single models and model families. We provide quantitative analysis of the performance gap between hardware and traditional model architectures and demonstrate the advantages of using true hardware performance (i.e., throughput and latency), instead of the performance proxy (FLOPs), as the performance optimization objective. Leveraging this advanced hardware-aware NAS and building upon the EfficientNet architecture, we developed a family of models, called EfficientNet-X, that demonstrate the effectiveness of this approach for Pareto-optimized ML models on TPUs and GPUs.
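To make the notion of Pareto optimization concrete, here is a minimal sketch that filters a list of candidate models down to those that cannot be made faster without losing accuracy. The candidate names, latencies and accuracies are made up for illustration, and the function is a generic dominance filter, not the search procedure from the paper.

```python
def pareto_front(models):
    """Return the models that are not dominated on (latency, accuracy).

    A model is dominated if some other model is at least as fast and at
    least as accurate, and strictly better on one of the two.
    """
    front = []
    for m in models:
        dominated = any(
            o["latency_ms"] <= m["latency_ms"]
            and o["accuracy"] >= m["accuracy"]
            and (o["latency_ms"] < m["latency_ms"]
                 or o["accuracy"] > m["accuracy"])
            for o in models
        )
        if not dominated:
            front.append(m)
    return sorted(front, key=lambda m: m["latency_ms"])


# Hypothetical candidates; the numbers are not measurements.
candidates = [
    {"name": "a", "latency_ms": 5.0, "accuracy": 0.76},
    {"name": "b", "latency_ms": 8.0, "accuracy": 0.80},
    {"name": "c", "latency_ms": 9.0, "accuracy": 0.79},  # slower and worse than b
    {"name": "d", "latency_ms": 14.0, "accuracy": 0.83},
]
print([m["name"] for m in pareto_front(candidates)])  # ['a', 'b', 'd']
```

Every model on the resulting front is a valid choice; which one to deploy depends on the latency budget.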

Platform-Aware NAS for DC ML Accelerators
To achieve high performance, ML models need to adapt to modern ML accelerators. Platform-aware NAS integrates knowledge of the hardware accelerator properties into all three pillars of NAS: (i) the search objectives; (ii) the search space; and (iii) the search algorithm (shown below). We focus on the new search space because it contains the building blocks needed to compose the models and is the key link between the ML model architectures and accelerator hardware architectures.

We construct TPU/GPU specialized search spaces with TPU/GPU-friendly operations to infuse hardware awareness into NAS. For example, a key adaptation is maximizing parallelism to ensure different hardware components inside the accelerators work together as efficiently as possible. This includes the matrix multiplication units (MXUs) in TPUs and the TensorCore in GPUs for matrix/tensor computation, as well as the vector processing units (VPUs) in TPUs and CUDA cores in GPUs for vector processing. Maximizing model arithmetic intensity (i.e., optimizing the parallelism between computation and operations on the high bandwidth memory) is also critical to achieve top performance. To tap into the full potential of the hardware, it is crucial for ML models to achieve high parallelism inside and across these hardware components.
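One common way to reason about arithmetic intensity is the roofline model, which is brought in here as an assumption and is not taken from the text above. In that model, intensity is the ratio of FLOPs to bytes moved to and from memory, and attainable throughput is capped either by the compute peak or by memory bandwidth times intensity. The sketch below uses invented hardware numbers to show how a workload with more total FLOPs can still finish sooner if its intensity is high enough.

```python
def attainable_tflops(peak_tflops, bandwidth_tb_s, intensity_flop_per_byte):
    """Roofline bound: limited by the compute peak or by memory traffic.

    TB/s multiplied by FLOP/byte gives TFLOP/s, so the units line up.
    """
    return min(peak_tflops, bandwidth_tb_s * intensity_flop_per_byte)


def runtime_s(total_tflop, peak_tflops, bandwidth_tb_s, intensity_flop_per_byte):
    """Best-case runtime for a workload of `total_tflop` (units of 1e12 FLOPs)."""
    rate = attainable_tflops(peak_tflops, bandwidth_tb_s, intensity_flop_per_byte)
    return total_tflop / rate


# Hypothetical accelerator: 100 TFLOP/s peak, 1 TB/s memory bandwidth.
PEAK, BW = 100.0, 1.0

low_intensity = runtime_s(1.0, PEAK, BW, 20.0)    # memory-bound at 20 TFLOP/s
high_intensity = runtime_s(2.0, PEAK, BW, 100.0)  # compute-bound at 100 TFLOP/s
print(low_intensity, high_intensity)  # 0.05 0.02
```

In this toy setting the second workload performs twice the FLOPs yet runs 2.5x faster, because the first one spends most of its time waiting on memory.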

Overview of platform-aware NAS on TPUs/GPUs, highlighting the search space and search objectives.

Advanced platform-aware NAS has an optimized search space containing a set of complementary techniques to holistically improve parallelism for ML model execution on TPUs and GPUs:

  1. It uses specialized tensor reshaping techniques to maximize the parallelism in the MXUs / TensorCores.
  2. It dynamically selects different activation functions depending on matrix operation types to ensure overlapping of vector and matrix/tensor processing.
  3. It employs hybrid convolutions and a novel fusion strategy to strike a balance between total compute and arithmetic intensity to ensure that computation and memory access happen in parallel and to reduce the contention on VPUs / CUDA cores.
  4. With latency-aware compound scaling (LACS), which uses hardware performance instead of FLOPs as the performance objective to search for model depth, width and resolutions, we ensure parallelism at all levels for the entire model family on the Pareto-front.

EfficientNet-X: Platform-Aware NAS-Optimized Computer Vision Models for TPUs and GPUs
Using this approach to platform-aware NAS, we have designed EfficientNet-X, an optimized computer vision model family for TPUs and GPUs. This family builds upon the EfficientNet architecture, which itself was originally designed by traditional multi-objective NAS without true hardware-awareness as the baseline. The resulting EfficientNet-X model family achieves an average speedup of ~1.5x–2x over EfficientNet on TPUv3 and GPUv100, respectively, with comparable accuracy.

In addition to the improved speeds, EfficientNet-X has shed light on the non-proportionality between FLOPs and true performance. Many think FLOPs are a good ML performance proxy (i.e., FLOPs and performance are proportional), but they are not. While FLOPs are a good performance proxy for simple hardware such as scalar machines, they can exhibit a margin of error of up to 400% on advanced matrix/tensor machines. For example, because of its hardware-friendly model architecture, EfficientNet-X requires ~2x more FLOPs than EfficientNet, but is ~2x faster on TPUs and GPUs.

EfficientNet-X family achieves 1.5x–2x speedup on average over the state-of-the-art EfficientNet family, with comparable accuracy on TPUv3 and GPUv100.

Self-Driving ML Model Performance on New Accelerator Hardware Platforms
Platform-aware NAS exposes the inner workings of the hardware and leverages these properties when designing hardware-optimized ML models. In a sense, the “platform-awareness” of the model is a “gene” that preserves knowledge of how to optimize performance for a hardware family, even on new generations, without the need to redesign the models. For example, TPUv4i delivers up to 3x higher peak performance (FLOPS) than its predecessor TPUv2, but EfficientNet performance only improves by 30% when migrating from TPUv2 to TPUv4i. In comparison, EfficientNet-X retains its platform-aware properties even on new hardware and achieves a 2.6x speedup when migrating from TPUv2 to TPUv4i, utilizing almost all of the 3x peak performance gain expected when upgrading between the two generations.

Hardware peak performance ratio of TPUv2 to TPUv4i and the geometric mean speedup of EfficientNet-X and EfficientNet families, respectively, when migrating from TPUv2 to TPUv4i.
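The claim that EfficientNet-X uses almost all of the peak gain can be checked with simple arithmetic. The sketch below takes the figures quoted above (a 3x peak ratio, a 30% improvement read as a 1.3x speedup, and a 2.6x speedup) and divides each speedup by the peak ratio; treating that quotient as the "realized fraction" is one reasonable reading, not a metric defined in the paper.

```python
def realized_fraction(observed_speedup, peak_ratio):
    """Share of the hardware's peak-performance gain that a model realizes."""
    return observed_speedup / peak_ratio


PEAK_RATIO = 3.0  # TPUv4i peak relative to TPUv2, as quoted above
SPEEDUPS = {"EfficientNet": 1.3, "EfficientNet-X": 2.6}

for name, speedup in SPEEDUPS.items():
    print(f"{name}: {realized_fraction(speedup, PEAK_RATIO):.0%}")
# EfficientNet: 43%
# EfficientNet-X: 87%
```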

Conclusion and Future Work
We demonstrate how to improve the capabilities of platform-aware NAS for datacenter ML accelerators, especially TPUs and GPUs. Both platform-aware NAS and the EfficientNet-X model family have been deployed in production and materialize up to ~40% efficiency gains and significant quality improvements for various internal computer vision projects across Google. Additionally, because of its deep understanding of accelerator hardware architecture, platform-aware NAS was able to identify critical performance bottlenecks on TPUv2-v4i architectures and has enabled design enhancements to future TPUs with significant potential performance uplift. As next steps, we are working on expanding platform-aware NAS’s capabilities to the ML hardware and model design beyond computer vision.

Acknowledgements
Special thanks to our co-authors: Mingxing Tan, Ruoming Pang, Andrew Li, Liqun Cheng, Quoc Le. We also thank many collaborators including Jeff Dean, David Patterson, Shengqi Zhu, Yun Ni, Gang Wu, Tao Chen, Xin Li, Yuan Qi, Amit Sabne, Shahab Kamali, and many others from the broad Google research and engineering teams who helped on the research and the subsequent broad production deployment of platform-aware NAS.

Ask a Techspert: What does AI do when it doesn’t know?

As humans, we constantly learn from the world around us. We experience inputs that shape our knowledge — including the boundaries of both what we know and what we don’t know.

Many of today’s machines also learn by example. However, these machines are typically trained on datasets and information that don’t always include rare or out-of-the-ordinary examples that inevitably come up in real-life scenarios. What is an algorithm to do when faced with the unknown?

I recently spoke with Abhijit Guha Roy, an engineer on the Health AI team, and Ian Kivlichan, an engineer on the Jigsaw team, to hear more about using AI in real-world scenarios and better understand the importance of training it to know when it doesn’t know.

Abhijit, tell me about your recent research in the dermatology space.

We’re applying deep learning to a number of areas in health, including in medical imaging where it can be used to aid in the identification of health conditions and diseases that might require treatment. In the dermatological field, we have shown that AI can be used to help identify possible skin issues and are in the process of advancing research and products, including DermAssist, that can support both clinicians and people like you and me.

In these real-world settings, the algorithm might come up against something it’s never seen before. Rare conditions, while individually infrequent, might not be so rare in aggregate. These so-called “out-of-distribution” examples are a common problem for AI systems, which can perform less well when they’re exposed to things they haven’t seen before in training.

Can you explain what “out-of-distribution” means for AI?

Most traditional machine learning examples that are used to train AI deal with fairly unsubtle — or obvious — changes. For example, if an algorithm that is trained to identify cats and dogs comes across a car, then it can typically detect that the car — which is an “out-of-distribution” example — is an outlier. Building an AI system that can recognize the presence of something it hasn’t seen before in training is called “out-of-distribution detection,” and is an active and promising field of AI research.

Okay, let’s go back to how this applies to AI in medical settings.

Going back to our research in the dermatology space, the differences between skin conditions can be much more subtle than recognizing a car from a dog or a cat, even more subtle than recognizing a previously unseen “pick-up truck” from a “truck”. As such, the out-of-distribution detection task in medical AI demands even more of our focused attention.

This is where our latest research comes in. We trained our algorithm to recognize even the most subtle of outliers (a so-called “near-out-of-distribution” detection task). Then, instead of the model inaccurately guessing, it can take a safer course of action, like deferring to human experts.

Ian, you’re working on another area where AI needs to know when it doesn’t know something. What’s that?

The field of content moderation. Our team at Jigsaw used AI to build a free tool called Perspective that scores comments according to how likely they are to be considered toxic by readers. Our AI algorithms help identify toxic language and online harassment at scale so that human content moderators can make better decisions for their online communities. A range of online platforms use Perspective more than 600 million times a day to reduce toxicity and the human time required to moderate content.

In the real world, online conversations — both the things people say and even the ways they say them — are continually changing. For example, two years ago, nobody would have understood the phrase “non-fungible token (NFT).” Our language is always evolving, which means a tool like Perspective doesn’t just need to identify potentially toxic or harassing comments, it also needs to “know when it doesn’t know,” and then defer to human moderators when it encounters comments very different from anything it has encountered before.

In our recent research, we trained Perspective to identify comments it was uncertain about and flag them for separate human review. By prioritizing these comments, human moderators can correct more than 80% of the mistakes the AI might otherwise have made.
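A simple way to picture this kind of prioritization is to rank comments by how uncertain the model is and send the top of that ranking to human reviewers. The sketch below is only an illustration under a loud assumption: it uses the distance of a score from 0.5 as a stand-in for uncertainty, which is not necessarily how Perspective estimates it, and the scores and function name are invented.

```python
def review_queue(scores, budget):
    """Return indices of the `budget` most uncertain comments, most uncertain first.

    Assumption: a score close to 0.5 means the model is unsure. A production
    system would likely use a dedicated uncertainty estimate instead.
    """
    by_uncertainty = sorted(range(len(scores)),
                            key=lambda i: abs(scores[i] - 0.5))
    return by_uncertainty[:budget]


# Hypothetical toxicity scores for five comments.
scores = [0.02, 0.48, 0.95, 0.60, 0.10]
print(review_queue(scores, budget=2))  # [1, 3]
```

Comments outside the review budget are handled automatically, so the human effort is concentrated where the model is most likely to be wrong.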

What connects these two examples?

We have more in common with the dermatology problem than you’d expect at first glance — even though the problems we try to solve are so different.

Building AI that knows when it doesn’t know something means you can prevent certain errors that might have unintended consequences. In both cases, the safest course of action for the algorithm entails deferring to human experts rather than trying to make a decision that could lead to potentially negative effects downstream.

There are some fields where this isn’t as important and others where it’s critical. You might not care if an automated vegetable sorter incorrectly sorts a purple carrot after being trained on orange carrots, but you would definitely care if an algorithm didn’t know what to do about an abnormal shadow on an X-ray that a doctor might recognize as an unexpected cancer.

How is AI uncertainty related to AI safety?

Most of us are familiar with safety protocols in the workplace. In safety-critical industries like aviation or medicine, protocols like “safety checklists” are routine and very important in order to prevent harm to both the workers and the people they serve.

It’s important that we also think about safety protocols when it comes to machines and algorithms, especially when they are integrated into our daily workflow and aid in decision-making or triaging that can have a downstream impact.

Teaching algorithms to refrain from guessing in unfamiliar scenarios and to ask for help from human experts falls within these protocols, and is one of the ways we can reduce harm and build trust in our systems. This is something Google is committed to, as outlined in its AI Principles.

Can Robots Follow Instructions for New Tasks?

People can flexibly maneuver objects in their physical surroundings to accomplish various goals. One of the grand challenges in robotics is to successfully train robots to do the same, i.e., to develop a general-purpose robot capable of performing a multitude of tasks based on arbitrary user commands. Robots that are faced with the real world will also inevitably encounter new user instructions and situations that were not seen during training. Therefore, it is imperative for robots to be trained to perform multiple tasks in a variety of situations and, more importantly, to be capable of solving new tasks as requested by human users, even if the robot was not explicitly trained on those tasks.

Existing robotics research has made strides towards allowing robots to generalize to new objects, task descriptions, and goals. However, enabling robots to complete instructions that describe entirely new tasks has largely remained out-of-reach. This problem is remarkably difficult since it requires robots to both decipher the novel instructions and identify how to complete the task without any training data for that task. This goal becomes even more difficult when a robot needs to simultaneously handle other axes of generalization, such as variability in the scene and positions of objects. So, we ask the question: How can we confer noteworthy generalization capabilities onto real robots capable of performing complex manipulation tasks from raw pixels? Furthermore, can the generalization capabilities of language models help support better generalization in other domains, such as visuomotor control of a real robot?

In “BC-Z: Zero-Shot Task Generalization with Robotic Imitation Learning”, published at CoRL 2021, we present new research that studies how robots can generalize to new tasks that they were not trained to do. The system, called BC-Z, comprises two key components: (i) the collection of a large-scale demonstration dataset covering 100 different tasks and (ii) a neural network policy conditioned on a language or video instruction of the task. The resulting system can perform at least 24 novel tasks, including ones that require interaction with pairs of objects that were not previously seen together. We are also excited to release the robot demonstration dataset used to train our policies, along with pre-computed task embeddings.

The BC-Z system allows a robot to complete instructions for new tasks that the robot was not explicitly trained to do. It does so by training the policy to take as input a description of the task along with the robot’s camera image and to predict the correct action.

Collecting Data for 100 Tasks
Generalizing to a new task altogether is substantially harder than generalizing to held-out variations in training tasks. Simply put, we want robots to have more generalization all around, which requires that we train them on large amounts of diverse data.

We collect data by teleoperating the robot with a virtual reality headset. This data collection follows a scheme similar to how one might teach an autonomous car to drive. First, the human operator records complete demonstrations of each task. Then, once the robot has learned an initial policy, this policy is deployed under close supervision where, if the robot starts to make a mistake or gets stuck, the operator intervenes and demonstrates a correction before allowing the robot to resume.

This mixture of demonstrations and interventions has been shown to significantly improve performance by mitigating compounding errors. In our experiments, we see a 2x improvement in performance when using this data collection strategy compared to only using human demonstrations.

Example demonstrations collected for 12 out of the 100 training tasks, visualized from the perspective of the robot and shown at 2x speed.

Training a General-Purpose Policy
For all 100 tasks, we use this data to train a neural network policy to map from camera images to the position and orientation of the robot’s gripper and arm. Crucially, to allow this policy the potential to solve new tasks beyond the 100 training tasks, we also input a description of the task, either in the form of a language command (e.g., “place grapes in red bowl”) or a video of a person doing the task.

To accomplish a variety of tasks, the BC-Z system takes as input either a language command describing the task or a video of a person doing the task, as shown here.

By training the policy on 100 tasks and conditioning the policy on such a description, we unlock the possibility that the neural network will be able to interpret and complete instructions for new tasks. This is a challenge, however, because the neural network needs to correctly interpret the instruction, visually identify relevant objects for that instruction while ignoring other clutter in the scene, and translate the interpreted instruction and perception into the robot’s action space.

Experimental Results
In language models, it is well known that sentence embeddings generalize on compositions of concepts encountered in training data. For instance, if you train a translation model on sentences like “pick up a cup” and “push a bowl”, the model should also translate “push a cup” correctly.

We study the question of whether the compositional generalization capabilities found in language encoders can be transferred to real robots, i.e., being able to compose unseen object-object and task-object pairs.

We test this method by pre-selecting a set of 28 tasks, none of which were among the 100 training tasks. For example, one of these new test tasks is to pick up the grapes and place them into a ceramic bowl, but the training tasks involve doing other things with the grapes and placing other items into the ceramic bowl. The grapes and the ceramic bowl never appeared in the same scene during training.

In our experiments, we see that the robot can complete many tasks that were not included in the training set. Below are a few examples of the robot’s learned policy.

The robot completes three instructions of tasks that were not in its training data, shown at 2x speed.

Quantitatively, we see that the robot can succeed to some degree on a total of 24 out of the 28 held-out tasks, indicating a promising capacity for generalization. Further, we see a notably small gap between the performance on the training tasks and performance on the test tasks. These results indicate that simply improving multi-task visuomotor control could considerably improve performance.

The BC-Z performance on held-out tasks, i.e., tasks that the robot was not trained to perform. The system correctly interprets the language command and translates that into action to complete many of the tasks in our evaluation.

Takeaways
The results of this research show that simple imitation learning approaches can be scaled in a way that enables zero-shot generalization to new tasks. That is, it shows one of the first indications of robots being able to successfully carry out behaviors that were not in the training data. Interestingly, language embeddings pre-trained on ungrounded language corpora make for excellent task conditioners. We demonstrated that natural language models can not only provide a flexible input interface to robots, but that pretrained language representations actually confer new generalization capabilities to the downstream policy, such as composing unseen object pairs together.

In the course of building this system, we confirmed that periodic human interventions are a simple but important technique for achieving good performance. While there is a substantial amount of work to be done in the future, we believe that the zero-shot generalization capabilities of BC-Z are an important advancement towards increasing the generality of robotic learning systems and allowing people to command robots. We have released the teleoperated demonstrations used to train the policy in this paper, which we hope will provide researchers with a valuable resource for future multi-task robotic learning research.

Acknowledgements
We would like to thank the co-authors of this research: Alex Irpan, Mohi Khansari, Daniel Kappler, Frederik Ebert, Corey Lynch, and Sergey Levine. This project was a collaboration between Google Research and the Everyday Robot Project. We would like to give special thanks to Noah Brown, Omar Cortes, Armando Fuentes, Kyle Jeffrey, Linda Luu, Sphurti Kirit More, Jornell Quiambao, Jarek Rettinghouse, Diego Reyes, Rosario Jauregui Ruano, and Clayton Tan for overseeing robot operations and collecting human videos of the tasks, as well as Jeffrey Bingham, Jonathan Weisz, and Kanishka Rao for valuable discussions. We would also like to thank Tom Small for creating animations in this post and Paul Mooney for helping with dataset open-sourcing.

Applying Differential Privacy to Large Scale Image Classification

Machine learning (ML) models are becoming increasingly valuable for improved performance across a variety of consumer products, from recommendations to automatic image classification. However, despite aggregating large amounts of data, in theory it is possible for models to encode characteristics of individual entries from the training set. For example, experiments in controlled settings have shown that language models trained using email datasets may sometimes encode sensitive information included in the training data and may have the potential to reveal the presence of a particular user’s data in the training set. As such, it is important to prevent the encoding of such characteristics from individual training entries. To these ends, researchers are increasingly employing federated learning approaches.

Differential privacy (DP) provides a rigorous mathematical framework that allows researchers to quantify and understand the privacy guarantees of a system or an algorithm. Within the DP framework, privacy guarantees of a system are usually characterized by a positive parameter ε, called the privacy loss bound, with smaller ε corresponding to better privacy. One usually trains a model with DP guarantees using DP-SGD, a specialized training algorithm that provides DP guarantees for the trained model.
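For reference, the standard definition behind the parameter ε can be written as follows. A randomized algorithm M is (ε, δ)-differentially private if, for all pairs of datasets D and D′ that differ in a single entry and for every set S of possible outputs, the inequality below holds. The δ term is not mentioned in the text above and is included here as part of the commonly used definition; setting δ = 0 recovers pure ε-DP.

```latex
\Pr[\mathcal{M}(D) \in S] \;\le\; e^{\varepsilon}\,\Pr[\mathcal{M}(D') \in S] + \delta
```

The smaller ε is, the closer the two output distributions must be, which is why smaller ε corresponds to better privacy.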

However, training with DP-SGD typically has two major drawbacks. First, most existing implementations of DP-SGD are inefficient and slow, which makes them hard to use on large datasets. Second, DP-SGD training often significantly impacts utility (such as model accuracy) to the point that models trained with DP-SGD may become unusable in practice. As a result, most DP research papers evaluate DP algorithms on very small datasets (MNIST, CIFAR-10, or UCI) and don’t even try to perform evaluation on larger datasets, such as ImageNet.

In “Toward Training at ImageNet Scale with Differential Privacy”, we share initial results from our ongoing effort to train a large image classification model on ImageNet using DP while maintaining high accuracy and minimizing computational cost. We show that the combination of various training techniques, such as careful choice of the model and hyperparameters, large batch training, and transfer learning from other datasets, can significantly boost accuracy of an ImageNet model trained with DP. To substantiate these discoveries and encourage follow-up research, we are also releasing the associated source code.

Testing Differential Privacy on ImageNet
We choose ImageNet classification as a demonstration of the practicality and efficacy of DP because: (1) it is an ambitious task for DP, for which no prior work shows sufficient progress; and (2) it is a public dataset on which other researchers can operate, so it represents an opportunity to collectively improve the utility of real-life DP training. Classification on ImageNet is challenging for DP because it requires large networks with many parameters. This translates into a significant amount of noise added into the computation, because the noise added scales with the size of the model.

Scaling Differential Privacy with JAX
Exploring multiple architectures and training configurations to research what works for DP can be debilitatingly slow. To streamline our efforts, we used JAX, a high-performance computational library based on XLA that can do efficient auto-vectorization and just-in-time compilation of the mathematical computations. Using these JAX features was previously recommended as a good way to speed up DP-SGD in the context of smaller datasets such as CIFAR-10.

We created our own implementation of DP-SGD on JAX and benchmarked it on the large ImageNet dataset (the code is included in our release). The implementation in JAX was relatively simple and resulted in noticeable performance gains simply because of using the XLA compiler. Compared to other implementations of DP-SGD, such as that in TensorFlow Privacy, the JAX implementation is consistently several times faster. It is typically even faster than the custom-built and optimized PyTorch Opacus.

Each step of our DP-SGD implementation takes approximately two forward-backward passes through the network. While this is slower than non-private training, which requires only a single forward-backward pass, it is still the most efficient known approach to train with the per-example gradients necessary for DP-SGD. The graph below shows training runtimes for two models on ImageNet with DP-SGD vs. non-private SGD, each on JAX. Overall, we find DP-SGD on JAX sufficiently fast to run large experiments just by slightly reducing the number of training runs used to find optimal hyperparameters compared to non-private training. This is significantly better than alternatives, such as TensorFlow Privacy, which we found to be ~5x–10x slower on our CIFAR-10 and MNIST benchmarks.

Time in seconds per training epoch on ImageNet using a ResNet-18 or ResNet-50 architecture with 8 V100 GPUs.
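
For readers unfamiliar with why DP-SGD needs per-example gradients, the following is a minimal sketch of a single DP-SGD update in plain Python. It is not the released JAX implementation: the function names (`clip_by_l2_norm`, `dp_sgd_step`) and parameters (`clip_norm`, `noise_multiplier`) are hypothetical, and the gradients are passed in as plain lists rather than computed by a network.

```python
import math
import random


def clip_by_l2_norm(grad, clip_norm):
    """Scale a gradient vector so that its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, clip_norm / norm) if norm > 0.0 else 1.0
    return [g * scale for g in grad]


def dp_sgd_step(params, per_example_grads, clip_norm, noise_multiplier, lr, rng):
    """One DP-SGD update: clip each example's gradient, sum, add noise, average."""
    batch_size = len(per_example_grads)
    # Clipping bounds how much any single example can influence the update.
    clipped = [clip_by_l2_norm(g, clip_norm) for g in per_example_grads]
    summed = [sum(column) for column in zip(*clipped)]
    # Gaussian noise is added per coordinate, scaled to the clipping norm.
    stddev = noise_multiplier * clip_norm
    noisy_mean = [(s + rng.gauss(0.0, stddev)) / batch_size for s in summed]
    return [p - lr * g for p, g in zip(params, noisy_mean)]


rng = random.Random(0)
params = [0.0, 0.0]
grads = [[3.0, 4.0], [0.3, 0.4]]  # L2 norms 5.0 and 0.5

# noise_multiplier=0.0 is used here only to make the arithmetic easy to follow;
# it provides no privacy.
new_params = dp_sgd_step(params, grads, clip_norm=1.0,
                         noise_multiplier=0.0, lr=1.0, rng=rng)
```

In this toy run the first gradient is scaled down to norm 1.0 while the second is left unchanged, which is the step that cannot be done from a batch-averaged gradient alone.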

Combining Techniques for Improved Accuracy
It is possible that future training algorithms may improve DP’s privacy-utility tradeoff. However, with current algorithms, such as DP-SGD, our experience points to an engineering “bag-of-tricks” approach to make DP more practical on challenging tasks like ImageNet.

Because we can train models faster with JAX, we can iterate quickly and explore multiple configurations to find what works well for DP. We report the following combination of techniques as useful to achieve non-trivial accuracy and privacy on ImageNet:

  • Full-batch training

    Theoretically, it is known that larger minibatch sizes improve the utility of DP-SGD, with full-batch training (i.e., where a full dataset is one batch) giving the best utility [1, 2], and empirical results are emerging to support this theory. Indeed, our experiments demonstrate that increasing the batch size along with the number of training epochs leads to a decrease in ε while still maintaining accuracy. However, training with extremely large batches is non-trivial as the batch cannot fit into GPU/TPU memory. So, we employed virtual large-batch training by accumulating gradients for multiple steps before updating the weights instead of applying gradient updates on each training step.

    Batch size              1024         4 × 1024     16 × 1024    64 × 1024
    Number of epochs        10           40           160          640
    Accuracy                56%          57.5%        57.9%        57.2%
    Privacy loss bound ε    9.8 × 10^8   6.1 × 10^7   3.5 × 10^6   6.7 × 10^4

  • Transfer learning from public data

    Pre-training on public data followed by DP fine-tuning on private data has previously been shown to improve accuracy on other benchmarks [3, 4]. A question that remains is what public data to use for a given task to optimize transfer learning. In this work we simulate a private/public data split by using ImageNet as “private” data and using Places365, another image classification dataset, as a proxy for “public” data. We pre-trained our models on Places365 before fine-tuning them with DP-SGD on ImageNet. Places365 only has images of landscapes and buildings, not of animals as ImageNet does, so it is quite different, making it a good candidate to demonstrate the ability of the model to transfer to a different but related domain.

    We found that transfer learning from Places365 gave us 47.5% accuracy on ImageNet with a reasonable level of privacy (ε = 10). This is low compared to the 70% accuracy of a similar non-private model, but compared to naïve DP training on ImageNet, which yields either very low accuracy (2–5%) or no privacy (ε = 10^9), this is quite good.

Privacy-accuracy tradeoff for ResNet-18 on ImageNet using large-batch training with transfer learning from Places365.
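
The virtual large-batch training described under "Full-batch training" above can be sketched as plain gradient accumulation. This is an illustrative toy in plain Python, not the released code: `accumulate_and_update` and `squared_error_grad` are hypothetical names, and the DP-specific steps are omitted (in a DP-SGD setting one would presumably clip each per-example gradient before accumulating and add noise once per virtual batch, but that placement is an assumption here).

```python
def accumulate_and_update(params, micro_batches, grad_fn, lr):
    """Apply one weight update from gradients accumulated over micro-batches."""
    total = [0.0] * len(params)
    num_examples = 0
    for batch in micro_batches:
        # Only one micro-batch needs to be held in memory at a time.
        for example in batch:
            grad = grad_fn(params, example)
            total = [t + g for t, g in zip(total, grad)]
            num_examples += 1
    # The weights change once, after all micro-batches have been processed.
    mean_grad = [t / num_examples for t in total]
    return [p - lr * g for p, g in zip(params, mean_grad)]


def squared_error_grad(params, example):
    """Gradient of 0.5 * (w * x - y) ** 2 with respect to the single weight w."""
    x, y = example
    return [(params[0] * x - y) * x]


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]

# Two micro-batches of two examples versus one batch of four examples.
virtual = accumulate_and_update([0.0], [data[:2], data[2:]], squared_error_grad, lr=0.1)
single = accumulate_and_update([0.0], [data], squared_error_grad, lr=0.1)
```

Both calls produce the same updated weight, which is the property that lets a batch too large for accelerator memory be emulated with several smaller passes.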

Next Steps
We hope these early results and source code provide an impetus for other researchers to work on improving DP for ambitious tasks such as ImageNet as a proxy for challenging production-scale tasks. With the much faster DP-SGD on JAX, we urge DP and ML researchers to explore diverse training regimes, model architectures, and algorithms to make DP more practical. To continue advancing the state of the field, we recommend researchers start with a baseline that incorporates full-batch training plus transfer learning.

Acknowledgments
This work was carried out with the support of the Google Visiting Researcher Program while Prof. Geambasu, an Associate Professor with Columbia University, was on sabbatical with Google Research. This work received substantial contributions from Steve Chien, Shuang Song, Andreas Terzis and Abhradeep Guha Thakurta.

Controlling Neural Networks with Rule Representations

Deep neural networks (DNNs) provide more accurate results as the size and coverage of their training data increases. While investing in high-quality and large-scale labeled datasets is one path to model improvement, another is leveraging prior knowledge, concisely referred to as “rules” — reasoning heuristics, equations, associative logic, or constraints. Consider a common example from physics where a model is given the task of predicting the next state in a double pendulum system. While the model may learn to estimate the total energy of the system at a given point in time only from empirical data, it will frequently overestimate the energy unless also provided an equation that reflects the known physical constraints, e.g., energy conservation. The model fails to capture such well-established physical rules on its own. How could one effectively teach such rules so that DNNs absorb the relevant knowledge beyond simply learning from the data?

In “Controlling Neural Networks with Rule Representations”, published at NeurIPS 2021, we present Deep Neural Networks with Controllable Rule Representations (DeepCTRL), an approach for providing rules to a model that is agnostic to data type and model architecture and can be applied to any kind of rule defined for inputs and outputs. The key advantage of DeepCTRL is that it does not require retraining to adapt the rule strength. At inference, the user can adjust rule strength based on the desired operating point for accuracy. We also propose a novel input perturbation method, which helps generalize DeepCTRL to non-differentiable constraints. In real-world domains where incorporating rules is critical, such as physics and healthcare, we demonstrate the effectiveness of DeepCTRL in teaching rules for deep learning. DeepCTRL ensures that models follow rules more closely while also providing accuracy gains at downstream tasks, thus improving reliability and user trust in the trained models. Additionally, DeepCTRL enables novel use cases, such as hypothesis testing of the rules on data samples and unsupervised adaptation based on shared rules between datasets.

The benefits of learning from rules are multifaceted:

  • Rules can provide extra information for cases with minimal data, improving the test accuracy.
  • A major bottleneck for widespread use of DNNs is the lack of understanding of the rationale behind their reasoning and of their inconsistencies. By minimizing inconsistencies, rules can improve the reliability of and user trust in DNNs.
  • DNNs are sensitive to slight input changes that are human-imperceptible. With rules, the impact of these changes can be minimized as the model search space is further constrained to reduce underspecification.

Learning Jointly from Rules and Tasks
The conventional approach incorporates rules by including them in the calculation of the loss. There are three limitations of this approach that we aim to address: (i) rule strength needs to be defined before learning (thus the trained model cannot operate flexibly based on how much the data satisfies the rule); (ii) rule strength is not adaptable to target data at inference if there is any mismatch with the training setup; and (iii) the rule-based objective needs to be differentiable with respect to learnable parameters (to enable learning from labeled data).

DeepCTRL modifies canonical training by creating rule representations, coupled with data representations, which is key to enabling the rule strength to be controlled at inference time. During training, these representations are stochastically concatenated with a control parameter, indicated by α, into a single representation. The influence of the rule on the output decision can be strengthened by increasing the value of α. By modifying α at inference, users can control the behavior of the model to adapt to unseen data.

DeepCTRL pairs a data encoder and rule encoder, which produce two latent representations, which are coupled with corresponding objectives. The control parameter α is adjustable at inference to control the relative weight of each encoder.
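
As a rough illustration of how a single control parameter can weight two latent representations, the sketch below scales a rule representation by α and a data representation by 1 − α before concatenating them. The function names, the plain-list representations, and the uniform sampling of α during training are all assumptions made for this example; the paper should be consulted for the actual formulation.

```python
import random


def combine_representations(z_rule, z_data, alpha):
    """Concatenate rule and data representations, weighted by alpha and 1 - alpha."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return [alpha * v for v in z_rule] + [(1.0 - alpha) * v for v in z_data]


def sample_alpha(rng):
    """Draw a random rule strength for one training step (uniform is an assumption)."""
    return rng.random()


z_rule, z_data = [1.0, 2.0], [4.0, 6.0]

# During training, alpha would be re-sampled so the decoder sees many mixtures.
training_alpha = sample_alpha(random.Random(0))
training_input = combine_representations(z_rule, z_data, training_alpha)

# At inference, the user picks alpha directly; no retraining is involved.
rule_only = combine_representations(z_rule, z_data, alpha=1.0)
balanced = combine_representations(z_rule, z_data, alpha=0.5)
```

Because the downstream layers are trained on many values of α, choosing a different α at inference changes the relative weight of the two encoders without touching the learned weights.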

Integrating Rules via Input Perturbations
Training with rule-based objectives requires the objectives to be differentiable with respect to the learnable parameters of the model. There are many valuable rules that are non-differentiable with respect to input. For example, “higher blood pressure than 140 is likely to lead to cardiovascular disease” is a rule that is hard to combine with conventional DNNs. We therefore introduce a novel input perturbation method to generalize DeepCTRL to non-differentiable constraints by introducing small perturbations (random noise) to input features and constructing a rule-based constraint based on whether the outcome is in the desired direction.
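
One way to picture the perturbation idea is to compare a model's output before and after nudging a single input feature, and to penalize only the cases where the output moves in the wrong direction. The sketch below does this for the blood pressure rule. The hinge-style penalty, the function names, the feature layout, and the two toy models are assumptions for illustration, not the paper's exact objective.

```python
import math


def perturbation_rule_penalty(model, features, index, delta):
    """Penalty when raising one feature by delta lowers the predicted risk."""
    perturbed = list(features)
    perturbed[index] += delta
    original_risk = model(features)
    perturbed_risk = model(perturbed)
    # Zero when the risk moves in the desired direction (does not decrease).
    return max(0.0, original_risk - perturbed_risk)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def follows_rule(features):
    """Toy model whose predicted risk increases with systolic blood pressure."""
    return sigmoid(0.05 * (features[0] - 140.0))


def violates_rule(features):
    """Toy model whose predicted risk decreases with systolic blood pressure."""
    return sigmoid(-0.05 * (features[0] - 140.0))


patient = [130.0, 55.0]  # hypothetical layout: [systolic blood pressure, age]

penalty_ok = perturbation_rule_penalty(follows_rule, patient, index=0, delta=5.0)
penalty_bad = perturbation_rule_penalty(violates_rule, patient, index=0, delta=5.0)
```

The penalty depends only on two model outputs, so it can be computed for a rule that has no closed-form, differentiable expression of its own.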

Use Cases
We evaluate DeepCTRL on machine learning use cases from physics and healthcare, where utilization of rules is particularly important.

  • Improved Reliability Given Known Principles in Physics

    We quantify reliability of a model with the verification ratio, which is the fraction of output samples that satisfy the rules. Operating at a better verification ratio could be beneficial, especially if the rules are known to be always valid, as in natural sciences. By adjusting the control parameter α, a higher rule verification ratio, and thus more reliable predictions, can be achieved.

    To demonstrate this, we consider the time-series data generated from double pendulum dynamics with friction from a given initial state. We define the task as predicting the next state of the double pendulum from the current state while imposing the rule of energy conservation. To quantify how much the rule is learned, we evaluate the verification ratio.

    DeepCTRL enables controlling a model’s behavior after learning, without retraining. For the example of a double pendulum, conventional learning imposes no constraints to ensure the model follows physical laws, e.g., conservation of energy. The situation is similar for the case of DeepCTRL where the rule strength is low. So, the total energy of the system predicted at time t+1 (blue) can sometimes be greater than that measured at time t (red), which is physically disallowed (bottom left). If rule strength in DeepCTRL is high, the model may follow the given rule but lose accuracy (discrepancy between red and blue is larger; bottom right). If rule strength is between the two extremes, the model may achieve higher accuracy (blue curve is close to red) and follow the rule properly (blue curve is lower than red one).

    We compare the performance of DeepCTRL on this task to conventional baselines of training with a fixed rule-based constraint as a regularization term added to the objective, weighted by a coefficient λ. The highest of these regularization coefficients provides the highest verification ratio (shown by the green line in the second graph below); however, the prediction error is slightly worse than that of λ = 0.1 (orange line). We find that the lowest prediction error of the fixed baseline is comparable to that of DeepCTRL, but the highest verification ratio of the fixed baseline is still lower, which implies that DeepCTRL could provide accurate predictions while following the law of energy conservation. In addition, we consider the benchmark of imposing the rule constraint with the Lagrangian Dual Framework (LDF) and demonstrate two results where its hyperparameters are chosen by the lowest mean absolute error (LDF-MAE) and the highest rule verification ratio (LDF-Ratio) on the validation set. The performance of the LDF method is highly sensitive to what the main constraint is and its output is not reliable (black and pink dashed lines).

    Experimental results for the double pendulum task, showing the task-based mean absolute error (MAE), which measures the discrepancy between the ground truth and the model prediction, versus DeepCTRL as a function of the control parameter α. TaskOnly doesn’t have a rule constraint and Task & Rule has different rule strength (λ). LDF enforces rules by solving a constraint optimization problem.
    As above, but showing the verification ratio from different models.
    Experimental results for the double pendulum task showing the current and predicted energy at time t and t + 1, respectively.

    Additionally, the figures above illustrate the advantage DeepCTRL has over conventional approaches. For example, increasing the rule strength λ from 0.1 to 1.0 improves the verification ratio (from 0.7 to 0.9), but does not improve the mean absolute error. Arbitrarily increasing λ will continue to drive the verification ratio closer to 1, but will result in worse accuracy. Thus, finding the optimal value of λ will require many training runs through the baseline model, whereas DeepCTRL can find the optimal value for the control parameter α much more quickly.

  • Adapting to Distribution Shifts in Healthcare

    The strengths of some rules may differ between subsets of the data. For example, in disease prediction, the correlation between cardiovascular disease and higher blood pressure is stronger for older patients than younger patients. In such situations, when the task is shared but data distribution and the validity of the rule differ between datasets, DeepCTRL can adapt to the distribution shifts by controlling α.

    Exploring this example, we focus on the task of predicting whether cardiovascular disease is present or not using a cardiovascular disease dataset. Given that higher systolic blood pressure is known to be strongly associated with cardiovascular disease, we consider the rule: “higher risk if the systolic blood pressure is higher”. Based on this, we split the patients into two groups: (1) unusual, where a patient has high blood pressure but no disease, or lower blood pressure but has disease; and (2) usual, where a patient has high blood pressure and disease, or low blood pressure and no disease.

    We demonstrate below that the source data do not always follow the rule, and thus the effect of incorporating the rule can depend on the source data. The test cross entropy, which indicates classification accuracy (lower cross entropy is better), is visualized below as a function of rule strength for source or target datasets with varying usual / unusual ratios. The error monotonically increases as α → 1 because the enforcement of the imposed rule, which doesn’t accurately reflect the source data, becomes more strict.

    Test cross entropy vs. rule strength for a source dataset with usual / unusual ratio of 0.30.

    When a trained model is transferred to the target domain, the error can be reduced by controlling α. To demonstrate this, we show three domain-specific datasets, which we call Target 1, 2, and 3. In Target 1, where the majority of patients are from the usual group, as α is increased, the rule-based representation has more weight and the resultant error decreases monotonically.

    As above, but for a Target dataset (1) with a usual / unusual ratio of 0.77.

    When the ratio of usual patients is decreased in Target 2 and 3, the optimal α is an intermediate value between 0 and 1. These results demonstrate the capability to adapt the trained model via α.

    As above, but for Target 2 with a usual / unusual ratio of 0.50.
    As above, but for Target 3 with a usual / unusual ratio of 0.40.
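
Adapting to a target dataset by controlling α, as described in the healthcare use case above, amounts to a one-dimensional search over candidate values with a fixed model. The sketch below is a toy version in plain Python: `select_alpha`, the `predict(x, alpha)` interface, and the linear blend of a rule-based and a data-based probability are assumptions for illustration, not the DeepCTRL architecture.

```python
import math


def binary_cross_entropy(labels, probs, eps=1e-12):
    """Mean binary cross entropy; probabilities are clamped away from 0 and 1."""
    total = 0.0
    for y, p in zip(labels, probs):
        p = min(max(p, eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)


def select_alpha(predict, inputs, labels, candidate_alphas):
    """Return the (alpha, loss) pair with the lowest cross entropy; no retraining."""
    scored = [
        (binary_cross_entropy(labels, [predict(x, a) for x in inputs]), a)
        for a in candidate_alphas
    ]
    best_loss, best_alpha = min(scored)
    return best_alpha, best_loss


def toy_predict(x, alpha):
    """Blend a rule-based and a data-based probability (hypothetical stand-in)."""
    rule_prob, data_prob = x
    return alpha * rule_prob + (1.0 - alpha) * data_prob


# Toy target set in which the rule-based probability agrees with every label.
inputs = [(0.9, 0.6), (0.1, 0.4), (0.8, 0.3)]
labels = [1, 0, 1]

best_alpha, best_loss = select_alpha(toy_predict, inputs, labels,
                                     [0.0, 0.25, 0.5, 0.75, 1.0])
```

On this toy target set the rule is always right, so the search settles on the largest α; on data where the rule holds less often, an intermediate value would be selected instead.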

Conclusions
Learning from rules can be crucial for constructing interpretable, robust, and reliable DNNs. We propose DeepCTRL, a new methodology used to incorporate rules into data-learned DNNs. DeepCTRL enables controllability of rule strength at inference without retraining. We propose a novel perturbation-based rule encoding method to integrate arbitrary rules into meaningful representations. We demonstrate three use cases of DeepCTRL: improving reliability given known principles, examining candidate rules, and domain adaptation using the rule strength.

Acknowledgements
We greatly appreciate the contributions of Jinsung Yoon, Xiang Zhang, Kihyuk Sohn and Tomas Pfister.

Does Your Medical Image Classifier Know What It Doesn’t Know?

Deep machine learning (ML) systems have achieved considerable success in medical image analysis in recent years. One major contributing factor is access to abundant labeled datasets, which are used to train highly effective supervised deep learning models. However, in the real world, these models may encounter samples exhibiting rare conditions that are individually too infrequent for per-condition classification. Nevertheless, such conditions can be collectively common because they follow a long-tail distribution and when taken together can represent a significant portion of cases. For example, in a recent deep learning dermatological study, hundreds of rare conditions made up around 20% of cases encountered by the model at test time.

To prevent models from generating erroneous outputs on rare samples at test time, there remains a considerable need for deep learning systems with the ability to recognize when a sample is not a condition it can identify. Detecting previously unseen conditions can be thought of as an out-of-distribution (OOD) detection task. By successfully identifying OOD samples, preventive measures can be taken, like abstaining from prediction or deferring to a human expert.

Traditional computer vision OOD detection benchmarks work to detect dataset distribution shifts. For example, a model may be trained on CIFAR images but be presented with street view house numbers (SVHN) as OOD samples, two datasets with very different semantic meanings. Other benchmarks seek to detect slight differences in semantic information, e.g., between images of a truck and a pickup truck, or two different skin conditions. The semantic distribution shifts in such near-OOD detection problems are more subtle in comparison to dataset distribution shifts, and thus, are harder to detect.

In “Does Your Dermatology Classifier Know What it Doesn’t Know? Detecting the Long-Tail of Unseen Conditions”, published in Medical Image Analysis, we tackle this near-OOD detection task in the application of dermatology image classification. We propose a novel hierarchical outlier detection (HOD) loss, which leverages existing fine-grained labels of rare conditions from the long tail and modifies the loss function to group unseen conditions and improve identification of these near OOD categories. Coupled with various representation learning methods and the diverse ensemble strategy, this approach enables us to achieve better performance for detecting OOD inputs.

The Near-OOD Dermatology Dataset
We curated a near-OOD dermatology dataset that includes 26 inlier conditions, each of which is represented by at least 100 samples, and 199 rare conditions considered to be outliers. Outlier conditions can have as few as one sample per condition. The separation criteria between inlier and outlier conditions can be specified by the user. Here the cutoff sample size between inlier and outlier was 100, consistent with our previous study. The outliers are further split into training, validation, and test sets that are intentionally mutually exclusive to mimic real-world scenarios, where rare conditions shown during test time may not have been seen in training.

Long-tail distribution of different dermatological conditions in our dataset: the 26 inlier conditions, each with at least 100 samples (blue), and the remaining 199 rare outlier conditions (orange). Outlier conditions can have as few as one sample per condition.
                     Train set           Validation set      Test set
                     Inlier   Outlier    Inlier   Outlier    Inlier   Outlier
Number of classes    26       68         26       66         26       65
Number of samples    8854     1111       1251     1082       1192     937
Inlier and outlier conditions in our benchmark dataset and detailed dataset split statistics. The outliers are further split into mutually exclusive train, validation, and test sets.
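
The inlier/outlier separation by sample count, and the mutually exclusive split of outlier conditions, can be sketched as follows. The cutoff of 100 comes from the text above; the function names, the placeholder condition labels, and the round-robin assignment of outlier classes to splits are assumptions for illustration, since the actual split procedure is not described here.

```python
from collections import Counter


def split_inliers_outliers(labels, min_samples=100):
    """Conditions with at least min_samples examples are inliers; the rest outliers."""
    counts = Counter(labels)
    inliers = {cond for cond, n in counts.items() if n >= min_samples}
    outliers = set(counts) - inliers
    return inliers, outliers


def partition_classes(classes, num_splits=3):
    """Assign each class to exactly one split (round-robin over sorted names)."""
    ordered = sorted(classes)
    return [set(ordered[i::num_splits]) for i in range(num_splits)]


# Placeholder condition names and counts, not taken from the dataset.
labels = (["condition_a"] * 120 + ["condition_b"] * 100 +
          ["rare_a"] * 3 + ["rare_b"] * 1 + ["rare_c"] * 40 + ["rare_d"] * 99)

inliers, outliers = split_inliers_outliers(labels)
train_out, val_out, test_out = partition_classes(outliers)
```

Splitting at the level of conditions rather than individual samples is what guarantees that no outlier condition seen in training reappears at test time.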

Hierarchical Outlier Detection Loss
We propose to use “known outlier” samples during training that are leveraged to aid detection of “unknown outlier” samples during test time. Our novel hierarchical outlier detection (HOD) loss performs a fine-grained classification of individual classes for all inlier or outlier classes and, in parallel, a coarse-grained binary classification of inliers vs. outliers in a hierarchical setup (see the figure below). Our experiments confirmed that HOD is more effective than performing a coarse-grained classification followed by a fine-grained classification, as this could result in a bottleneck that impacted the performance of the fine-grained classifier.

We use the sum of the predictive probabilities of the outlier classes as the OOD score. As a primary OOD detection metric we use the area under the receiver operating characteristic curve (AUROC), which ranges between 0 and 1 and gives us a measure of separability between inliers and outliers. A perfect OOD detector, which separates all inliers from outliers, is assigned an AUROC score of 1. A popular baseline method, called reject bucket, separates each inlier individually from the outliers, which are grouped into a dedicated single abstention class. In addition to a fine-grained classification for each individual inlier and outlier class, the HOD loss–based approach separates the inliers collectively from the outliers with a coarse-grained prediction loss, resulting in better generalization. Although the two approaches are similar, we demonstrate that our HOD loss–based approach outperforms other baseline methods that leverage outlier data during training, achieving an AUROC score of 79.4% on the benchmark, a significant improvement over that of reject bucket, which achieves 75.6%.

Our model architecture and the HOD loss. The encoder (green) represents the wide ResNet 101×3 model pre-trained with different representation learning models (ImageNet, BiT, SimCLR, and MICLe; see below). The output of the encoder is sent to the HOD loss where fine-grained and coarse-grained predictions for inliers (blue) and outliers (orange) are obtained. The coarse predictions are obtained by summing over the fine-grained probabilities as indicated in the figure. The OOD score is defined as the sum of the probabilities of outlier classes.
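
The relationship between the fine-grained probabilities, the coarse-grained probabilities, and the OOD score can be sketched in a few lines. The summation of fine-grained probabilities and the definition of the OOD score follow the description above; the ordering of classes (inliers first), the use of cross entropy for both terms, and the `coarse_weight` parameter are assumptions made for this example rather than details confirmed by the source.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hod_outputs(logits, num_inliers):
    """Fine-grained probabilities plus coarse inlier/outlier probabilities."""
    probs = softmax(logits)
    p_inlier = sum(probs[:num_inliers])
    p_outlier = sum(probs[num_inliers:])  # this sum is used as the OOD score
    return probs, p_inlier, p_outlier


def hod_loss(logits, label, num_inliers, coarse_weight=1.0):
    """Fine-grained cross entropy plus a weighted coarse-grained cross entropy."""
    probs, p_inlier, p_outlier = hod_outputs(logits, num_inliers)
    fine = -math.log(probs[label])
    coarse = -math.log(p_inlier if label < num_inliers else p_outlier)
    return fine + coarse_weight * coarse


# Two inlier classes (indices 0-1) followed by two known outlier classes (2-3).
logits = [0.0, 0.0, 0.0, 0.0]
_, _, ood_score = hod_outputs(logits, num_inliers=2)
loss = hod_loss(logits, label=0, num_inliers=2)
```

With uniform logits the OOD score is 0.5, and the loss is the sum of a four-way and a two-way cross entropy term.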

Representation Learning and the Diverse Ensemble Strategy
We also investigate how different types of representation learning help in OOD detection in conjunction with HOD by pretraining on ImageNet, BiT-L, SimCLR and MICLe models. We observe that including HOD loss improves OOD performance compared to the reject bucket baseline method for all four representation learning methods.

Representation       OOD detection metric (AUROC %)
Learning Methods     With reject bucket    With HOD loss
ImageNet             74.7%                 77%
BiT-L                75.6%                 79.4%
SimCLR               75.2%                 77.2%
MICLe                76.7%                 78.8%
OOD detection performance for different representation learning models with reject bucket and with HOD loss.
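
For reference, AUROC values like those in the table above can be computed directly from OOD scores. The sketch below uses the pairwise definition, in which AUROC is the probability that a randomly chosen outlier receives a higher OOD score than a randomly chosen inlier, with ties counted as one half. The function name and the example scores are made up for illustration.

```python
def auroc(outlier_scores, inlier_scores):
    """Fraction of (outlier, inlier) pairs ranked correctly; ties count as 0.5."""
    wins = 0.0
    for o in outlier_scores:
        for i in inlier_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(outlier_scores) * len(inlier_scores))


# Illustrative OOD scores only.
outlier_scores = [0.9, 0.6, 0.4]
inlier_scores = [0.1, 0.5, 0.3]

score = auroc(outlier_scores, inlier_scores)       # 8 of 9 pairs ranked correctly
perfect = auroc([0.8, 0.9], [0.1, 0.2])            # fully separated scores
```

This quadratic-time version is meant for clarity; for large test sets a rank-based implementation would be the usual choice.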

Another orthogonal approach for improving OOD detection performance and accuracy is deep ensemble, which aggregates outputs from multiple independently trained models to provide a final prediction. We build upon deep ensemble, but instead of using a fixed architecture with a fixed pre-training, we combine different representation learning architectures (ImageNet, BiT-L, SimCLR and MICLe) and introduce objective loss functions (HOD and reject bucket). We call this a diverse ensemble strategy, which we demonstrate outperforms the deep ensemble for OOD performance and inlier accuracy.
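
As a minimal sketch of what aggregating outputs across ensemble members might look like for OOD detection, the snippet below averages per-sample OOD scores from several models. Simple averaging, the function name, and the member names and scores are assumptions for illustration; the source does not specify the aggregation rule here.

```python
def ensemble_ood_score(member_scores):
    """Average per-sample OOD scores across ensemble members."""
    num_members = len(member_scores)
    return [sum(sample) / num_members for sample in zip(*member_scores)]


# Hypothetical members that differ in pre-training and in loss function.
members = {
    "bit_l_hod": [0.2, 0.8],
    "simclr_hod": [0.4, 0.6],
    "micle_reject_bucket": [0.3, 0.7],
}

combined = ensemble_ood_score(list(members.values()))
```

The members are assumed to score the same samples in the same order.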

Downstream Clinical Trust Analysis
While we mainly focus on improving the performance for OOD detection, the ultimate goal for our dermatology model is to have high accuracy in predicting inlier and outlier conditions. We go beyond traditional performance metrics and introduce a “penalty” matrix that jointly evaluates inlier and outlier predictions for model trust analysis to approximate downstream impact. For a fixed confidence threshold, we count the following types of mistakes: (i) incorrect inlier predictions (i.e., mistaking inlier condition A as inlier condition B); (ii) incorrect abstention of inliers (i.e., abstaining from making a prediction for an inlier); and (iii) incorrect prediction for outliers as one of the inlier classes.

To account for the asymmetrical consequences of the different types of mistakes, penalties can be 0, 0.5, or 1. Both incorrect inlier and outlier-as-inlier predictions can potentially erode user trust in the model and were penalized with a score of 1. Incorrect abstention of an inlier as an outlier was penalized with a score of 0.5, indicating that potential model users should seek additional guidance given the model-expressed uncertainty or abstention. For correct decisions no cost is incurred, indicated by a score of 0.

                      Action of the Model
            Prediction as Inlier                              Abstain
Inlier      0 (Correct)                                       0.5 (Incorrect, abstains inliers)
            1 (Incorrect, mistakes that may erode trust)
Outlier     1 (Incorrect, mistakes that may erode trust)      0 (Correct)
The penalty matrix is designed to capture the potential impact of different types of model errors.

Because real-world scenarios are more complex and contain a variety of unknown variables, the numbers used here represent simplifications to enable qualitative approximations for the downstream impact on user trust of outlier detection models, which we refer to as “cost”. We use the penalty matrix to estimate a downstream cost on the test set and compare our method against the baseline, thereby making a stronger case for its effectiveness in real-world scenarios. As shown in the plot below, our proposed solution incurs a much lower estimated cost in comparison to baseline over all possible operating points.

Trust analysis comparing our proposed method to the baseline (reject bucket) for a range of outlier recall rates, indicated by 𝛕. We show that our method reduces downstream estimated cost, potentially reflecting improved downstream impact.
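
The penalty matrix above translates into a small scoring function. The penalty values (0, 0.5, 1) are taken from the matrix; the function names, the use of `None` to represent abstention, and the example labels are assumptions for illustration. How the model decides to abstain (for example, by thresholding the OOD score) is outside this sketch.

```python
ABSTAIN = None  # hypothetical marker for "the model declined to predict"


def penalty(true_label, predicted_label, inlier_classes):
    """Penalty for one decision, following the penalty matrix."""
    is_inlier = true_label in inlier_classes
    if predicted_label is ABSTAIN:
        # Abstaining is correct for outliers and a mild error for inliers.
        return 0.5 if is_inlier else 0.0
    if is_inlier and predicted_label == true_label:
        return 0.0
    # Wrong inlier class, or an outlier predicted as some inlier class.
    return 1.0


def total_cost(true_labels, predictions, inlier_classes):
    """Sum of penalties over a set of decisions."""
    return sum(penalty(t, p, inlier_classes)
               for t, p in zip(true_labels, predictions))


inlier_classes = {"A", "B"}
true_labels = ["A", "A", "B", "rare", "rare"]
predictions = ["A", "B", ABSTAIN, "A", ABSTAIN]

cost = total_cost(true_labels, predictions, inlier_classes)
```

Here the five decisions incur penalties of 0, 1, 0.5, 1, and 0, for a total cost of 2.5; sweeping the abstention threshold and recomputing this total would trace out a cost curve over operating points.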

Conclusion
In real-world deployment, medical ML models may encounter conditions that were not seen in training, and it’s important that they accurately identify when they do not know a specific condition. Detecting those OOD inputs is an important step to improving safety. We develop an HOD loss that leverages outlier data during training, and combine it with pre-trained representation learning models and a diverse ensemble to further boost performance, significantly outperforming the baseline approach on our new dermatology benchmark dataset. We believe that our approach, aligned with our AI Principles, can aid successful translation of ML algorithms into real-world scenarios. Although we have primarily focused on OOD detection for dermatology, most of our contributions are fairly generic and can be easily incorporated into OOD detection for other applications.

Acknowledgements
We would like to thank Shekoofeh Azizi, Aaron Loh, Vivek Natarajan, Basil Mustafa, Nick Pawlowski, Jan Freyberg, Yuan Liu, Zach Beaver, Nam Vo, Peggy Bui, Samantha Winter, Patricia MacWilliams, Greg S. Corrado, Umesh Telang, Yun Liu, Taylan Cemgil, Alan Karthikesalingam, Balaji Lakshminarayanan, and Jim Winkens for their contributions. We would also like to thank Tom Small for creating the post animation.
