Empirical study: We evaluated three approaches for robots to navigate to objects in six visually diverse homes.
TLDR: Semantic navigation is necessary to deploy mobile robots in uncontrolled environments like our homes, schools, and hospitals. Many learning-based approaches have been proposed in response to the lack of semantic understanding of the classical pipeline for spatial navigation. But learned visual navigation policies have predominantly been evaluated in simulation. How well do different classes of methods work on a robot? We present a large-scale empirical study of semantic visual navigation methods comparing representative methods from classical, modular, and end-to-end learning approaches. We evaluate policies across six homes with no prior experience, maps, or instrumentation. We find that modular learning works well in the real world, attaining a 90% success rate. In contrast, end-to-end learning does not, dropping from 77% simulation to 23% real-world success rate due to a large image domain gap between simulation and reality. For practitioners, we show that modular learning is a reliable approach to navigate to objects: modularity and abstraction in policy design enable Sim-to-Real transfer. For researchers, we identify two key issues that prevent today’s simulators from being reliable evaluation benchmarks — (A) a large Sim-to-Real gap in images and (B) a disconnect between simulation and real-world error modes.
Object Goal Navigation
We instantiate semantic navigation with the Object Goal navigation task [Anderson 2018], where a robot starts in a completely unseen environment and is asked to find an instance of an object category, let’s say a toilet. The robot has access to only a first-person RGB and depth camera and a pose sensor (computed with LiDAR-based SLAM).
This task is challenging. It requires not only spatial scene understanding of distinguishing free space and obstacles and semantic scene understanding of detecting objects, but also requires learning semantic exploration priors. For example, if a human wants to find a toilet in this scene, most of us would choose the hallway because it is most likely to lead to a toilet. Teaching this kind of spatial common sense or semantic priors to an autonomous agent is challenging. While exploring the scene for the desired object, the robot also needs to remember explored and unexplored areas.
Methods
So how do we train autonomous agents capable of efficient navigation while tackling all these challenges? A classical approach to this problem builds a geometric map using depth sensors, explores the environment with a heuristic, like frontier exploration [Yamauchi 1997], which explores the closest unexplored region, and uses an analytical planner to reach exploration goals and the goal object as soon as it is in sight. An end-to-end learning approach predicts actions directly from raw observations with a deep neural network consisting of visual encoders for image frames followed by a recurrent layer for memory [Ramrakhya 2022]. A modular learning approach builds a semantic map by projecting predicted semantic segmentation using depth, predicts an exploration goal with a goal-oriented semantic policy as a function of the semantic map and the goal object, and reaches it with a planner [Chaplot 2020].
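To make the classical baseline more concrete, here is a minimal sketch of frontier detection on a 2D occupancy grid. The grid encoding (`FREE`, `OCCUPIED`, `UNKNOWN`), the function names, and the straight-line distance that ignores obstacles are illustrative assumptions, not the implementation evaluated in the study.

```python
import numpy as np

# Hypothetical cell labels for a 2D occupancy grid.
FREE, OCCUPIED, UNKNOWN = 0, 1, -1

def frontier_cells(grid):
    """Return (row, col) of free cells that touch at least one unknown cell."""
    rows, cols = grid.shape
    frontiers = []
    for r in range(rows):
        for c in range(cols):
            if grid[r, c] != FREE:
                continue
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nr, nc = r + dr, c + dc
                if 0 <= nr < rows and 0 <= nc < cols and grid[nr, nc] == UNKNOWN:
                    frontiers.append((r, c))
                    break
    return frontiers

def closest_frontier(grid, robot):
    """Pick the frontier cell nearest to the robot (straight-line distance)."""
    cells = frontier_cells(grid)
    if not cells:
        return None  # nothing left to explore
    return min(cells, key=lambda rc: (rc[0] - robot[0]) ** 2 + (rc[1] - robot[1]) ** 2)

grid = np.array([
    [0, 0, -1],
    [0, 1, -1],
    [0, 0, 0],
])
goal = closest_frontier(grid, robot=(0, 0))
```

A real system would hand `goal` to the analytical planner and recompute frontiers as the map grows.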
Large-scale Real-world Empirical Evaluation
While many approaches to navigate to objects have been proposed over the past few years, learned navigation policies have predominantly been evaluated in simulation, which opens the field to the risk of sim-only research that does not generalize to the real world. We address this issue through a large-scale empirical evaluation of representative classical, end-to-end learning, and modular learning approaches across 6 unseen homes and 6 goal object categories (chair, couch, plant, bed, toilet, TV).
Results
We compare approaches in terms of success rate within a limited budget of 200 robot actions and Success weighted by Path Length (SPL), a measure of path efficiency. In simulation, all approaches perform comparably. But in the real world, modular learning and classical approaches transfer really well while end-to-end learning fails to transfer.
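For readers new to the metric, SPL is commonly defined as the average over episodes of \(S_i \cdot \ell_i / \max(p_i, \ell_i)\), where \(S_i\) is the binary success indicator, \(\ell_i\) the shortest-path length to the goal, and \(p_i\) the length of the path the robot actually took. The sketch below computes it on made-up episode data; the numbers are illustrative, not results from the study.

```python
def spl(successes, shortest_lengths, path_lengths):
    """Success weighted by Path Length, averaged over episodes."""
    total = 0.0
    for s, shortest, taken in zip(successes, shortest_lengths, path_lengths):
        # A failed episode contributes 0; a success is discounted by detours.
        total += s * shortest / max(taken, shortest)
    return total / len(successes)

# Three hypothetical episodes: an optimal success, a success with a
# path twice as long as needed, and a failure.
episode_success = [1, 1, 0]
shortest = [5.0, 4.0, 6.0]
taken = [5.0, 8.0, 3.0]

score = spl(episode_success, shortest, taken)
success_rate = sum(episode_success) / len(episode_success)
```

Here the success rate is 2/3 while SPL is 0.5, which shows how SPL penalizes the inefficient second episode.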
We illustrate these results qualitatively with one representative trajectory.
Result 1: Modular Learning is Reliable
We find that modular learning is very reliable on a robot, with a 90% success rate.
Result 2: Modular Learning Explores more Efficiently than the Classical Approach
Modular learning improves real-world success rate by 10% over the classical approach. With a limited time budget, inefficient exploration can lead to failure.
Result 3: End-to-end Learning Fails to Transfer
While classical and modular learning approaches work well on a robot, end-to-end learning does not, at only 23% success rate.
Analysis
Insight 1: Why does Modular Transfer while End-to-end does not?
Why does modular learning transfer so well while end-to-end learning does not? To answer this question, we reconstructed one real-world home in simulation and conducted experiments with identical episodes in sim and reality.
The semantic exploration policy of the modular learning approach takes a semantic map as input, while the end-to-end policy directly operates on the RGB-D frames. The semantic map space is invariant between sim and reality, while the image space exhibits a large domain gap.
The semantic map domain invariance allows the modular learning approach to transfer well from sim to reality. In contrast, the image domain gap causes a large drop in performance when transferring a segmentation model trained in the real world to simulation and vice versa. If semantic segmentation transfers poorly from sim to reality, it is reasonable to expect an end-to-end semantic navigation policy trained on sim images to transfer poorly to real-world images.
Insight 2: Sim vs Real Gap in Error Modes for Modular Learning
Surprisingly, modular learning works even better in reality than simulation. Detailed analysis reveals that a lot of the failures of the modular learning policy that occur in sim are due to reconstruction errors, both visual and physical, which do not happen in reality. In contrast, failures in the real world are predominantly due to depth sensor errors, while most semantic navigation benchmarks in simulation assume perfect depth sensing. Besides explaining the performance gap between sim and reality for modular learning, this gap in error modes is concerning because it limits the usefulness of simulation to diagnose bottlenecks and further improve policies. We show representative examples of each error mode and propose concrete steps forward to close this gap in the paper.
Takeaways
For practitioners:
Modular learning can reliably navigate to objects with 90% success
For researchers:
Models relying on RGB images are hard to transfer from sim to real => leverage modularity and abstraction in policies
Disconnect between sim and real error modes => evaluate semantic navigation on real robots
If you’ve enjoyed this post and would like to learn more, please check out the Science Robotics 2023 paper and talk. Code coming soon. Also, please don’t hesitate to reach out to Theophile Gervet!
ReLM enables writing tests that are guaranteed to come from the set of valid strings, such as dates. Without ReLM, LLMs are free to complete prompts with non-date answers, which are difficult to assess.
TL;DR: While large language models (LLMs) have been touted for their ability to generate natural-sounding text, there are concerns around potential negative effects of LLMs such as data memorization, bias, and inappropriate language. We introduce ReLM (MLSys ’23), a system for validating and querying LLMs using standard regular expressions. We demonstrate via validation tasks on memorization, bias, toxicity, and language understanding that ReLM achieves up to \(15\times\) higher system efficiency, \(2.5\times\) data efficiency, and increased prompt-tuning coverage compared to state-of-the-art ad-hoc queries.
The Winners and Losers in Sequence Prediction
Consider playing a video game (perhaps in your youth). You randomly enter the following sequence in your controller:
Suddenly, your character becomes invincible. You’ve discovered the “secret” sequence that the game developer used for testing the levels. After this point in time, everything you do is trivial—the game is over, you win.
I claim that using large language models (LLMs) to generate text content is similar to playing a game with such secret sequences. Rather than getting surprised to see a change in game state, users of LLMs may be surprised to see a response that is not quite right. It’s possible the LLM violates someone’s privacy, encodes a stereotype, contains explicit material, or hallucinates an event. However, unlike the game, it may be difficult to even reason about how that sequence manifested.
LLMs operate over tokens (i.e., integers), which are translated via the tokenizer to text. For encoding systems such as Byte-Pair Encoding (BPE), each token maps to one or more characters. Using the controller analogy, an LLM is a controller having 50000+ “buttons”, and certain buttons operate as “macros” over the string space. For example, ⇑ could stand for one run of button presses and ⇓ for another, enabling the same code to be represented with ⇑⇓. Importantly, the LLM is unaware of this equivalence mapping: a single edit to the underlying button presses would invalidate ⇑ being substituted into the sequence. Writing “the” instead of “The” could result in a different response from the LLM, even though the difference is stylistic to humans. These tokenization artifacts combined with potential shortcomings in the LLM’s internal reasoning create a minefield of unassuming LLM “bugs”.
The possibility that a model may deviate from the “correct” set of sequences motivates LLM validation: the task of evaluating a model’s behavior along many axes so that shortcomings can be identified and addressed. The problem can be much worse than our game example: when we expect a single sequence, nearly all sequences are incorrect, and the number of incorrect sequences grows exponentially with the sequence length. Intuitively, it gets much harder to output the right sequence when the sequence length grows; correctly “dancing” a short sequence is easier than a long one. In the lab, it’s hard to notice the consequences of generating an incorrect sequence, but as society embraces LLMs for more serious tasks (e.g., writing emails, filing taxes), we’ll want to have more confidence that they work as intended.
Short of formal verification, the best validation mechanism we have is to build comprehensive test suites for characterizing model behavior over a set of input sequences. Benchmarking efforts such as HELM are continuing to increase the scope of LLM validation by providing a gamut of test sequences. While I strongly agree with the motivation, I ask: Should we be rethinking how tests themselves are written? Can we systematically generalize sequences to high-level patterns such that test writers don’t have to reason about all the peculiar LLM implementation details that we just discussed?
Background: Prompting LLMs
With game codes, the code is entered through the controller. The result, on the other hand, is reflected in the game state (i.e., your character becomes invincible, which I represent with a good outcome ✓). But how does this analogy hold for LLMs?
⇒✓
For autoregressive LLMs, typically the input is a sequence and the output is a sequence, and both of these are in the same space (e.g., strings of human language). For example, prompting the model with the word “The” would perhaps be followed by “ cat” in the sense that it is either likely or simply possible according to the LLM and the sampling procedure.
Ⓣⓗⓔ⇒ ⓒⓐⓣ
If “ cat” is considered a good answer, then we “won” the sequence lottery (represented by ✓). If the sequence is considered a bad answer, e.g., the misspelling “ kAt”, then we lost (represented by ✗).
Ⓣⓗⓔ ⓒⓐⓣ⇒✓
Ⓣⓗⓔ ⓚⒶⓣ⇒✗
Keep in mind that the token-level encoding is not unique for a given string sequence, so the above LLM examples will have many representations. The number of representations compounds with the size of the reference strings, e.g., all the possible misspellings of “ cat”. Furthermore, the LLM will output a distribution over good and bad sequences, so we’d like to summarize them, e.g., by measuring what percentage of sequences are good.
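To see how quickly encodings multiply, the following sketch enumerates every way to split a string into tokens from a small vocabulary. The vocabulary is a toy assumption made up for illustration; a real BPE vocabulary has tens of thousands of entries and the count grows accordingly.

```python
def tokenizations(text, vocab):
    """Enumerate every way to split `text` into tokens drawn from `vocab`."""
    if text == "":
        return [[]]  # one way to encode the empty string: no tokens
    results = []
    for end in range(1, len(text) + 1):
        piece = text[:end]
        if piece in vocab:
            for rest in tokenizations(text[end:], vocab):
                results.append([piece] + rest)
    return results

# Hypothetical vocabulary: single characters plus a few merged tokens.
toy_vocab = {" ", "c", "a", "t", " c", "at", " cat"}
encodings = tokenizations(" cat", toy_vocab)

for tokens in encodings:
    print(tokens)
```

Even this tiny vocabulary gives several token sequences that all decode to the same string, and a test that only checks one of them silently ignores the rest.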
Problem: Testing LLMs
As test designers, our goal is to quantitatively measure some aspect of the LLM’s behavior. As we are studying a general notion of tests, we’ll introduce a small amount of formalism to argue our points. Let us call a test, \(T\), which takes a model, \(M\), and returns a boolean represented with 0 (bad answer) or 1 (good answer).
$$T: M \to \{0, 1\}$$
For classification tasks, \(T\) represents whether the model, \(M\), classified a particular example correctly; the average of these tests is reported with test accuracy. Since correct classification boils down to the predicted class (\(y_\text{pred} := M(x)\)) matching the ground-truth class (\(y\)), this test can be implemented in one line of code.
y_pred == y
What does \(T\) look like for LLMs? Let’s say we want to test if “The” is followed by “ cat”. Constructing such a test is straightforward, because we can just check if the statement is true. We can imagine \(x\) representing “The” and \(y\) representing “ cat”. If \(y\) is sampled from some distribution (i.e., it’s a random variable), we can get many samples to compute the mean score. Depending on the application, we may or may not be interested in including all the encodings discussed previously as well as possible variations of the base pattern, e.g., misspellings.
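As a rough sketch of such a sampled test, the code below estimates the mean score over many completions. The sampler `sample_completion` is a stand-in with made-up probabilities, not a real LLM call, and the accepted sets are assumptions chosen to mirror the “ cat” / “ kAt” example.

```python
import random

def sample_completion(prompt, rng):
    """Stand-in for an LLM sampler; the completions and weights are invented."""
    return rng.choices([" cat", " dog", " kAt"], weights=[0.6, 0.3, 0.1])[0]

def run_test(prompt, accepted, num_samples, seed=0):
    """Fraction of sampled completions that fall inside the accepted set."""
    rng = random.Random(seed)
    hits = sum(sample_completion(prompt, rng) in accepted for _ in range(num_samples))
    return hits / num_samples

# Strict test: only the exact string counts as a good answer.
strict = run_test("The", {" cat"}, num_samples=1000)
# Lenient test: the misspelling is also accepted.
lenient = run_test("The", {" cat", " kAt"}, num_samples=1000)
```

The only change between the two tests is the accepted set, yet the reported score differs, which is why the scope of the pattern has to be an explicit decision.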
Because of the potentially massive number of sequences involved in a test, LLM tests are both more difficult to express and evaluate, leading to tests with insufficient coverage. For example, if we happened to miss some prompt that does lead to “ cat”, our test had a false negative—it concluded it was not possible when it actually was. If we were to check if “ cat” is the most likely string following “The”, we may get false positives in the omitted cases where “ kAt” was more likely. The test designer must carefully consider trading off such sources of error with the implementation and execution complexity of the test.
With traditional string-level APIs, it’s difficult to make testing trade-offs without rewriting the testing logic altogether—one has to write testing code that explicitly samples from the distribution of interest (e.g., the choice of encodings and misspellings). For example, a privacy-oriented user would want you to be reasonably sure that the LLM couldn’t emit their private information, even with the presence of encoding or misspelling artifacts. Such a minor change in the test’s scope would result in dramatic changes to the underlying test implementation. To make matters worse, testing becomes even more difficult when the base pattern of interest is a combinatorial object, such as integers, dates, URL strings, and phone numbers—sets too large to enumerate.
Example: Does GPT-2XL know George Washington’s birth date?
To give a concrete example of false positives and false negatives, let’s consider a simple test of knowledge: Does the LLM know George Washington’s birth date? As shown in the figure below, we formulate this ‘test’ by asking the model to rank 4 choices. Such multiple-choice questions are common in today’s benchmark suites because they are simple to implement. However, 4 choices do not cover all birth dates; what if the model was lucky enough to eliminate the other 3 answers and just guess? That would be a false positive. As shown below, the correct date of February 22, 1732, is chosen by the model because it is the most likely; thus this test concludes the model does know the birth date.
We can also try free response, as shown in the following figure. However, the most likely reply is not a date and thus penalizes the model for being more general than the test task, a possible false negative. “this day in 1732” and “a farm” are reasonable completions for the fill-in-the-blank, yet an automated test system would mark them as not matching the solution set.
A more natural alternative, and one that we explore via our work in ReLM (MLSys ’23), would be to only consider answers that follow a specific date-related format. The way we evaluate this query is by constraining generation to be of the form <Month> <Day>, <Year>, as if we had a “complete” multiple choice solution set, which is too large to enumerate. Because this pattern contains exactly all the solutions of interest, the test minimizes spurious conclusions due to false positives and false negatives. In doing so, we confirm a true negative—GPT-2XL believes George Washington was born on July 4, 1732. That’s of course factually incorrect, but we didn’t trick ourselves into thinking the LLM knew the answer when it didn’t.
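To illustrate the date pattern itself, here is a sketch using Python’s standard `re` module rather than ReLM’s own API, which is not reproduced here. It only checks whether a finished string belongs to the `<Month> <Day>, <Year>` set, whereas ReLM constrains generation to that set; the exact pattern (for instance, allowing days 1 to 31 for every month) is a simplifying assumption.

```python
import re

MONTHS = ("January|February|March|April|May|June|July|"
          "August|September|October|November|December")

# <Month> <Day>, <Year>; deliberately loose: it accepts "February 31, 1732".
DATE_PATTERN = re.compile(rf"({MONTHS}) ([1-9]|[12][0-9]|3[01]), ([0-9]{{1,4}})")

def is_date(text):
    """True if the whole string matches the date pattern."""
    return DATE_PATTERN.fullmatch(text) is not None

candidates = ["February 22, 1732", "July 4, 1732", "this day in 1732", "a farm"]
in_pattern = [c for c in candidates if is_date(c)]
```

The two free-response completions from the previous example fall outside the pattern, while both the correct and the incorrect date fall inside it, so the test compares like with like.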
While we don’t have the space to exactly write out how to run these queries in ReLM, you can rest assured that you’ll find the above example in our code.
The Case for ReLM
Regular expressions describe the regular languages and are a way of specifying text patterns. Many text-processing tools, such as grep, use regular expressions to locate patterns in text. At a high level, regular languages can describe patterns using the primitives of string literals, disjunction (“OR”), and repetitions. For the purpose of this blog, you can think of regular languages as allowing you to interpolate between a 4-way multiple choice (e.g., A OR B OR C OR D) and one with a combinatorial explosion of choices in a free-response (e.g., all strings of length \(N\)). At the implementation level, regular expressions can be expressed with an equivalent directed graph, called an automaton, that represents all sequences via the edge transitions in the graph.
ReLM is a Regular Expression engine for Language Models. As shown below, ReLM is an automaton-based constrained decoding system on top of the LLM. Users of ReLM construct queries that encompass the test pattern and how to execute it. Because the user explicitly describes the pattern of interest, ReLM can avoid doing extra work that results in false negatives. Additionally, since the user describes variations of the pattern (e.g., encodings and misspellings), ReLM can cover often-ignored elements in the test set, avoiding false positives. We can essentially describe any pattern or mutation of the pattern as long as the effects can be correctly propagated to the final automaton. Thankfully, there is a rich theory on ways to perform operations on automata (e.g., including misspellings and rewrites), which we utilize when compiling the final automaton. Thus, the user can 1) exactly specify large sets of interest and 2) cover the tokenization artifacts mentioned in the introduction.
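The idea of automaton-based constrained decoding can be sketched with a toy example. Everything below is an assumption made for illustration: the hand-written automaton, the token set, the fixed scores standing in for an LLM, and the greedy strategy. It is not ReLM’s implementation.

```python
# Hypothetical automaton over tokens: state -> {token: next_state}.
# It accepts " cat" together with a few spelling variants such as " kAt".
TRANSITIONS = {
    0: {" c": 1, " k": 1},
    1: {"a": 2, "A": 2},
    2: {"t": 3},
    3: {},
}
ACCEPTING = {3}

def toy_scores(prefix):
    """Stand-in for LLM next-token scores; it favors tokens the pattern forbids."""
    return {" dog": 0.5, "!": 0.6, " c": 0.2, " k": 0.1, "a": 0.3, "A": 0.05, "t": 0.4}

def constrained_greedy(transitions, accepting, score_fn, max_steps=10):
    """Greedy decoding that only considers tokens the automaton allows."""
    state, output = 0, []
    for _ in range(max_steps):
        allowed = transitions[state]
        if not allowed:
            break  # no outgoing edges: the pattern is complete
        scores = score_fn(output)
        token = max(allowed, key=lambda tok: scores.get(tok, float("-inf")))
        output.append(token)
        state = allowed[token]
    return "".join(output), state in accepting

text, matched = constrained_greedy(TRANSITIONS, ACCEPTING, toy_scores)
```

Unconstrained greedy decoding would pick “!” at every step here; masking by the automaton keeps every emitted token inside the pattern of interest.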
Since the same query pattern can be used for many execution parameters, a single test encoded as a regular expression can lead to a variety of analyses. For example, the query in the above figure could be modified to include all misspellings of the base pattern as well as all the encodings. Additionally, the user can choose between sampling from the test set or finding the most likely sequence in it. Our paper’s results exploring queries surrounding memorization (extracting URLs), gender bias (measuring distributional bias in professions), toxicity (extracting offensive words), and language understanding (completing the correct answer) show that ReLM achieves up to \(15\times\) higher system efficiency in extracting memorized URLs, \(2.5\times\) data efficiency in extracting offensive content, and increased statistical and prompt-tuning coverage compared to state-of-the-art ad-hoc queries.
Our results indicate that subtle differences in query specification can yield dramatically different results. For example, we find that randomly sampling from a URL prefix “https://www.” tends to generate invalid or duplicated URLs. ReLM avoids such inefficiency by returning strings matching the valid URL pattern sorted by likelihood. Likewise, searching over the space of all encodings as well as misspellings enables the \(2.5\times\) data efficiency in extracting toxic content from the LLM and changes the measured outcome on the gender bias task. Finally, we can recover prompt tuning behavior on the LAMBADA dataset by modifying the regular expression pattern, demonstrating that even language understanding tasks can benefit from such pattern specification.
Conclusion
In this blog, we outlined why it’s important to think of LLM tests in terms of patterns rather than individual sequences. Our work introduces ReLM, a Regular Expression engine for Language Models, to enable test writers to easily write LLM tests that can be described via pattern matching. If you’re interested in learning more about ReLM and how it can reduce the burden of LLM validation, please check out our paper (MLSys ’23) as well as our open-source code.
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
How can we be better prepared for the next pandemic?
Patient data collected by groups such as hospitals and health agencies is a critical tool for monitoring and preventing the spread of disease. Unfortunately, while this data contains a wealth of useful information for disease forecasting, the data itself may be highly sensitive and stored in disparate locations (e.g., across multiple hospitals, health agencies, and districts).
In this post we discuss our research on federated learning, which aims to tackle this challenge by performing decentralized learning across private data silos. We then explore an application of our research to the problem of privacy-preserving pandemic forecasting—a scenario where we recently won a 1st place, $100k prize in a competition hosted by the US & UK governments—and end by discussing several directions of future work based on our experiences.
Part 1: Privacy, Personalization, and Cross-Silo Federated Learning
Federated learning (FL) is a technique to train models using decentralized data without directly communicating such data. Typically:
a central server sends a model to participating clients;
the clients train that model using their own local data and send back updated models; and
the server aggregates the updates (e.g., via averaging, as in FedAvg)
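The three steps above can be sketched in a few lines of NumPy. The least-squares objective, the synthetic client data, and the hyperparameters are assumptions picked to keep the example small; weighting the average by client dataset size is one common choice.

```python
import numpy as np

def local_train(weights, X, y, lr=0.1, steps=5):
    """A few gradient steps of least-squares regression on one client's data."""
    w = weights.copy()
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / len(y)
        w -= lr * grad
    return w

def fedavg_round(global_w, clients):
    """One round: broadcast, local training, then a size-weighted average."""
    updates = [local_train(global_w, X, y) for X, y in clients]
    sizes = np.array([len(y) for _, y in clients], dtype=float)
    return np.average(updates, axis=0, weights=sizes)

# Synthetic clients that share the same underlying linear model.
rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0])
clients = []
for n in (20, 30, 50):
    X = rng.normal(size=(n, 2))
    clients.append((X, X @ true_w))

w = np.zeros(2)
for _ in range(50):
    w = fedavg_round(w, clients)
```

Only model weights cross the client boundary in this loop; the raw `(X, y)` pairs stay with each client.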
However, while significant attention has been given to cross-device FL (e.g., learning across large networks of devices such as mobile phones), the area of cross-silo FL (e.g., learning across a handful of data silos such as hospitals or financial institutions) is relatively under-explored, and it presents interesting challenges in terms of how to best model federated data and mitigate privacy risks. In Part 1.1, we’ll examine a suitable privacy granularity for such settings, and in Part 1.2, we’ll see how this interfaces with model personalization, an important technique in handling data heterogeneity across clients.
1.1. How should we protect privacy in cross-silo federated learning?
Although the high-level federated learning workflow described above can help to mitigate systemic privacy risks, past work suggests that FL’s data minimization principle alone isn’t sufficient for data privacy, as the client models and updates can still reveal sensitive information.
This is where differential privacy (DP) can come in handy. DP provides both a formal guarantee and an effective empirical mitigation to attacks like membership inference and data poisoning. In a nutshell, DP is a statistical notion of privacy where we add randomness to a query on a “dataset” to create quantifiable uncertainty about whether any one “data point” has contributed to the query output. DP is typically measured by two scalars \((\varepsilon, \delta)\): the smaller, the more private.
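As a small illustration of “adding randomness to a query”, here is a sketch of the classical Gaussian mechanism applied to a counting query. The noise calibration \(\sigma = \Delta \sqrt{2 \ln(1.25/\delta)} / \varepsilon\) is the textbook bound that holds for \(\varepsilon < 1\); the records and the predicate are made up.

```python
import math
import numpy as np

def gaussian_sigma(sensitivity, epsilon, delta):
    """Noise scale of the classical Gaussian mechanism (valid for 0 < epsilon < 1)."""
    return sensitivity * math.sqrt(2.0 * math.log(1.25 / delta)) / epsilon

def private_count(records, predicate, epsilon, delta, rng):
    """Noisy count; adding or removing one record changes the count by at most 1."""
    true_count = sum(1 for r in records if predicate(r))
    return true_count + rng.normal(0.0, gaussian_sigma(1.0, epsilon, delta))

rng = np.random.default_rng(0)
ages = [34, 71, 52, 68, 45]  # hypothetical patient ages
noisy = private_count(ages, lambda a: a >= 65, epsilon=0.5, delta=1e-5, rng=rng)
```

Smaller \(\varepsilon\) or \(\delta\) gives a larger `gaussian_sigma`, which is the privacy/utility trade-off in its simplest form.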
In the above, “dataset” and “data point” are in quotes because privacy granularity matters. In cross-device FL, it is common to apply “client-level DP” when training a model, where the federated clients (e.g., mobile phones) are thought of as “data points”. This effectively ensures that each participating client/mobile phone user remains private.
However, while client-level DP makes sense for cross-device FL as each client naturally corresponds to a person, this privacy granularity may not be suitable for cross-silo FL, where there are fewer (2-100) ‘clients’ but each holds many data subjects that require protection, e.g., each ‘client’ may be a hospital, bank, or school with many patient, customer, or student records.
In our recent work (NeurIPS’22), we instead consider the notion of “silo-specific example-level DP” in cross-silo FL (see figure above). In short, this says that the \(k\)-th data silo may set its own \((\varepsilon_k, \delta_k)\) example-level DP target for any learning algorithm with respect to its local dataset.
This notion is better aligned with real-world use cases of cross-silo FL, where each data subject contributes a single “example”, e.g., each patient in a hospital contributes their individual medical record. It is also very easy to implement: each silo can just run DP-SGD for local gradient steps with calibrated per-step noise. As we discuss below, this alternate privacy granularity affects how we consider modeling federated data to improve privacy/utility trade-offs.
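A single DP-SGD step inside one silo can be sketched as follows: clip each example’s gradient, sum, add Gaussian noise, and average. The logistic-regression model and the parameter names (`clip_norm`, `noise_multiplier`) are assumptions for illustration, and converting the noise multiplier into an \((\varepsilon_k, \delta_k)\) guarantee requires a privacy accountant that is not shown.

```python
import numpy as np

def dp_sgd_step(w, X, y, lr, clip_norm, noise_multiplier, rng):
    """One DP-SGD step for logistic regression on a batch (X, y)."""
    probs = 1.0 / (1.0 + np.exp(-(X @ w)))
    per_example_grads = (probs - y)[:, None] * X           # shape (batch, d)
    # Clip each example's gradient to L2 norm at most clip_norm.
    norms = np.linalg.norm(per_example_grads, axis=1)
    scale = np.minimum(1.0, clip_norm / np.maximum(norms, 1e-12))
    clipped = per_example_grads * scale[:, None]
    # Add Gaussian noise calibrated to the clipping norm, then average.
    noise = rng.normal(0.0, noise_multiplier * clip_norm, size=w.shape)
    noisy_mean = (clipped.sum(axis=0) + noise) / len(y)
    return w - lr * noisy_mean

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))
y = (rng.random(8) < 0.5).astype(float)
w = dp_sgd_step(np.zeros(3), X, y, lr=0.5, clip_norm=1.0,
                noise_multiplier=1.0, rng=rng)
```

Because clipping bounds each example’s influence on the update, the noise scale can be set per silo without coordinating with other silos.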
1.2. The interplay of privacy, heterogeneity, and model personalization
Let’s now look at how this privacy granularity may interface with model personalization in federated learning.
Model personalization is a common technique used to improve model performance in FL when data heterogeneity (i.e. non-identically distributed data) exists between data silos.¹ Indeed, existing benchmarks suggest that realistic federated datasets may be highly heterogeneous and that fitting separate local models on the federated data is already a competitive baseline.
When considering model personalization techniques under silo-specific example-level privacy, we find that a unique trade-off may emerge between the utility costs from privacy and data heterogeneity (see figure below):
As DP noises are added independently by each silo for its own privacy targets, these noises are reflected in the silos’ model updates and can thus be smoothed out when these updates are averaged (e.g. via FedAvg), leading to a smaller utility drop from DP for the federated model.
On the other hand, federation also means that the shared, federated model may suffer from data heterogeneity (“one size does not fit all”).
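The noise-smoothing half of this trade-off is easy to check numerically: averaging independent Gaussian noise from \(K\) silos shrinks its standard deviation by roughly \(\sqrt{K}\). The number of silos, the dimension, and the noise scale below are arbitrary choices for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_silos, dim, sigma = 10, 10_000, 1.0

# Each row is the DP noise one silo adds to its own model update.
silo_noise = rng.normal(0.0, sigma, size=(num_silos, dim))

single_std = silo_noise[0].std()               # noise seen by one local model
averaged_std = silo_noise.mean(axis=0).std()   # noise left after averaging
expected_ratio = 1.0 / np.sqrt(num_silos)
```

The ratio `averaged_std / single_std` lands close to `expected_ratio`; the heterogeneity cost, by contrast, does not shrink with averaging, which is what creates the trade-off.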
This “privacy-heterogeneity cost tradeoff” is interesting because it suggests that model personalization can play a key and distinct role in cross-silo FL. Intuitively, local training (no FL participation) and FedAvg (full FL participation) can be viewed as two ends of a personalization spectrum with identical privacy costs—silos’ participation in FL itself does not incur privacy costs due to DP’s robustness to post-processing—and various personalization algorithms (finetuning, clustering, …) are effectively navigating this spectrum in different ways.
If local training minimizes the effect of data heterogeneity but enjoys no DP noise reduction, and contrarily for FedAvg, it is natural to wonder whether there are personalization methods that lie in between and achieve better utility. If so, what methods would work best?
Our analysis points to mean-regularized multi-task learning (MR-MTL) as a simple yet particularly suitable form of personalization. MR-MTL simply asks each client \(k\) to train its own local model \(w_k\), regularize it towards the mean of others’ models \(\bar w\) via a penalty \(\frac{\lambda}{2} \| w_k - \bar w \|_2^2\), and keep \(w_k\) across rounds (i.e. client is stateful). The mean model \(\bar w\) is maintained by the FL server (as in FedAvg) and may be updated in every round. More concretely, each local update step takes the following form:
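One form consistent with the penalty above, where the step size \(\eta\) and the local (possibly clipped and noised) gradient \(g_k\) are notation assumed here rather than taken from the post:

$$w_k \leftarrow w_k - \eta \left( g_k + \lambda \, (w_k - \bar w) \right)$$

The second term is the gradient of \(\frac{\lambda}{2} \| w_k - \bar w \|_2^2\) with respect to \(w_k\).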
The hyperparameter \(\lambda\) serves as a smooth knob between local training and FedAvg: \(\lambda = 0\) recovers local training, and a larger \(\lambda\) forces the personalized models to be closer to each other (intuitively, “federate more”).
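A minimal, non-private sketch of this knob is shown below. The least-squares objective, the two synthetic silos with different target weights, and all hyperparameters are assumptions for illustration; DP noise and clipping are left out to keep the effect of the regularizer visible.

```python
import numpy as np

def mr_mtl_round(local_ws, mean_w, silos, lam, lr=0.1, steps=5):
    """Each silo updates its own model with a pull toward the current mean."""
    new_ws = []
    for w_k, (X, y) in zip(local_ws, silos):
        w = w_k.copy()
        for _ in range(steps):
            grad = X.T @ (X @ w - y) / len(y)       # local least-squares gradient
            w -= lr * (grad + lam * (w - mean_w))   # mean-regularization term
        new_ws.append(w)
    return new_ws, np.mean(new_ws, axis=0)          # server recomputes the mean

def train(silos, lam, rounds=30):
    ws = [np.zeros(2) for _ in silos]               # stateful local models
    mean_w = np.zeros(2)
    for _ in range(rounds):
        ws, mean_w = mr_mtl_round(ws, mean_w, silos, lam)
    return ws

def spread(ws):
    """Distance between the two silos' personalized models."""
    return float(np.linalg.norm(ws[0] - ws[1]))

# Two heterogeneous silos: same features, different target weights.
rng = np.random.default_rng(0)
silos = []
for target in (np.array([1.0, 0.0]), np.array([0.0, 1.0])):
    X = rng.normal(size=(50, 2))
    silos.append((X, X @ target))

spread_local = spread(train(silos, lam=0.0))        # pure local training
spread_regularized = spread(train(silos, lam=5.0))  # models pulled together
```

With `lam=0.0` each silo recovers its own target, and with a larger `lam` the two personalized models move toward each other.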
MR-MTL has some nice properties in the context of private cross-silo FL:
Noise reduction is attained throughout training via the soft proximity constraint towards an averaged model;
The mean-regularization itself has no privacy overhead;² and
\(\lambda\) provides a smooth interpolation along the personalization spectrum.
Why is the above interesting? Consider the following experiment where we try a range of \(\lambda\) values roughly interpolating local training and FedAvg. Observe that we could find a “sweet spot” \(\lambda^\ast\) that outperforms both of the endpoints under the same privacy cost. Moreover, both the utility advantage of MR-MTL(\(\lambda^\ast\)) over the endpoints, and \(\lambda^\ast\) itself, are larger under privacy; intuitively, this says that silos are encouraged to “federate more” for noise reduction.
The above provides rough intuition on why MR-MTL may be a strong baseline for private cross-silo FL and motivates this approach for a practical pandemic forecasting problem, which we discuss in Part 2. Our full paper delves deeper into the analyses and provides additional results and discussions!
Part 2: Federated Pandemic Forecasting at the US/UK PETs Challenge
The pandemic forecasting problem asks the following: Given a person’s demographic attributes (e.g. age, household size), locations, activities, infection history, and the contact network, what is the likelihood of infection in the next \(t_\text{pred}=7\) days? Can we make predictions while protecting the privacy of individuals? Moreover, what if the data are siloed across administrative regions?
There’s a lot to unpack in the above. First, the pandemic outbreak problem follows a discrete-time SIR model (Susceptible → Infectious → Recovered) and we begin with a subset of the population infected. Subsequently,
Each person goes about their usual daily activities and gets into contact with others (e.g. at a shopping mall)—this forms a contact graph where individuals are nodes and direct contacts are edges;
Each person may get infected with different risk levels depending on a myriad of factors—their age, the nature and duration of their contact(s), their node centrality, etc.; and
Such infection can also be asymptomatic—the individual can appear in the S state while being secretly infectious.
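The dynamics above can be sketched as a toy discrete-time SIR simulation on a contact graph. The uniform infection and recovery probabilities and the five-person chain are assumptions for illustration; the challenge data uses person-specific risk factors, and asymptomatic cases are not modeled here.

```python
import random

def sir_step(states, contacts, p_infect, p_recover, rng):
    """Advance one day: infectious contacts may infect S, and I may recover."""
    new_states = dict(states)  # read yesterday's states, write today's
    for person, state in states.items():
        if state == "S":
            for other in contacts.get(person, ()):
                if states[other] == "I" and rng.random() < p_infect:
                    new_states[person] = "I"
                    break
        elif state == "I" and rng.random() < p_recover:
            new_states[person] = "R"
    return new_states

# Hypothetical contact graph: a chain 0 - 1 - 2 - 3 - 4, person 0 infected.
contacts = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
states = {0: "I", 1: "S", 2: "S", 3: "S", 4: "S"}

rng = random.Random(0)
history = [states]
for _ in range(10):
    history.append(sir_step(history[-1], contacts, p_infect=0.3, p_recover=0.1, rng=rng))
```

Each entry of `history` is one day of health states, which is the kind of per-person time series the forecasting task starts from.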
The challenge dataset models a pandemic outbreak in Virginia and contains roughly 7.7 million nodes (persons) and 186 million edges (contacts) with health states over 63 days; so the actual contact graph is fairly large but also quite sparse.
There are a few extra factors that make this problem challenging:
Data imbalance: less than 5% of people are ever in the I or R state and roughly 0.3% of people became infected in the final week.
Data silos: the true contact graph is cut along administrative boundaries, e.g., by grouped FIPS codes/counties. Each silo only sees a local subgraph, but people may still travel and make contacts across multiple regions! In the official evaluation, the population sizes can also vary by more than \(10\times\) across silos.
Temporal modeling: we are given the first \(t_\text{train} = 56\) days of each person’s health states (S/I/R) and asked to predict individual infections any time in the subsequent \(t_\text{pred} = 7\) days. What is a training example in this case? How should we perform temporal partitioning? How does this relate to privacy accounting?
Graphs generally complicate DP: we are often used to ML settings where we can clearly define the privacy granularity and how it relates to an actual individual (e.g. medical images of patients). This is tricky with graphs: people can make different numbers of contacts each of different natures, and their influence can propagate throughout the graph. At a high level (and as specified by the scope of sensitive data of the competition), what we care about is known as node-level DP—the model output is “roughly the same” if we add/remove/replace a node, along with its edges.
2.2. Applying MR-MTL with silo-specific example-level privacy
One clean approach to the pandemic forecasting problem is to just operate on the individual level and view it as (federated) binary classification: if we could build a feature vector to summarize an individual, then risk scores are simply the sigmoid probabilities of near-term infection.
Of course, the problem lies in what that feature vector (and the corresponding label) is—we’ll get to this in the following section. But already, we can see that MR-MTL with silo-specific example-level privacy (from Part 1) is a nice framework for a number of reasons:
Model personalization is likely needed as the silos are large and heterogeneous by construction (geographic regions are unlikely to all be similar).
Privacy definition: There are a small number of clients, but each holds many data subjects, and client-level DP isn’t suitable.
Usability, efficiency, and scalability: MR-MTL is remarkably easy to implement with minimal resource overhead (over FedAvg and local training). This is crucial for real-world applications.
Adaptability and explainability: The framework is highly adaptable to any learning algorithm that can take DP-SGD-style updates. It also preserves the explainability of the underlying ML algorithm as we don’t obfuscate the model weights, updates, or predictions.
It is also helpful to look at the threat model we might be dealing with and how our framework behaves under it; the interested reader may find more details in the extended post!
2.3. Building training examples
We now describe how to convert individual information and the contact network into a tabular dataset for every silo \( k \) with \( n_k \) nodes.
Recall that our task is to predict the risk of infection of a person within \( t_\text{pred} = 7 \) days, and that each silo only sees its local subgraph. We formulate this via a silo-specific set of examples \( ( X_k \in \mathbb{R}^{n_k \times d}, Y_k \in \{0, 1\}^{n_k} ) \), where the features \( X_k^{(i)} \in \mathbb{R}^d \) describe the neighborhood around a person \( i \) (see figure) and the binary label \( Y_k^{(i)} \) denotes whether the person becomes infected in the next \( t_\text{pred} \) days.
Each example’s features \( X_k^{(i)} \) consist of the following:
(1) Individual features: Basic (normalized) demographic features like age, gender, and household size; activity features like working, school, going to church, or shopping; and the individual’s infection history as concatenated one-hot vectors (which depends on how we create labels; see below).
(2) Contact features: One of our key simplifying heuristics is that each node’s \(\ell\)-hop neighborhood should contain most of the information we need to predict infection. We build the contact features as follows:
Every sampled neighbor \(v\) of a node \(u\) is encoded using its individual features (as above) along with the edge features describing the contact, e.g. the location, the duration, and the activity type.
We use iterative neighborhood sampling (figure above), meaning that we first select a set of \( S_1 \) 1-hop neighbors, and then sample \( S_2 \) 2-hop neighbors adjacent to those 1-hop neighbors, and so on. This allows reusing 1-hop edge features and keeps the feature dimension \(d\) low.
We also used deterministic neighborhood sampling—the same person always takes the same subset of neighbors. This drastically reduces computation as the graph/neighborhoods can now be cached. For the interested reader, this also has implications on privacy accounting.
The figure above illustrates the neighborhood feature vector that describes a person and their contacts for the binary classifier! Intriguingly, this makes the per-silo models a simplified variant of a graph neural network (GNN) with a single-step, non-parameterized neighborhood aggregation and prediction (cf. SGC models).
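The iterative, deterministic sampling described above can be sketched as follows. This is a minimal illustration rather than the challenge implementation: the hash-based ranking, the function names `pick_neighbors` and `sample_neighborhood`, the toy graph, and the choice to draw \( S_2 \) neighbors per sampled 1-hop neighbor are all assumptions made for the example.

```python
import hashlib

def pick_neighbors(node, neighbors, size):
    """Deterministically pick up to `size` neighbors of `node`.

    Neighbors are ranked by a hash of the (node, neighbor) pair, so the same
    person always gets the same subset and the result can be cached.
    """
    def rank(v):
        return hashlib.sha256(f"{node}-{v}".encode()).hexdigest()
    return sorted(neighbors, key=rank)[:size]

def sample_neighborhood(graph, root, sizes):
    """Iterative sampling: sizes[0] neighbors of the root, then sizes[1]
    neighbors of each sampled 1-hop neighbor, and so on.

    For brevity this sketch does not filter out nodes that were already
    visited (for example, the root can reappear at hop 2).
    """
    frontier, hops = [root], []
    for size in sizes:
        next_frontier = []
        for u in frontier:
            next_frontier.extend(pick_neighbors(u, graph.get(u, []), size))
        hops.append(next_frontier)
        frontier = next_frontier
    return hops

# Toy contact graph as an adjacency list (hypothetical data).
contacts = {
    "a": ["b", "c", "d"],
    "b": ["a", "e"],
    "c": ["a", "f", "g"],
    "d": ["a"],
    "e": ["b"], "f": ["c"], "g": ["c"],
}
hops = sample_neighborhood(contacts, "a", sizes=[2, 1])
```

Because the selection depends only on node identifiers, calling `sample_neighborhood` twice for the same person returns the same neighborhood, which is what makes caching possible.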
For the labels \( Y_k^{(i)} \), we deployed a random infection window strategy:
Pick a window size \( t_\text{window} \) (say 21 days);
Select a random day \( t' \) within the valid range (\( t_\text{window} \le t' \le t_\text{train} - t_\text{pred} \));
Encode the S/I/R states in the past window from \( t' \) for every node in the neighborhood as individual features;
The label is then whether person \( i \) is infected in any of the next \( t_\text{pred} \) days from \( t' \).
Our strategy implicitly assumes that a person’s infection risk is individual: whether Bob gets infected depends only on his own activities and contacts in the past window. This is certainly not perfect as it ignores population-level modeling (e.g. denser areas have higher risks of infection), but it makes the ML problem very simple: just plug in existing tabular data modeling approaches!
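To make the random infection window concrete, here is a minimal sketch for a single person. It assumes each person's history is a string of daily states indexed from day 0, and it treats "infected" as being in state I on any day of the prediction window; the name `make_example` and the one-hot layout are illustrative assumptions. In the full pipeline the same window would be encoded for every node in the sampled neighborhood.

```python
import random

T_TRAIN, T_PRED = 56, 7
ONE_HOT = {"S": (1, 0, 0), "I": (0, 1, 0), "R": (0, 0, 1)}

def make_example(states, t_window=21, rng=random):
    """Build (features, label) for one person from a daily S/I/R string.

    `states[t]` is the health state on day t, for t in 0..T_TRAIN-1.
    """
    # Pick t' with t_window <= t' <= T_TRAIN - T_PRED (both ends inclusive).
    t = rng.randint(t_window, T_TRAIN - T_PRED)
    history = states[t - t_window:t]            # past window, used as features
    label = int("I" in states[t:t + T_PRED])    # infected in the next T_PRED days?
    features = [bit for s in history for bit in ONE_HOT[s]]
    return features, label

# Hypothetical 56-day history: susceptible, then infected, then recovered.
states = "S" * 30 + "I" * 10 + "R" * 16
features, label = make_example(states, rng=random.Random(0))
```

Each call draws a fresh window, so the same person can contribute different (features, label) pairs across draws.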
2.4. Putting it all together
We can now see our solution coming together: each silo builds a tabular dataset using neighborhood vectors for features and infection windows for labels, and each silo trains a personalized binary classifier under MR-MTL with silo-specific example-level privacy. We complete our method with a few additional ingredients:
Privacy accounting. We’ve so far glossed over what silo-specific “example-level” DP actually means for an individual. We’ve put more details in the extended blog post, and the main idea is that local DP-SGD can give “neighborhood-level” DP since each node’s enclosing neighborhood is fixed and unique, and we can then convert it to node-level DP (our privacy goal from Part 2.1) by carefully accounting for how a certain node may appear in other nodes’ neighborhoods.
Noisy SGD as an empirical defense. While we have a complete framework for providing silo-specific node-level DP guarantees, for the PETs challenge specifically we decided to opt for weak DP (\(\varepsilon > 500\)) as an empirical protection, rather than a rigorous theoretical guarantee. While some readers may find this mildly disturbing at first glance, we note that the strength of protection depends on the data, the models, the actual threats, the desired privacy-utility trade-off, and several crucial factors linking theory and practice which we outline in the extended blog. Our solution was in turn attacked by several red teams to test for vulnerabilities.
Model architecture: simple is good. While the model design space is large, we are interested in methods amenable to gradient-based private optimization (e.g. DP-SGD) and weight-space averaging for federated learning. We compared simple logistic regression and a 3-layer MLP and found that the variance in data strongly favors linear models, which also have benefits in privacy (in terms of limited capacity for memorization) as well as explainability, efficiency, and robustness.
Computation-utility tradeoff for neighborhood sampling. While larger neighborhood sizes \(S\) and more hops \(\ell\) better capture the original contact graph, they also blow up the computation, and our experiments found that larger \(S\) and \(\ell\) tend to have diminishing returns.
Data imbalance and weighted loss. Because the data are highly imbalanced, training naively will suffer from low recall and AUPRC. While there are established over-/under-sampling methods to deal with such imbalance, they, unfortunately, make privacy accounting a lot trickier in terms of the subsampling assumption or the increased data queries. We leveraged the focal loss from the computer vision literature designed to emphasize hard examples (infected cases) and found that it did improve both the AUPRC and the recall considerably.
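For reference, the binary focal loss can be written in a few lines. This is a sketch of the standard formulation from the focal loss paper, not the exact loss configuration used in our entry; the values `gamma=2.0` and `alpha=0.25` are that paper's common defaults and are used here purely for illustration.

```python
import math

def focal_loss(p, y, gamma=2.0, alpha=0.25, eps=1e-12):
    """Binary focal loss for one example.

    p: predicted probability of the positive class; y: label in {0, 1}.
    The factor (1 - p_t) ** gamma shrinks the loss of well-classified
    examples, so hard examples dominate the gradient.
    """
    p = min(max(p, eps), 1.0 - eps)            # avoid log(0)
    p_t = p if y == 1 else 1.0 - p             # probability of the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha # class weighting
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

def cross_entropy(p, y, eps=1e-12):
    """Plain binary cross-entropy, for comparison."""
    p = min(max(p, eps), 1.0 - eps)
    return -math.log(p if y == 1 else 1.0 - p)

# A confidently correct negative versus a missed positive.
easy_negative = focal_loss(0.1, 0)
missed_positive = focal_loss(0.1, 1)
```

With `gamma=0` the modulating factor disappears and the loss reduces to class-weighted cross-entropy, which is a convenient sanity check.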
The above captures the essence of our entry to the challenge. Despite the many subtleties in fully building out a working system, the main ideas were quite simple: train personalized models with DP and add some proximity constraints!
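As a rough illustration of "personalized models with DP plus a proximity constraint", the sketch below performs one local update for a silo. It assumes the mean-regularized objective \( f_k(w_k) + \frac{\lambda}{2} \lVert w_k - \bar{w} \rVert^2 \) and a standard DP-SGD step (per-example clipping plus Gaussian noise). The function names, the hyperparameter names, and the choice to add the proximity gradient outside the clipped sum are assumptions made for this example, not a description of our actual code.

```python
import math
import random

def clip(grad, max_norm):
    """Scale `grad` so that its L2 norm is at most `max_norm`."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-12))
    return [g * scale for g in grad]

def mrmtl_dp_step(w, w_mean, per_example_grads, lam, lr,
                  clip_norm, noise_mult, rng):
    """One noisy local step on silo weights `w`, pulled toward `w_mean`."""
    n, d = len(per_example_grads), len(w)
    # DP-SGD part: clip each example's gradient, sum, add Gaussian noise.
    clipped = [clip(g, clip_norm) for g in per_example_grads]
    summed = [sum(g[j] for g in clipped) for j in range(d)]
    noisy = [(s + rng.gauss(0.0, noise_mult * clip_norm)) / n for s in summed]
    # Proximity part: gradient of (lam / 2) * ||w - w_mean||^2.
    # Adding it after clipping and noising is an assumption of this sketch.
    return [w[j] - lr * (noisy[j] + lam * (w[j] - w_mean[j])) for j in range(d)]

# Hypothetical two-parameter model with three per-example gradients.
rng = random.Random(0)
w_new = mrmtl_dp_step(
    w=[0.5, -0.2], w_mean=[0.0, 0.0],
    per_example_grads=[[0.3, 0.1], [2.0, -1.0], [-0.4, 0.2]],
    lam=0.1, lr=0.5, clip_norm=1.0, noise_mult=1.0, rng=rng,
)
```

Setting `lam=0` recovers purely local DP training, while a very large `lam` keeps every silo's model close to the shared mean; the interesting regime lies in between.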
Takeaways and Open Challenges
In Part 1, we reviewed our NeurIPS’22 paper that studied the application of differential privacy in cross-silo federated learning scenarios, and in Part 2, we saw how the core ideas and methods from the paper helped us develop our submission to the PETs prize challenge and win 1st place in the pandemic forecasting track. For readers interested in more details (such as theoretical analyses, hyperparameter tuning, further experiments, and failure modes), please check out our full paper. Our work also identified several important future directions in this context:
DP under data imbalance. DP is inherently a uniform guarantee, but data imbalance implies that examples are not created equal: minority examples (e.g., disease infection, credit card fraud) are more informative, and they tend to give off (much) larger gradients during model training. Should we instead do class-specific (group-wise) DP or refine “heterogeneous DP” or “outlier DP” notions to better cater to the discrepancy between data points?
Graphs and privacy. Another fundamental basis of DP is that we could delineate what is and isn’t an individual. But as we’ve seen, the information boundaries are often nebulous when an individual is a node in a graph (think social networks and gossip propagation), particularly when the node is arbitrarily well connected. Instead of having rigid constraints (e.g., imposing a max node degree and accounting for it), are there alternative privacy definitions that offer varying degrees of protection for varying node connectedness?
Scalable, private, and federated trees for tabular data. Decision trees/forests tend to work extremely well for tabular data such as ours, even with data imbalance, but despite recent progress, we argue that they are not yet mature under private and federated settings due to some underlying assumptions.
Novel training frameworks. While MR-MTL is a simple and strong baseline under our privacy granularity, it has clear limitations in terms of modeling capacity. Are there other methods that can also provide similar properties to balance the emerging privacy-heterogeneity cost tradeoff?
Honest privacy cost of hyperparameter search. When searching for better frameworks, the dependence on hyperparameters is particularly interesting: our full paper (section 7) made a surprising but somewhat depressing observation that the honest privacy cost of just tuning (on average) 10 configurations (values of \(\lambda\) in this case) may already outweigh the utility advantage of the best-tuned MR-MTL(\(\lambda^\ast\)). What does this mean if MR-MTL is already a strong baseline with just a single hyperparameter?
Example of embodied commonsense reasoning. A robot proactively identifies a remote on the floor and knows it is out of place without instruction. Then, the robot figures out where to place it in the scene and manipulates it there.
For robots to operate effectively in the world, they should be more than explicit step-by-step instruction followers. Robots should take actions in situations when there is a clear violation of the normal circumstances and be able to infer relevant context from partial instruction. Consider a situation where a home robot identifies a remote control which has fallen to the kitchen floor. The robot should not need to wait until a human instructs the robot to “pick the remote control off the floor and place it on the coffee table”. Instead, the robot should understand that the remote on the floor is clearly out of place, and act to pick it up and place it in a reasonable location. Even if a human were to spot the remote control first and instruct the agent to “put away the remote that is on the living room floor”, the robot should not require a second instruction for where to put the remote, but instead infer from experience that a reasonable location would be, for example, on the coffee table. After all, it would become tiring for a home robot user to have to specify every desire in excruciating detail (think about for each item you want the robot to move, specifying an instruction such as “pick up the shoes beneath the coffee table and place them next to the door, aligned with the wall”).
The type of reasoning that would permit such partial or self-generated instruction following involves a deep sense of how things in the world (objects, physics, other agents, etc.) ought to behave. Reasoning and acting of this kind are all aspects of embodied commonsense reasoning and are vastly important for robots to act and interact seamlessly in the physical world.
There has been much work on embodied agents that follow detailed step-by-step instructions, but less on embodied commonsense reasoning, where the task involves learning how to perceive and act without explicit instruction. One task in which to study embodied commonsense reasoning is that of tidying up, where the agent must identify objects which are out of their natural locations and act in order to bring the identified objects to plausible locations. This task combines many desirable capabilities of intelligent agents with commonsense reasoning of object placements. The agent must search in likely locations for objects to be displaced, identify when objects are out of their natural locations in the context of the current scene, and figure out where to reposition the objects so that they are in proper locations – all while intelligently navigating and manipulating.
In our recent work, we propose TIDEE, an embodied agent that can tidy up never-before-seen rooms without any explicit instruction. TIDEE is the first of its kind for its ability to search a scene for out of place objects, identify where in the scene to reposition the out of place objects, and effectively manipulate the objects to the identified locations. We’ll walk through how TIDEE is able to do this in a later section, but first let’s describe how we create a dataset to train and test our agent for the task of tidying up.
Creating messy homes
To create clean and messy scenes that teach our agent what constitutes a tidy scene and what constitutes a messy one, we use a simulation environment called ai2thor. Ai2thor is an interactive 3D environment of indoor scenes that allows objects to be picked up and moved around. The simulator comes ready with 120 scenes of kitchens, bathrooms, living rooms, and bedrooms with over 116 object categories (and significantly more object instances) scattered throughout. Each of the scenes comes with a default initialization of object placements that are meticulously chosen by humans to be highly structured and “neat”. These default object locations make up our “tidy” scenes for providing our agent examples of objects in their natural locations. To create messy scenes, we apply forces to a subset of the objects with a random direction and magnitude (we “throw” the objects around) so they end up in uncommon locations and poses. You can see below some examples of objects which have been moved out of place.
Next, let’s see how TIDEE learns from this dataset to be able to tidy up rooms.
How does TIDEE work?
We give our agent a depth and RGB sensor to use for perceiving the scene. From this input, the agent must navigate around, detect objects, pick them up, and place them. The goal of the tidying task is to rearrange a messy room back to a tidy state.
TIDEE tidies up rooms in three phases. In the first phase, TIDEE explores around the room and runs an out of place object detector at each time step until one is identified. Then, TIDEE navigates over to the object, and picks it up. In the second phase, TIDEE uses graph inference in its joint external graph memory and scene graph to infer a plausible receptacle to place the object on within the scene. It then explores the scene guided by a visual search network that suggests where the receptacle may be found if TIDEE has not identified it in a previous time step. For navigation and keeping track of objects, TIDEE maintains an obstacle map of the scene and stores in memory the estimated 3D centroids of previously detected objects.
The out of place detector uses visual and relational language features to determine if an object is in or out of place in the context of the scene. The visual features for each object are obtained from an off-the-shelf object detector, and the relational language features are obtained by giving predicted 3D relations of the objects (e.g. next to, supported by, above, etc.) to a pretrained language model. We combine the visual and language features to classify whether each detected object is in or out of place. We find that combining the visual and relational modalities performs best for out of place classification over using a single modality.
To infer where to place an object once it has been picked up, TIDEE includes a neural graph module which is trained to predict plausible placement proposals for objects. The module works by passing information between the object to be placed, a memory graph encoding plausible contextual relations from training scenes, and a scene graph encoding the object-relation configuration in the current scene. For our memory graph, we take inspiration from “Beyond Categories: The Visual Memex Model for Reasoning About Object Relationships” by Tomasz Malisiewicz and Alexei A. Efros (2009), which models instance-level object features and their relations to provide more complete appearance-based context. Our memory graph consists of the tidy object instances in the training scenes to provide fine-grain contextualization of tidy object placements. We show in the paper that this fine-grain visual and relational information is important for TIDEE to place objects in human-preferred locations.
To search for objects that have not been previously found, TIDEE uses a visual search network that takes as input the semantic obstacle map and a search category and predicts the likelihood of the object being present at each spatial location in the obstacle map. The agent then searches in those likely locations for the object of interest.
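The last step, turning a predicted likelihood map into places to look, can be sketched in a few lines. This is only an illustration under simple assumptions (a small 2D grid of likelihoods, a set of already-visited cells, and a hypothetical helper named `top_search_goals`); it is not the search policy implemented in TIDEE.

```python
def top_search_goals(likelihood, visited, k=3):
    """Return the k unvisited grid cells with the highest predicted likelihood.

    likelihood: 2D list where likelihood[r][c] is the predicted chance that
    the target category is at map cell (r, c).
    visited: set of (r, c) cells the agent has already searched.
    """
    cells = [
        (likelihood[r][c], (r, c))
        for r in range(len(likelihood))
        for c in range(len(likelihood[0]))
        if (r, c) not in visited
    ]
    cells.sort(key=lambda item: item[0], reverse=True)
    return [pos for _, pos in cells[:k]]

# Hypothetical 2x2 likelihood map; cell (1, 0) has already been searched.
likelihood = [[0.1, 0.7],
              [0.9, 0.2]]
goals = top_search_goals(likelihood, visited={(1, 0)}, k=2)
```

The agent would then navigate to each returned cell in turn, running its object detector along the way.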
Combining all the above modules provides us with a method to be able to detect out of place objects, infer where they should go, search intelligently, and navigate & manipulate effectively. In the next section, we’ll show you how well our agent performs at tidying up rooms.
How good is TIDEE at tidying up?
Using a set of messy test scenes that TIDEE has never seen before, we task our agent with reconfiguring the messy room to a tidy state. Since a single object may be tidy in multiple locations within a scene, we evaluate our method by asking humans whether they prefer the placements of TIDEE compared to baseline placements that do not make use of one or more of TIDEE’s commonsense priors. Below we show that TIDEE placements are significantly preferred to the baseline placements, and even competitive with human placements (last row).
We additionally show that the placements of TIDEE can be customized based on user preferences. For example, based on user input such as “I never want my alarm on the desk”, we can use online learning techniques to update the model so that it predicts an alarm clock on the desk as out of place (and to be moved). Below we show some examples of locations and relations of alarm clocks that were predicted as being in the correct locations (and not out of place) within the scene after our initial training. However, after doing the user-specified finetuning, our network predicts that the alarm clock on the desk is out of place and should be repositioned.
We also show that a simplified version of TIDEE can generalize to the task of rearrangement, where the agent sees the original state of the objects, then some of the objects get rearranged to new locations, and the agent must rearrange the objects back to their original state. We outperform the previous state-of-the-art model that utilizes semantic mapping and reinforcement learning, even with noisy sensor measurements.
Summary
In this article, we discussed TIDEE, an embodied agent that uses commonsense reasoning to tidy up novel messy scenes. We introduce a new benchmark to test agents in their ability to clean up messy scenes without any human instruction. To check out our paper, code, and more, please visit our website at https://tidee-agent.github.io/.
Also, feel free to shoot me an email at gsarch@andrew.cmu.edu! I would love to chat!
Figure 1. This blog post discusses the effectiveness of black-box model explanations in aiding end users to make decisions. We observe that explanations do not in fact help with concrete applications such as fraud detection and paper matching for peer review. Our work further motivates novel directions for developing and evaluating tools to support human-ML interactions.
Model explanations have been touted as crucial information to facilitate human-ML interactions in many real-world applications where end users make decisions informed by ML predictions. For example, explanations are thought to assist model developers in identifying when models rely on spurious artifacts and to aid domain experts in determining whether to follow a model’s prediction. However, while numerous explainable AI (XAI) methods have been developed, XAI has yet to deliver on this promise. XAI methods are typically optimized for diverse but narrow technical objectives disconnected from their claimed use cases. To connect methods to concrete use cases, we argued in our Communications of the ACM paper [1] for researchers to rigorously evaluate how well proposed methods can help real users in their real-world applications.
Towards bridging this gap, our group has since completed two collaborative projects where we worked with domain experts in e-commerce fraud detection and paper matching for peer review. Through these efforts, we’ve gleaned the following two insights:
Existing XAI methods are not useful for decision-making. Presenting humans with popular, general-purpose XAI methods does not improve their performance on real-world use cases that motivated the development of these methods. Our negative findings align with those of contemporaneous works.
Rigorous, real-world evaluation is important but hard. These findings were obtained through user studies that were time-consuming to conduct.
We believe that each of these insights motivates a corresponding research direction to support human-ML interactions better moving forward. First, beyond methods that attempt to explain the ML model itself, we should consider a wider range of approaches that present relevant task-specific information to human decision-makers; we refer to these approaches as human-centered ML (HCML) methods [10]. Second, we need to create new workflows to evaluate proposed HCML methods that are both low-cost and informative of real-world performance.
In this post, we first outline our workflow for evaluating XAI methods. We then describe how we instantiated this workflow in two domains: fraud detection and peer review paper matching. Finally, we describe the two aforementioned insights from these efforts; we hope these takeaways will motivate the community to rethink how HCML methods are developed and evaluated.
How do you rigorously evaluate explanation methods?
In our CACM paper [1], we introduced a use-case-grounded workflow to evaluate explanation methods in practice—this means showing that they are ‘useful,’ i.e., that they can actually improve human-ML interactions in the real-world applications that they are motivated by. This workflow contrasts with evaluation workflows of XAI methods in prior work, which relied on researcher-defined proxy metrics that may or may not be relevant to any downstream task. Our proposed three-step workflow is based on the general scientific method:
Step 1: Define a concrete use case. To do this, researchers may need to work closely with domain experts to define a task that reflects the practical use case of interest.
Step 2: Select explanation methods for evaluation. While selected methods might comprise popular XAI methods, the appropriate set of methods is to a large extent application-specific and should also include relevant non-explanation baselines.
Step 3: Evaluate explanation methods against baselines. While researchers should ultimately evaluate selected methods through a user study with real-world users, researchers may want to first conduct cheaper, noisier forms of evaluation to narrow down the set of methods in consideration (Figure 2).
Instantiating the workflow in practice
We collaborated with experts from two domains (fraud detection and peer review paper matching) to instantiate this use-case-grounded workflow and evaluate existing XAI methods:
Domain 1: Fraud detection [3]. We partnered with researchers at Feedzai, a financial start-up, to assess whether providing model explanations improved the ability of fraud analysts to detect fraudulent e-commerce transactions. Given that we had access to real-world data (i.e., historical e-commerce transactions for which we had ground truth answers of whether the transaction was fraudulent) and real users (i.e., fraud analysts), we directly conducted a user study in this context. An example of the interface shown to analysts is in Figure 3. We compared analysts’ average performance when shown different explanations to a baseline setting where they were only provided the model prediction. We ultimately found that none of the popular XAI methods we evaluated (LIME, SHAP, and Tree Interpreter) resulted in any improvement in the analysts’ decisions compared to the baseline setting (Figure 5, left). Evaluating these methods with real users additionally posed many logistical challenges because fraud analysts took time from their regular day-to-day work to periodically participate in our study.
Domain 2: Peer review paper matching [4]. We collaborated with Professor Nihar Shah (CMU), an expert in peer review, to investigate what information could help meta-reviewers of a conference better match submitted papers to suitable reviewers. Learning from our prior experience, we first conducted a user study using proxy tasks and users, which we worked with Professor Shah to design as shown in Figure 4. In this proxy setting, we found that providing explanations from popular XAI methods in fact led users to be more confident (the majority of participants shown highlights from XAI methods believed the highlighted information was helpful), yet they made statistically worse decisions (Figure 5, right)!
How can we better support human-ML interactions?
Through these collaborations, we identified two important directions for future work, which we describe in more detail along with our initial efforts in each direction.
We need to develop methods for specific use cases. Our results suggest that explanations from popular, general-purpose XAI methods can both hurt decision-making and make users overconfident. These findings have also been observed in multiple contemporaneous works (e.g., [7,8,9]). Researchers, instead, need to consider developing human-centered ML (HCML) methods [10] tailored for each downstream use case. An HCML method is any approach that provides information about the particular use case and context that can inform human decisions.
Our contributions: In the peer review matching setting, we proposed an HCML method designed in tandem with a domain expert [4]. Notably, our method is not a model explanation approach, as it highlights information in the input data, specifically sentences and phrases that are similar in the submitted paper and the reviewer profile. Figure 6 compares the text highlighted using our method to the text highlighted using existing methods. Our method outperformed both a baseline where there was no explanation and the model explanation condition (Figure 5, right). Based on these positive results, we plan to move evaluations of our proposed method to more realistic peer review settings. Further, we performed an exploratory study to better understand how people interact with information provided by HCML methods as a first step towards coming up with a more systematic approach to devise task-specific HCML methods [5].
We need more efficient evaluation pipelines. While user studies conducted in a real-world use case and with real users are the ideal way to evaluate HCML methods, they are time- and resource-consuming. We highlight the need for more cost-effective evaluations that can be utilized to narrow down candidate HCML methods and still implicate the downstream use case. One option is to work with domain experts to design a proxy task as we did in the peer review setting, but even these studies require careful consideration of the generalizability to the real-world use case.
Our contributions. We introduced an algorithmic evaluation called simulated user evaluation (SimEvals) [2]. Instead of conducting studies on proxy tasks, researchers can train SimEvals, which are ML models that serve as human proxies. SimEvals more faithfully reflect aspects of real-world evaluation because their training and evaluation data are instantiated on the same data and task considered in real-world studies. To train SimEvals, the researcher first needs to generate a dataset of observation-label pairs. The observation corresponds to the information that would be presented in a user study (and critically includes the HCML method), while the label is the ground truth for the use case of interest. For example, in the fraud detection setting, the observation would consist of both the e-commerce transaction and ML model score shown in Figure 3(a) along with the explanation shown in Figure 3(b). The ground truth label is whether or not the transaction was fraudulent. SimEvals are trained to predict a label given an observation, and their test set accuracies can be interpreted as a measure of whether the information contained in the observation is predictive for the use case.
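The idea can be illustrated with a toy comparison: train the same simple model on observations with and without a candidate explanation, then compare held-out accuracy. Everything in this sketch is an assumption made for illustration: the data are synthetic, the "explanation" is a single numeric feature constructed to be informative, and a small logistic regression stands in for whatever model a real SimEval would use.

```python
import math
import random

def train_logreg(X, y, epochs=20, lr=0.1):
    """Fit a logistic regression with plain SGD; returns (weights, bias)."""
    w, b = [0.0] * len(X[0]), 0.0
    for _ in range(epochs):
        for xi, yi in zip(X, y):
            z = sum(wj * xj for wj, xj in zip(w, xi)) + b
            z = max(-30.0, min(30.0, z))          # keep exp() well-behaved
            err = 1.0 / (1.0 + math.exp(-z)) - yi
            w = [wj - lr * err * xj for wj, xj in zip(w, xi)]
            b -= lr * err
    return w, b

def accuracy(model, X, y):
    w, b = model
    preds = [int(sum(wj * xj for wj, xj in zip(w, xi)) + b > 0) for xi in X]
    return sum(p == yi for p, yi in zip(preds, y)) / len(y)

def simeval(observations, labels, n_train):
    """Train on the first n_train pairs, report accuracy on the rest."""
    model = train_logreg(observations[:n_train], labels[:n_train])
    return accuracy(model, observations[n_train:], labels[n_train:])

rng = random.Random(0)
labels = [int(rng.random() < 0.5) for _ in range(600)]
# Hypothetical observation parts: a score unrelated to the label, and a
# candidate "explanation" feature that happens to carry label information.
score = [[rng.gauss(0, 1)] for _ in labels]
with_expl = [s + [(2 * y - 1) + rng.gauss(0, 0.3)] for s, y in zip(score, labels)]

acc_score_only = simeval(score, labels, n_train=400)
acc_with_expl = simeval(with_expl, labels, n_train=400)
```

A large gap between the two accuracies suggests the added information is predictive for the use case. A negligible gap suggests the explanation adds little, which is cheap evidence to gather before running a user study.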
We not only evaluated SimEvals on a variety of proxy tasks but also tested SimEvals in practice by working with Feedzai, where we found results that corroborate the negative findings from the user study [6]. Although SimEvals should not replace user studies because SimEvals are not designed to mimic human decision-making, these results suggest that SimEvals could be initially used to identify more promising explanations (Figure 6).
Conclusion
In summary, our recent efforts motivate two ways the community should rethink how to support human-ML interactions: (1) we need to replace general-purpose XAI techniques with HCML methods tailored to specific use cases, and (2) we need to create intermediate evaluation procedures that can help narrow down the HCML methods to evaluate in more costly settings.
For more information about the various papers mentioned in this blog post, see the links below:
[1] Chen, V., Li, J., Kim, J. S., Plumb, G., & Talwalkar, A. Interpretable Machine Learning. Communications of the ACM, 2022. (link)
[2] Chen, V., Johnson, N., Topin, N., Plumb, G., & Talwalkar, A. Use-case-grounded simulations for explanation evaluation. NeurIPS, 2022. (link)
[3] Amarasinghe, K., Rodolfa, K. T., Jesus, S., Chen, V., Balayan, V., Saleiro, P., Bizzaro, P., Talwalkar, A. & Ghani, R. (2022). On the Importance of Application-Grounded Experimental Design for Evaluating Explainable ML Methods. arXiv. (link)
[4] Kim, J. S., Chen, V., Pruthi, D., Shah, N., & Talwalkar, A. Assisting Human Decisions in Document Matching. arXiv. (link)
[5] Chen, V., Liao, Q. V., Vaughan, J. W., & Bansal, G. (2023). Understanding the Role of Human Intuition on Reliance in Human-AI Decision-Making with Explanations. arXiv. (link)
[6] Martin, A., Chen, V., Jesus, S., Saleiro, P. A Case Study on Designing Evaluations of ML Explanations with Simulated User Studies. arXiv. (link)
[7] Bansal, G., Wu, T., Zhou, J., Fok, R., Nushi, B., Kamar, E., Ribeiro, M. T. & Weld, D. Does the whole exceed its parts? the effect of ai explanations on complementary team performance. CHI, 2021. (link)
[8] Adebayo, J., Muelly, M., Abelson, H., & Kim, B. Post hoc explanations may be ineffective for detecting unknown spurious correlation. ICLR, 2022. (link)
[9] Zhang, Y., Liao, Q. V., & Bellamy, R. K. Effect of confidence and explanation on accuracy and trust calibration in AI-assisted decision making. FAccT, 2020. (link)
[10] Chancellor, S. (2023). Toward Practices for Human-Centered Machine Learning. Communications of the ACM, 66(3), 78-85. (link)
Acknowledgments
We would like to thank Kasun Amarasinghe, Jeremy Cohen, Nari Johnson, Joon Sik Kim, Q. Vera Liao, and Junhong Shen for helpful feedback and suggestions on earlier versions of the blog post. Thank you also to Emma Kallina for her help with designing the main figure!
Figure 1: Behavior-driven AI development centers model iteration on evaluating and improving specific real-world use cases.
It has never been easier to prototype AI-driven systems. With a bit of programming knowledge and a couple of hours, you can spin up a chatbot for your notes, a text-based image editor, or a tool for summarizing customer feedback. But play around with your prototype for a bit, and you might find that it doesn’t work as well as you first expected. Your system might make up facts or respond with racist suggestions. How would you evaluate your model and predict its performance in deployment?
The canonical process for benchmarking AI systems revolves around model-centric metrics. Calculate a metric (F1-score, precision, etc.), and if it increases, you are going in the right direction. But these metrics are oversimplified objectives that sand away the complexity of model behavior and cannot fully represent a model’s performance. A metric may tell you how well your model can predict the next word in a sentence, but it won’t tell you how factually accurate, logical, or fair your model is across diverse, real-world use cases. Generative AI systems such as ChatGPT or Stable Diffusion make evaluation even more challenging since there are no well-defined metrics that can summarize their performance.
When creating deployed AI products, practitioners instead focus on the specific use cases their customers have and whether or not their models are fulfilling them. In interviews with 18 AI practitioners, we found that they constantly collect user feedback and develop “golden test sets” of behaviors that they expect deployed models to have. We term this behavior-driven AI development, a development process focused on evaluating and updating models to improve performance on real-world use cases. While chatbot A might sound more human-like, a practitioner will deploy chatbot B if it produces concise and accurate answers that customers prefer.
The landscape of AI evaluation tools primarily revolves around model-centric metrics that do not capture important behaviors like these chatbot characteristics. While there are specific tools for behavior-driven development, such as fairness toolkits and robustness analysis libraries, practitioners end up cobbling together disparate tools into ad-hoc scripts or computational notebooks that are hard to maintain and reproduce.
I believe that there is a set of abstractions that can unify AI evaluation in line with model use cases in practice. This philosophy revolves around model behaviors: metrics summarizing patterns of output on subgroups of instances. This simple concept can encode any model evaluation or analysis, from fairness audits to language model hallucinations. We show what this can look like with Zeno, an interactive platform we built for behavior-driven development that supports interactive data exploration, slicing, and reporting. By investigating their own models using Zeno, practitioners have been able to pinpoint significant and actionable issues such as biases and systematic failures.
What is model behavior?
The dictionary describes behavior as anything that an organism does involving action and response to stimulation. In the case of AI systems, model behavior is a specific pattern of output for a semantically meaningful subgroup of input data (stimulus). By semantically meaningful, I mean subgroups that can be described with human-interpretable concepts, such as “audio with noise in the background” or “people who identify as women.” Similarly, a pattern of output could be “high audio transcription error” or “low loan approval rate.”
Behaviors can be quantified as metrics on subgroups of data, often using the same metrics as are used for model-centric evaluation. But unlike summary metrics across an entire dataset, metrics in behavior-centric development quantify specific patterns of behavior, like how often an image generation model produces unintelligible text. Tests of model behaviors are like exams for specific subjects, while summary metrics resemble IQ tests.
Model behaviors are a relatively simple concept, but encoding behaviors can be challenging in practice. Practitioners may not have enough data to validate or fix important model behaviors and have to collect or generate more data. If they have extensive data, they need ways to subdivide it into meaningful groups of instances – how do I find all images that have text? Lastly, for each subgroup, practitioners have to derive the appropriate metrics to quantify the prevalence of behavior – how do I detect blurry text? Succinctly, behavior-driven development requires sufficient data that is representative of expected behaviors and metadata for defining and quantifying the behaviors.
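To make the idea of "a metric on a subgroup" concrete, here is a minimal sketch in plain Python. The field names (`label`, `prediction`, `has_text`) and the helper `slice_metric` are made up for illustration and are not part of any particular tool.

```python
from typing import Callable, Dict, List

Instance = Dict[str, object]

def error_rate(instances: List[Instance]) -> float:
    """Fraction of instances whose prediction differs from the label."""
    wrong = sum(1 for x in instances if x["prediction"] != x["label"])
    return wrong / len(instances)

def slice_metric(
    instances: List[Instance],
    in_slice: Callable[[Instance], bool],
    metric: Callable[[List[Instance]], float],
) -> float:
    """Evaluate `metric` only on the subgroup selected by `in_slice`."""
    subgroup = [x for x in instances if in_slice(x)]
    return metric(subgroup) if subgroup else float("nan")

# Toy data: "has_text" is a hypothetical metadata column.
data = [
    {"label": "cat", "prediction": "cat", "has_text": False},
    {"label": "dog", "prediction": "dog", "has_text": False},
    {"label": "cat", "prediction": "dog", "has_text": True},
    {"label": "dog", "prediction": "dog", "has_text": True},
]

overall = error_rate(data)                                           # summary metric
with_text = slice_metric(data, lambda x: x["has_text"], error_rate)  # behavior
```

In this toy data the summary error rate hides the fact that every error falls in the "images with text" subgroup, which is the kind of pattern a behavior is meant to surface.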
A platform for behavior-driven AI development
The beauty of a behavior-based framing on AI development is that it is still data and model agnostic. While the specific behaviors for each ML task will be vastly different, subgroups of data and metrics are universal concepts.
To test this theory, we built a platform for behavior-driven AI development called Zeno. Zeno is a platform that empowers users to explore data and model outputs, interactively create subgroups of data, and calculate and quantify model behaviors. Zeno consists of a Python API for scaffolding the data needed for analysis and a user interface for interactively creating subgroups and evaluating behaviors.
The Python API is a set of decorator functions (wrappers on user-defined functions) that can be used to plug in ML models and derive metadata features and metrics from input data. Since the decorators are generic wrappers, Zeno supports any Python-based model, processing function, or metric. Zeno preprocesses the input data with these functions, which it passes into the UI for analysis.
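As a rough illustration of the decorator pattern described above, the sketch below registers user-defined functions and uses them to add metadata columns. It is not Zeno's actual API: the decorator names `metadata` and `metric`, the registries, and `preprocess` are all hypothetical.

```python
from typing import Callable, Dict, List

# Hypothetical registries for illustration only; NOT Zeno's real API.
METADATA_FUNCS: Dict[str, Callable] = {}
METRIC_FUNCS: Dict[str, Callable] = {}

def metadata(fn: Callable) -> Callable:
    """Register a function that derives one metadata value per instance."""
    METADATA_FUNCS[fn.__name__] = fn
    return fn

def metric(fn: Callable) -> Callable:
    """Register a function that summarizes a list of instances."""
    METRIC_FUNCS[fn.__name__] = fn
    return fn

@metadata
def prompt_length(instance: dict) -> int:
    return len(instance["prompt"].split())

@metric
def accuracy(instances: List[dict]) -> float:
    return sum(x["prediction"] == x["label"] for x in instances) / len(instances)

def preprocess(instances: List[dict]) -> List[dict]:
    """Add one column per registered metadata function to every instance."""
    return [
        {**x, **{name: fn(x) for name, fn in METADATA_FUNCS.items()}}
        for x in instances
    ]

rows = preprocess([
    {"prompt": "a red car", "label": 1, "prediction": 1},
    {"prompt": "a very long and detailed prompt", "label": 0, "prediction": 1},
])
```

Because the decorators only wrap ordinary functions, any Python model, feature extractor, or metric can be plugged in the same way.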
Zeno’s UI is the primary interface for behavior-driven evaluation. It allows users to interactively explore and filter their data, create slices, calculate metrics, and create exportable visualizations. On the right side of the UI is Zeno’s instance view, where users can explore the raw data on which the model is being evaluated. In addition to the standard list view, users can also see the data in a table or a 2D scatterplot representation. The left side of the interface holds the metadata panel. All the metadata columns that either came with the dataset or were generated with the Python API have their distributions displayed in the panel. Users can interactively filter the distributions to update the instance view and create named subgroups.
The UI also has a report page for creating interactive summary visualizations of behaviors. For example, a user could create a bar chart comparing the performance of three models across ten different slices. Or they could create a line chart showing how a model performs on data slices from each day of data. These visualizations can be exported or shared directly with other stakeholders.
Case Studies
We have worked with various ML practitioners to apply Zeno to the models and tasks on which they work. Using Zeno, practitioners found significant model issues and areas for improvement, including gender biases and regional model disparities.
Audio transcription. This first case study I ran myself after I heard that OpenAI released a new speech-to-text model, Whisper, with state-of-the-art performance. I was curious how the model compared to some existing off-the-shelf transcription models. Instead of looking at aggregate metrics, I ran the models on the Speech Accent Archive dataset, which has speakers worldwide saying the same phrase. By filtering the dataset’s extensive metadata, I found that the models perform worse for English speakers who learned the language later in life and speakers from countries where English is not the native language.
Cancer classification. In another case study, we worked with a researcher who wanted to improve a breast cancer classifier for mammogram images. Since the data was anonymized and lacked meaningful metadata, the practitioner wrote dozens of functions using a Python library to extract meaningful metadata features. By exploring the distributions, they found that images with higher “entropy,” which correlates with denser breast tissue, had a significantly higher error rate than images with lower entropy, or less dense, tissue. This finding matches performance differences in human radiologists, who also perform worse for images of denser breast tissue since it makes it harder to detect lesions.
[Figure: AUC by slice. Low density: 0.86. High density (656 instances; entropy > 2.75 && gray level variance > 2.5): 0.76.]
Figure 6. The breast cancer classification model performed significantly worse for high-density images (described by high entropy and gray level variance metadata levels) compared to the low-density images. (left, low density, right, high density).
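The post does not say which library or which definition of “entropy” the practitioner used. As one plausible reading, the sketch below computes the Shannon entropy of an image's gray-level histogram with NumPy; the function name `image_entropy` and the 8-bit value range are assumptions.

```python
import numpy as np

def image_entropy(image: np.ndarray, bins: int = 256) -> float:
    """Shannon entropy (in bits) of the gray-level histogram.

    Assumes 8-bit intensities in [0, 255]; this is one common definition,
    not necessarily the one used in the case study.
    """
    hist, _ = np.histogram(image, bins=bins, range=(0, 256))
    p = hist / hist.sum()
    p = p[p > 0]  # empty bins contribute nothing to the sum
    return float(-(p * np.log2(p)).sum())

# A constant image carries no gray-level variation; a checkerboard has two
# equally likely gray levels.
flat = np.full((8, 8), 128, dtype=np.uint8)
checker = (np.indices((8, 8)).sum(axis=0) % 2) * 255

flat_entropy = image_entropy(flat)
checker_entropy = image_entropy(checker)
```

A function like this can be applied to every image to produce a metadata column, which can then be thresholded to define a slice.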
Image generation. Models with complex outputs often do not have clearly defined metrics, including text-to-image generation models such as DALL·E and Stable Diffusion. We can instead look at metrics that measure specific behaviors. In this example, a practitioner we worked with was exploring the DiffusionDB dataset, which has over two million prompt-image pairs from the Stable Diffusion model. The dataset also has metadata for how NSFW or inappropriate the prompts and images are. This data was used to derive an “average NSFW” metric, which can show us interesting potential biases in the model. For example, the participant compared the images generated using prompts with the word “boy” versus “girl” and found that prompts with “girl” generated images with a significantly higher NSFW level than prompts with “boy”, showing potential biases in the types of images created by the model.
Discussion and Opportunities
Model iteration is still a primarily reactive process of finding and defining behaviors after a model has been deployed and the customer complaints start rolling in. There remains significant room for improving this process, from making it easier to ideate model behaviors to tracking model changes over time.
Discovering behaviors. While practitioners often need a model to discover the behaviors the model should have, methods for defining expected model behaviors before deployment can prevent serious real-world model issues. For example, crowdsourcing techniques for eliciting potential edge cases could preemptively catch model errors. Algorithmic methods that find clusters of data with high error have also shown promise for surfacing problematic behaviors.
Data discovery and generation. Having high-quality, representative data remains a persistent obstacle for behavioral evaluation. In some domains with ample data, such as natural images, methods like Stable Diffusion have shown promise for generating new data for evaluation or training. In less data-rich domains, techniques for searching through large unlabeled datasets, such as text-based image search, can surface valuable data for evaluation and retraining. It is also challenging to derive metadata from instances for creating subgroups and calculating metrics. While it can be easy to generate metadata for simple concepts like “image brightness,” many behaviors are defined by complex metadata such as “images with a person wearing clear glasses” that cannot be encoded by a simple function. Foundation models have shown some promise in using text-based descriptions to generate complex metadata and metrics.
Model comparison. Models are almost never one-off jobs and can be updated daily or weekly. While it is easy to compare aggregate metrics, it can be challenging to compare model performance in behavior-driven development. To pick between models, users may have to compare dozens of behaviors and qualitative insights. Improved visual encodings or intelligent recommendations of model differences could help users make informed decisions and deploy the right models.
Fixing behaviors. Discovering and encoding behaviors is one thing, but fixing behaviors is another massive challenge. A common approach to fixing issues is to gather more data and retrain the model, but this process can lead to catastrophic forgetting and regressions. There are recent techniques that align well with behavior-driven development, such as slice-based learning, which can selectively fix model behaviors without new data.
Conclusion
There is significant excitement for this new era of AI systems. But along with their growing capability, the complexity of their behavior is also increasing. We need powerful tools to empower behavior-driven development and ensure we build intelligent systems that align with human values. Zeno provides a general-purpose platform that empowers users to do this deep evaluation across the diverse tasks of modern AI. Learn more about Zeno at zenoml.com, read the full paper, or reach out if you would like to use Zeno for your models!
Acknowledgments
I’d like to thank Will Epperson, Jason I. Hong, Yi-Cheng Huang, Misha Khodak, Adam Perer, Venkat Sivaraman, Ameet Talwalkar, and Kristen Vossler for their thoughtful feedback and advice.
Figure 1: Overview of RLPrompt for discrete prompt optimization. All language models (LMs) are frozen. We build our policy network by training a task-specific multi-layer perceptron (MLP) network inserted into a frozen pre-trained LM. The figure above illustrates generation of a prompt (left), example usages in a masked LM for classification (top right) and a left-to-right LM for generation (bottom right), and update of the MLP using RL reward signals (red arrows).
TL;DR: Prompting enables large language models (LLMs) to perform various NLP tasks without changing the model. Discrete prompts have many desirable properties, but are difficult to optimize. We propose an efficient approach using reinforcement learning, which shows superior performance and facilitates rich interpretations and analyses. You can easily adapt it for your own tasks using our code base here.
Prompting has emerged as a promising approach to solving a wide range of NLP problems using large pre-trained language models (LMs), including left-to-right models such as GPTs (Radford et al., 2019; Brown et al., 2020) and masked LMs such as BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019), etc.
Compared to conventional fine-tuning that expensively updates the massive LM parameters for each downstream task, prompting concatenates the inputs with an additional piece of text that steers the LM to produce the desired outputs. A key question with prompting is how to find the optimal prompts to improve the LM’s performance on various tasks, often with only a few training examples.
Most existing work resorts to tuning soft prompts (e.g., embeddings), which fall short on interpretability, reusability across LMs, and applicability when gradients are not accessible. Discrete prompts, on the other hand, are difficult to optimize and are often created by “enumeration (e.g., paraphrasing)-then-selection” heuristics that do not explore the prompt space systematically.
Instead, we propose RLPrompt, an efficient discrete prompt optimization approach with reinforcement learning (RL). RLPrompt is flexibly applicable to different types of LMs (e.g., BERT and GPTs) for both classification and generation tasks. Experiments on few-shot classification and unsupervised text style transfer show superior performance over a wide range of existing finetuning or prompting methods.
Interestingly, the resulting optimized prompts are often ungrammatical gibberish text; and surprisingly, those gibberish prompts are transferable between different LMs to retain significant performance, indicating LMs may have grasped shared structures for prompting, but do not follow human language patterns.
Discrete Prompt Optimization with RL
This paper presents RLPrompt, a new discrete prompt optimization approach based on reinforcement learning (RL). This approach brings together a wide range of desirable properties for efficient use on diverse tasks and LMs (see the table below).
Crucially, rather than directly editing the discrete tokens, which has been difficult and inefficient, RLPrompt trains a policy network that generates the desired prompts. Discrete prompt optimization thus amounts to learning a small number of policy parameters which we set as an MLP layer inserted into a frozen compact model such as distilGPT-2 (HuggingFace, 2019). We describe the specific formulations in Section §2.1-2.3 of our paper.
This formulation also allows us to employ off-the-shelf RL algorithms (e.g., Guo et al., 2021) that learn the policy with arbitrary reward functions—defined either with available data (e.g., in few-shot classification) or other weak signals when no supervised data is accessible (e.g., in controllable text generation).
Reward Stabilization
On the other hand, RL for prompt optimization poses new challenges to learning efficiency: the large black-box LM presents a highly complex environment that, given the prompt (i.e., actions), goes through a long series of complex transitions (e.g., reading the input and inferring the output) before computing the rewards. This makes the reward signals extremely unstable and hard to learn from.
To overcome this difficulty, we propose two simple yet surprisingly effective ways to stabilize the rewards and improve the optimization efficiency.
Normalizing the training signal by computing the z-score of rewards for the same input.
Designing piecewise reward functions that provide a sparse, qualitative bonus for desirable behaviors (e.g., reaching a certain accuracy on a certain class).
We describe more details in Section §2.4 of our paper.
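The first of the two techniques above can be sketched in a few lines of NumPy. This is a minimal illustration, assuming rewards are arranged as one row per input and one column per sampled prompt; the function name `zscore_rewards` and the `eps` guard are our own choices here, not taken from the paper's code.

```python
import numpy as np

def zscore_rewards(rewards, eps: float = 1e-8) -> np.ndarray:
    """Normalize rewards per input across the prompts sampled for it.

    rewards: array-like of shape (num_inputs, num_prompts), where
    rewards[i][j] is the reward of prompt j on input i (assumed layout).
    """
    rewards = np.asarray(rewards, dtype=float)
    mean = rewards.mean(axis=1, keepdims=True)
    std = rewards.std(axis=1, keepdims=True)
    return (rewards - mean) / (std + eps)  # eps avoids division by zero

# Two inputs whose raw rewards live on very different scales.
raw = [[1.0, 2.0, 3.0],
       [10.0, 20.0, 30.0]]
z = zscore_rewards(raw)
```

After normalization both inputs give the same training signal for the same ranking of prompts, so inputs with large raw rewards no longer dominate the update.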
Experiments
We evaluate our approach on both classification (in the few-shot setting) and generation (unsupervised text style transfer), and perform rich analyses for new insights on LM prompting. We describe implementation details such as reward function design in Section §3 of our paper, and publish the code at our GitHub codebase.
Few-Shot Text Classification
For few-shot classification, we follow previous work and experiment on popular sentiment and topic classification tasks, using 16 examples per class for both training and validation (Perez et al., 2021). Results using RoBERTa-large (left table below) show that our approach improves over a wide range of fine-tuning and prompting methods and is as efficient to optimize as similar methods that tune soft prompts (e.g., right figure below). We report detailed dataset-level results in Section §3.1 of our paper.
Unsupervised Text Style Transfer
For text style transfer, we evaluate on the popular Yelp sentiment transfer dataset (Shen et al., 2017) using popular automatic metrics for content preservation, style accuracy, and fluency, and report their sentence-level joint product \(J(\cdot)\) below. Our full paper also includes few-shot experiments on the Shakespeare (Xu et al., 2012) dataset and human evaluations.
Results using GPT-2 (left table below) show our method outperforms or competes with various fine-tuning and prompting baselines, including DiRR (Liu et al., 2021c) which expensively fine-tunes all parameters of a GPT-2 model. Ablation study (right figure below) shows that our proposed reward normalization technique is crucial to optimization success. We describe the full evaluation results in Section §3.2 of our paper.
Analysis
Optimal Prompts Don’t Follow Human Language
The resulting discrete prompts also facilitate rich interpretations and analyses for new insights into LM prompting. In particular, the optimized prompts, though inducing strong task performance, tend to be gibberish text without clear human-understandable meaning (e.g., table below), echoing recent research (Webson and Pavlick, 2021; Zhao et al., 2021; Prasad et al., 2022) that LMs making use of prompts do not necessarily follow human language patterns.
Learned Prompts Transfer Trivially Across LMs
Perhaps surprisingly, those gibberish prompts learned with one LM can be used in other LMs for significant performance, indicating that those different pre-trained LMs have grasped shared structures for prompting (e.g., figures below).
Conclusion
We have presented RLPrompt, an efficient and flexible approach for discrete prompt optimization using RL, which improves over a wide range of fine-tuning and prompting methods in experiments on few-shot classification and unsupervised text style transfer.
Analysis reveals that strong optimized prompts are incoherent but transferable between LMs for remarkable performance. The observation opens up many promising possibilities for prompting, such as learning prompts cheaply from smaller models and performing inference with larger models. We are excited to explore further.
We perform open vocabulary detection of the objects mentioned in the sentence using both bottom-up and top-down feedback.
Object detection is the fundamental computer vision task of finding all “objects” that are present in a visual scene. However, this raises the question, what is an object? Typically, this question is side-stepped by defining a vocabulary of categories and then training a model to detect instances of this vocabulary. This means that if “apple” is not in this vocabulary, the model does not consider it as an object. The problem gets even worse when we try to integrate these object detectors into real household agents. Imagine that we want a robot that can pick up “your favorite green mug from the table right in front of you”. We want the robot to specifically detect the “green mug” which is on the “table in front of you” and not any other mug or table. Obviously, treating descriptions such as “green mug from the table right in front of you” as separate classes in the detector’s vocabulary cannot scale; one can come up with countless variations of such descriptions.
In light of this, we introduce Bottom-up Top-Down DEtection TRansformer (BUTD-DETR pron. Beauty-DETER), a model that conditions directly on a language utterance and detects all objects that the utterance mentions. When the utterance is a list of object categories, BUTD-DETR operates as a standard object detector. It is trained from both fixed vocabulary object detection datasets and referential grounding datasets which provide image-language pairs annotated with the bounding boxes for all objects referred to in the language utterance. With minimal changes, BUTD-DETR grounds language phrases both in 3D point clouds and 2D images.
No box bottleneck: BUTD-DETR decodes object boxes directly by attending to language and visual input instead of selecting them from a pool. Language-directed attention helps us localize objects that our bottom-up, task-agnostic attention may miss. For example, in the above image, the hint of “clock on top of the shelf” suffices to guide our attention to the right place, though the clock is not a salient object in the scene. Previous approaches for language grounding are detection-bottlenecked: they select the referred object from a pool of box proposals obtained from a pre-trained object detector. This means that if the object detector fails, then the grounding model will fail as well.
How does it work?
The input to our model is a scene and a language utterance. A pre-trained object detector is used to extract box proposals. Next, the scene, boxes, and utterance are encoded using per-modality-specific encoders into visual, box, and language tokens respectively. These tokens are contextualized by attending to one another. The refined visual tokens are used to initialize object queries that attend to the different streams and decode boxes and spans.
Augmenting supervision with Detection prompts
Object detection is an instance of referential language grounding in which the utterance is simply the object category label. We cast object detection as the referential grounding of detection prompts: we randomly sample some object categories from the detector’s vocabulary and generate synthetic utterances by sequencing them, e.g., “Couch. Person. Chair.”, as shown in the figure above. We use these detection prompts as additional supervision data: the task is to localize all object instances of the category labels mentioned in the prompt if they appear in the scene. For the category labels with no instances present in the visual input (e.g. “person” in the above figure), the model is trained to not match them to any boxes. In this way, a single model can perform both language grounding and object detection simultaneously and share the supervision information.
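The construction of a detection prompt can be sketched as follows. This is an illustration of the idea rather than our training code: the function name `make_detection_prompt`, the box format, and the toy vocabulary are all assumptions.

```python
import random
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # assumed (x_min, y_min, x_max, y_max)

def make_detection_prompt(
    vocabulary: Sequence[str],
    scene_boxes: Sequence[Tuple[str, Box]],
    num_categories: int,
    rng: random.Random,
) -> Tuple[str, Dict[str, List[Box]]]:
    """Sample category labels and build a synthetic utterance plus targets.

    Categories that are sampled but absent from the scene map to an empty
    list, i.e. they should not be matched to any box.
    """
    sampled = rng.sample(list(vocabulary), num_categories)
    utterance = " ".join(f"{name.capitalize()}." for name in sampled)
    targets = {
        name: [box for label, box in scene_boxes if label == name]
        for name in sampled
    }
    return utterance, targets

vocabulary = ["couch", "person", "chair", "table"]
scene_boxes = [
    ("couch", (0, 0, 2, 1)),
    ("chair", (3, 0, 4, 1)),
    ("chair", (5, 0, 6, 1)),
]
utterance, targets = make_detection_prompt(
    vocabulary, scene_boxes, num_categories=3, rng=random.Random(0)
)
```

The resulting `(utterance, targets)` pair has the same shape as a referential grounding example, which is what lets one model share supervision across both tasks.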
Results
BUTD-DETR achieves a large boost in performance over state-of-the-art approaches across all 3D language grounding benchmarks (SR3D, NR3D, ScanRefer). Moreover, it was the winning entry in the ReferIt3D challenge, held at the ECCV workshop on Language for 3D Scenes. On 2D language grounding benchmarks, BUTD-DETR performs on par with state-of-the-art methods when trained on large-scale data. Importantly, our model converges twice as fast compared to state-of-the-art MDETR, mainly because of the efficient deformable attention which we used with our 2D model.
We show the qualitative results of our model in the video at the beginning of the blog. For more visualizations, please refer to our project page and paper.
What’s next?
Our method detects all objects mentioned in the sentence — however, this assumes that the user needs to mention all relevant objects in the sentence. This is not desirable in general — for example, in response to “make breakfast” we would like our model to detect all the relevant ingredients like bread, eggs etc., even if they are not mentioned in the sentence. Additionally, while our architecture works for both 2D and 3D language grounding with minimal changes, we do not share parameters between the two modalities. This prevents transferring representations across modalities, which would be particularly helpful for the low-resource 3D modality. Our ongoing work is investigating these two directions.
We have released our code and model weights on GitHub, making it easy to reproduce our results and build upon our method. If you are interested in a language-conditioned open vocabulary detector for your project, then give BUTD-DETR a run! For more details, please check out our project page and paper.
A standard assumption in sequential decision making is that we observe everything required to make good decisions. In practice however, this isn’t always the case. We discuss two specific examples (temporally correlated noise (a) and unobserved contexts (c)) that have stymied the use of IL/RL algorithms (in autonomous helicopters (b) and self-driving (d)). We derive provably correct algorithms for both of these problems that scale to continuous control problems.
Reinforcement Learning (RL) and Imitation Learning (IL) methods have achieved impressive results in recent years like beating the world champion at Go or controlling stratospheric balloons. Usually, these results are on problems where we either a) observe the full state or b) are able to faithfully execute our intended actions on the system. However, we frequently have to contend with situations where this isn’t the case: our self-driving car might miss a person’s hand gestures or persistent wind might make it difficult to fly our quadcopter perfectly straight. These sorts of situations can cause standard IL approaches to perform poorly ([1], [2]). In causal inference, we call a random variable that we don’t observe that influences a relationship we’d like to model a confounder. Using techniques from causal inference, we derive provably correct and scalable algorithms for sequential decision making in these sorts of confounded settings.
We’re going to be focused mostly on imitation learning (see our last blog post for more details). In IL, we observe trajectories (sequences of states and actions) generated by some expert policy \(\pi_E\) and want to recover the policy that generated them. The expert could be a) an experienced driver if we’re trying to build a self-driving car, b) a user interacting with a recommender system we’re attempting to model, or c) scraped internet text used to train a large language model. We’re now going to discuss two issues of causation that are difficult for standard IL algorithms to handle. I’ll be drawing upon material from our ICML ’22 paper Causal Imitation under Temporally Correlated Noise and our NeurIPS ’22 paper Sequence Model Imitation Learning with Unobserved Contexts.
Issue 1: Temporally Correlated Noise
The first problem we’ll consider is when there’s temporally correlated noise (TCN) affecting pairs of expert actions. For example, if our expert was a human quadcopter pilot, the TCN could be persistent wind [1]. Let’s assume that the expert intended to fly completely straight but was unable to do so because of the persistent wind, producing swerving trajectories. If the learner ignores the fact that the wind was the root cause of these swerves, they might attempt to reproduce the swerving behavior of the expert, causing them to deviate even further at test time due to the continued influence of the wind.
What’s happened is that the TCN has created a spurious correlation (e.g. being a little on the left is often followed by going more to the left) which the learner has latched onto. What we’d like is to instead learn a policy that can fly as straight as the expert did in the demonstrations.
Let’s try and formalize what’s happening in the above example via graphical models. In a Markov Decision Process (MDP), the learner \(\pi\) responds to the observed state \(s_t\) by taking an action \(a_t\), and the MDP transitions via dynamics \(\mathcal{T}: s \times a \rightarrow s\) to the next state \(s_{t+1}\).
Adding in the TCN corresponds to adding in a common cause of a pair of actions that’s not part of the observed state (\(u_t\) below).
Now, let’s dig into what goes wrong with imitation learning under TCN. Let’s say we apply a standard algorithm like behavioral cloning: direct regression from states (\(X = s_t\)) to actions (\(Y = a_t\)). The TCN travels through the dynamics to influence both the inputs and outputs of our regression procedure, creating spurious correlations that our learner might latch onto, leading to poor test-time performance. We visualize this effect in red below.
In effect, the TCN makes some elements of state \(s_2\) look more correlated with action \(a_2\) than they would otherwise. Attempting to minimize training error, the learner will use this correlation to their advantage but will end up swerving unnecessarily at test time as a result.
Filtering out TCN with Instrumental Variable Regression
To learn well under TCN, we’re going to utilize the idea of a natural experiment from causal inference, in which we’re able to use variation in the observational data to simulate the effect of an intervention / randomized control trial (RCT). Interestingly enough, this idea was one of the central contributions of the winners of the Nobel Prize in Economics in 2021. Let’s first take a look at this idea in the single shot setting. Assume that we’re trying to predict from \(X\) to \(Y\), both of which are affected by an unobserved confounder \(U\):
Notice the random variable \(Z\) both (1) affects \(X\) and (2) is independent of \(U\). \(Z\) is therefore an independent source of variation in \(X\). We call such variables instruments. Under some assumptions (see our paper for details), one can condition on an instrument to de-noise the inputs to our regression procedure and recover a causal relationship. So, instead of regressing from $$X \rightarrow Y,$$ one regresses from $$X|Z \rightarrow Y|Z.$$
We’ll skip the derivation here but intuitively, conditioning on an independent source of variation, \(Z\), “washes out” the effect of the confounder, \(U\). The variation in \(Z\) therefore acts as a sort of “natural experiment,” simulating an RCT without requiring a true intervention.
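To make the single-shot idea concrete, here is a minimal sketch on synthetic data. Everything in it is an assumption made for illustration (the linear structural equations, the noise scales, and the function names); it uses textbook two-stage least squares, which in the linear case amounts to regressing \(Y\) on an estimate of \(E[X|Z]\), and it is not the estimator from the paper.

```python
import numpy as np


def simulate_confounded_data(n=50000, true_effect=2.0, seed=0):
    """Draw (Z, X, Y) where an unobserved U affects both X and Y (toy model)."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=n)                 # unobserved confounder U
    z = rng.normal(size=n)                 # instrument Z: affects X, independent of U
    x = z + u + 0.1 * rng.normal(size=n)   # X is driven by both Z and U
    y = true_effect * x + u + 0.1 * rng.normal(size=n)
    return z, x, y


def ols_slope(x, y):
    """Slope of the ordinary least-squares fit of y on x (with intercept)."""
    slope, _intercept = np.polyfit(x, y, 1)
    return slope


def two_stage_slope(z, x, y):
    """Two-stage least squares: fit X from Z, then regress Y on the fitted X."""
    stage1_slope, stage1_intercept = np.polyfit(z, x, 1)
    x_given_z = stage1_slope * z + stage1_intercept   # estimate of E[X | Z]
    return ols_slope(x_given_z, y)


if __name__ == "__main__":
    z, x, y = simulate_confounded_data()
    print("naive regression slope:", ols_slope(x, y))
    print("instrumented slope:    ", two_stage_slope(z, x, y))
```

In this toy model the naive slope is inflated because \(U\) pushes \(X\) and \(Y\) in the same direction, while the instrumented slope should land close to the assumed true effect of 2.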
Now, you might well be wondering what a valid instrument in the sequential setting is? Well, given the past is independent of future confounding, we can use past states as an instrument! Graphically,
Many econometrics papers will spend quite a lot of text arguing something is a valid instrument (i.e. satisfies (1) and (2)) but here, the arrow of time gives us an instrument for free. Algorithmically, instrumental variable regression for imitation learning under TCN corresponds to solving $$ s_t | s_{t-1} \rightarrow a_t | s_{t-1}.$$
So, one first fits models of the left and right-hand sides of the above expression and then regresses between them. We give two algorithms for doing so efficiently in our paper. On simulated control tasks, we observe that our approaches (in teal and orange) are able to recover a policy that matches expert performance, while regression-based behavioral cloning struggles to do so.
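As a rough illustration of using the past state as an instrument, the sketch below builds a scalar linear expert whose action noise is shared between adjacent timesteps, then compares plain regression against a simple covariance-ratio instrumental variable estimate. The dynamics, the expert gain of -0.5, the noise model, and all function names are assumptions chosen for the toy; this is not either of the two algorithms from the paper.

```python
import numpy as np


def simulate_expert(num_steps=20000, gain=-0.5, seed=0):
    """Roll out a scalar linear expert whose action noise spans adjacent timesteps."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=num_steps + 1)      # u[t] perturbs both a_{t-1} and a_t
    w = 0.1 * rng.normal(size=num_steps)    # independent process noise
    states = np.empty(num_steps)
    actions = np.empty(num_steps)
    s = 0.0
    for t in range(num_steps):
        a = gain * s + u[t] + u[t + 1]      # expert action with temporally correlated noise
        states[t], actions[t] = s, a
        s = s + a + w[t]                    # assumed dynamics: s_{t+1} = s_t + a_t + w_t
    return states, actions


def covariance(a, b):
    """Sample covariance of two equally long 1-D arrays."""
    return np.mean((a - a.mean()) * (b - b.mean()))


def behavioral_cloning_gain(states, actions):
    """Plain regression of a_t on s_t."""
    return covariance(states, actions) / covariance(states, states)


def instrumented_gain(states, actions):
    """Instrumental variable estimate with s_{t-1} as the instrument for s_t."""
    past_s, current_s, current_a = states[:-1], states[1:], actions[1:]
    return covariance(past_s, current_a) / covariance(past_s, current_s)


if __name__ == "__main__":
    states, actions = simulate_expert()
    print("behavioral cloning gain:", behavioral_cloning_gain(states, actions))
    print("instrumented gain:      ", instrumented_gain(states, actions))
```

Because \(s_t\) already contains the noise term that also perturbs \(a_t\), the plain regression estimate is pulled away from the expert gain, whereas \(s_{t-1}\) is independent of that noise in this construction and the instrumented estimate should sit near -0.5.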
So, putting it all together, temporally correlated noise between expert actions can create spurious correlations between recorded states and actions, causing standard regression-based imitation learning approaches to produce inaccurate predictors. One can use the “natural experiment” caused by the variation in past states to de-noise the inputs to their regression procedure and recover causally correct predictors that do not suffer from spurious correlations.
Issue 2: Unobserved Contexts
We’re now going to consider a different kind of confounder. Let’s assume we’re trying to imitate an expert who observes some context, \(c\), which the learner does not. For example, if we were trying to imitate an expert driver but our sensors don’t pick up a stop sign on the side of the road, \(c\) could refer to this side information. Now, this problem is in general impossible to solve: if we don’t see the sign, how would we know to stop? However, if we pay attention to our surroundings and accumulate information over time, we might be able to eventually filter out what we missed: if we observe all the other cars around us slowing down for a few timesteps, we might rationally conclude that we should as well, even though we never saw the stop sign.
More formally, we’re in a Partially Observed Markov Decision Process (a POMDP), which we can visualize as follows:
So, the effect of the stop sign (\(c\)) is reflected in the states (\(s_1, s_2\)). There are two key differences between this graphical model and the TCN graphical model. First, the confounder directly affects the state. Second, the confounder is constant, rather than time-varying. These two characteristics make it plausible that, given enough state observations, we can figure out what \(c\) is and act as the expert does.
To enable us to accumulate evidence over time, we’re going to allow our policy to condition its actions on the entire history up to this point. Denoting \( h_t = (s_1, a_1, \dots, s_t) \), our policy now takes the form $$\pi: h_t \rightarrow a_t.$$ This means our policy is now a sequence model (e.g. an RNN or a transformer).
Ok, we use sequence models as our policy class and apply behavioral cloning (prediction of the next token from the prefix) and call it a day, right? Unfortunately, this often produces policies that naively repeat their own past actions. We call this the latching effect. For example, in the self-driving domain, this can lead to learning a driving policy that begins to turn and then just keeps turning, driving in circles [2]. This is clearly less than ideal.
Modeling Unobserved Contexts with Sequence Models + On-Policy Training
Let’s dig into why the latching effect stymies off-policy imitation learners. The model we learn during training looks like $$\pi(a_t | h_t) \approx p(a_t^E | s_1^E, a_1^E, \dots, s_t^E),$$ where the superscript \(E\) denotes that these are expert states and actions. Now at test time, the past actions we pass into our sequence model policy are our own rather than the expert’s. This means we’re trying to sample actions from $$p(a_t^E | s_1, a_1, \dots, s_t).$$ Unless we know that the elements to the right of the conditioning bar are exactly the same (i.e. we’ve perfectly matched the expert policy), we have no reason to believe that the model we’ve learned will generalize to our own induced state distribution. Put differently, we have no reason to believe we’ve learned a policy which will perform well at test-time.
In machine learning terms, we’re facing the problem of policy-induced covariate shift (PICS). Here, our covariates are the history of states and actions the policy uses to predict the next action. PICS is a well known problem in imitation learning and sequential prediction writ large (see last blog post for more details). What’s special about the unobserved context setting is the degree to which it causes problems. Early on in an episode, no learner can do well as they haven’t accumulated enough information to perform well. This means that we have an unavoidable source of covariate shift from preliminary mistakes.
Now, let’s assume that the expert’s actions were relatively similar between adjacent timesteps (which is frequently true in practice). This means rather than learning a complex mapping from states to actions, our learner might have instead learned a policy that mostly copies the last action. This was fine when the previous actions were the expert’s. But when they are the learner’s extremely suboptimal initial actions, naive copying is a recipe for disaster. This is what leads to us learning a driving policy that is only capable of turning in circles. A different way of phrasing this point is that, because the learner was trained only on trajectories from the expert, it treats its own test-time past actions as though they were produced by the expert, leading to unfortunate downstream consequences.
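The copying failure described above can be reproduced in a deliberately tiny toy. In the sketch below, a hidden context \(c \in \{-1, +1\}\) is fixed per episode, the expert always plays \(a_t = c\), and each state is a noisy hint of \(c\). All of this (the environment, the linear policy on \((s_t, a_{t-1})\), the zero first action, and the function names) is an assumption made for illustration and not the experimental setup from the paper.

```python
import numpy as np


def expert_episodes(num_episodes=200, horizon=20, seed=0):
    """Generate expert data: hidden context c, noisy states, constant expert actions."""
    rng = np.random.default_rng(seed)
    contexts = rng.choice([-1.0, 1.0], size=num_episodes)                   # hidden c
    states = contexts[:, None] + rng.normal(size=(num_episodes, horizon))   # noisy hints of c
    actions = np.repeat(contexts[:, None], horizon, axis=1)                 # expert plays a_t = c
    return contexts, states, actions


def fit_off_policy(states, actions):
    """Least-squares fit of a_t from (s_t, a_{t-1}) using expert histories only."""
    features = np.stack([states[:, 1:].ravel(), actions[:, :-1].ravel()], axis=1)
    targets = actions[:, 1:].ravel()
    weights, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return weights   # [weight on s_t, weight on a_{t-1}]


def rollout(weights, states, first_action=0.0):
    """Run the learned policy on its own action history for one episode."""
    previous, taken = first_action, []
    for s in states:
        previous = weights[0] * s + weights[1] * previous
        taken.append(previous)
    return np.array(taken)


if __name__ == "__main__":
    contexts, states, actions = expert_episodes()
    weights = fit_off_policy(states, actions)
    print("learned weights (state, previous action):", weights)
    taken = rollout(weights, states[0])
    print("mean absolute error on own history:", np.mean(np.abs(taken - contexts[0])))
```

On expert data the previous action predicts the next one perfectly, so the fit puts essentially all of its weight on \(a_{t-1}\) and ignores the state. Rolled out on its own history, the policy then keeps repeating its uninformed first action and never recovers \(c\).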
On-policy training (i.e. performing rollouts in the environment and comparing them to expert rollouts) doesn’t suffer from these issues. This is because we’re instead trying to satisfy $$ p(a_t^E | s_1^E, a_1^E, \dots, s_t^E) \approx p(a_t | s_1, a_1, \dots, s_t).$$ In words, rather than matching expert actions on expert histories, we’re instead trying to match expert and learner trajectory distributions. We’re making sure we predict well based on our own histories and therefore can’t be surprised at test-time by states we didn’t expect to end up in. We prove in our paper that under certain identifiability conditions, on-policy training of sequence model policies is both necessary and sufficient for the learner to generalize well to their own induced state distribution and successfully imitate the expert.
We also conduct experiments on continuous control tasks with unobserved contexts and show that equipping an off-policy learner (in grey) with access to history can actually hurt its performance, while we see no such effect for an on-policy learner (in orange).
Discussion
Confounding commonly occurs in real world situations, including those in which we want to make a sequence of decisions. When we don’t filter out or model the effects of what we don’t observe, we can end up making extremely suboptimal decisions. We describe two situations in which there are clean algorithms for handling confounding in sequential decision making (temporally correlated noise can be filtered out via IVR, unobserved contexts can be modeled via on-policy training of a sequence model). Moving forward, we’re interested in other causal structures (e.g. negative controls or proxy shifts) and applying our techniques to problems where confounders abound (e.g. recommender systems).
If you’re interested in learning more, I recommend you look at the full papers or the attached code:
Alina Beygelzimer, Yann N. Dauphin, Percy Liang, Jennifer Wortman Vaughan (NeurIPS 2021 Program Chairs)
Charvi Rastogi, Ivan Stelmakh, Zhenyu Xue, Hal Daumé III, Emma Pierson, and Nihar B. Shah
There is a considerable body of research on peer review. Within the machine learning community, there have been experiments establishing significant disagreement across reviewers and across reviewer panels—including at NeurIPS 2021—and active discussions about the state of peer review. But how do author perceptions about their submitted papers match up to the outcomes of the peer-review process and perceptions of other authors? We investigate this question by asking authors who submitted papers to NeurIPS 2021 three questions:
(Q1) [At the time of paper submission] What is your best estimate of the probability (as a percentage) that this submission will be accepted?
(Q2) [At the time of paper submission; to authors submitting two or more papers] Rank your submissions in terms of your own perception of their scientific contributions to the NeurIPS community, if published in their current form.
(Q3) [After preliminary reviews were available to authors] After you read the reviews of this paper, how did your perception of the value of its scientific contribution to the NeurIPS community change (assuming it was published in its initially submitted form)?
Here are five key findings.
1. How well do authors estimate the probability of acceptance of their papers?
Authors significantly overestimate their papers’ chances of acceptance. When answering Q1, authors were informed that the acceptance rate at NeurIPS over the last 4 years had been about 21%. The acceptance rate at NeurIPS 2021 turned out to be 25.8%. The authors’ responses had a nearly three-fold overestimate, with a median prediction of 70%.
2. Are some sub-groups better calibrated than others?
We examined calibration error across sub-groups, measuring this error in terms of the Brier score (squared loss) and controlling for other confounders. We find that the calibration error of female authors is slightly (but statistically significantly) higher than that of male authors. We also see a trend of miscalibration decreasing with seniority, with authors who were invited to serve as (meta-)reviewers better calibrated than the rest. All sub-groups we examined over-predicted their papers’ chances of acceptance.
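For readers unfamiliar with the Brier score, the following is a minimal sketch of the basic quantity: the mean squared difference between predicted acceptance probabilities and the 0/1 outcomes. The function name and the four-paper example are hypothetical, and the sketch leaves out the controls for confounders used in the actual analysis.

```python
def brier_score(predicted_probabilities, outcomes):
    """Mean squared error between predicted probabilities and binary outcomes."""
    if len(predicted_probabilities) != len(outcomes):
        raise ValueError("predictions and outcomes must have the same length")
    if not predicted_probabilities:
        raise ValueError("need at least one prediction")
    squared_errors = [
        (p - y) ** 2 for p, y in zip(predicted_probabilities, outcomes)
    ]
    return sum(squared_errors) / len(squared_errors)


if __name__ == "__main__":
    # Hypothetical author with four submissions, one of which is accepted.
    outcomes = [1, 0, 0, 0]
    print("predicting 70% for every paper:", brier_score([0.7] * 4, outcomes))
    print("predicting 25% for every paper:", brier_score([0.25] * 4, outcomes))
```

Lower is better: in this made-up example, predicting 70% for every paper scores 0.39, while predicting 25% scores 0.1875.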
3. Among authors with multiple papers, how much do their predictions of acceptance probabilities agree with their own perceived scientific merit?
These two sets of responses are largely in agreement: The strict ranking provided by authors about their perceived scientific merit (Q2) and the strict ranking induced by their predicted acceptance probabilities (Q1) agree for 93% of responses. However, there is a noticeable 7% of responses where the authors think that the peer review is more likely to reject the better of their two papers.
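One plausible way to compute such an agreement rate is sketched below. The input format, the function name, and the choice to skip pairs with tied probabilities (which induce no strict ranking) are all assumptions made for illustration; the study’s exact procedure may differ.

```python
def ranking_agreement(responses):
    """Fraction of paper pairs where the Q1-induced ranking matches the Q2 ranking.

    Each response is (prob_a, prob_b, prefers_a): the predicted acceptance
    probabilities of papers A and B, and whether the author ranked A above B
    on scientific contribution.
    """
    agree = 0
    total = 0
    for prob_a, prob_b, prefers_a in responses:
        if prob_a == prob_b:
            continue  # tied probabilities induce no strict ranking (assumed to be skipped)
        total += 1
        if (prob_a > prob_b) == prefers_a:
            agree += 1
    if total == 0:
        raise ValueError("no pairs with a strict probability ranking")
    return agree / total


if __name__ == "__main__":
    # Hypothetical responses, not data from the study.
    example = [
        (0.8, 0.6, True),    # agrees: A more likely and ranked higher
        (0.5, 0.7, False),   # agrees: B more likely and ranked higher
        (0.9, 0.4, False),   # disagrees: A more likely but B ranked higher
        (0.5, 0.5, True),    # tie in probabilities, skipped
    ]
    print("agreement rate:", ranking_agreement(example))
```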
4. How much do co-authors agree on the relative quality of their joint papers?
Strikingly, the amount of disagreement between co-authors in terms of the perceived relative scientific contribution of their papers (Q2) is similar to the amount of disagreement between authors and reviewers! In cases where one paper from an author was ultimately accepted and another rejected, authors rated the rejected paper higher about a third of the time. But looking at pairs of papers with overlapping authors in which both authors provided rankings, the co-authors also disagreed with each other about a third of the time. While there are discussions in the literature about inter-reviewer disagreements, this result suggests that there is similar disagreement in co-authors’ views of their papers as well.
5. Does peer review change authors’ perception of their own papers?
The question Q3 was a multiple-choice question with five choices: much more positive (“++”), slightly more positive (“+”), did not change (“0”), slightly more negative (“-”), much more negative (“- -”).
We find that among both accepted and rejected papers, about 50% of authors report that their perception about their own paper changed after seeing the initial reviews (Q3). Moreover, among both accepted and rejected papers, over 30% of authors report that their perception became more positive.
[Figure: authors’ responses to Q3, in two panels: accepted papers and rejected papers.]
Discussion
The fact that authors vastly overestimated the probability that their papers will be accepted suggests it would be useful for conference organizers and research mentors to attempt to recalibrate expectations prior to each conference. The disagreements we document around paper quality — between co-authors as well as between authors and reviewers — taken together with the disagreement among committees of reviewers observed in the complementary NeurIPS 2021 consistency experiment, suggest that assessing paper quality is not only an extremely noisy process, but may be a fundamentally challenging task with no objective right answer. The outcomes of paper submissions should thus be taken with a grain of salt. More broadly, as a community, we may take these findings into account when deciding on our policies and perceptions pertaining to the peer-review process and its outcomes. We hope the results of our experiment encourage discussion and introspection in the community.