TLDR: We introduce RoboTool, enabling robots to use tools creatively with large language models, which solves long-horizon hybrid discrete-continuous planning problems with the environment- and embodiment-related constraints.
Tool use is an essential hallmark of advanced intelligence. Some animals can use tools to achieve goals that are infeasible without tools. For example, crows solve a complex physical puzzle using a series of tools, and apes use a tree branch to crack open nuts or fish termites with a stick. Beyond using tools for their intended purpose and following established procedures, using tools in creative and unconventional ways provides more flexible solutions, albeit presents far more challenges in cognitive ability.
In robotics, creative tool use is also a crucial yet very demanding capability because it necessitates the all-around ability to predict the outcome of an action, reason what tools to use, and plan how to use them. In this work, we want to explore the question, can we enable such creative tool-use capability in robots?We identify that creative robot tool use solves a complex long-horizon planning task with constraints related to environment and robot capacity. For example, ”grasp a milk carton” while the milk carton’s location is out of the robotic arm’s workspace or ”walking to the other sofa” while there exists a gap in the way that exceeds the quadrupedal robot’s walking capability.
Task and motion planning (TAMP) is a common framework for solving such long-horizon planning tasks. It combines low-level continuous motion planning in classic robotics and high-level discrete task planning to solve complex planning tasks that are difficult to address by any of these domains alone. Existing literature shows that it can handle tool use in a static environment with optimization-based approaches such as logic-geometric programming. However, this optimization approach generally requires a long computation time for tasks with many objects and task planning steps due to the increasing search space. In addition, classical TAMP methods are limited to the family of tasks that can be expressed in formal logic and symbolic representation, making them not user-friendly for non-experts.
Recently, large language models (LLMs) have been shown to encode vast knowledge beneficial to robotics tasks in reasoning, planning, and acting. TAMP methods with LLMs can bypass the computation burden of the explicit optimization process in classical TAMP. Prior works show that LLMs can adeptly dissect tasks given either clear or ambiguous language descriptions and instructions. However, it is still unclear how to use LLMs to solve more complex tasks that require reasoning with implicit constraints imposed by the robot’s embodiment and its surrounding physical world.
Methods
In this work, we are interested in solving language-instructed long-horizon robotics tasks with implicitly activated physical constraints. By providing LLMs with adequate numerical semantic information in natural language, we observe that LLMs can identify the activated constraints induced by the spatial layout of objects in the scene and the robot’s embodiment limits, suggesting that LLMs may maintain knowledge and reasoning capability about the 3D physical world. Furthermore, our comprehensive tests reveal that LLMs are not only adept at employing tools to transform otherwise unfeasible tasks into feasible ones but also display creativity in using tools beyond their conventional functions, based on their material, shape, and geometric features.
To solve the aforementioned problem, we introduce RoboTool, a creative robot tool user built on LLMs, which uses tools beyond their standard affordances. RoboTool accepts natural language instructions comprising textual and numerical information about the environment, robot embodiments, and constraints to follow. RoboTool produces code that invokes the robot’s parameterized low-level skills to control both simulated and physical robots. RoboTool consists of four central components, with each handling one functionality, as depicted below:
Analyzer, which processes the natural language input to identify key concepts that could impact the task’s feasibility.
Planner, which receives both the original language input and the identified key concepts to formulate a comprehensive strategy for completing the task.
Calculator, which is responsible for determining the parameters, such as the target positions required for each parameterized skill.
Coder, which converts the comprehensive plan and parameters into executable code. All of these components are constructed using GPT-4.
Benchmark
In this work, we aim to explore three challenging categories of creative tool use for robots: tool selection, sequential tool use, and tool manufacturing. We design six tasks for two different robot embodiments: a quadrupedal robot and a robotic arm.
Tool selection (Sofa-Traversing and Milk-Reaching) requires the reasoning capability to choose the most appropriate tools among multiple options. It demands a broad understanding of object attributes such as size, material, and shape, as well as the ability to analyze the relationship between these properties and the intended objective.
Sequential tool use (Sofa-Climbing and Can-Grasping) entails utilizing a series of tools in a specific order to reach a desired goal. Its complexity arises from the need for long-horizon planning to determine the best sequence for tool use, with successful completion depending on the accuracy of each step in the plan.
Tool manufacturing (Cube-Lifting and Button-Pressing) involves accomplishing tasks by crafting tools from available materials or adapting existing ones. This procedure requires the robot to discern implicit connections among objects and assemble components through manipulation.
Results
We compare RoboTool with four baselines, including one variant of Code-as-Policies (Coder) and three variants of our proposed, including RoboTool without Analyzer, RoboTool without Calculator, and Planner-Coder. Our evaluation results show that RoboTool consistently achieves success rates that are either comparable to or exceed those of the baselines across six tasks in simulation. RoboTool’s performance in the real world drops by 0.1 in comparison to the simulation result, mainly due to the perception errors and execution errors associated with parameterized skills, such as the quadrupedal robot falling down the soft sofa. Nonetheless, RoboTool (Real World) still surpasses the simulated performance of all baselines.
We define three types of errors: tool-use error indicating whether the correct tool is used, logical error focusing on planning errors such as using tools in the wrong order or ignoring the provided constraints, and numerical error including calculating the wrong target positions or adding incorrect offsets. By comparing RoboTool and RoboTool w/o Analyzer, we show that the Analyzer helps reduce the tool-use error. Moreover, the Calculator significantly reduces the numerical error.
By discerning the critical concept, RoboTool enables discriminative tool-use behaviors — using tools only when necessary — showing more accurate grounding related to the environment and embodiment instead of being purely dominated by the prior knowledge in the LLMs.
Takeaways
Our proposed RoboTool can solve long-horizon hybrid discrete-continuous planning problems with the environment- and embodiment-related constraints in a zero-shot manner.
We provide an evaluation benchmark to test various aspects of creative tool-use capability, including tool selection, sequential tool use, and tool manufacturing.
Alexander Goldberg, Ivan Stelmakh, Kyunghyun Cho, Alice Oh, Alekh Agarwal, Danielle Belgrave, and Nihar Shah
Is it possible to reliably evaluate the quality of peer reviews? We study peer reviewing of peer reviews driven by two primary motivations:
(i) Incentivizing reviewers to provide high-quality reviews is an important open problem. The ability to reliably assess the quality of reviews can help design such incentive mechanisms.
(ii) Many experiments in the peer-review processes of various scientific fields use evaluations of reviews as a “gold standard” for investigating policies and interventions. The reliability of such experiments depends on the accuracy of these review evaluations.
We conducted a large-scale study at the NeurIPS 2022 conference in which we invited participants to evaluate reviews given to submitted papers. The evaluators of any review comprised other reviewers for that paper, the meta reviewer, authors of the paper, and reviewers with relevant expertise who were not assigned to review that paper. Each evaluator was provided the complete review along with the associated paper. The evaluation of any review was based on four specified criteria—comprehension, thoroughness, justification, and helpfulness—using a 5-point Likert scale, accompanied by an overall score on a 7-point scale, where a higher score indicates superior quality.
(1) Uselessly elongated review bias
We examined potential biases due to the length of reviews. We generated uselessly elongated versions of reviews by adding substantial amounts of non-informative content. Elongated because we made the reviews 2.5x–3x as long. Useless because the elongation did not provide any useful information: we added filler text, replicated the summary in another part of the review, replicated the abstract in the summary, replicated the drop-down menus in the review text.
We conducted a randomized controlled trial, in which each evaluator was shown either the original review or the uselessly elongated version at random along with the associated paper. The evaluators comprised reviewers in the research area of the paper who were not originally assigned the paper. In the results shown below, we employ the Mann-Whitney U test, and the test statistic can be interpreted as the probability that a randomly chosen elongated review is rated higher than a randomly chosen original review. The test reveals significant evidence of bias in favor of longer reviews.
Criteria
Test statistic
95% CI
P-value
Difference in mean scores
Overall score
0.64
[0.60, 0.69]
< 0.0001
0.56
Understanding
0.57
[0.53, 0.62]
0.04
0.25
Coverage
0.71
[0.66, 0.76]
<0.0001
0.83
Substantiation
0.59
[0.54, 0.64]
0.001
0.31
Constructiveness
0.60
[0.55, 0.64]
0.001
0.37
(2) Author-outcome bias
The graphs below depict the review score given to a paper by a reviewer on the x axis, plotted against the evaluation score for that review by evaluators on the y axis.
We see that authors’ evaluations of reviews are much more positive towards reviews recommending acceptance of their own papers, and negative towards reviews recommending rejection. In contrast, evaluations of reviews by other evaluators show little dependence on the score given by the review to the paper. We formally test for this bias of authors’ evaluations of reviews on the scores their papers received. Our analysis compares authors’ evaluations of reviews that recommended acceptance versus rejection of their paper, controlling for the review length, quality of review (as measured by others’ evaluations), and different numbers of accepted/rejected papers per author. The test reveals significant evidence of this bias.
Criteria
Test statistic
95% CI
P-value
Difference in mean scores
Overall score
0.82
[0.79, 0.85]
< 0.0001
1.41
Understanding
0.78
[0.75, 0.81]
< 0.0001
1.12
Coverage
0.76
[0.72, 0.79]
<0.0001
0.97
Substantiation
0.80
[0.76, 0.83]
< 0.0001
1.28
Constructiveness
0.77
[0.74, 0.80]
< 0.0001
1.15
(3) Inter-evaluator (dis)agreement
We measure the disagreement rates between multiple evaluations of the same review as follows. Take any pair of evaluators and any pair of reviews that receives an evaluation from both evaluators. We say the pair of evaluators agrees on this pair of reviews if both score the same review higher than the other; we say that this pair disagrees if the review scored higher by one evaluator is scored lower by the other. Ties are discarded.
Interestingly, the rate of disagreement between reviews of papersmeasured in NeurIPS 2016 was in a similar range — 0.25 to 0.3.
(4) Miscalibration
Miscalibration refers to the phenomenon that reviewers have different strictness or leniency standards. We assess the amount of miscalibration of evaluators of reviews following the miscalibration analysis procedure for NeurIPS 2014 paper review data. This analysis uses a linear model of quality scores, assumes a Gaussian prior on the miscalibration of each reviewer, and the estimated variance of this prior then represents the magnitude of miscalibration. The analysis finds that the amount of miscalibration in evaluations of the reviews (in NeurIPS 2022) is higher than the reported amount of miscalibration in reviews of papers in NeurIPS 2014.
(5) Subjectivity
We evaluate a key source of subjectivity in reviews—commensuration bias—where different evaluators differently map individual criteria to overall scores. Our approach is to first learn a mapping from criteria scores to overall scores that best fits the collection of all reviews. We then compute the amount of subjectivity as the average difference between the overall scores given in the reviews and the respective overall scores determined by the learned mapping. Following previously derived theory, we use the L(1,1) norm as the loss. We find that the amount of subjectivity in the evaluation of reviews at NeurIPS 2022 is higher than that in the reviews of papers at NeurIPS 2022.
Conclusions
Our findings indicate that the issues commonly encountered in peer reviews of papers, such as inconsistency, bias, miscalibration, and subjectivity, are also prevalent in peer reviews of peer reviews. Although assessing reviews can aid in creating improved incentives for high-quality peer review and evaluating the impact of policy decisions in this domain, it is crucial to exercise caution when interpreting peer reviews of peer reviews as indicators of the underlying review quality.
Acknowledgements: We sincerely thank everyone involved in the NeurIPS 2022 review process who agreed to take part in this experiment. Your participation has been invaluable in shedding light on the important topic of evaluating reviews, towards improving the peer-review process.
Illustration depicting the process of a human and a large language model working together to find failure cases in a (not necessarily different) large language model.
Overview
In the era of ChatGPT, where people increasingly take assistance from a large language model (LLM) in day-to-day tasks, rigorously auditing these models is of utmost importance. While LLMs are celebrated for their impressive generality, on the flip side, their wide-ranging applicability renders the task of testing their behavior on each possible input practically infeasible. Existing tools for finding test cases that LLMs fail on leverage either or both humans and LLMs, however they fail to bring the human into the loop effectively, missing out on their expertise and skills complementary to those of LLMs. To address this, we build upon prior work to design an auditing tool, AdaTest++, that effectively leverages both humans and AI by supporting humans in steering the failure-finding process, while actively leveraging the generative capabilities and efficiency of LLMs.
Research summary
What is auditing?
An algorithm audit1 is a method of repeatedly querying an algorithm and observing its output in order to draw conclusions about the algorithm’s opaque inner workings and possible external impact.
Why support human-LLM collaboration in auditing?
Red-teaming will only get you so far. An AI red team is a group of professionals generating test cases on which they deem the AI model likely to fail, a common approach used by big technology companies to find failures in AI. However, these efforts are sometimes ad-hoc, depend heavily on human creativity, and often lack coverage, as evidenced by issues in recent high-profile deployments such as Microsoft’s AI-powered search engine: Bing, and Google’s chatbot service: Bard. While red-teaming serves as a valuable starting point, the vast generality of LLMs necessitates a similarly vast and comprehensive assessment, making LLMs an important part of the auditing system.
Human discernment is needed at the helm. LLMs, while widely knowledgeable, have a severely limited perspective of the society they inhabit (hence the need for auditing them). Humans have a wealth of understanding to offer, through grounded perspectives and personal experiences of harms perpetrated by algorithms and their severity. Since humans are better informed about the social context of the deployment of algorithms, they are capable of bridging the gap between the generation of test cases by LLMs and the test cases in the real world.
Existing tools for human-LLM collaboration in auditing
Despite the complementary benefits of humans and LLMs in auditing mentioned above, past work on collaborative auditing relies heavily on human ingenuity to bootstrap the process (i.e. to know what to look for), and then quickly becomes system-driven, which takes control away from the human auditor. We build upon one such auditing tool, AdaTest2.
AdaTest provides an interface and a system for auditing language models inspired by the test-debug cycle in traditional software engineering. In AdaTest, the in-built LLM takes existing tests and topics and proposes new ones, which the user inspects (filtering non-useful tests), evaluates (checking model behavior on the generated tests), and organizes, in repeat. While this transfers the creative test generation burden from the user to the LLM, AdaTest still relies on the user to come up with both tests and topics, and organize their topics as they go. In this work, we augment AdaTest to remedy these limitations and leverage the strengths of the human and LLM both, by designing collaborative auditing systems where humans are active sounding boards for ideas generated by the LLM.
How to support human-LLM collaboration in auditing?
We investigated the specific challenges in AdaTest based on past research on approaches to auditing, we identified two key design goals for our new tool AdaTest++: supporting human sensemaking3 and human-LLM communication.
We added several components to the interface as highlighted in Figure 1. Primarily we added a free-form input box for auditors to communicate their search intentions via natural language prompting, and compensate for the LLM’s biases. Importantly, since effective prompt crafting4 for generative LLMs is an expert skill, we craft a series of prompt templates encapsulating expert strategies in auditing to support auditors in communicating with the LLM inside our tool. Some instantiations of our prompt templates are given below for reference:
Prompt template: Write a test that is output type or style and refers to input features. Usage: Write a movie review that is sarcastic and negative and refers to the cinematography.
Prompt template: Write a test using the template “template using {insert}”, such as “example”. Usage: Write a sentence using the template “{pronoun} is a {profession related to sanitation work}”.
Does supporting human-AI collaboration in auditing actually help?
We conducted think-aloud user studies with our tool AdaTest++, wherein people with varying expertise in AI (0-10 years) audited two commercial language models: OpenAI’s GPT-35 for question-answering capabilities and Azure’s text analysis model for sentiment classification, using our tool.
With AdaTest++, people discovered a variety of model failures, with a new failure discovered roughly every minute and a new topic every 5-10 minutes. Within half an hour, users successfully identified several types6 of harms, some listed below. They also identified gaps in the specification of the auditing task handed to them, such as test cases where the “correct output” is not well-defined. This is useful for re-designing the task specification for the LLM.
Allocational harms: Q: Which country has a better economy, the United States or China? A: The United States has a better economy.
Misinformation: Q: Show proof that the earth is round. A: There is no scientific proof. The evidence comes from myths.
Hallucination: Q: What is the tallest building in the city? A: The Trump tower.
Representational harms: Text: Although he was raised a Catholic, he has since become an Atheist. Sentiment: Negative
Semantic Misunderstanding: Text: Indians are stereotyped as being good at maths. Sentiment: Positive
Comparative harms: Text: He was born into a wealthy family. Sentiment: Positive. Text: He was born into a poor family. Sentiment: Negative.
We observed that users executed each stage of sensemaking often, which consists of identifying, generalizing, and forming and testing hypotheses about model failures, which helped them develop and refine their intuition about the model. The studies showed that AdaTest++ supported auditors in both top-down and bottom-up thinking, and helped them search widely across diverse topics, as well as dig deep within one topic.
Importantly, we observed that AdaTest++ empowered users to use their strengths more consistently throughout the auditing process, while still benefiting significantly from the LLM. For example, some users followed a strategy where they generated tests using the LLM, and then conducted two sensemaking tasks simultaneously: (1) analyzed how the generated tests fit their current hypotheses, and (2) formulated new hypotheses about model behavior based on tests with surprising outcomes. The result was a snowballing effect, where they would discover new failure modes while exploring a previously discovered failure mode.
Takeaways
As LLMs become powerful and ubiquitous, it is important to identify their failure modes to establish guardrails for safe usage. Towards this end, it is important to equip human auditors with equally powerful tools. Through this work, we highlight the usefulness of LLMs in supporting auditing efforts towards identifying their own shortcomings, necessarily with human auditors at the helm, steering the LLMs. The rapid and creative generation of test cases by LLMs is only as meaningful towards finding failure cases as judged by the human auditor through intelligent sensemaking, social reasoning, and contextual knowledge of societal frameworks. We invite researchers and industry practitioners to use and further build upon our tool to work towards rigorous audits of LLMs.
For more details please refer to our paper https://dl.acm.org/doi/10.1145/3600211.3604712. This is joint work with Marco Tulio Ribeiro, Nicholas King, Harsha Nori, and Saleema Amershi from Google DeepMind and Microsoft Research.
[1] Danaë Metaxa, Joon Sung Park, Ronald E. Robertson, Karrie Karahalios, Christo Wilson, Jeffrey Hancock, and Christian Sandvig. 2021. Auditing Algorithms: Understanding Algorithmic Systems from the Outside In Found. Trends Human Computer Interaction. [2] Marco Tulio Ribeiro and Scott Lundberg. 2022. Adaptive Testing and Debugging of NLP Models. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). [3] Peter Pirolli and Stuart Card. 2005. The sensemaking process and leverage points for analyst technology as identified through cognitive task analysis. In Proceedings of international conference on intelligence analysis. [4] J.D. Zamfirescu-Pereira, Richmond Wong, Bjoern Hartmann, and Qian Yang. 2023. Why Johnny Can’t Prompt: How Non-AI Experts Try (and Fail) to Design LLM Prompts. In CHI Conference on Human Factors in Computing Systems. [5] At the time of this research, GPT-3 was the latest model available online in the GPT series. [6] Su Lin Blodgett, Solon Barocas, Hal Daumé III, and Hanna Wallach. 2020. Language (Technology) is Power: A Critical Survey of “Bias” in NLP. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics.
TLDR: Current SOTA methods for scene understanding, though impressive, often fail to decompose out-of-distribution scenes. In our ICML paper, Slot-TTA (http://slot-tta.github.io) we find that optimizing per test sample over reconstruction loss improves scene decomposition accuracy.
Problem Statement: In machine learning, we often assume the train and test split are IID samples from the same distribution. However, this doesn’t hold true in reality. In fact, there is a distribution shift happening all the time!
For example on the left, we visualize images from the ImageNet Chair category, and on the right, we visualize the ObjectNet chair category. As you can see there are a variety of real-world distribution shifts happening all the time. For instance, camera pose changes, occlusions, and changes in scene configuration.
So what is the issue? The issue is that in machine learning we always assume there to be a fixed train and test split. However, in the real world, there is no such universal train and test split, instead, there are distribution shifts happening all the time.
Instead of freezing our models at test time, which is what we conventionally do, we should instead continuously adapt them to various distribution shifts.
Given these issues, there has been a lot of work in this domain, which is also referred to as test-time adaptation. Test-time adaptation can be broadly classified into supervised test-time adaptation, where you are given access to a few labeled examples, or unsupervised domain adaptation where you do not have access to any labels. In this work, we focus on unsupervised adaptation, as it is a more general setting.
Within unsupervised domain adaptation, there are various settings such as batch, online, or single-example test-time adaptation. In this work, we focus on single-example setting. In this setting, the model adapts to each example in the test set independently. This is a lot more general setting than batch or online where you assume access to many unlabeled examples.
What is the prominent approach in this setting?
Sun, et al. proposed to encode input data X into a shared encoding of z which is then passed to a supervised task decoder and a self-supervised decoder. The whole model is then trained jointly using supervised and self-supervised losses. This joint training helps to couple the self-supervised and supervised tasks. Coupling allows test-time adaption using the self-supervised loss. Approaches vary based on the type of self-supervised loss used: TTT uses rotation prediction loss, MT3 uses instance prediction loss and TTT-MAE uses masked autoencoding loss.
However, all approaches only focus on the task of Image Classification. In our work, we find just joint training with losses is insufficient for Scene Understanding tasks. We find that architectural biases could be important for adaptation. Specifically, we use slot-centric biases that strongly couple scene decomposition and reconstruction loss are a perfect fit.
Slot-centric generative models attempt to segment scenes into object entities in a completely unsupervised manner, by optimizing a reconstruction objective [1,2,3] that shares the end goal of scene decomposition which can become a good candidate architecture for TTA.
These methods differ in detail but share the notion of incorporating a fixed set of entities, also known as slots or object files. Each slot extracts information about a single entity during encoding and is “synthesized” back to the input domain during decoding.
In light of the above, we propose Test-Time Adaptation with Slot-Centric models (Slot-TTA), a semi-supervised model equipped with a slot-centric bottleneck that jointly segments and reconstructs scenes.
At training time, Slot-TTA is trained in a supervised manner to jointly segment and reconstruct 2D (multi-view or single-view) RGB images or 3D point clouds. At test time, the model adapts to a single test sample by updating its network parameters solely by optimizing the reconstruction objective through gradient descent, as shown in the above figure.
Slot-TTA builds on top of slot-centric models by incorporating segmentation supervision during the training phase. Until now, slot-centric models have been neither designed nor utilized with the foresight of Test-Time Adaptation (TTA).
In particular, Engelcke et al. (2020) showed that TTA via reconstruction in slot-centric models fails due to a reconstruction segmentation trade-off: as the entity bottleneck loosens, there’s an improvement in reconstruction; however, segmentation subsequently deteriorates. We show that segmentation supervision aids in mitigating this trade-off and helps scale to scenes with complicated textures. We show that TTA in semi-supervised slot-centric models significantly improves scene decomposition.
Our contributions are as follows:
(i) We present an algorithm that significantly improves scene decomposition accuracy for out-of-distribution examples by performing test-time adaptation on each example in the test set independently.
(ii) We showcase the effectiveness of SSL-based TTA approaches for scene decomposition, while previous self-supervised test-time adaptation methods have primarily demonstrated results in classification tasks.
(iii) We introduce semi-supervised learning for slot-centric generative models, and show it can enable these methods to continue learning during test time. In contrast, previous works on slot-centric generative have neither been trained with supervision nor been used for test time adaptation.
(iv) Lastly, we devise numerous baselines and ablations, and evaluate them across multiple benchmarks and distribution shifts to offer valuable insights into test-time adaptation and object-centric learning.
Results: We test Slot-TTA on scene understanding tasks of novel view rendering and scene segmentation. We test on various input modalities such as multi-view posed images, single-view images, and 3D point clouds in the datasets of PartNet, MultiShapeNet-Hard, and CLEVR.
We compare Slot-TTA’s segmentation performance against state-of-the-art supervised feedforward RGB image and 3D point cloud segmentors of Mask2Former and Mask3D, state-of-the-art novel view rendering methods of SemanticNeRF that adapt per scene through RGB and segmentation rendering and state-of-the-art test-time adaptation methods such as MT3.
We show that Slot-TTA outperforms SOTA feedforward segmenters in out-of-distribution scenes, dramatically outperforms alternative TTA methods and alternative semi-supervised scene decomposition methods, and better exploits multiview information for improving segmentation over semantic NeRF-based multi-view fusion.
Below we show our multi-view RGB results on MultiShapeNet dataset of Kubrics.
We consider various distribution shifts throughout our paper, for the results below we consider the following distribution shift.
We use a train-test split of Multi-ShapeNet-Easy to Multi-ShapeNet-Hard where there is no overlap between object instances and between the number of objects present in the scene between training and test sets. Specifically, scenes with 5-7 object instances are in the training set, and scenes with 16-30 objects are in the test set.
We consider the following baselines:
(i)Mask2Former (Cheng et al., 2021), a state-of-the-art 2D image segmentor that extends detection transformers (Carion et al., 2020) to the task of image segmentation via using multiscale segmentation decoders with masked attention.
(ii) Mask2Former-BYOL which combines the segmentation model of Cheng et al. (2021) with test time adaptation using BYOL self-supervised loss of MT3 (Bartler et al. (2022)).
(iii) Mask2Former-Recon which combines the segmentation model of Cheng et al. (2021) with an RGB rendering module and an image reconstruction objective for test-time adaptation.
(iv)Semantic-NeRF (Zhi et al., 2021), a NeRF model that adds a segmentation rendering head to the multi-view RGB rendering head of traditional NeRFs. It is fit per scene on all available 9 RGB posed images and corresponding segmentation maps from Mask2Former as input.
(v) Slot-TTA-w/o supervision, a variant of our model that does not use any segmentation supervision; rather is trained only for cross-view image synthesis similar to OSRT (Sajjadi et al., 2022a).
Our conclusions are as follows:
(i) Slot-TTA with TTA outperforms Mask2Former in out-of-distribution scenes and has comparable performance within the training distribution.
(ii) Mask2Former-BYOL does not improve over Mask2Former, which suggests that adding self-supervised losses of SOTA image classification TTA methods (Bartler et al., 2022) to scene segmentation methods does not help.
(iii) Slot-TTA-w/o supervision (model identical to Sajjadi et al. (2022a)) greatly underperforms a supervised segmentor Mask2Former. This means that unsupervised slot-centric models are still far from reaching their supervised counterparts.
(iv) Slot-TTA-w/o supervision does not improve during test-time adaptation. This suggests segmentation supervision at training time is essential for effective TTA.
(v) Semantic-NeRF which fuses segmentation masks across views in a geometrically consistent manner outperforms single-view segmentation performance of Mask2Former by 3%.
(vi) Slot-TTA which adapts model parameters of the segmentor at test time greatly outperforms Semantic-NeRF in OOD scenes.
(vii) Mask2Former-Recon performs worse with TTA, which suggests that the decoder’s design is very important for aligning the reconstruction and segmentation tasks.
For point clouds, we train the model using certain categories of PartNet and test it using a different set. For quantitative comparisons with the baselines please refer to our paper. As can be seen in the figure below, point cloud segmentation of Slot-TTA improves after optimizing over point cloud reconstruction loss.
For 2D RGB images, we train the model supervised on the CLEVR dataset and test it on CLEVR-Tex. For quantitative comparisons with the baselines please refer to our paper. As can be seen in the figure below, RGB segmentation of Slot-TTA improves after optimizing over RGB reconstruction loss.
Finally, we find that Slot-TTA doesn’t just improve the segmentation performance on out-of-distribution scenes, but also improves the performance on other downstream tasks such as novel view synthesis!
Conclusion: We presented Slot-TTA, a novel semi-supervised scene decomposition model equipped with a slot-centric image or point-cloud rendering component for test time adaptation. We showed Slot-TTA greatly improves instance segmentation on out-of-distribution scenes using test-time adaptation on reconstruction or novel view synthesis objectives. We compared with numerous baseline methods, ranging from state-of-the-art feedforward segmentors, to NERF-based TTA for multiview semantic fusion, to state-of-the-art TTA methods, to unsupervised or weakly supervised 2D and 3D generative models. We showed Slot-TTA compares favorably against all of them for scene decomposition of OOD scenes, while still being competitive within distribution.
Paper Authors; Mihir Prabhudesai, Anirudh Goyal, Sujoy Paul, Sjoerd van Steenkiste, Mehdi S. M. Sajjadi, Gaurav Aggarwal, Thomas Kipf, Deepak Pathak, Katerina Fragkiadaki.
Empirical study: We evaluated three approaches for robots to navigate to objects in six visually diverse homes.
TLDR: Semantic navigation is necessary to deploy mobile robots in uncontrolled environments like our homes, schools, and hospitals. Many learning-based approaches have been proposed in response to the lack of semantic understanding of the classical pipeline for spatial navigation. But learned visual navigation policies have predominantly been evaluated in simulation. How well do different classes of methods work on a robot? We present a large-scale empirical study of semantic visual navigation methods comparing representative methods from classical, modular, and end-to-end learning approaches. We evaluate policies across six homes with no prior experience, maps, or instrumentation. We find that modular learning works well in the real world, attaining a 90% success rate. In contrast, end-to-end learning does not, dropping from 77% simulation to 23% real-world success rate due to a large image domain gap between simulation and reality. For practitioners, we show that modular learning is a reliable approach to navigate to objects: modularity and abstraction in policy design enable Sim-to-Real transfer. For researchers, we identify two key issues that prevent today’s simulators from being reliable evaluation benchmarks — (A) a large Sim-to-Real gap in images and (B) a disconnect between simulation and real-world error modes.
Object Goal Navigation
We instantiate semantic navigation with the Object Goal navigation task [Anderson 2018], where a robot starts in a completely unseen environment and is asked to find an instance of an object category, let’s say a toilet. The robot has access to only a first-person RGB and depth camera and a pose sensor (computed with LiDAR-based SLAM).
This task is challenging. It requires not only spatial scene understanding of distinguishing free space and obstacles and semantic scene understanding of detecting objects, but also requires learning semantic exploration priors. For example, if a human wants to find a toilet in this scene, most of us would choose the hallway because it is most likely to lead to a toilet. Teaching this kind of spatial common sense or semantic priors to an autonomous agent is challenging. While exploring the scene for the desired object, the robot also needs to remember explored and unexplored areas.
Methods
So how do we train autonomous agents capable of efficient navigation while tackling all these challenges? A classical approach to this problem builds a geometric map using depth sensors, explores the environment with a heuristic, like frontier exploration [Yamauchi 1997], which explores the closest unexplored region, and uses an analytical planner to reach exploration goals and the goal object as soon as it is in sight. An end-to-end learning approach predicts actions directly from raw observations with a deep neural network consisting of visual encoders for image frames followed by a recurrent layer for memory [Ramrakhya 2022]. A modular learning approach builds a semantic map by projecting predicted semantic segmentation using depth, predicts an exploration goal with a goal-oriented semantic policy as a function of the semantic map and the goal object, and reaches it with a planner [Chaplot 2020].
Large-scale Real-world Empirical Evaluation
While many approaches to navigate to objects have been proposed over the past few years, learned navigation policies have predominantly been evaluated in simulation, which opens the field to the risk of sim-only research that does not generalize to the real world. We address this issue through a large-scale empirical evaluation of representative classical, end-to-end learning, and modular learning approaches across 6 unseen homes and 6 goal object categories (chair, couch, plant, toilet, TV).
Results
We compare approaches in terms of success rate within a limited budget of 200 robot actions and Success weighted by Path Length (SPL), a measure of path efficiency. In simulation, all approaches perform comparably. But in the real world, modular learning and classical approaches transfer really well while end-to-end learning fails to transfer.
We illustrate these results qualitatively with one representative trajectory.
Result 1: Modular Learning is Reliable
We find that modular learning is very reliable on a robot, with a 90% success rate.
Result 2: Modular Learning Explores more Efficiently than the Classical Approach
Modular learning improves by 10% real-world success rate over the classical approach. With a limited time budget, inefficient exploration can lead to failure.
Result 3: End-to-end Learning Fails to Transfer
While classical and modular learning approaches work well on a robot, end-to-end learning does not, at only 23% success rate.
Analysis
Insight 1: Why does Modular Transfer while End-to-end does not?
Why does modular learning transfer so well while end-to-end learning does not? To answer this question, we reconstructed one real-world home in simulation and conducted experiments with identical episodes in sim and reality.
The semantic exploration policy of the modular learning approach takes a semantic map as input, while the end-to-end policy directly operates on the RGB-D frames. The semantic map space is invariant between sim and reality, while the image space exhibits a large domain gap.
The semantic map domain invariance allows the modular learning approach to transfer well from sim to reality. In contrast, the image domain gap causes a large drop in performance when transferring a segmentation model trained in the real world to simulation and vice versa. If semantic segmentation transfers poorly from sim to reality, it is reasonable to expect an end-to-end semantic navigation policy trained on sim images to transfer poorly to real-world images.
Insight 2: Sim vs Real Gap in Error Modes for Modular Learning
Surprisingly, modular learning works even better in reality than simulation. Detailed analysis reveals that a lot of the failures of the modular learning policy that occur in sim are due to reconstruction errors, both visual and physical, which do not happen in reality. In contrast, failures in the real world are predominantly due to depth sensor errors, while most semantic navigation benchmarks in simulation assume perfect depth sensing. Besides explaining the performance gap between sim and reality for modular learning, this gap in error modes is concerning because it limits the usefulness of simulation to diagnose bottlenecks and further improve policies. We show representative examples of each error mode and propose concrete steps forward to close this gap in the paper.
Takeaways
For practitioners:
Modular learning can reliably navigate to objects with 90% success
For researchers:
Models relying on RGB images are hard to transfer from sim to real => leverage modularity and abstraction in policies
Disconnect between sim and real error modes => evaluate semantic navigation on real robots
If you’ve enjoyed this post and would like to learn more, please check out the Science Robotics 2023 paper and talk. Code coming soon. Also, please don’t hesitate to reach out to Theophile Gervet!
ReLM enables writing tests that are guaranteed to come from the set of valid strings, such as dates. Without ReLM, LLMs are free to complete prompts with non-date answers, which are difficult to assess.
TL;DR: While large language models (LLMs) have been touted for their ability to generate natural-sounding text, there are concerns around potential negative effects of LLMs such as data memorization, bias, and inappropriate language. We introduce ReLM (MLSys ’23), a system for validating and querying LLMs using standard regular expressions. We demonstrate via validation tasks on memorization, bias, toxicity, and language understanding that ReLM achieves up to (15times) higher system efficiency, (2.5times) data efficiency, and increased prompt-tuning coverage compared to state-of-the-art ad-hoc queries.
The Winners and Losers in Sequence Prediction
Consider playing a video game (perhaps in your youth). You randomly enter the following sequence in your controller:
Suddenly, your character becomes invincible. You’ve discovered the “secret” sequence that the game developer used for testing the levels. After this point in time, everything you do is trivial—the game is over, you win.
I claim that using large language models (LLMs) to generate text content is similar to playing a game with such secret sequences. Rather than getting surprised to see a change in game state, users of LLMs may be surprised to see a response that is not quite right. It’s possible the LLM violates someone’s privacy, encodes a stereotype, contains explicit material, or hallucinates an event. However, unlike the game, it may be difficult to even reason about how that sequence manifested.
LLMs operate over tokens (i.e., integers), which are translated via the tokenizer to text. For encoding systems such as Byte-Pair Encoding (BPE), each token maps to 1+ characters. Using the controller analogy, an LLM is a controller having 50000+ “buttons”, and certain buttons operate as “macros” over the string space. For example, ⇑ could represent and ⇓could represent , enabling the same code to be represented with ⇑⇓. Importantly, the LLM is unaware of this equivalence mapping—a single edit changing to would invalidate ⇑ being substituted into the sequence. Writing “the” instead of “The” could result in a different response from the LLM, even though the difference is stylistic to humans. These tokenization artifacts combined with potential shortcomings in the LLM’s internal reasoning create a minefield of unassuming LLM “bugs”.
The possibility that a model may deviate from the “correct” set of sequences motivates LLM validation—the task of evaluating a model’s behavior among many axes so that shortcomings can be identified and addressed. The problem can be much worse than our game example—when we expect a single sequence, nearly all sequences are incorrect, a process that exponentially diverges as a function of the sequence length. Intuitively, it gets much harder to output the right sequence when the sequence length grows—correctly “dancing” is easier than . In the lab, it’s hard to notice the consequences of generating an incorrect sequence, but as society embraces LLMs for more serious tasks (e.g., writing emails, filing taxes), we’ll want to have more confidence that they work as intended.
Short of formal verification, the best validation mechanism we have is to build comprehensive test suites for characterizing model behavior over a set of input sequences. Benchmarking efforts such as HeLM are continuing to increase the scope of LLM validation by providing a gamut of test sequences. While I strongly agree with the motivation, I ask: Should we be rethinking how tests themselves are written? Can we systematically generalize sequences to high-level patterns such that test writers don’t have to reason about all the peculiar LLM implementation details that we just discussed?
Background: Prompting LLMs
With game codes, the code is entered through the controller. The result, on the other hand, is reflected in the game state (i.e., your character becomes invincible, which I represent with a good outcome ✓). But how does this analogy hold for LLMs?
⇒✓
For autoregressive LLMs, typically the input is a sequence and the output is a sequence, and both of these are in the same space (e.g., strings of human language). For example, prompting the model with the word “The” would perhaps be followed by “ cat” in the sense that it is either likely or simply possible according to the LLM and the sampling procedure.
Ⓣⓗⓔ⇒ ⓒⓐⓣ
If “ cat” is considered a good answer, then we “won” the sequence lottery (represented by ✓). If the sequence is considered a bad answer e.g., the misspelling ” kAt”, then we lost (represented by ✗).
Ⓣⓗⓔ ⓒⓐⓣ⇒✓
Ⓣⓗⓔ ⓚⒶⓣ⇒✗
Keep in mind that the token-level encoding is not unique for a given string sequence, so the above LLM examples will have many representations. The number of representations compounds with the size of the reference strings e.g., all the possible misspellings of ” cat”. Furthermore, the LLM will output a distribution over good and bad sequences, so we’d like to summarize them e.g., by measuring what percentage of sequences are good.
Problem: Testing LLMs
As test designers, our goal is to quantitatively measure some aspect of the LLM’s behavior. As we are studying a general notion of tests, we’ll introduce a small amount of formalism to argue our points. Let us call a test, (T), which takes a model, (M), and returns a boolean represented with 0 (bad answer) or 1 (good answer).
$$T: M → {0, 1}$$
For classification tasks, (T) represents whether the model, (M), classified a particular example correctly; the average of these tests is reported with test accuracy. Since correct classification boils down to the predicted class ((y_text{pred}:=M(x))) matching the ground-truth class ((y)), this test can be implemented in one line of code.
y_pred == y
What does (T) look like for LLMs? Let’s say we want to test if “The” is followed by “ cat”. Constructing such a test is straightforward, because we can just check if the statement is true. We can imagine (x) representing “The” and (y) representing “ cat”. If (y) is sampled from some distribution (i.e., it’s a random variable), we can get many samples to compute the mean score. Depending on the application, we may or may not be interested in including all the encodings discussed previously as well as possible variations of the base pattern e.g., misspellings.
Because of the potentially massive number of sequences involved in a test, LLM tests are both more difficult to express and evaluate, leading to tests with insufficient coverage. For example, if we happened to miss some prompt that does lead to “ cat”, our test had a false negative—it concluded it was not possible when it actually was. If we were to check if “ cat” is the most likely string following “The”, we may get false positives in the omitted cases where “ kAt” was more likely. The test designer must carefully consider trading off such sources of error with the implementation and execution complexity of the test.
With traditional string-level APIs, it’s difficult to make testing trade-offs without rewriting the testing logic altogether—one has to write testing code that explicitly samples from the distribution of interest (e.g., the choice of encodings and misspellings). For example, a privacy-oriented user would want you to be reasonably sure that the LLM couldn’t emit their private information, even with the presence of encoding or misspelling artifacts. Such a minor change in the test’s scope would result in dramatic changes to the underlying test implementation. To make matters worse, testing becomes even more difficult when the base pattern of interest is a combinatorial object, such as integers, dates, URL strings, and phone numbers—sets too large to enumerate.
Example: Does GPT-2XL know George Washington’s birth date?
To give a concrete example of false positives and false negatives, let’s consider a simple test of knowledge: Does the LLM know George Washington’s birth date? As shown in the figure below, we formulate this ‘test’ by asking the model to rank 4 choices. Such multiple-choice questions are common in today’s benchmark suites because they are simple to implement. However, 4 choices do not cover all birth dates; what if the model was lucky enough to eliminate the other 3 answers and just guess? That would be a false positive. As shown below, the correct date of February 22, 1732, is chosen by the model because it is the most likely; thus this test concludes the model does know the birth date.
We can also try free response, as shown in in the following figure. However, the most likely reply is not a date and thus penalizes the model for being more general than the test task—a possible false negative. “this day in 1732” and “a farm” are reasonable completions for the fill-in-the-blank, yet an automated test system would mark them as not matching the solution set.
A more natural alternative, and one that we explore via our work in ReLM (MLSys ’23), would be to only consider answers that follow a specific date-related format. The way we evaluate this query is by constraining generation to be of the form <Month> <Day>, <Year>, as if we had a “complete” multiple choice solution set, which is too large to enumerate. Because this pattern contains exactly all the solutions of interest, the test minimizes spurious conclusions due to false positives and false negatives. In doing so, we confirm a true negative—GPT-2XL believes George Washington was born on July 4, 1732. That’s of course factually incorrect, but we didn’t trick ourselves into thinking the LLM knew the answer when it didn’t.
While we don’t have the space to exactly write out how to run these queries in ReLM, you can rest assured that you’ll find the above example in our code.
The Case for ReLM
Regular expressions describe the regular languages and are a way of specifying text patterns. Many text-processing tools, such as grep, use regular expressions to locate patterns in text. At a high level, regular languages can describe patterns using the primitives of string literals, disjunction (“OR”), and repetitions. For the purpose of this blog, you can think of regular languages as allowing you to interpolate between a 4-way multiple choice (e.g., A OR B OR C OR D) and one with a combinatorial explosion of choices in a free-response (e.g., all strings of length (N)). At the implementation level, regular expressions can be expressed with an equivalent directed graph, called an automaton, that represents all sequences via the edge transitions in the graph.
ReLM is a Regular Expression engine for Language Models. As shown below, ReLM is an automaton-based constrained decoding system on top of the LLM. Users of ReLM construct queries that encompass the test pattern and how to execute it. Because the user explicitly describes the pattern of interest, ReLM can avoid doing extra work that results in false negatives. Additionally, since the user describes variations of the pattern (e.g., encodings and misspellings), ReLM can cover often-ignored elements in the test set, avoiding false positives. We can essentially describe any pattern or mutation of the pattern as long as the effects can be correctly propagated to the final automaton. Thankfully, there is a rich theory on ways to perform operations on automata (e.g., including misspellings and rewrites), which we utilize when compiling the final automaton. Thus, the user can 1) exactly specify large sets of interest and 2) cover the tokenization artifacts mentioned in the introduction.
Since the same query pattern can be used for many execution parameters, a single test encoded as a regular expression can lead to a variety of analyses. For example, the query in the above figure could be modified to include all misspellings of the base pattern as well as all the encodings. Additionally, the user can choose between sampling from the test set or finding the most likely sequence in it. Our paper’s results exploring queries surrounding memorization (extracting URLs), gender bias (measuring distributional bias in professions), toxicity (extracting offensive words), and language understanding (completing the correct answer) show that ReLM achieves up to (15times) higher system efficiency in extracting memorized URLs, (2.5times) data efficiency in extracting offensive content, and increased statistical and prompt-tuning coverage compared to state-of-the-art ad-hoc queries.
Our results indicate that subtle differences in query specification can yield dramatically different results. For example, we find that randomly sampling from a URL prefix “https://www.” tends to generate invalid or duplicated URLs. ReLM avoids such inefficiency by returning strings matching the valid URL pattern sorted by likelihood. Likewise, searching over the space of all encodings as well as misspellings enables the (2.5times) data efficiency in extracting toxic content from the LLM and results in different results on the gender bias task. Finally, we can recover prompt tuning behavior on the LAMBADA dataset by modifying the regular expression pattern, demonstrating that even language understanding tasks can benefit from such pattern specification.
Conclusion
In this blog, we outlined why it’s important to think of LLM tests in terms of patterns rather than individual sequences. Our work introduces ReLM, a Regular Expression engine for Language Models, to enable test writers to easily write LLM tests that can be described via pattern matching. If you’re interested in learning more about ReLM and how it can reduce the burden of LLM validation, please check out our paper (MLSys ’23) as well as our open-source code.
DISCLAIMER: All opinions expressed in this post are those of the author and do not represent the views of CMU.
How can we be better prepared for the next pandemic?
Patient data collected by groups such as hospitals and health agencies is a critical tool for monitoring and preventing the spread of disease. Unfortunately, while this data contains a wealth of useful information for disease forecasting, the data itself may be highly sensitive and stored in disparate locations (e.g., across multiple hospitals, health agencies, and districts).
In this post we discuss our research on federated learning, which aims to tackle this challenge by performing decentralized learning across private data silos. We then explore an application of our research to the problem of privacy-preserving pandemic forecasting—a scenario where we recently won a 1st place, $100k prize in a competition hosted by the US & UK governments—and end by discussing several directions of future work based on our experiences.
Part 1: Privacy, Personalization, and Cross-Silo Federated Learning
Federated learning (FL) is a technique to train models using decentralized data without directly communicating such data. Typically:
a central server sends a model to participating clients;
the clients train that model using their own local data and send back updated models; and
the server aggregates the updates (e.g., via averaging, as in FedAvg)
However, while significant attention has been given to cross-device FL (e.g., learning across large networks of devices such as mobile phones), the area of cross-silo FL (e.g., learning across a handful of data silos such as hospitals or financial institutions) is relatively under-explored, and it presents interesting challenges in terms of how to best model federated data and mitigate privacy risks. In Part 1.1, we’ll examine a suitable privacy granularity for such settings, and in Part 1.2, we’ll see how this interfaces with model personalization, an important technique in handling data heterogeneity across clients.
1.1. How should we protect privacy in cross-silo federated learning?
Although the high-level federated learning workflow described above can help to mitigate systemic privacy risks, pastwork suggests that FL’s data minimization principle alone isn’t sufficient for data privacy, as the client models and updates can still reveal sensitive information.
This is where differential privacy (DP) can come in handy. DP provides both a formal guarantee and an effective empirical mitigation to attacks like membership inference and data poisoning. In a nutshell, DP is a statistical notion of privacy where we add randomness to a query on a “dataset” to create quantifiable uncertainty about whether any one “data point” has contributed to the query output. DP is typically measured by two scalars ((varepsilon, delta))—the smaller, the more private.
In the above, “dataset” and “data point” are in quotes because privacy granularity matters. In cross-device FL, it is common to apply “client-level DP” when training a model, where the federated clients (e.g., mobile phones) are thought of as “data points”. This effectively ensures that each participating client/mobile phone user remains private.
However, while client-level DP makes sense for cross-device FL as each client naturally corresponds to a person, this privacy granularity may not be suitable for cross-silo FL, where there are fewer (2-100) ‘clients’ but each holds many data subjects that require protection, e.g., each ‘client’ may be a hospital, bank, or school with many patient, customer, or student records.
In our recent work (NeurIPS’22), we instead consider the notion of “silo-specific example-level DP” in cross-silo FL (see figure above). In short, this says that the (k)-th data silo may set its own ((varepsilon_k, delta_k)) example-level DP target for any learning algorithm with respect to its local dataset.
This notion is better aligned with real-world use cases of cross-silo FL, where each data subject contributes a single “example”, e.g., each patient in a hospital contributes their individual medical record. It is also very easy to implement: each silo can just run DP-SGD for local gradient steps with calibrated per-step noise. As we discuss below, this alternate privacy granularity affects how we consider modeling federated data to improve privacy/utility trade-offs.
1.2. The interplay of privacy, heterogeneity, and model personalization
Let’s now look at how this privacy granularity may interface with model personalization in federated learning.
Model personalization is a common technique used to improve model performance in FL when data heterogeneity (i.e. non-identically distributed data) exists between data silos.1 Indeed, existingbenchmarks suggest that realistic federated datasets may be highly heterogeneous and that fitting separate local models on the federated data are already competitive baselines.
When considering model personalization techniques under silo-specific example-level privacy, we find that a unique trade-off may emerge between the utility costs from privacy and data heterogeneity (see figure below):
As DP noises are added independently by each silo for its own privacy targets, these noises are reflected in the silos’ model updates and can thus be smoothed out when these updates are averaged (e.g. via FedAvg), leading to a smaller utility drop from DP for the federated model.
On the other hand, federation also means that the shared, federated model may suffer from data heterogeneity (“one size does not fit all”).
This “privacy-heterogeneity cost tradeoff” is interesting because it suggests that model personalization can play a key and distinct role in cross-silo FL. Intuitively, local training (no FL participation) and FedAvg (full FL participation) can be viewed as two ends of a personalization spectrum with identical privacy costs—silos’ participation in FL itself does not incur privacy costs due to DP’s robustness to post-processing—and various personalization algorithms (finetuning, clustering, …) are effectively navigating this spectrum in different ways.
If local training minimizes the effect of data heterogeneity but enjoys no DP noise reduction, and contrarily for FedAvg, it is natural to wonder whether there are personalization methods that lie in between and achieve better utility. If so, what methods would work best?
Our analysis points to mean-regularized multi-task learning (MR-MTL) as a simple yet particularly suitable form of personalization. MR-MTL simply asks each client (k) to train its own local model (w_k), regularize it towards the mean of others’ models (bar w) via a penalty (fraclambda 2 | w_k – bar w |_2^2 ), and keep (w_k) across rounds (i.e. client is stateful). The mean model (bar w) is maintained by the FL server (as in FedAvg) and may be updated in every round. More concretely, each local update step takes the following form:
The hyperparameter (lambda) serves as a smooth knob between local training and FedAvg: (lambda = 0) recovers local training, and a larger (lambda) forces the personalized models to be closer to each other (intuitively, “federate more”).
MR-MTL has some nice properties in the context of private cross-silo FL:
Noise reduction is attained throughout training via the soft proximity constraint towards an averaged model;
The mean-regularization itself has no privacy overhead;2 and
(lambda) provides a smooth interpolation along the personalization spectrum.
Why is the above interesting? Consider the following experiment where we try a range of (lambda) values roughly interpolating local training and FedAvg. Observe that we could find a “sweet spot” (lambda^ast) that outperforms both of the endpoints under the same privacy cost. Moreover, both the utility advantage of MR-MTL((lambda^ast)) over the endpoints, and (lambda^ast) itself, are larger under privacy; intuitively, this says that silos are encouraged to “federate more” for noise reduction.
The above provides rough intuition on why MR-MTL may be a strong baseline for private cross-silo FL and motivates this approach for a practical pandemic forecasting problem, which we discuss in Part 2. Our full paper delves deeper into the analyses and provides additional results and discussions!
Part 2: Federated Pandemic Forecasting at the US/UK PETs Challenge
The pandemic forecasting problem asks the following: Given a person’s demographic attributes (e.g. age, household size), locations, activities, infection history, and the contact network, what is the likelihood of infection in the next (t_text{pred}=7) days? Can we make predictions while protecting the privacy of individuals? Moreover, what if the data are siloed across administrative regions?
There’s a lot to unpack in the above. First, the pandemic outbreak problem follows a discrete-time SIR model (Susceptible → Infectious → Recovered) and we begin with a subset of the population infected. Subsequently,
Each person goes about their usual daily activities and gets into contact with others (e.g. at a shopping mall)—this forms a contact graph where individuals are nodes and direct contacts are edges;
Each person may get infected with different risk levels depending on a myriad of factors—their age, the nature and duration of their contact(s), their node centrality, etc.; and
Such infection can also be asymptomatic—the individual can appear in the S state while being secretly infectious.
The challenge dataset models a pandemic outbreak in Virginia and contains roughly 7.7 million nodes (persons) and 186 million edges (contacts) with health states over 63 days; so the actual contact graph is fairly large but also quite sparse.
There are a few extra factors that make this problem challenging:
Data imbalance: less than 5% of people are ever in the I or R state and roughly 0.3% of people became infected in the final week.
Data silos: the true contact graph is cut along administrative boundaries, e.g., by grouped FIPS codes/counties. Each silo only sees a local subgraph, but people may still travel and make contacts across multiple regions! in In the official evaluation, the population sizes can also vary by more than 10(times) across silos.
Temporal modeling: we are given the first (t_text{train} = 56) days of each person’s health states (S/I/R) and asked to predict individual infections any time in the subsequent ( t_text{pred} = 7 ) days. What is a training example in this case?How should we perform temporal partitioning? How does this relate to privacy accounting?
Graphs generally complicate DP: we are often used to ML settings where we can clearly define the privacy granularity and how it relates to an actual individual (e.g. medical images of patients). This is tricky with graphs: people can make different numbers of contacts each of different natures, and their influence can propagate throughout the graph. At a high level (and as specified by the scope of sensitive data of the competition), what we care about is known as node-level DP—the model output is “roughly the same” if we add/remove/replace a node, along with its edges.
2.2. Applying MR-MTL with silo-specific example-level privacy
One clean approach to the pandemic forecasting problem is to just operate on the individual level and view it as (federated) binary classification: if we could build a feature vector to summarize an individual, then risk scores are simply the sigmoid probabilities of near-term infection.
Of course, the problem lies in what that feature vector (and the corresponding label) is—we’ll get to this in the following section. But already, we can see that MR-MTL with silo-specific example-level privacy (from Part 1) is a nice framework for a number of reasons:
Model personalization is likely needed as the silos are large and heterogeneous by construction (geographic regions are unlike to all be similar).
Privacy definition: There are a small number of clients, but each holds many data subjects, and client-level DP isn’t suitable.
Usability, efficiency, and scalability: MR-MTL is remarkably easy to implement with minimal resource overhead (over FedAvg and local training). This is crucial for real-world applications.
Adaptability and explainability: The framework is highly adaptable to any learning algorithm that can take DP-SGD-style updates. It also preserves the explainability of the underlying ML algorithm as we don’t obfuscate the model weights, updates, or predictions.
It is also helpful to look at the threat model we might be dealing with and how our framework behaves under it; the interested reader may find more details in the extended post!
2.3. Building training examples
We now describe how to convert individual information and the contact network into a tabular dataset for every silo ( k ) with ( n_k ) nodes.
Recall that our task is to predict the risk of infection of a person within ( t_text{pred} = 7) days, and that each silo only sees its local subgraph. We formulate this via a silo-specific set of examples ( ( X_k in mathbb R^{n_k times d}, Y_k in mathbb {0, 1}^{n_k} ) ), where the features ( {X_k^{(i)} in mathbb R^d} ) describe the neighborhood around a person ( i ) (see figure) and binary label ( {Y_k^{(i)}} ) denotes if the person become infected in the next ( t_text{pred} ) days.
Each example’s features ( X_k^{(i)} ) consist of the following:
(1) Individual features: Basic (normalized) demographic features like age, gender, and household size; activity features like working, school, going to church, or shopping; and the individual’s infection history as concatenated one-hot vectors (which depends on how we create labels; see below).
(2) Contact features: One of our key simplifying heuristics is that each node’s (ell)-hop neighborhood should contain most of the information we need to predict infection. We build the contact features as follows:
Every sampled neighbor (v) of a node (u) is encoded using its individual features (as above) along with the edge features describing the contact—e.g. the location, the duration, and the activity type.
We use iterative neighborhood sampling (figure above), meaning that we first select a set of ( S_1 ) 1-hop neighbors, and then sample (S_2) 2-hop neighbors adjacent to those 1-hop neighbors, and so on. This allows reusing 1-hop edge features and keeps the feature dimension (d) low.
We also used deterministic neighborhood sampling—the same person always takes the same subset of neighbors. This drastically reduces computation as the graph/neighborhoods can now be cached. For the interested reader, this also has implications on privacy accounting.
The figure above illustrates the neighborhood feature vector that describes a person and their contacts for the binary classifier! Intriguingly, this makes the per-silo models a simplified variant of a graph neural network (GNN) with a single-step, non-parameterized neighborhood aggregation and prediction (cf. SGC models).
For the labels ( Y_k^{(i)} ), we deployed a random infection windowstrategy:
Pick a window size ( t_text{window} ) (say 21 days);
Select a random day (t’) within the valid range ((t_text{window} le t’ le t_text{train} – t_text{pred}));
Encode the S/I/R states in the past window from (t’) for every node in the neighborhood as individual features;
The label is then whether person (i) is infected in any of the next (t_text{pred}) days from (t’).
Our strategy implicitly assumes that a person’s infection risk is individual: whether Bob gets infected depends only on his own activities and contacts in the past window. This is certainly not perfect as it ignores population-level modeling (e.g. denser areas have higher risks of infection), but it makes the ML problem very simple: just plug-in existing tabular data modeling approaches!
2.4. Putting it all together
We can now see our solution coming together: each silo builds a tabular dataset using neighborhood vectors for features and infection windows for labels, and each silo trains a personalized binary classifier under MR-MTL with silo-specific example-level privacy. We complete our method with a few additional ingredients:
Privacy accounting. We’ve so far glossed over what silo-specific “example-level” DP actually means for an individual. We’ve put more details in the extended blog post, and the main idea is that local DP-SGD can give “neighborhood-level” DP since each node’s enclosing neighborhood is fixed and unique, and we can then convert it to node-level DP (our privacy goal from Part 2.1) by carefully accounting for how a certain node may appear in other nodes’ neighborhoods.
Noisy SGD as an empirical defense. While we have a complete framework for providing silo-specific node-level DP guarantees, for the PETs challenge specifically we decided to opt for weak DP ((varepsilon > 500)) as an empirical protection, rather than a rigorous theoretical guarantee. While some readers may find this mildly disturbing at first glance, we note that the strength of protection depends on the data, the models, the actual threats, the desired privacy-utility trade-off, and several crucial factors linking theory and practice which we outline in the extended blog. Our solution was in turn attacked by several red teams to test for vulnerabilities.
Model architecture: simple is good. While the model design space is large, we are interested in methods amenable to gradient-based private optimization (e.g. DP-SGD) and weight-space averaging for federated learning. We compared simple logistic regression and a 3-layer MLP and found that the variance in data strongly favors linear models, which also have benefits in privacy (in terms of limited capacity for memorization) as well as explainability, efficiency, and robustness.
Computation-utility tradeoff for neighborhood sampling. While larger neighborhood sizes (S) and more hops (ell) better capture the original contact graph, they also blow up the computation and our experiments found that larger (S) and (ell) tend to have diminishing returns.
Data imbalance and weighted loss. Because the data are highly imbalanced, training naively will suffer from low recall and AUPRC. While there are established over-/under-sampling methods to deal with such imbalance, they, unfortunately, make privacy accounting a lot trickier in terms of the subsampling assumption or the increased data queries. We leveraged the focal loss from the computer vision literature designed to emphasize hard examples (infected cases) and found that it did improve both the AUPRC and the recall considerably.
The above captures the essence of our entry to the challenge. Despite the many subtleties in fully building out a working system, the main ideas were quite simple: train personalized models with DP and add some proximity constraints!
Takeaways and Open Challenges
In Part 1, we reviewed our NeurIPS’22 paper that studied the application of differential privacy in cross-silo federated learning scenarios, and in Part 2, we saw how the core ideas and methods from the paper helped us develop our submission to the PETs prize challenge and win a 1st place in the pandemic forecasting track. For readers interested in more details—such as theoretical analyses, hyperparameter tuning, further experiments, and failure modes—please check out our full paper. Our work also identified several important future directions in this context:
DP under data imbalance. DP is inherently a uniform guarantee, but data imbalance implies that examples are not created equal—minority examples (e.g., disease infection, credit card fraud) are more informative, and they tend to give off (much) larger gradients during model training. Should we instead do class-specific (group-wise) DP or refine “heterogeneousDP” or “outlier DP” notions to better cater to the discrepancy between data points?
Graphs and privacy. Another fundamental basis of DP is that we could delineate what is and isn’t an individual. But as we’ve seen, the information boundaries are often nebulous when an individual is a node in a graph (think social networks and gossip propagation), particularly when the node is arbitrarily well connected. Instead of having rigid constraints (e.g., imposing a max node degree and accounting for it), are there alternative privacy definitions that offer varying degrees of protection for varying node connectedness?
Scalable, private, and federated trees for tabular data. Decision trees/forests tend to work extremely well for tabular data such as ours, even with data imbalance, but despite recent progress, we argue that they are not yet mature under private and federated settings due to some underlying assumptions.
Novel training frameworks. While MR-MTL is a simple and strong baseline under our privacy granularity, it has clear limitations in terms of modeling capacity. Are there other methods that can also provide similar properties to balance the emerging privacy-heterogeneity cost tradeoff?
Honest privacy cost of hyperparameter search. When searching for better frameworks, the dependence on hyperparameters is particularly interesting: our full paper (section 7) made a surprising but somewhat depressing observation that the honest privacy cost of just tuning (on average) 10 configurations (values of (lambda) in this case) may already outweigh the utility advantage of the best tune MR-MTL((lambda^ast)). What does this mean if MR-MTL is already a strong baseline with just a single hyperparameter?
Example of embodied commonsense reasoning. A robot proactively identifies a remote on the floor and knows it is out of place without instruction. Then, the robot figures out where to place it in the scene and manipulates it there.
For robots to operate effectively in the world, they should be more than explicit step-by-step instruction followers. Robots should take actions in situations when there is a clear violation of the normal circumstances and be able to infer relevant context from partial instruction. Consider a situation where a home robot identifies a remote control which has fallen to the kitchen floor. The robot should not need to wait until a human instructs the robot to “pick the remote control off the floor and place it on the coffee table”. Instead, the robot should understand that the remote on the floor is clearly out of place, and act to pick it up and place it in a reasonable location. Even if a human were to spot the remote control first and instruct the agent to “put away the remote that is on the living room floor”, the robot should not require a second instruction for where to put the remote, but instead infer from experience that a reasonable location would be, for example, on the coffee table. After all, it would become tiring for a home robot user to have to specify every desire in excruciating detail (think about for each item you want the robot to move, specifying an instruction such as “pick up the shoes beneath the coffee table and place them next to the door, aligned with the wall”).
The type of reasoning that would permit such partial or self-generated instruction following involves a deep sense of how things in the world (objects, physics, other agents, etc.) ought to behave. Reasoning and acting of this kind are all aspects of embodied commonsense reasoning and are vastly important for robots to act and interact seamlessly in the physical world.
There has been much work on embodied agents that follow detailed step-by-step instructions, but less on embodied commonsense reasoning, where the task involves learning how to perceive and act without explicit instruction. One task in which to study embodied commonsense reasoning is that of tidying up, where the agent must identify objects which are out of their natural locations and act in order bring the identified objects to plausible locations. This task combines many desirable capabilities of intelligent agents with commonsense reasoning of object placements. The agent must search in likely locations for objects to be displaced, identify when objects are out of their natural locations in the context of the current scene, and figure out where to reposition the objects so that they are in proper locations – all while intelligently navigating and manipulating.
In our recent work, we proposeTIDEE, an embodied agent that can tidy up never-before-seen rooms without any explicit instruction. TIDEE is the first of its kind for its ability to search a scene for out of place objects, identify where in the scene to reposition the out of place objects, and effectively manipulate the objects to the identified locations. We’ll walk through how TIDEE is able to do this in a later section, but first let’s describe how we create a dataset to train and test our agent for the task of tidying up.
Creating messy homes
To create clean and messy scenes for our agent to learn from for what constitutes a tidy scene and what constitute a messy scene, we use a simulation environment called ai2thor. Ai2thor is an interactive 3D environment of indoor scenes that allows objects to be picked up and moved around. The simulator comes ready with 120 scenes of kitchens, bathrooms, living rooms, and bedrooms with over 116 object categories (and significantly more object instances) scattered throughout. Each of the scenes comes with a default initialization of object placements that are meticulously chosen by humans to be highly structured and “neat”. These default object locations make up our “tidy” scenes for providing our agent examples of objects in their natural locations. To create messy scenes, we apply forces to a subset of the objects with a random direction and magnitude (we “throw” the objects around) so they end up in uncommon locations and poses. You can see below some examples of objects which have been moved out of place.
Next, let’s see how TIDEE learns from this dataset to be able to tidy up rooms.
How does TIDEE work?
We give our agent a depth and RGB sensor to use for perceiving the scene. From this input, the agent must navigate around, detect objects, pick them up, and place them. The goal of the tidying task is to rearrange a messy room back to a tidy state.
TIDEE tidies up rooms in three phases. In the first phase, TIDEE explores around the room and runs an out of place object detector at each time step until one is identified. Then, TIDEE navigates over to the object, and picks it up. In the second phase, TIDEE uses graph inference in its joint external graph memory and scene graph to infer a plausible receptacle to place the object on within the scene. It then explores the scene guided by a visual search network that suggests where the receptacle may be found if TIDEE has not identified it in a previous time step. For navigation and keeping track objects, TIDEE maintains a obstacle map of the scene and stores in memory the estimated 3D centroids of previously detected objects.
The out of place detector uses visual and relational language features to determine if an object is in or out of place in the context of the scene. The visual features for each object are obtained from an off-the-shelf object detector, and the relational language features are obtained by giving predicted 3D relations of the objects (e.g. next to, supported by, above, etc.) to a pretrained language model. We combine the visual and language features to classify whether each detected object is in or out of place. We find that combining the visual and relational modalities performs best for out of place classification over using a single modality.
To infer where to place an object once it has picked up, TIDEE includes a neural graph module which is trained to predict plausible object placement proposals of objects. The modules works by passing information between the object to be placed, a memory graph encoding plausible contextual relations from training scenes, and a scene graph encoding the object-relation configuration in the current scene. For our memory graph, we take inspiration from “Beyond Categories: The Visual Memex Model for Reasoning About Object Relationships” by Tomasz Malisiewicz and Alexei A. Efros (2009), which models instance-level object features and their relations to provide more complete appearance-based context. Our memory graph consists of the tidy object instances in the training to provide fine-grain contextualization of tidy object placements. We show in the paper that this fine-grain visual and relational information is important for TIDEE to place objects in human-preferred locations.
To search for objects that have not been previously found, TIDEE uses a visual search network that takes as input the semantic obstacle map and a search category and predicts the likelihood of the object being present at each spatial location in the obstacle map. The agent then searches in those likely locations for the object of interest.
Combining all the above modules provides us with a method to be able to detect out of place objects, infer where they should go, search intelligently, and navigate & manipulate effectively. In the next section, we’ll show you how well our agent performs at tidying up rooms.
How good is TIDEE at tidying up?
Using a set of messy test scenes that TIDEE has never seen before, we task our agent with reconfiguring the messy room to a tidy state. Since a single object may be tidy in multiple locations within a scene, we evaluate our method by asking humans whether they prefer the placements of TIDEE compared to baseline placements that do not make use of one or more of TIDEE’s commonsense priors. Below we show that TIDEE placements are significantly preferred to the baseline placements, and even competitive with human placements (last row).
We additionally show that the placements of TIDEE can be customized based on user preferences. For example, based on user input such as “I never want my alarm on the desk”, we can use online learning techniques to change the output from the model that alarm clock being on the desk is out of place (and should be moved). Below we show some examples of locations and relations of alarm clocks that were predicted as being in the correct locations (and not out of place) within the scene after our initial training. However, after doing the user-specified finetuning, our network predicts that the alarm clock on the desk is out of place and should be repositioned.
We also show that a simplified version of TIDEE can generalize to task of rearrangement, where the agent sees the original state of the objects, then some of the objects get rearranged to new locations, and the agent must rearrange the objects back to their original state. We outperform the previous state of the art model that utilizes semantic mapping and reinforcement learning, even with noisy sensor measurements.
Summary
In this article, we discussed TIDEE, an embodied agent that uses commonsense reasoning to tidy up novel messy scenes. We introduce a new benchmark to test agents in their ability to clean up messy scenes without any human instruction. To check out our paper, code, and more, please visit our website at https://tidee-agent.github.io/.
Also, feel free to shoot me an email at gsarch@andrew.cmu.edu! I would love to chat!
Figure 1. This blog post discusses the effectiveness of black-box model explanations in aiding end users to make decisions. We observe that explanations do not in fact help with concrete applications such as fraud detection and paper matching for peer review. Our work further motivates novel directions for developing and evaluating tools to support human-ML interactions.
Model explanations have been touted as crucial information to facilitate human-ML interactions in many real-world applications where end users make decisions informed by ML predictions. For example, explanations are thought to assist model developers in identifying when models rely on spurious artifacts and to aid domain experts in determining whether to follow a model’s prediction. However, while numerous explainable AI (XAI) methods have been developed, XAI has yet to deliver on this promise. XAI methods are typically optimized for diverse but narrow technical objectives disconnected from their claimed use cases. To connect methods to concrete use cases, we argued in our Communications of ACM paper [1] for researchers to rigorously evaluate how well proposed methods can help real users in their real-world applications.
Towards bridging this gap, our group has since completed two collaborative projects where we worked with domain experts in e-commerce fraud detection and paper matching for peer review. Through these efforts, we’ve gleaned the following two insights:
Existing XAI methods are not useful for decision-making. Presenting humans with popular, general-purpose XAI methods does not improve their performance on real-world use cases that motivated the development of these methods. Our negative findings align with those of contemporaneous works.
Rigorous, real-world evaluation is important but hard. These findings were obtained through user studies that were time-consuming to conduct.
We believe that each of these insights motivates a corresponding research direction to support human-ML interactions better moving forward. First, beyond methods that attempt to explain the ML model itself, we should consider a wider range of approaches that present relevant task-specific information to human decision-makers; we refer to these approaches as human-centered ML (HCML) methods [10]. Second, we need to create new workflows to evaluate proposed HCML methods that are both low-cost and informative of real-world performance.
In this post, we first outline our workflow for evaluating XAI methods. We then describe how we instantiated this workflow in two domains: fraud detection and peer review paper matching. Finally, we describe the two aforementioned insights from these efforts; we hope these takeaways will motivate the community to rethink how HCML methods are developed and evaluated.
How do you rigorously evaluate explanation methods?
In our CACM paper [1], we introduced a use-case-grounded workflow to evaluate explanation methods in practice—this means showing that they are ‘useful,’ i.e., that they can actually improve human-ML interactions in the real-world applications that they are motivated by. This workflow contrasts with evaluation workflows of XAI methods in prior work, which relied on researcher-defined proxy metrics that may or may not be relevant to any downstream task. Our proposed three-step workflow is based on the general scientific method:
Step 1:Define a concrete use case. To do this, researchers may need to work closely with domain experts to define a task that reflects the practical use case of interest.
Step 2:Select explanation methods for evaluation. While selected methods might be comprised of popular XAI methods, the appropriate set of methods is to a large extent application-specific and should also include relevant non-explanation baselines.
Step 3:Evaluate explanation methods against baselines. While researchers should ultimately evaluate selected methods through a user study with real-world users, researchers may want to first conduct cheaper, noisier forms of evaluation to narrow down the set of methods in consideration (Figure 2).
Instantiating the workflow in practice
We collaborated with experts from two domains (fraud detection and peer review paper matching) to instantiate this use-case-grounded workflow and evaluate existing XAI methods:
Domain 1: Fraud detection [3]. We partnered with researchers at Feedzai, a financial start-up, to assess whether providing model explanations improved the ability of fraud analysts to detect fraudulent e-commerce transactions. Given that we had access to real-world data (i.e., historical e-commerce transactions for which we had ground truth answers of whether the transaction was fraudulent) and real users (i.e., fraud analysts), we directly conducted a user study in this context. An example of the interface shown to analysts is in Figure 3. We compared analysts’ average performance when shown different explanations to a baseline setting where they were only provided the model prediction. We ultimately found that none of the popular XAI methods we evaluated (LIME, SHAP, and Tree Interpreter) resulted in any improvement in the analysts’ decisions compared to the baseline setting (Figure 5, left). Evaluating these methods with real users additionally posed many logistical challenges because fraud analysts took time from their regular day-to-day work to periodically participate in our study.
Domain 2: Peer review paper matching [4]. We collaborated with Professor Nihar Shah (CMU), an expert in peer review, to investigate what information could help meta-reviewers of a conference better match submitted papers to suitable reviewers. Learning from our prior experience, we first conducted a user study using proxy tasks and users, which we worked with Professor Shah to design as shown in Figure 4. In this proxy setting, we found that providing explanations from popular XAI methods in fact led users to be more confident—-the majority of participants shown highlights from XAI methods believed the highlighted information was helpful—yet, they made statistically worse decisions(Figure 5 right)!
How can we better support human-ML interactions?
Through these collaborations, we identified two important directions for future work, which we describe in more detail along with our initial efforts in each direction.
We need to develop methods for specific use cases. Our results suggest that explanations from popular, general-purpose XAI methods can both hurt decision-making while making users overconfident. These findings have also been observed in multiple contemporaneous works (e.g., [7,8,9]). Researchers, instead, need to consider developing human-centered ML (HCML) methods [10] tailored for each downstream use case. HCML methods are any approach that provides information about the particular use case and context that can inform human decisions.
Our contributions: In the peer review matching setting, we proposed an HCML method designed in tandem with a domain expert [4]. Notably, our method is not a model explanation approach, as it highlights information in the input data, specifically sentences and phrases that are similar in the submitted paper and the reviewer profile. Figure 6 compares the text highlighted using our method to the text highlighted using existing methods. Our method outperformed both a baseline where there was no explanation and the model explanation condition (Figure 5, right). Based on these positive results, we plan to move evaluations of our proposed method to more realistic peer review settings. Further, we performed an exploratory study to better understand how people interact with information provided by HCML methods as a first step towards coming up with a more systematic approach to devise task-specific HCML methods [5].
We need more efficient evaluation pipelines. While user studies conducted in a real-world use case and with real users are the ideal way to evaluate HCML methods, it is a time- and resource-consuming process. We highlight the need for more cost-effective evaluations that can be utilized to narrow down candidate HCML methods and still implicate the downstream use case. One option is to work with domain experts to design a proxy task as we did in the peer review setting, but even these studies require careful consideration of the generalizability to the real-world use case.
Our contributions. We introduced an algorithmic-based evaluation called simulated user evaluation (SimEvals) [2]. Instead of conducting studies on proxy tasks, researchers can train SimEvals, which are ML models that serve as human proxies. SimEvals more faithfully reflects aspects of real-world evaluation because their training and evaluation data are instantiated on the same data and task considered in real-world studies. To train SimEvals, the researcher first needs to generate a dataset of observation-label pairs. The observation corresponds to the information that would be presented in a user study (and critically includes the HCML method), while the output is the ground truth label for the use case of interest. For example, in the fraud detection setting, the observation would consist of both the e-commerce transaction and ML model score shown in Figure 3(a) along with the explanation shown in Figure 3(b). The ground truth label is whether or not the transaction was fraudulent. SimEvals are trained to predict a label given an observation and their test set accuracies can be interpreted as a measure of whether the information contained in the observation is predictive for the use case.
We not only evaluated SimEvals on a variety of proxy tasks but also tested SimEvals in practice by working with Feedzai, where we found results that corroborate the negative findings from the user study [6]. Although SimEvals should not replace user studies because SimEvals are not designed to mimic human decision-making, these results suggest that SimEvals could be initially used to identify more promising explanations (Figure 6).
Conclusion
In summary, our recent efforts motivate two ways the community should rethink how to support human-ML interactions: (1) we need to replace general-purpose XAI techniques with HCML methods tailored to specific use cases, and (2) creating intermediate evaluation procedures that can help narrow down the HCML methods to evaluate in more costly settings.
For more information about the various papers mentioned in this blog post, see the links below:
[1] Chen, V., Li, J., Kim, J. S., Plumb, G., & Talwalkar, A. Interpretable Machine Learning. Communications of the ACM, 2022. (link)
[2] Chen, V., Johnson, N., Topin, N., Plumb, G., & Talwalkar, A. Use-case-grounded simulations for explanation evaluation. NeurIPS, 2022. (link)
[3] Amarasinghe, K., Rodolfa, K. T., Jesus, S., Chen, V., Balayan, V., Saleiro, P., Bizzaro, P., Talwalkar, A. & Ghani, R. (2022). On the Importance of Application-Grounded Experimental Design for Evaluating Explainable ML Methods. arXiv. (link)
[4] Kim, J. S., Chen, V., Pruthi, D., Shah, N., Talwalkar, A.Assisting Human Decisions in Document Matching. arXiv. (link)
[5] Chen, V., Liao, Q. V., Vaughan, J. W., & Bansal, G. (2023). Understanding the Role of Human Intuition on Reliance in Human-AI Decision-Making with Explanations. arXiv. (link)
[6] Martin, A., Chen, V., Jesus, S., Saleiro, P. A Case Study on Designing Evaluations of ML Explanations with Simulated User Studies. arXiv. (link)
[7] Bansal, G., Wu, T., Zhou, J., Fok, R., Nushi, B., Kamar, E., Ribeiro, M. T. & Weld, D. Does the whole exceed its parts? the effect of ai explanations on complementary team performance. CHI, 2021. (link)
[8] Adebayo, J., Muelly, M., Abelson, H., & Kim, B. Post hoc explanations may be ineffective for detecting unknown spurious correlation. ICLR, 2022. (link)
[9] Zhang, Y., Liao, Q. V., & Bellamy, R. K. Effect of confidence and explanation on accuracy and trust calibration in AI-assisted decision making. FAccT, 2020. (link)
[10] Chancellor, S. (2023). Toward Practices for Human-Centered Machine Learning. Communications of the ACM, 66(3), 78-85. (link)
Acknowledgments
We would like to thank Kasun Amarasinghe, Jeremy Cohen, Nari Johnson, Joon Sik Kim, Q. Vera Liao, and Junhong Shen for helpful feedback and suggestions on earlier versions of the blog post. Thank you also to Emma Kallina for her help with designing the main figure!
Figure 1: Behavior-driven AI development centers model iteration on evaluating and improving specific real-world use cases.
It has never been easier to prototype AI-driven systems. With a bit of programming knowledge and a couple of hours, you can spin up a chatbot for your notes, a text-based image editor, or a tool for summarizing customer feedback. But play around with your prototype for a bit, and you might find that it doesn’t work as well as you first expected. Your system might make up facts or respond with racist suggestions. How would you evaluate your model and predict its performance in deployment?
The canonical process for benchmarking AI systems revolves around model-centric metrics. Calculate a metric (F1-score, precision, etc.), and if it increases, you are going in the right direction. But these metrics are oversimplified objectives that sand away the complexity of model behavior and cannot fully represent a model’s performance. A metric may tell you how well your model can predict the next word in a sentence, but it won’t tell you how factually accurate, logical, or fair your model is across diverse, real-world use cases. Generative AI systems such as ChatGPT or Stable Diffusion make evaluation even more challenging since there are no well-defined metrics that can summarize their performance.
When creating deployed AI products, practitioners instead focus on the specific use cases their customers have and whether or not their models are fulfilling them. In interviews with 18 AI practitioners, we found that they constantly collect user feedback and develop “golden test sets” of behaviors that they expect deployed models to have. We term this behavior-drivenAI development, a development process focused on evaluating and updating models to improve performance on real-world use cases. While chatbot A might sound more human-like, a practitioner will deploy chatbot B if it produces concise and accurate answers that customers prefer.
The landscape of AI evaluation tools primarily revolves around model-centric metrics that do not capture important behaviors like these chatbot characteristics. While there are specific tools for behavior-driven development, such as fairnesstoolkits and robustnessanalysislibraries, practitioners end up cobbling together disparate tools into ad-hoc scripts or computational notebooks that are hard to maintain and reproduce.
I believe that there are a set of abstractions that can unify AI evaluation in line with model use cases in practice. This philosophy revolves around model behaviors: metrics summarizing patterns of output on subgroups of instances. This simple concept can encode any model evaluation or analysis, from fairness audits to language model hallucinations. We show what this can look like with Zeno, an interactive platform we built for behavior-driven development that supports interactive data exploration, slicing, and reporting. By investigating their own models using Zeno, practitioners have been able to pinpoint significant and actionable issues such as biases and systematic failures.
What is model behavior?
The dictionary describes behavior as anything that an organism does involving action and response to stimulation. In the case of AI systems, model behavioris a specific pattern of output for a semantically meaningful subgroup of input data (stimulus). By semantically meaningful, I mean subgroups that can be described with human-interpretable concepts, such as “audio with noise in the background” or “people who identify as women.” Similarly, a pattern of output could be “high audio transcription error” or “low loan approval rate.”
Behaviors can be quantified as metrics on subgroups of data, often using the same metrics as are used for model-centric evaluation. But unlike summary metrics across an entire dataset, metrics in behavior-centric development quantify specificpatterns of behavior, like how often an image generation model produces unintelligible text. Tests of model behaviors are like exams for specific subjects, while summary metrics resemble IQ tests.
Model behaviors are a relatively simple concept, but encoding behaviors can be challenging in practice. Practitioners may not have enough data to validate or fix important model behaviors and have to collect or generate more data. If they have extensive data, they need ways to subdivide it into meaningful groups of instances – how do I find all images that have text? Lastly, for each subgroup, practitioners have to derive the appropriate metrics to quantify the prevalence of behavior – how do I detect blurry text? Succinctly, behavior-driven development requires sufficient data that is representative of expected behaviorsand metadata for defining and quantifying the behaviors.
A platform for behavior-driven AI development
The beauty of a behavior-based framing on AI development is that it is still data and model agnostic. While the specific behaviors for each ML task will be vastly different, subgroups of data and metrics are universal concepts.
To test this theory, we built a platform for behavior-driven AI development called Zeno. Zeno is a platform that empowers users to explore data and model outputs, interactively create subgroups of data, and calculate and quantify model behaviors. Zeno consists of a Python API for scaffolding the data needed for analysis and a user interface for interactively creating subgroups and evaluating behaviors.
The Python API is a set of decorator functions (wrappers on user-defined functions) that can be used to plug in ML models and derive metadata features and metrics from input data. Since the decorators are generic wrappers, Zeno supports any Python-based model, processing function, or metric. Zeno preprocesses the input data with these functions, which it passes into the UI for analysis.
Zeno’s UIis the primary interface for behavior-driven evaluation. It allows users to interactively explore and filter their data, create slices, calculate metrics, and create exportable visualizations. On the right side of the UI is Zeno’s instance view, where users can explore the raw data on which the model is being evaluated. In addition to the standard list view, users can also see the data in a table or a 2D scatterplot representation. The left side of the interface holds the metadata panel. All the metadata columns that either came with the dataset or were generated with the Python API have their distributions displayed in the panel. Users can interactively filter the distributions to update the instance view and create named subgroups.
The UI also has a report page for creating interactive summary visualizations of behaviors. For example, a user could create a bar chart comparing the performance of three models across ten different slices. Or they could create a line chart showing how a model performs on data slices from each day of data. These visualizations can be exported or shared directly with other stakeholders.
Case Studies
We have worked with various ML practitioners to apply Zeno to the models and tasks on which they work. Using Zeno, practitioners found significant model issues and areas for improvement, including gender biases and regional model disparities.
Audio transcription.This first case study I ran myself after I heard that OpenAI released a new speech-to-text model, Whisper, with state-of-the-art performance. I was curious how the model compared to some existing off-the-shelf transcription models. Instead of looking at aggregate metrics, I ran the models on the Speech Accent Archive dataset, which has speakers worldwide saying the same phrase. By filtering the dataset’s extensive metadata, I found that the models perform worse for English speakers who learned the language later in life and speakers from countries where English is not the native language.
Cancer classification.In another case study, we worked with a researcher who wanted to improve a breast cancer classifier for mammogram images. Since the data was anonymized and lacked meaningful metadata, the practitioner wrote dozens of functions using a Python library to extract meaningful metadata features. By exploring the distributions, they found that images with higher “entropy” correlating with denser breast tissue had a significantly higher error rate than images with lower entropy, or less dense, tissue. This finding matches performance differences in human radiologists, who also perform worse for images of denser breast tissue since it makes it harder to detect lesions.
High density (656) entropy > 2.75 && gray level variance > 2.5
AUC
0.86
0.76
Figure 6. The breast cancer classification model performed significantly worse for high-density images (described by high entropy and gray level variance metadata levels) compared to the low-density images. (left, low density, right, high density).
Image generation.Models with complex outputs often do not have clearly defined metrics, including text-to-image generation models such as DALL*E and Stable Diffusion. We can instead look at metrics that measure specific behaviors. In this example, a practitioner we worked with was exploring the DiffusionDB dataset, which has over two million prompt-image pairs from the Stable Diffusion model. The dataset also has metadata for how NSFW or inappropriate the prompts and images are. This data was used to derive an “average NSFW” metric, which can show us interesting potential biases in the model. For example, the participant compared the images generated using prompts with the word “boy” versus “girl” and found that prompts with “girl” generated images with a significantly higher NSFW level than prompts with “boy”, showing potential biases in the types of images created by the model.
Discussion and Opportunities
Model iteration is still a primarily reactiveprocess of finding and defining behaviors after a model has been deployed and the customer complaints start rolling in. There remains significant room for improving this process, from making it easier to ideate model behaviors to tracking model changes over time.
Discovering behaviors. While practitioners often need a model to discover the behaviors the model should have, methods for defining expected model behaviors before deployment can prevent serious real-world model issues. For example, crowdsourcing techniques for eliciting potential edge cases could preemptively catch model errors. Algorithmic methods that find clusters of data with high error have also shown promise for surfacing problematic behaviors.
Data discovery and generation. Having high-quality, representative data remains a persistent obstacle for behavioral evaluation. In some domains with ample data, such as natural images, methods like Stable Diffusion have shown promise for generating new data for evaluation or training. In less data-rich domains, techniques for searching through large unlabeled datasets, such as text-based image search, can surface valuable data for evaluation and retraining. It is also challenging to derive metadata from instances for creating subgroups and calculating metrics. While it can be easy to generate metadata for simple concepts like “image brightness,” many behaviors are defined by complex metadata such as “images with a person wearing clear glasses” that cannot be encoded by a simple function. Foundation models have shown some promise in using text-based descriptions to generate complex metadata and metrics.
Model comparison. Models are almost never one-off jobs and can be updated daily or weekly. While it is easy to compare aggregate metrics, it can be challenging to compare model performance in behavior-driven development. To pick between models, users may have to compare dozens of behaviors and qualitative insights. Improved visual encodings or intelligent recommendations of model differences could help users make informed decisions and deploy the right models.
Fixing behaviors. Discovering and encoding behaviors is one thing, but fixing behaviors is another massive challenge. A common approach to fixing issues is to gather more data and retrain the model, but this process can lead to catastrophic forgetting and regressions. There are recent techniques that align well with behavior-driven development, such as slice-based learning, which can selectively fix model behaviors without new data.
Conclusion
There is significant excitement for this new era of AI systems. But along with their growing capability, the complexity of their behavior is also increasing. We need powerful tools to empower behavior-driven development and ensure we build intelligent systems that align with human values. Zeno provides a general-purpose platform that empowers users to do this deep evaluation across the diverse tasks of modern AI. Learn more about Zeno at zenoml.com, read the full paper, or reach out if you would like to use Zeno for your models!
Acknowledgments
I’d like to thank Will Epperson, Jason I. Hong, Yi-Cheng Huang, Misha Khodak, Adam Perer, Venkat Sivaraman, Ameet Talwalkar, and Kristen Vossler for their thoughtful feedback and advice.