Progress and Challenges in Long-Form Open-Domain Question Answering

Posted by Aurko Roy, Research Scientist, Google Research

Open-domain long-form question answering (LFQA) is a fundamental challenge in natural language processing (NLP) that involves retrieving documents relevant to a given question and using them to generate an elaborate paragraph-length answer. While there has been remarkable recent progress in factoid open-domain question answering (QA), where a short phrase or entity is enough to answer a question, much less work has been done in the area of long-form question answering. LFQA is nevertheless an important task, especially because it provides a testbed to measure the factuality of generative text models. But, are current benchmarks and evaluation metrics really suitable for making progress on LFQA?

In “Hurdles to Progress in Long-form Question Answering” (to appear at NAACL 2021), we present a new system for open-domain long-form question answering that leverages two recent advances in NLP: 1) state-of-the-art sparse attention models, such as Routing Transformer (RT), which allow attention-based models to scale to long sequences, and 2) retrieval-based models, such as REALM, which facilitate retrievals of Wikipedia articles related to a given query. To encourage more factual grounding, our system combines information from several retrieved Wikipedia articles related to the given question before generating an answer. It achieves a new state of the art on ELI5, the only large-scale publicly available dataset for long-form question answering.

However, while our system tops the public leaderboard, we discover several troubling trends with the ELI5 dataset and its associated evaluation metrics. In particular, we find 1) little evidence that models actually use the retrievals on which they condition; 2) that trivial baselines (e.g., input copying) beat modern systems, like RAG / BART+DPR; and 3) that there is a significant train/validation overlap in the dataset. Our paper suggests mitigation strategies for each of these issues.

Text Generation
The main workhorse of NLP models is the Transformer architecture, in which each token in a sequence attends to every other token, resulting in a model that scales quadratically with sequence length. The RT model introduces a dynamic, content-based sparse attention mechanism that reduces the complexity of attention in the Transformer model from n^2 to n^1.5, where n is the sequence length, which enables it to scale to long sequences. This allows each word to attend to other relevant words anywhere in the entire piece of text, unlike methods such as Transformer-XL where a word can only attend to words in its immediate vicinity.

The key insight of the RT work is that each token attending to every other token is often redundant, and may be approximated by a combination of local and global attention. Local attention allows each token to build up a local representation over several layers of the model, where each token attends to a local neighborhood, facilitating local consistency and fluency. Complementing local attention, the RT model also uses mini-batch k-means clustering to enable each token to attend only to a set of most relevant tokens.
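
To make the routing idea concrete, here is a minimal NumPy sketch of content-based sparse attention in which tokens are clustered with k-means and each token attends only within its cluster. It illustrates the mechanism rather than the actual Routing Transformer implementation, which uses learned projections, online mini-batch k-means, multiple heads, and combines routing with local attention.

```python
import numpy as np

def routing_attention(x, n_clusters=4, n_iters=10, seed=0):
    """Toy content-based sparse attention: tokens attend only within their
    k-means cluster (simplified illustration, not the RT code)."""
    rng = np.random.default_rng(seed)
    n, d = x.shape
    q, k, v = x, x, x                                   # identity projections for brevity
    feats = q / (np.linalg.norm(q, axis=-1, keepdims=True) + 1e-6)
    centroids = feats[rng.choice(n, n_clusters, replace=False)]
    for _ in range(n_iters):                            # plain k-means on token features
        assign = np.argmax(feats @ centroids.T, axis=-1)
        for c in range(n_clusters):
            members = feats[assign == c]
            if len(members):
                centroids[c] = members.mean(axis=0)
    out = np.zeros_like(v)
    for c in range(n_clusters):                         # attention restricted to each cluster
        idx = np.where(assign == c)[0]
        if len(idx) == 0:
            continue
        scores = q[idx] @ k[idx].T / np.sqrt(d)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        out[idx] = weights @ v[idx]
    return out

tokens = np.random.randn(64, 32)        # 64 tokens, 32-dim embeddings
print(routing_attention(tokens).shape)  # (64, 32)
```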

Attention maps for the content-based sparse attention mechanism used in Routing Transformer. The word sequence is represented by the diagonal dark colored squares. In the Transformer model (left), each token attends to every other token. The shaded squares represent the tokens in the sequence to which a given token (the dark square) is attending. The RT model uses both local attention (middle), where tokens attend only to other tokens in their local neighborhood, and routing attention (right), in which a token only attends to clusters of tokens most relevant to it in context. The dark red, green and blue tokens only attend to the corresponding color of lightly shaded tokens.

We pre-train an RT model on the Project Gutenberg (PG-19) dataset with a language modeling objective, i.e., the model learns to predict the next word given all the previous words, so that it can generate fluent, paragraph-long text.
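
For readers unfamiliar with this objective, the sketch below spells out the language modeling loss in NumPy: given the model's logits at each position, we score how well it predicts the next token. It is a generic illustration, not the RT training code.

```python
import numpy as np

def next_token_loss(logits, targets):
    """Language-modeling objective: average cross-entropy of predicting token
    t+1 from tokens <= t. `logits` has shape (seq_len, vocab) and `targets`
    holds the next-token ids (toy sketch)."""
    logits = logits - logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

vocab, seq_len = 100, 16
logits = np.random.randn(seq_len, vocab)         # fake model outputs
targets = np.random.randint(0, vocab, seq_len)   # fake next-token ids
print(next_token_loss(logits, targets))
```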

Information Retrieval
To demonstrate the effectiveness of the RT model on the task of LFQA, we combine it with retrievals from REALM. The REALM model (Guu et al. 2020) is a retrieval-based model that uses maximum inner product search to retrieve Wikipedia articles relevant to a particular query or question. The model was fine-tuned for factoid-based question answering on the Natural Questions dataset. REALM utilizes the BERT model to learn good representations for a question and uses ScaNN to retrieve Wikipedia articles that have a high topical similarity with the question representation. The system is then trained end-to-end to maximize the log-likelihood on the QA task.
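
As a rough illustration of the retrieval step, the sketch below performs brute-force maximum inner product search over a matrix of document embeddings; REALM relies on ScaNN to do this approximately and at Wikipedia scale. The embedding dimensions and corpus here are made up.

```python
import numpy as np

def retrieve_top_k(question_vec, doc_matrix, k=5):
    """Maximum inner product search over dense document embeddings,
    done by brute force here (illustrative sketch only)."""
    scores = doc_matrix @ question_vec          # inner product with every document
    top = np.argsort(-scores)[:k]               # indices of the k highest scores
    return top, scores[top]

docs = np.random.randn(10_000, 128)             # 10k hypothetical document embeddings
q = np.random.randn(128)                        # encoded question
indices, scores = retrieve_top_k(q, docs, k=3)
print(indices, scores)
```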

We further improve the quality of REALM retrievals by using a contrastive loss. The idea is to encourage the representation of a question to move closer to the representation of its ground-truth answer and away from the other answers in its mini-batch. This ensures that when the system retrieves relevant items using this question representation, it returns articles that are “similar” to ground-truth answers. We call this retriever contrastive-REALM, or c-REALM.
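
A minimal sketch of such an in-batch contrastive objective is shown below, assuming we already have question and answer encoders that produce fixed-size embeddings. The similarity function, temperature and other details are illustrative assumptions rather than the exact c-REALM loss.

```python
import numpy as np

def in_batch_contrastive_loss(q_emb, a_emb, temperature=0.1):
    """Contrastive objective in the spirit of c-REALM: each question is pulled
    toward its own ground-truth answer and pushed away from the other answers
    in the mini-batch (sketch; details are assumptions)."""
    q = q_emb / np.linalg.norm(q_emb, axis=-1, keepdims=True)
    a = a_emb / np.linalg.norm(a_emb, axis=-1, keepdims=True)
    logits = (q @ a.T) / temperature                       # (batch, batch) similarities
    logits -= logits.max(axis=-1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -np.mean(np.diag(log_probs))                    # positives sit on the diagonal

answers = np.random.randn(8, 128)                          # fake answer embeddings
questions = answers + 0.1 * np.random.randn(8, 128)        # fake matching question embeddings
print(in_batch_contrastive_loss(questions, answers))
```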

The combined RT + c-REALM system for LFQA.

Evaluation
We test the model on long-form question answering using the ELI5 dataset, which is a part of the KILT benchmark, and is the only publicly available large-scale LFQA dataset. The KILT benchmark measures text retrievals using Precision (R-Prec) and text generation using ROUGE-L. The two scores are combined to give a KILT R-L score, which determines a model’s ranking on the leaderboard. We fine-tune the pre-trained RT model together with retrievals from c-REALM on the ELI5 dataset from KILT.
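
A simplified sketch of how the two metrics combine is below: as we understand the KILT scoring, an example's ROUGE-L only counts toward the KILT R-L score when its retrieval is judged perfect (an R-Prec of 1), so the combined number rewards models that both retrieve and generate well. See the KILT benchmark for the authoritative definition.

```python
def kilt_rouge_l(per_example_r_prec, per_example_rouge_l):
    """Combined KILT R-L score (simplified reading of the KILT scoring):
    credit an example's ROUGE-L only when its R-Prec equals 1."""
    assert len(per_example_r_prec) == len(per_example_rouge_l)
    gated = [r if p == 1.0 else 0.0
             for p, r in zip(per_example_r_prec, per_example_rouge_l)]
    return sum(gated) / len(gated)

# Hypothetical per-example scores for three questions.
print(kilt_rouge_l([1.0, 0.0, 1.0], [30.0, 25.0, 10.0]))  # -> 13.33...
```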

Our submission tops the KILT leaderboard for long-form question answering on ELI5 with a combined KILT R-L score of 2.36. It improves on the previous leaderboard entry of BART + DPR (KILT R-L score of 1.9), while having a similar number of parameters as the other models on the leaderboard. In terms of text generation quality, we see an improvement of +4.11, +5.78 and +9.14 ROUGE-L over T5, BART + DPR and RAG, respectively.

Our results on the KILT leaderboard for long form question answering on ELI5.

Example Generations from the RT Model

  • Q: Why are almost all boats white?

    Boats are generally painted white for obvious reasons, as they are inexpensive and easy to clean. Ships are also often painted to provide a visible visual contrast to the sea. Boats also often have high visibility ports to allow for easy international trade. In addition, many boats have an air-breathing function that makes them easier to spot and stay afloat, which is a good thing if you’re a pirate or trying to get into some country’s waters. Lastly, many ships are built in the last few decades, so they are not as exotic looking as they were thirty years ago. And lastly, many ships are made out of materials that are easy to work with, and so many of them are painted white as a practical matter. There are many, many more reasons for why ships are white, but those are the ones I know of. I’m sure there are many other reasons, but those are the ones I know of. Edit: I forgot to mention that the other answers about visibility and easy to see ports are also true. Ships are painted to appear to be larger, and more expensive. That makes them appear more impressive, and so they are more likely to be bought by people who want to show off their wealth and status.

  • Q: What causes the trail behind jets at high altitude?

    The same thing that causes the clouds you see at the start of a storm to form in the first place. The air pressure changes as the jet moves through the air, causing it to cool down and condense into water droplets. These droplets are in fact a lot bigger than the water droplets in the clouds, so when the plane passes through them, they can cause the air pressure to change, causing the water droplets to condense into clouds. This is what you see when you look at a cloud formation from the ground. In addition to this, when the plane passes through the clouds, it also moves through the air, causing it to change its temperature, causing the water droplets to condense and form clouds. These clouds are often quite large, and can block out the sun. This is what you see when you look at a contrail. Finally, the water vapor that forms at high altitudes is cooled by the air above, and as it cools, it condenses into clouds. This is what you see when you look at a cumulonimbus cloud. These clouds are even larger than the water droplets in the clouds, and they block out the sun.

  • Q: How do shampoo and conditioner work?

    The key ingredients in shampoo and conditioner are usually surfactants, they tend to attract oil from your skin. Shampoo will usually contain an anti-dandruff agent (dandruff is caused by fungi in your skin which cause dandruff). They also contain some moisturizing agents (moisturizers are things that stay in your skin longer and protect your scalp from cleaning), and a lot of anti-dandruff (dandruff is a side effect of high doses of both surfactants and moisturizers) Conditioner will contain moisturizing agents, like a skin moisturizer. They tend to cause oils to be separated from your hair (e.g. wash your hair with conditioner instead of shampoo). They both have a foaming agent that makes them foam up (usually Hydrogen Peroxide). They also have a foaming agent that stops dirt from attaching to your hair. Some conditioner will also suppress the action of the shampoo (to help prevent dirt from attaching to your hair).

  • Q: When glancing at a clock, why does the first second after glancing at it sometimes feel longer than the rest?

    This is known as chronostasis. Your brain gets used to seeing the second hand move after the first, so it ignores the first one. If you were to look at the clock a lot, you’d see the second hand move a lot more than the first, and you’d be able to tell the difference between the second and the first. The second hand will appear to move much slower than the first hand because your brain is used to seeing the second hand move a lot more than the first. If you were to go on a long trip, you’d get used to the second hand moving a lot more than the first, and you’d have no problem telling the difference between the second and the first hand. It’s kind of like how you don’t notice your nose when you’re driving, but if you look at it for a while, you start to notice it. Your brain is just getting used to seeing the nose move a lot more than the other, and it can’t tell the difference between the two. It’s kind of like that with your eyes.

  • Q: Why do Britain and other English empire countries still bow to monarchs? What real purpose does the queen serve?

    Depends what you define as “purpose”. There is no specific purpose for the British monarch – rather, they are figureheads that serve a role. The role is to have a high degree of influence over a country, though the rules of the monarchy can somewhat vary from country to country, and the ceremony of “kissing the royal corpse” does not always involve kissing the monarch. Whether that’s the only reason for the custom, or if it was the only reason, I can’t say, but that’s the reasoning. When the U.K. was in full power, the monarch was a puppet of Parliament, and the powers of the monarch were transferred to the Prime Minister and thus to the Prime Minister’s deputy, who then became the Prime Minister. Since then, Parliament has been able to vote on legislation that goes through the monarch, although they may still act as the monarch’s representative in negotiating treaties, which can have very very deep consequences. The Queen’s role, as a representative of Britain, doesn’t necessarily involve her formal approval of any of the laws or legislation that goes through Parliament, though.

  • Q: What exactly is fire, in detail? How can light and heat come from something we can’t really touch?

    Fire is the chemical reaction of fuel (oxygen) with (what we call) impurities (ash, soot, oil etc). These impurities are created due to a number of things. If you were to study the chemical composition of a log (or a candle) and analyse it you would discover the chemical formula (A, C, and G) which is a simplified version of what actually happens. There is always fuel, and there is always the ash/soot. Wood is made up of mostly carbon and hydrogen and these are the fuels that create the fire. Hydrogen and oxygen are reactive, and they form water (H2O) and CO2 (CO2). If we split a log into smaller pieces, we start to see the process of combustion as it occurs. The carbon is burned, and with the explosion of CO, H2O and H2CO2, we get water and carbon dioxide (mostly). If we smash a log together, we start to see what happens when we also hit another log with another log…imagine smashing two sand castles together. We see how the sand castles and brick castles form and collapse at different rates…then it starts to burn. The smoke, ash and flames are the actual combustion of fuel.

  • Q: What goes on in those tall tower buildings owned by major banks?

    The actual buildings are not there for profit but for show. There are many real world historical buildings that have little to no effect other than being a symbol of an economic entity or symbol of an era or location, and nothing more. For example look at Sears, Sears didn’t care what went on inside, it was all about the _appearance_ of its location, the prestige of the location, the facilities and so on. It didn’t care about how long it took it to operate, it was about how much people would pay to go see it. Sears was a landmark as a cultural movement and other big companies followed suit, so if you want to see a building you’ve never seen before, you have to go see Sears, just like you have to see a Toyota Camry for Toyota Camry. They used to be all about building new factories, some of them if I recall, but now that they’re bigger, that means that more factory jobs are coming to them. You’ve probably seen them in stores as stores where people buy and sell stuff, so there aren’t that many places for them to come from. Instead, it’s just for show, a symbol of rich people.

Hurdles to Progress in LFQA
However, while the RT system described here tops the public leaderboard, a detailed analysis of the model and the ELI5 dataset reveals some concerning trends.

  • Many held-out questions are paraphrases of questions in the training set; simply returning the best answer to a similar training question achieves 27.4 ROUGE-L.

  • Simply retrieving answers to random unrelated training questions yields relatively high ROUGE-L, while actual gold answers underperform generations.

  • Conditioning answer generation on random documents instead of relevant ones does not measurably impact its factual correctness. Longer outputs get higher ROUGE-L.

We find little to no evidence that the model is actually grounding its text generation in the retrieved documents — fine-tuning an RT model with random retrievals from Wikipedia (i.e., random retrieval + RT) performs nearly as well as the c-REALM + RT model (24.2 vs. 24.4 ROUGE-L). We also find significant overlap in the training, validation and test sets of ELI5 (with several questions being paraphrases of each other), which may eliminate the need for retrievals. The KILT benchmark measures the quality of retrievals and generations separately, without making sure that the text generation actually uses the retrievals.

Trivial baselines get higher ROUGE-L scores than RAG and BART + DPR.

Moreover, we find issues with the ROUGE-L metric used to evaluate the quality of text generation: trivial, nonsensical baselines, such as a random training-set answer and input copying, achieve relatively high ROUGE-L scores (even beating BART + DPR and RAG).
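
To see why such baselines do well, it helps to look at how ROUGE-L is computed. The toy implementation below scores candidates by their longest common subsequence with the reference, so long outputs full of common words, or text copied from the question, can pick up substantial overlap. The example strings are invented; the leaderboard uses a standard ROUGE package.

```python
def lcs_len(a, b):
    """Length of the longest common subsequence between two token lists."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            dp[i][j] = dp[i-1][j-1] + 1 if x == y else max(dp[i-1][j], dp[i][j-1])
    return dp[-1][-1]

def rouge_l_f1(candidate, reference):
    """ROUGE-L F1 from LCS precision and recall (toy implementation)."""
    c, r = candidate.split(), reference.split()
    lcs = lcs_len(c, r)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(c), lcs / len(r)
    return 2 * precision * recall / (precision + recall)

reference = "the trail is water vapor from the engine condensing into ice crystals"
copied_question = "what causes the trail behind jets at high altitude"
long_generic = ("the answer is that the air is cold and the water in the air "
                "condenses into clouds which is what you see in the sky")
print(rouge_l_f1(copied_question, reference), rouge_l_f1(long_generic, reference))
```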

Conclusion
We proposed a system for long-form question answering based on Routing Transformers and REALM, which tops the KILT leaderboard on ELI5. However, a detailed analysis reveals several issues with the benchmark that preclude using it to inform meaningful modeling advances. We hope that the community works together to solve these issues so that researchers can climb the right hills and make meaningful progress in this challenging but important task.

Acknowledgments
The Routing Transformer work has been a team effort involving Aurko Roy, Mohammad Saffar, Ashish Vaswani and David Grangier. The follow-up work on open-domain long-form question answering has been a collaboration involving Kalpesh Krishna, Aurko Roy and Mohit Iyyer. We wish to thank Vidhisha Balachandran, Niki Parmar and Ashish Vaswani for several helpful discussions, and the REALM team (Kenton Lee, Kelvin Guu, Ming-Wei Chang and Zora Tung) for help with their codebase and several useful discussions, which helped us improve our experiments. We are grateful to Tu Vu for help with the QQP classifier used to detect paraphrases in ELI5 train and test sets. We thank Jules Gagnon-Marchand and Sewon Min for suggesting useful experiments on checking ROUGE-L bounds. Finally we thank Shufan Wang, Andrew Drozdov, Nader Akoury and the rest of the UMass NLP group for helpful discussions and suggestions at various stages in the project.


Helping newsrooms experiment together with AI

In our JournalismAI report, journalists around the world told researchers they are eager to collaborate and explore the benefits of AI, especially as it applies to newsgathering, production and distribution. 

To facilitate their collaboration, the Google News Initiative and Polis – the journalism think tank at the London School of Economics and Political Science – are launching the JournalismAI Collab Challenges, an opportunity for three groups of five newsrooms from the Americas, Europe, the Middle East and Africa, and Asia Pacific to experiment together.

Each cohort – selected by Polis – will have six months to either cover global news stories using AI-powered storytelling techniques or to develop prototypes of new AI-based products and processes.

Participants will receive support from the JournalismAI team and partner institutions in each region: in the Americas, the challenge will be co-hosted with the Knight Lab at Northwestern University; in Europe, the Middle East and Africa, the challenge will be co-hosted with BBC News Labs and Clwstwr. JournalismAI’s partner in Asia Pacific will be announced later this year.

The Collab Challenges build on a successful pilot run by JournalismAI last year. More than 20 global newsrooms joined forces to solve four common problems using AI, from creating automated news summaries to mitigating newsroom biases, and from powering archives to increasing audience loyalty. JournalismAI online trainings are available on the Google News Initiative Training Center, where they have already been seen by more than 110,000 participants.

Newsrooms interested in participating in this free, year-long program must have made AI a strategic priority, must guarantee the participation of two staff members – one from editorial and one from the technical department – who can dedicate two to four hours a week, and must embrace collaboration with other publishers.

The outcome of their work – whose ownership will be shared among participants – will be presented at the second edition of the JournalismAI Festival in November.

Applications for the Americas challenge and the Europe, the Middle East and Africa Challenge open today and close at 11:59 PM GMT on April 5. The Challenge will open later this year in Asia Pacific.

To learn more about the process, please visit the Polis blog and sign up for the JournalismAI newsletter.


What drives Nithya Sambasivan’s fight for fairness

When Nithya Sambasivan was finishing her undergraduate degree in engineering, she felt slightly unsatisfied. “I wanted to know, ‘how will the technology I build impact people?’” she says. Luckily, she would soon discover the field of Human Computer Interaction (HCI) and pursue her graduate degrees. 

She completed her master’s and PhD in HCI focusing on technology design for low-income communities in India. “I worked with sex workers, slum communities, microentrepreneurs, fruit and vegetables sellers on the streetside…” she says. “I wanted to understand what their values, aspirations and struggles are, and how we can build with them in mind.” 

Today, Nithya is the founder of the HCI group at the Google Research India lab and an HCI researcher at PAIR, a multidisciplinary team at Google that explores the human side of AI by doing fundamental research, building tools, creating design frameworks, and working with diverse communities. She recently sat down to answer some of our questions about her journey to researching responsible AI, fairness and championing historically underrepresented technology users.

How would you explain your job to someone who isn’t in tech?

I’m a human-computer interaction (HCI) researcher, which means I study people to better understand how to build technology that works for them. There’s been a lot of focus in the research community on building AI systems and the possibility of positively impacting the lives of billions of people. I focus on human-centered, responsible AI; specifically looking for ways it can empower communities in the Global South, where over 80% of the world’s population lives. Today, my research outlines a road map for fairness research in India, calling for re-contextualizing datasets and models while empowering communities and enabling an entire fairness ecosystem.

What originally inspired your interest in technology? 

I grew up in a middle class family, the younger of two daughters from the South of India. My parents have very progressive views about gender roles and independence, especially in a conservative society — this definitely influenced what and how I research; things like gender, caste and  poverty. In school, I started off studying engineering, which is a conventional path in India. Then, I went on to focus on HCI and designing with my own and other under-represented communities around the world.

Nithya smiling at a small child while working in the field.

How do Google’s  AI Principles inform your research? And how do you approach your research in general?

Context matters. A general theory of algorithmic fairness cannot be based on “Western” populations alone. My general approach is to research an important long-term, foundational problem. For example, our research on algorithmic fairness reframes the conversation on ethical AI away from focusing mainly on Western, meaning largely European or North American, perspectives. Another project revealed that AI developers have historically focused more on the model — or algorithm — instead of the data. Both deeply affect the eventual AI performance, so being so focused on only one aspect creates downstream problems. For example, datasets may entirely miss sub-populations, so models trained on them may have much higher error rates or be unusable when deployed. Or they could make outcomes worse for certain groups, by misidentifying them as suspects for crimes or erroneously denying them bank loans they should receive.

These insights not only enable AI systems to be better designed for under-represented communities; they also generate new considerations in the field of computing for humane and inclusive data collection, gender and social status representation, and privacy and safety needs of the most vulnerable. They are then  incorporated into Google products that millions of people use, such as Safe Folder on Files Go, Google Go’s incognito mode, Neighbourly‘s privacy, Safe Safer by Google Maps and Women in STEM videos. 

What are some of the questions you’re seeking to answer with your work?

How do we challenge the inherent “West”-centric assumptions behind algorithmic fairness and tech norms, and make AI work better for people around the world?

For example, there’s an assumption that algorithmic biases can be fixed by adding more data from different groups. But in India, we’ve found that data can’t always represent individuals or events for many different reasons like economics and access to devices. The data could come mostly from middle class Indian men, since they’re more likely to have internet access. This means algorithms will work well for them. Yet, over half the population — primarily women, rural and tribal communities — lack access to the internet and they’re left out. Caste, religion and other factors can also contribute to new biases for AI models. 

How should aspiring AI thinkers and future technologists prepare for a career in this field? 

It’s really important that Brown and Black people enter this field. We not only bring technical skills but also lived experiences and values that are so critical to the field of computing. Our communities are the most vulnerable to AI interventions, so it’s important we shape and build these systems. To members of this community: Never play small or let someone make you feel small. Involve yourself in the political, social and ecological aspects of the invisible, not on tech innovation alone. We can’t afford not to.


Leveraging Machine Learning for Game Development

Posted by Ji Hun Kim and Richard Wu, Software Engineers, Stadia

Over the years, online multiplayer games have exploded in popularity, captivating millions of players across the world. This popularity has also exponentially increased demands on game designers, as players expect games to be well-crafted and balanced — after all, it’s no fun to play a game where a single strategy beats all the rest.

In order to create a positive gameplay experience, game designers typically tune the balance of a game iteratively:

  1. Stress-test through thousands of play-testing sessions from test users
  2. Incorporate feedback and re-design the game
  3. Repeat 1 & 2 until both the play-testers and game designers are satisfied

This process is not only time-consuming but also imperfect — the more complex the game, the easier it is for subtle flaws to slip through the cracks. Games often feature many different playable roles with dozens of interconnecting skills, which makes it all the more difficult to hit the right balance.

Today, we present an approach that leverages machine learning (ML) to adjust game balance by training models to serve as play-testers, and demonstrate this approach on the digital card game prototype Chimera, which we’ve previously shown as a testbed for ML-generated art. By running millions of simulations using trained agents to collect data, this ML-based game testing approach enables game designers to more efficiently make a game more fun, balanced, and aligned with their original vision.

Chimera
We developed Chimera as a game prototype that would heavily lean on machine learning during its development process. For the game itself, we purposefully designed the rules to expand the possibility space, making it difficult to build a traditional hand-crafted AI to play the game.

The gameplay of Chimera revolves around the titular chimeras, creature mash-ups that players aim to strengthen and evolve. The objective of the game is to defeat the opponent’s chimera. These are the key points in the game design:

  • Players may play:
    • creatures, which can attack (through their attack stat) or be attacked (against their health stat), or
    • spells, which produce special effects.
  • Creatures are summoned into limited-capacity biomes, which are placed physically on the board space. Each creature has a preferred biome and will take repeated damage if placed on an incorrect biome or a biome that is over capacity.
  • A player controls a single chimera, which starts off in a basic “egg” state and can be evolved and strengthened by absorbing creatures. To do this, the player must also acquire a certain amount of link energy, which is generated from various gameplay mechanics.
  • The game ends when a player has successfully brought the health of the opponent’s chimera to 0.

Learning to Play Chimera
Because Chimera is an imperfect information card game with a large state space, we expected it to be a difficult game for an ML model to learn, especially as we were aiming for a relatively simple model. We used an approach inspired by earlier game-playing agents like AlphaGo, in which a convolutional neural network (CNN) is trained to predict the probability of a win when given an arbitrary game state. After training an initial model on games where random moves were chosen, we set the agent to play against itself, iteratively collecting game data that was then used to train a new agent. With each iteration, the quality of the training data improved, as did the agent’s ability to play the game.
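
The sketch below outlines this self-play loop on a deliberately trivial stand-in game, with a logistic-regression win-probability model in place of the CNN: play games with the current agent, label every visited state with the eventual outcome, fit the model, and let the next agent pick moves greedily by predicted win probability. The toy "game" and all names here are illustrative, not the Chimera code.

```python
import numpy as np

rng = np.random.default_rng(0)

def play_game(policy):
    """Toy stand-in for a self-play game: two players alternate turns on a
    random 16-dim 'state'; the policy scores candidate moves, and the player
    with the higher accumulated score wins. Purely illustrative."""
    states, players, totals = [], [], [0.0, 0.0]
    for turn in range(10):
        player = turn % 2
        moves = rng.normal(size=(4, 16))                 # 4 candidate next states
        state = moves[int(np.argmax(policy(moves)))]     # pick the best-scored move
        totals[player] += state.sum()
        states.append(state.copy())
        players.append(player)
    winner = int(totals[1] > totals[0])
    return states, players, winner

def fit_win_model(X, y, epochs=200, lr=0.1):
    """Logistic-regression stand-in for the win-probability CNN."""
    w = np.zeros(X.shape[1])
    for _ in range(epochs):
        p = 1 / (1 + np.exp(-X @ w))
        w -= lr * X.T @ (p - y) / len(y)
    return lambda s: 1 / (1 + np.exp(-s @ w))

policy = lambda moves: rng.normal(size=len(moves))       # version 0 picks random moves
for iteration in range(3):                               # iterative self-play
    X, y = [], []
    for _ in range(200):
        states, players, winner = play_game(policy)
        X += states
        y += [1.0 if p == winner else 0.0 for p in players]
    win_prob = fit_win_model(np.array(X), np.array(y))   # label states with eventual outcome
    policy = win_prob                                    # next agent follows the new model
```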

The ML agent’s performance against our best hand-crafted AI as training progressed. The initial ML agent (version 0) picked moves randomly.

For the actual game state representation that the model would receive as input, we found that passing an “image” encoding to the CNN resulted in the best performance, beating all benchmark procedural agents and other types of networks (e.g. fully connected). The chosen model architecture is small enough to run on a CPU in reasonable time, which allowed us to download the model weights and run the agent live in a Chimera game client using Unity Barracuda.

An example game state representation used to train the neural network.
In addition to making decisions for the game AI, we also used the model to display the estimated win probability for a player over the course of the game.

Balancing Chimera
This approach enabled us to simulate millions more games than real players would be capable of playing in the same time span. After collecting data from the games played by the best-performing agents, we analyzed the results to find imbalances between the two player decks we had designed.

First, the Evasion Link Gen deck was composed of spells and creatures with abilities that generated extra link energy used to evolve a player’s chimera. It also contained spells that enabled creatures to evade attacks. In contrast, the Damage-Heal deck contained creatures of variable strength with spells that focused on healing and inflicting minor damage. Although we had designed these decks to be of equal strength, the Evasion Link Gen deck was winning 60% of the time when played against the Damage-Heal deck.
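
Estimating such a win rate from simulations is straightforward; a sketch is shown below, with a stand-in simulator in place of the real trained agents, including a confidence interval to indicate how tight the estimate becomes after many games.

```python
import math
import random

def estimate_win_rate(simulate_game, n_games=100_000):
    """Estimate how often deck A beats deck B from simulated agent-vs-agent
    games, with a normal-approximation 95% confidence interval.
    `simulate_game` is a hypothetical callable returning True when deck A wins."""
    wins = sum(simulate_game() for _ in range(n_games))
    p = wins / n_games
    half_width = 1.96 * math.sqrt(p * (1 - p) / n_games)
    return p, (p - half_width, p + half_width)

# Stand-in simulator reproducing the observed 60/40 imbalance.
rate, ci = estimate_win_rate(lambda: random.random() < 0.6)
print(f"Evasion Link Gen win rate ≈ {rate:.3f}, 95% CI {ci}")
```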

When we collected various stats related to biomes, creatures, spells, and chimera evolutions, two things immediately jumped out at us:

  1. There was a clear advantage in evolving a chimera — the agent won a majority of the games where it evolved its chimera more than the opponent did. Yet, the average number of evolves per game did not meet our expectations. To make it more of a core game mechanic, we wanted to increase the overall average number of evolves while keeping its usage strategic.
  2. The T-Rex creature was overpowered. Its appearances correlated strongly with wins, and the model would always play the T-Rex regardless of penalties for summoning into an incorrect or overcrowded biome.

From these insights, we made some adjustments to the game. To emphasize chimera evolution as a core mechanism in the game, we decreased the amount of link energy required to evolve a chimera from 3 to 1. We also added a “cool-off” period to the T-Rex creature, doubling the time it took to recover from any of its actions.

Repeating our ‘self-play’ training procedure with the updated rules, we observed that these changes pushed the game in the desired direction — the average number of evolves per game increased, and the T-Rex’s dominance faded.

One example comparison of the T-Rex’s influence before and after balancing. The charts present the number of games won (or lost) when a deck initiates a particular spell interaction (e.g., using the “Dodge” spell to benefit a T-Rex). Left: Before the changes, the T-Rex had a strong influence in every metric examined — highest survival rate, most likely to be summoned ignoring penalties, most absorbed creature during wins. Right: After the changes, the T-Rex was much less overpowered.

By weakening the T-Rex, we successfully reduced the Evasion Link Gen deck’s reliance on an overpowered creature. Even so, the win ratio between the decks remained at 60/40 rather than 50/50. A closer look at the individual game logs revealed that the gameplay was often less strategic than we would have liked. Searching through our gathered data again, we found several more areas in which to introduce changes.

To start, we increased the starting health of both players as well as the amount of health that healing spells could replenish. This was to encourage longer games that would allow a more diverse set of strategies to flourish. In particular, this enabled the Damage-Heal deck to survive long enough to take advantage of its healing strategy. To encourage proper summoning and strategic biome placement, we increased the existing penalties on playing creatures into incorrect or overcrowded biomes. And finally, we decreased the gap between the strongest and weakest creatures through minor attribute adjustments.

With the new adjustments in place, we arrived at the final game balance stats for these two decks:

Deck                 Avg # evolves per game (before → after)     Win % over 1M games (before → after)
Evasion Link Gen     1.54 → 2.16                                 59.1% → 49.8%
Damage-Heal          0.86 → 1.76                                 40.9% → 50.2%

Conclusion
Normally, identifying imbalances in a newly prototyped game can take months of playtesting. With this approach, we were able to not only discover potential imbalances but also introduce tweaks to mitigate them in a span of days. We found that a relatively simple neural network was sufficient to reach high level performance against humans and traditional game AI. These agents could be leveraged in further ways, such as for coaching new players or discovering unexpected strategies. We hope this work will inspire more exploration in the possibilities of machine learning for game development.

Acknowledgements
This project was conducted in collaboration with many people. Thanks to Ryan Poplin, Maxwell Hannaman, Taylor Steil, Adam Prins, Michal Todorovic, Xuefan Zhou, Aaron Cammarata, Andeep Toor, Trung Le, Erin Hoffman-John, and Colin Boswell. Thanks to everyone who contributed through playtesting, advising on game design, and giving valuable feedback.


Massively Parallel Graph Computation: From Theory to Practice

Posted by Jakub Łącki and Vahab Mirrokni, Research Scientists, Google Research

Graphs are useful theoretical representations of the connections between groups of entities, and have been used for a variety of purposes in data science, from ranking web pages by popularity and mapping out social networks, to assisting with navigation. In many cases, such applications require the processing of graphs containing hundreds of billions of edges, which is far too large to be processed on a single consumer-grade machine. A typical approach to scaling graph algorithms is to run in a distributed setting, i.e., to partition the data (and the algorithm) among multiple computers to perform the computation in parallel. While this approach allows one to process graphs with trillions of edges, it also introduces new challenges. Namely, because each computer only sees a small piece of the input graph at a time, one needs to handle inter-machine communication and design algorithms that can be split across multiple computers.

A framework for implementing distributed algorithms, MapReduce, was introduced in 2008. It transparently handled communication between machines while offering good fault-tolerance capabilities and inspired the development of a number of distributed computation frameworks, including Pregel, Apache Hadoop, and many others. Still, the challenge of developing algorithms for distributed computation on very large graphs remained, and designing efficient algorithms in this context even for basic problems, such as connected components, maximum matching or shortest paths, has been an active area of research. While recent work has demonstrated new algorithms for many problems, including our algorithms for connected components (both in theory and practice) and hierarchical clustering, there was still a need for methods that could solve a range of problems more quickly.

Today we present a pair of recent papers that address this problem by first constructing a theoretical model for distributed graph algorithms and then demonstrating how the model can be applied. The proposed model, Adaptive Massively Parallel Computation (AMPC), augments the theoretical capabilities of MapReduce, providing a pathway to solve many graph problems in fewer computation rounds. We also show how the AMPC model can be effectively implemented in practice. The suite of algorithms we describe, which includes algorithms for maximal independent set, maximum matching, connected components and minimum spanning tree, runs up to 7x faster than current state-of-the-art approaches.

Limitations of MapReduce
In order to understand the limitations of MapReduce for developing graph algorithms, consider a simplified variant of the connected components problem. The input is a collection of rooted trees, and the goal is to compute, for each node, the root of its tree. Even this seemingly simple problem is not easy to solve in MapReduce. In fact, in the Massively Parallel Computation (MPC) model — the theoretical model behind MapReduce, Pregel, Apache Giraph and many other distributed computation frameworks — this problem is widely believed to require at least a number of rounds of computation proportional to log n, where n is the total number of nodes in the graph. While log n may not seem to be a large number, algorithms processing trillion-edge graphs often write hundreds of terabytes of data to disk in each round, and thus even a small reduction in the number of rounds may bring significant resource savings.

The problem of finding root nodes. Nodes are represented by blue circles. Gray arrows point from each node to its parent. The root nodes are the nodes with no parents. The orange arrows illustrate the path an algorithm would follow from a node to the root of the tree to which it belongs.

A similar subproblem showed up in our algorithms for finding connected components and computing a hierarchical clustering. We observed that one can bypass the limitations of MapReduce by implementing these algorithms through the use of a distributed hash table (DHT), a service that is initialized with a collection of key-value pairs and then returns a value associated with a provided key in real-time. In our implementation, for each node, the DHT stores its parent node. Then, a machine that processes a graph node can use the DHT and “walk up” the tree until it reaches the root. While the use of a DHT worked well for this particular problem (although it relied on the input trees being not too deep), it was unclear if the idea could be applied more broadly.
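
The sketch below mimics this pattern with an in-memory dictionary standing in for the DHT: for each node we "walk up" parent pointers until we hit a root, caching the answers along the way. It is a single-machine illustration of the access pattern, not the distributed implementation.

```python
def find_roots(parent):
    """Walk up a forest of rooted trees using random-access lookups into a
    key-value store, as a machine with DHT access could. `parent` plays the
    role of the DHT: it maps each node to its parent (roots map to themselves).
    Caching (path compression) keeps repeated walks cheap."""
    root = {}
    for node in parent:
        path = []
        current = node
        while parent[current] != current and current not in root:
            path.append(current)
            current = parent[current]          # one "DHT lookup" per hop
        r = root.get(current, current)
        for visited in path:                   # cache the answer for the walked path
            root[visited] = r
        root[node] = r
    return root

# Two trees: 1 <- 2 <- 3 and 4 <- 5
parent = {1: 1, 2: 1, 3: 2, 4: 4, 5: 4}
print(find_roots(parent))                      # {1: 1, 2: 1, 3: 1, 4: 4, 5: 4}
```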

The Adaptive Massively Parallel Computation Model
To extend this approach to other problems, we started by developing a model to theoretically analyze algorithms that utilize a DHT. The resulting AMPC model builds upon the well-established MPC model and formally describes the capabilities brought by the use of a distributed hash table.

In the MPC model there is a collection of machines, which communicate via message passing in synchronous rounds. Messages sent in one round are delivered in the beginning of the following round and constitute that round’s entire input (i.e., the machines do not retain information from one round to the next). In the first round, one can assume that the input is randomly distributed across the machines. The goal is to minimize the number of computation rounds, while assuring load-balancing between machines in each round.

Computation in the MPC model. Each column represents one machine in subsequent computation rounds. Once all machines have completed a round of computation, all messages sent in that round are delivered, and the following round begins.

We then formalized the AMPC model by introducing a new approach, in which machines write to a write-only distributed hash table each round, instead of communicating via messages. Once a new round starts, the hash table from the previous round becomes read-only and a new write-only output hash table becomes available. Importantly, only the method of communication changes — the amount of communication and the available space per machine are constrained in exactly the same way as in the MPC model. Hence, at a high level, the added capability of the AMPC model is that each machine can choose what data to read, instead of being provided a piece of data.

Computation in the AMPC model. Once all machines have completed a round of computation, the data they produced is saved to a distributed hash table. In the following round, each machine can read arbitrary values from this distributed hash table and write to a new distributed hash table.

Algorithms and Empirical Evaluation
This seemingly small difference in the way machines communicate allowed us to design much faster algorithms for a number of basic graph problems. In particular, we show that it is possible to find connected components, minimum spanning tree, maximal matching and maximal independent set in a constant number of rounds, regardless of the size of the graph.

To investigate the practical applicability of the AMPC algorithms, we have instantiated the model by combining Flume C++ (a C++ counterpart of FlumeJava) with a DHT communication layer. We have evaluated our AMPC algorithms for minimum spanning tree, maximal independent set and maximum matching and observed that we can achieve up to 7x speedups over implementations that did not use a DHT. At the same time, the AMPC implementations used 10x fewer rounds on average to complete, and also wrote less data to disk.

Our implementation of the AMPC model took advantage of hardware-accelerated remote direct memory access (RDMA), a technology that allows reading from the memory of a remote machine with a latency of a few microseconds, which is just an order of magnitude slower than reading from local memory. While some of the AMPC algorithms communicated more data than their MPC counterparts, they were overall faster, as they performed mostly fast reads using RDMA, instead of costly writes to disk.

Conclusion
With the AMPC model, we built a theoretical framework inspired by practically efficient implementations, and then developed new theoretical algorithms that delivered good empirical performance and maintained good fault-tolerance properties. We’ve been happy to see that the AMPC model has already been the subject of further study and are excited to learn what other problems can be solved more efficiently using the AMPC model or its practical implementations.

Acknowledgements
Co-authors on the two papers covered in this blog post include Soheil Behnezhad, Laxman Dhulipala, Hossein Esfandiari, and Warren Schudy. We also thank members of the Graph Mining team for their collaborations, and especially Mohammad Hossein Bateni for his input on this post. To learn more about our recent work on scalable graph algorithms, see videos from our recent Graph Mining and Learning workshop.


Contactless Sleep Sensing in Nest Hub

Posted by Michael Dixon, Software Engineer and Reena Singhal Lee, Product Manager, Google Health

People often turn to technology to manage their health and wellbeing, whether it is to record their daily exercise, measure their heart rate, or increasingly, to understand their sleep patterns. Sleep is foundational to a person’s everyday wellbeing and can be impacted by (and in turn, have an impact on) other aspects of one’s life — mood, energy, diet, productivity, and more.

As part of our ongoing efforts to support people’s health and happiness, today we announced Sleep Sensing in the new Nest Hub, which uses radar-based sleep tracking in addition to an algorithm for cough and snore detection. While not intended for medical purposes1, Sleep Sensing is an opt-in feature that can help users better understand their nighttime wellness using a contactless bedside setup. Here we describe the technologies behind Sleep Sensing and discuss how we leverage on-device signal processing to enable sleep monitoring (comparable to other clinical- and consumer-grade devices) in a way that protects user privacy.

Soli for Sleep Tracking
Sleep Sensing in Nest Hub demonstrates the first wellness application of Soli, a miniature radar sensor that can be used for gesture sensing at various scales, from a finger tap to movements of a person’s body. In Pixel 4, Soli powers Motion Sense, enabling touchless interactions with the phone to skip songs, snooze alarms, and silence phone calls. We extended this technology and developed an embedded Soli-based algorithm that could be implemented in Nest Hub for sleep tracking.

Soli consists of a millimeter-wave frequency-modulated continuous wave (FMCW) radar transceiver that emits an ultra-low power radio wave and measures the reflected signal from the scene of interest. The frequency spectrum of the reflected signal contains an aggregate representation of the distance and velocity of objects within the scene. This signal can be processed to isolate a specified range of interest, such as a user’s sleeping area, and to detect and characterize a wide range of motions within this region, ranging from large body movements to sub-centimeter respiration.

Soli spectrogram illustrating its ability to detect a wide range of motions, characterized as (a) an empty room (no variation in the reflected signal demonstrated by the black space), (b) large pose changes, (c) brief limb movements, and (d) sub-centimeter chest and torso displacements from respiration while at rest.

In order to make use of this signal for Sleep Sensing, it was necessary to design an algorithm that could determine whether a person is present in the specified sleeping area and, if so, whether the person is asleep or awake. We designed a custom machine-learning (ML) model to efficiently process a continuous stream of 3D radar tensors (summarizing activity over a range of distances, frequencies, and time) and automatically classify each feature into one of three possible states: absent, awake, and asleep.
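
As a rough illustration only (the real model architecture, input shapes and training setup are not reproduced here), a small Keras classifier over such 3D radar tensors might look like the following; the layer sizes and dimensions are assumptions for the sketch.

```python
import tensorflow as tf

def build_sleep_state_model(range_bins=16, freq_bins=32, time_frames=30):
    """Toy 3D-CNN over radar tensors (range x frequency x time) that outputs
    probabilities for the three states: absent, awake, asleep. Illustrative
    only; not the Nest Hub model."""
    return tf.keras.Sequential([
        tf.keras.layers.Conv3D(8, kernel_size=3, activation="relu",
                               input_shape=(range_bins, freq_bins, time_frames, 1)),
        tf.keras.layers.MaxPool3D(pool_size=2),
        tf.keras.layers.Conv3D(16, kernel_size=3, activation="relu"),
        tf.keras.layers.GlobalAveragePooling3D(),
        tf.keras.layers.Dense(3, activation="softmax"),  # absent / awake / asleep
    ])

model = build_sleep_state_model()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
probs = model(tf.random.normal([1, 16, 32, 30, 1]))      # one fake radar tensor
print(probs.numpy())                                     # [[p_absent, p_awake, p_asleep]]
```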

To train and evaluate the model, we recorded more than a million hours of radar data from thousands of individuals, along with thousands of sleep diaries, reference sensor recordings, and external annotations. We then leveraged the TensorFlow Extended framework to construct a training pipeline to process this data and produce an efficient TensorFlow Lite embedded model. In addition, we created an automatic calibration algorithm that runs during setup to configure the part of the scene on which the classifier will focus. This ensures that the algorithm ignores motion from a person on the other side of the bed or from other areas of the room, such as ceiling fans and swaying curtains.

The custom ML model efficiently processes a continuous stream of 3D radar tensors (summarizing activity over a range of distances, frequencies, and time) to automatically compute probabilities for the likelihood of user presence and wakefulness (awake or asleep).

To validate the accuracy of the algorithm, we compared it to the gold-standard of sleep-wake determination, the polysomnogram sleep study, in a cohort of 33 “healthy sleepers” (those without significant sleep issues, like sleep apnea or insomnia) across a broad age range (19-78 years of age). Sleep studies are typically conducted in clinical and research laboratories in order to collect various body signals (brain waves, muscle activity, respiratory and heart rate measurements, body movement and position, and snoring), which can then be interpreted by trained sleep experts to determine stages of sleep and identify relevant events. To account for variability in how different scorers apply the American Academy of Sleep Medicine’s staging and scoring rules, our study used two board-certified sleep technologists to independently annotate each night of sleep and establish a definitive groundtruth.

We compared our Sleep Sensing algorithm’s outputs to the corresponding groundtruth sleep and wake labels for every 30-second epoch of time to compute standard performance metrics (e.g., sensitivity and specificity). While not a true head-to-head comparison, this study’s results can be compared against previously published studies in similar cohorts with comparable methodologies in order to get a rough estimate of performance. In “Sleep-wake detection with a contactless, bedside radar sleep sensing system”, we share the full details of these validation results, demonstrating sleep-wake estimation equivalent to or, in some cases, better than current clinical and consumer sleep tracking devices.
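
For readers unfamiliar with these metrics, the snippet below computes epoch-level sensitivity and specificity from per-epoch sleep/wake labels, mirroring the 30-second-epoch comparison described above; the label sequences are invented.

```python
def epoch_metrics(predicted, reference):
    """Sensitivity (sleep detected as sleep) and specificity (wake detected as
    wake) over per-epoch labels. Labels: 1 = asleep, 0 = awake (toy example)."""
    tp = sum(p == 1 and r == 1 for p, r in zip(predicted, reference))
    tn = sum(p == 0 and r == 0 for p, r in zip(predicted, reference))
    fn = sum(p == 0 and r == 1 for p, r in zip(predicted, reference))
    fp = sum(p == 1 and r == 0 for p, r in zip(predicted, reference))
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    return sensitivity, specificity

reference = [1, 1, 1, 0, 0, 1, 1, 0]   # ground-truth labels per 30-second epoch
predicted = [1, 1, 0, 0, 1, 1, 1, 0]   # algorithm output per epoch
print(epoch_metrics(predicted, reference))   # (0.8, 0.666...)
```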

Aggregate performance from previously published accuracies for detection of sleep (sensitivity) and wake (specificity) of a variety of sleep trackers against polysomnography in a variety of different studies, accounting for 3,990 nights in total. While this is not a head-to-head comparison, the performance of Sleep Sensing on Nest Hub in a population of healthy sleepers who simultaneously underwent polysomnography is added to the figure for rough comparison. The size of each circle is a reflection of the number of nights and the inset illustrates the mean±standard deviation for the performance metrics.

Understanding Sleep Quality with Audio Sensing
The Soli-based sleep tracking algorithm described above gives users a convenient and reliable way to see how much sleep they are getting and when sleep disruptions occur. However, to understand and improve their sleep, users also need to understand why their sleep is disrupted. To assist with this, Nest Hub uses its array of sensors to track common sleep disturbances, such as light level changes or uncomfortable room temperature. In addition to these, respiratory events like coughing and snoring are also frequent sources of disturbance, but people are often unaware of these events.

As with other audio-processing applications like speech or music recognition, coughing and snoring exhibit distinctive temporal patterns in the audio frequency spectrum, and with sufficient data an ML model can be trained to reliably recognize these patterns while simultaneously ignoring a wide variety of background noises, from a humming fan to passing cars. The model uses entirely on-device audio processing with privacy-preserving analysis, with no raw audio data sent to Google’s servers. A user can then opt to save the outputs of the processing (sound occurrences, such as the number of coughs and snore minutes) in Google Fit, in order to view personal insights and summaries of their night time wellness over time.

The Nest Hub displays when snoring and coughing may have disturbed a user’s sleep (top) and can track weekly trends (bottom).

To train the model, we assembled a large, hand-labeled dataset, drawing examples from the publicly available AudioSet research dataset as well as hundreds of thousands of additional real-world audio clips contributed by thousands of individuals.

Log-Mel spectrogram inputs comparing cough (left) and snore (right) audio snippets.

When a user opts in to cough and snore tracking on their bedside Nest Hub, the device first uses its Soli-based sleep algorithms to detect when a user goes to bed. Once it detects that a user has fallen asleep, it then activates its on-device sound sensing model and begins processing audio. The model works by continuously extracting spectrogram-like features from the audio input and feeding them through a convolutional neural network classifier in order to estimate the probability that coughing or snoring is happening at a given instant in time. These estimates are analyzed over the course of the night to produce a report of the overall cough count and snoring duration and highlight exactly when these events occurred.
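
The post-processing step can be pictured as in the sketch below, which turns per-frame cough and snore probabilities into a nightly cough count and snore duration. The thresholding and event-merging rules here are illustrative assumptions, not the production logic.

```python
def summarize_night(cough_probs, snore_probs, frame_seconds=1.0, threshold=0.5):
    """Turn per-frame cough/snore probabilities into a nightly summary:
    count cough events as runs of frames above threshold, and report total
    snoring minutes (illustrative post-processing sketch)."""
    cough_events, in_event = 0, False
    for p in cough_probs:
        if p >= threshold and not in_event:
            cough_events += 1
            in_event = True
        elif p < threshold:
            in_event = False
    snore_minutes = sum(p >= threshold for p in snore_probs) * frame_seconds / 60
    return {"cough_count": cough_events, "snore_minutes": snore_minutes}

cough = [0.1, 0.9, 0.8, 0.1, 0.2, 0.95, 0.1]    # two separate coughing bursts
snore = [0.7] * 300 + [0.1] * 300               # ~5 minutes of snoring at 1 s frames
print(summarize_night(cough, snore))            # {'cough_count': 2, 'snore_minutes': 5.0}
```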

Conclusion
The new Nest Hub, with its underlying Sleep Sensing features, is a first step in empowering users to understand their nighttime wellness using privacy-preserving radar and audio signals. We continue to research additional ways that ambient sensing and the predictive ability of consumer devices could help people better understand their daily health and wellness in a privacy-preserving way.

Acknowledgements
This work involved collaborative efforts from a multidisciplinary team of software engineers, researchers, clinicians, and cross-functional contributors. Special thanks to D. Shin for his significant contributions to this technology and blogpost, and Dr. Logan Schneider, visiting sleep neurologist affiliated with the Stanford/VA Alzheimer’s Center and Stanford Sleep Center, whose clinical expertise and contributions were invaluable to continuously guide this research. In addition to the authors, key contributors to this research from Google Health include Jeffrey Yu, Allen Jiang, Arno Charton, Jake Garrison, Navreet Gill, Sinan Hersek, Yijie Hong, Jonathan Hsu, Andi Janti, Ajay Kannan, Mukil Kesavan, Linda Lei, Kunal Okhandiar‎, Xiaojun Ping, Jo Schaeffer, Neil Smith, Siddhant Swaroop, Bhavana Koka, Anupam Pathak, Dr. Jim Taylor, and the extended team. Another special thanks to Ken Mixter for his support and contributions to the development and integration of this technology into Nest Hub. Thanks to Mark Malhotra and Shwetak Patel for their ongoing leadership, as well as the Nest, Fit, Soli, and Assistant teams we collaborated with to build and validate Sleep Sensing on Nest Hub.


1 Not intended to diagnose, cure, mitigate, prevent or treat any disease or condition. 


LEAF: A Learnable Frontend for Audio Classification

Posted by Neil Zeghidour, Research Scientist, Google Research

Developing machine learning (ML) models for audio understanding has seen tremendous progress over the past several years. Leveraging the ability to learn parameters from data, the field has progressively shifted from composite, handcrafted systems to today’s deep neural classifiers that are used to recognize speech, understand music, or classify animal vocalizations such as bird calls. However, unlike computer vision models, which can learn from raw pixels, deep neural networks for audio classification are rarely trained from raw audio waveforms. Instead, they rely on pre-processed data in the form of mel filterbanks — handcrafted mel-scaled spectrograms that have been designed to replicate some aspects of the human auditory response.

Although mel filterbanks have historically been successful for ML tasks, they are limited by the inherent biases of fixed features: even though using a fixed mel-scale and a logarithmic compression works well in general, we have no guarantee that they provide the best representations for the task at hand. In particular, even though matching human perception provides good inductive biases for some application domains, e.g., speech recognition or music understanding, these biases may be detrimental to domains for which imitating the human ear is not important, such as recognizing whale calls. So, in order to achieve optimal performance, the mel filterbanks should be tailored to the task of interest, a tedious process that requires an iterative effort informed by expert domain knowledge. As a consequence, standard mel filterbanks are used for most audio classification tasks in practice, even though they are suboptimal. In addition, while researchers have proposed ML systems to address these problems, such as Time-Domain Filterbanks, SincNet and Wavegram, they have yet to match the performance of traditional mel filterbanks.

In “LEAF: A Learnable Frontend for Audio Classification”, accepted at ICLR 2021, we present an alternative method for crafting learnable spectrograms for audio understanding tasks. LEarnable Audio Frontend (LEAF) is a neural network that can be initialized to approximate mel filterbanks, and then be trained jointly with any audio classifier to adapt to the task at hand, while only adding a handful of parameters to the full model. We show that over a wide range of audio signals and classification tasks, including speech, music and bird songs, LEAF spectrograms improve classification performance over fixed mel filterbanks and over previously proposed learnable systems. We have implemented the code in TensorFlow 2 and released it to the community through our GitHub repository.

Mel Filterbanks: Mimicking Human Perception of Sound
The first step in the traditional approach to creating a mel filterbank is to capture the sound’s time-variability by windowing, i.e., cutting the signal into short segments of fixed duration. Then, one performs filtering, passing the windowed segments through a bank of fixed frequency filters that replicate the human logarithmic sensitivity to pitch. Because we are more sensitive to variations in low frequencies than high frequencies, mel filterbanks give more importance to the low-frequency range of sounds. Finally, the audio signal is compressed to mimic the ear’s logarithmic sensitivity to loudness — for example, doubling the power of a sound raises its level by only 3 decibels.
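To make these three steps concrete, here is a minimal sketch of a fixed log-mel frontend written with TensorFlow’s signal-processing ops. The window length, hop size, number of mel bins, and frequency range below are illustrative placeholder values, not the settings used in the paper.

```python
import tensorflow as tf

def log_mel_spectrogram(waveform, sample_rate=16000,
                        frame_length=400, frame_step=160, num_mel_bins=64):
    """Fixed (non-learnable) mel-filterbank frontend: window -> filter -> compress."""
    # 1) Windowing: cut the signal into short, overlapping frames via an STFT.
    stft = tf.signal.stft(waveform, frame_length=frame_length,
                          frame_step=frame_step, fft_length=512)
    power = tf.abs(stft) ** 2.0

    # 2) Filtering: project linear-frequency bins onto a fixed mel scale, which
    #    allocates more filters to low frequencies than to high ones.
    mel_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=num_mel_bins,
        num_spectrogram_bins=power.shape[-1],   # 512 // 2 + 1 = 257 bins
        sample_rate=sample_rate,
        lower_edge_hertz=60.0,
        upper_edge_hertz=7800.0)
    mel = tf.tensordot(power, mel_matrix, axes=1)

    # 3) Compression: the logarithm mimics the ear's compressive response to loudness.
    return tf.math.log(mel + 1e-6)
```

Every quantity in this pipeline (the windows, the mel scale, the log compression) is fixed before training, which is exactly what LEAF replaces with learnable counterparts.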

LEAF loosely follows this traditional approach to mel filterbank generation, but replaces each of the fixed operations (i.e., the filtering layer, windowing layer, and compression function) with a learned counterpart. The output of LEAF is a time-frequency representation (a spectrogram) similar to mel filterbanks, but fully learnable. So, for example, while a mel filterbank uses a fixed scale for pitch, LEAF learns the scale that is best suited to the task of interest. Any model that can be trained using mel filterbanks as input features can also be trained on LEAF spectrograms.

Diagram of computation of mel filterbanks compared to LEAF spectrograms.

While LEAF can be initialized randomly, it can also be initialized in a way that approximates mel filterbanks, which have been shown to be a better starting point. Then, LEAF can be trained with any classifier to adapt to the task of interest.
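As a rough sketch of how this looks in practice, the snippet below pairs a learnable frontend with a small Keras classifier and trains both jointly on raw waveforms. The frontend is assumed to be a Keras layer mapping waveforms of shape (batch, num_samples) to spectrogram-like tensors of shape (batch, time, filters), as in the released leaf-audio repository; the import path, classifier head, and 35-class output are illustrative assumptions, so check the repository for the exact API.

```python
import tensorflow as tf

def build_classifier(learnable_frontend, num_classes=35):
    """Pair a learnable frontend (e.g., LEAF) with a small downstream classifier.

    Assumes `learnable_frontend` maps raw waveforms of shape (batch, num_samples)
    to spectrogram-like tensors of shape (batch, time_frames, num_filters).
    """
    return tf.keras.Sequential([
        learnable_frontend,                               # initialized to approximate mel filterbanks
        tf.keras.layers.Conv1D(64, 3, activation="relu"),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])

# model = build_classifier(leaf_audio.frontend.Leaf())  # assumed import path and constructor
# model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# model.fit(raw_waveforms, labels, ...)  # gradients also update the frontend's few parameters
```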

Left: Mel filterbanks for a person saying “wow”. Right: LEAF’s output for the same example, after training on a dataset of speech commands.

A Parameter-Efficient Alternative to Fixed Features
A potential downside of replacing fixed features that involve no learnable parameters with a trainable system is that it can significantly increase the number of parameters to optimize. To avoid this issue, LEAF uses Gabor convolution layers that have only two parameters per filter, instead of the ~400 parameters typical of a standard convolution layer. This way, even when paired with a small classifier, such as EfficientNetB0, the LEAF frontend accounts for only 0.01% of the total parameters.
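To see why two parameters suffice, note that a complex Gabor filter is fully determined by a center frequency and a bandwidth. The NumPy sketch below uses one common parameterization for illustration; it is not the paper’s exact implementation.

```python
import numpy as np

def gabor_filter(center_freq, bandwidth, filter_length=401, sample_rate=16000):
    """Complex Gabor filter defined by two scalars: center frequency and bandwidth (Hz)."""
    t = (np.arange(filter_length) - filter_length // 2) / sample_rate
    sigma = 1.0 / (2.0 * np.pi * bandwidth)            # Gaussian width derived from the bandwidth
    envelope = np.exp(-0.5 * (t / sigma) ** 2)         # Gaussian envelope
    carrier = np.exp(2j * np.pi * center_freq * t)     # complex sinusoid at the center frequency
    return envelope * carrier

# A 40-filter bank needs only 2 * 40 = 80 learnable scalars, versus roughly 400
# weights per filter for an unconstrained convolution of the same length.
bank = np.stack([gabor_filter(f, 100.0) for f in np.linspace(60.0, 7800.0, 40)])
```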

Top: Unconstrained convolutional filters after training for audio event classification. Bottom: LEAF filters at convergence after training for the same task.

Performance
We apply LEAF to diverse audio classification tasks, including recognizing speech commands, speaker identification, acoustic scene recognition, identifying musical instruments, and finding birdsongs. On average, LEAF outperforms both mel filterbanks and previous learnable frontends, such as Time-Domain Filterbanks, SincNet and Wavegram. In particular, LEAF achieves a 76.9% average accuracy across the different tasks, compared to 73.9% for mel filterbanks. Moreover, we show that LEAF can be trained in a multi-task setting, such that a single LEAF parametrization can work well across all these tasks. Finally, when combined with a large audio classifier, LEAF reaches state-of-the-art performance on the challenging AudioSet benchmark, with a 2.74 d-prime score.

D-prime score (the higher the better) of LEAF, mel filterbanks and previously proposed learnable spectrograms on the evaluation set of AudioSet.

Conclusion
The scope of audio understanding tasks keeps growing, from diagnosing dementia from speech to detecting humpback whale calls from underwater microphones. Adapting mel filterbanks to every new task can require a significant amount of hand-tuning and experimentation. In this context, LEAF provides a drop-in replacement for these fixed features that can be trained to adapt to the task of interest, with minimal task-specific adjustments. Thus, we believe that LEAF can accelerate the development of models for new audio understanding tasks.

Acknowledgements
We thank our co-authors, Olivier Teboul, Félix de Chaumont-Quitry and Marco Tagliasacchi. We also thank Dick Lyon, Vincent Lostanlen, Matt Harvey, and Alex Park for helpful discussions, and Julie Thomas for helping to design figures for this post.


A New Lens on Understanding Generalization in Deep Learning

Posted by Hanie Sedghi, Google Research and Preetum Nakkiran, Harvard University

Understanding generalization is one of the fundamental unsolved problems in deep learning. Why does optimizing a model on a finite set of training data lead to good performance on a held-out test set? This problem has been studied extensively in machine learning, with a rich history going back more than 50 years. There are now many mathematical tools that help researchers understand generalization in certain models. Unfortunately, most of these existing theories fail when applied to modern deep networks — they are both vacuous and non-predictive in realistic settings. This gap between theory and practice is largest for overparameterized models, which in theory have the capacity to overfit their train sets, but often do not in practice.

In “The Deep Bootstrap Framework: Good Online Learners are Good Offline Generalizers”, accepted at ICLR 2021, we present a new framework for approaching this problem by connecting generalization to the field of online optimization. In a typical setting, a model trains on a finite set of samples, which are reused for multiple epochs. But in online optimization, the model has access to an infinite stream of samples, and can be iteratively updated while processing this stream. In this work, we find that models that train quickly on infinite data are the same models that generalize well if they are instead trained on finite data. This connection brings new perspectives on design choices in practice, and lays a roadmap for understanding generalization from a theoretical perspective.

The Deep Bootstrap Framework
The main idea of the Deep Bootstrap framework is to compare the real world, where there is finite training data, to an “ideal world”, where there is infinite data. We define these as:

  • Real World (N, T): Train a model on N train samples from a distribution, for T minibatch stochastic gradient descent (SGD) steps, re-using the same N samples in multiple epochs, as usual. This corresponds to running SGD on the empirical loss (loss on training data), and is the standard training procedure in supervised learning.
  • Ideal World (T): Train the same model for T steps, but use fresh samples from the distribution in each SGD step. That is, we run the exact same training code (same optimizer, learning-rates, batch-size, etc.), but sample a fresh train set in each epoch instead of reusing samples. In this ideal world setting, with an effectively infinite “train set”, there is no difference between train error and test error.
Test soft-error for ideal world and real world during SGD iterations for ResNet-18 architecture. We see that the two errors are similar.
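In pseudocode, the two procedures are identical except for where each minibatch comes from. The sketch below is schematic, with hypothetical helpers (`sample_minibatch`, `sgd_step`, `distribution`); it is not code from the paper.

```python
def train_real_world(model, optimizer, dataset_n, num_steps, batch_size):
    """Real World (N, T): reuse the same N samples for T SGD steps (many epochs)."""
    for step in range(num_steps):
        x, y = sample_minibatch(dataset_n, batch_size)  # re-drawn from the SAME N samples
        sgd_step(model, optimizer, x, y)
    return model

def train_ideal_world(model, optimizer, distribution, num_steps, batch_size):
    """Ideal World (T): identical training code, but every minibatch is fresh."""
    for step in range(num_steps):
        x, y = distribution.sample(batch_size)          # never reuses a sample
        sgd_step(model, optimizer, x, y)                # same optimizer, learning rate, batch size
    return model
```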

A priori, one might expect the real and ideal worlds to have nothing to do with each other, since in the real world the model sees a finite number of examples from the distribution while in the ideal world the model sees the whole distribution. But in practice, we found that the real and ideal models actually have similar test error.

In order to quantify this observation, we simulated an ideal world setting by creating a new dataset, which we call CIFAR-5m. We trained a generative model on CIFAR-10, which we then used to generate ~6 million images. The scale of the dataset was chosen to ensure that it is “virtually infinite” from the model’s perspective, so that the model never resamples the same data. That is, in the ideal world, the model sees an entirely fresh set of samples.

Samples from CIFAR-5m

The figure below presents the test error of several models, comparing their performance when trained on CIFAR-5m data in the real world setting (i.e., re-used data) and the ideal world (“fresh” data). The solid blue line shows a ResNet model in the real world, trained on 50K samples for 100 epochs with standard CIFAR-10 hyperparameters. The dashed blue line shows the corresponding model in the ideal world, trained on 5 million samples in a single pass. Surprisingly, these worlds have very similar test error — the model in some sense “doesn’t care” whether it sees re-used samples or fresh ones.

The real world model is trained on 50K samples for 100 epochs, and the ideal world model is trained on 5M samples for a single epoch. The lines show the test error vs. the number of SGD steps.

This also holds for other architectures, e.g., a Multi-Layer-Perceptron (red), a Vision Transformer (green), and across many other settings of architecture, optimizer, data distribution, and sample size. These experiments suggest a new perspective on generalization: models that optimize quickly (on infinite data), generalize well (on finite data). For example, the ResNet model generalizes better than the MLP model on finite data, but this is “because” it optimizes faster even on infinite data.

Understanding Generalization from Optimization Behavior
The key observation is that real world and ideal world models remain close, in test error, for all timesteps, until the real world converges (< 1% train error). Thus, one can study models in the real world by studying their corresponding behavior in the ideal world.

This means that the generalization of the model can be understood in terms of its optimization performance under two frameworks:

  1. Online Optimization: How fast the ideal world test error decreases
  2. Offline Optimization: How fast the real world train error converges

Thus, to study generalization, we can equivalently study the two terms above, which can be conceptually simpler, since they only involve optimization concerns. Based on this observation, good models and training procedures are those that (1) optimize quickly in the ideal world and (2) do not optimize too quickly in the real world.

All design choices in deep learning can be viewed through their effect on these two terms. For example, some advances like convolutions, skip-connections, and pretraining help primarily by accelerating ideal world optimization, while other advances like regularization and data-augmentation help primarily by decelerating real world optimization.

Applying the Deep Bootstrap Framework
Researchers can use the Deep Bootstrap framework to study and guide design choices in deep learning. The principle is: whenever one makes a change that affects generalization in the real world (the architecture, learning-rate, etc.), one should consider its effect on (1) the ideal world optimization of test error (faster is better) and (2) the real world optimization of train error (slower is better).

For example, pre-training is often used in practice to help generalization of models in small-data regimes. However, the reason that pre-training helps remains poorly understood. One can study this using the Deep Bootstrap framework by looking at the effect of pre-training on terms (1) and (2) above. We find that the primary effect of pre-training is to improve the ideal world optimization (1) — pre-training turns the network into a “fast learner” for online optimization. The improved generalization of pretrained models is thus almost exactly captured by their improved optimization in the ideal world. The figure below shows this for Vision-Transformers (ViT) trained on CIFAR-10, comparing training from scratch vs. pre-training on ImageNet.

Effect of pre-training — pre-trained ViTs optimize faster in the ideal world.

One can also study data-augmentation using this framework. Data-augmentation in the ideal world corresponds to augmenting each fresh sample once, as opposed to augmenting the same sample multiple times. This framework implies that good data-augmentations are those that (1) do not significantly harm ideal world optimization (i.e., augmented samples don’t look too “out of distribution”) and (2) inhibit real world optimization speed (so the real world takes longer to fit its train set).

The main benefit of data-augmentation is through the second term, prolonging the real world optimization time. As for the first term, some aggressive data augmentations (mixup/cutout) can actually harm the ideal world, but this effect is dwarfed by the second term.

Concluding Thoughts
The Deep Bootstrap framework provides a new lens on generalization and empirical phenomena in deep learning. We are excited to see it applied to understand other aspects of deep learning in the future. It is especially interesting that generalization can be characterized via purely optimization considerations, which is in contrast to many prevailing approaches in theory. Crucially, we consider both online and offline optimization, which are individually insufficient, but that together determine generalization.

The Deep Bootstrap framework can also shed light on why deep learning is fairly robust to many design choices: many kinds of architectures, loss functions, optimizers, normalizations, and activation functions can generalize well. This framework suggests a unifying principle: that essentially any choice that works well in the online optimization setting will also generalize well in the offline setting.

Finally, modern neural networks can be either overparameterized (e.g., large networks trained on small data tasks) or underparameterized (e.g., OpenAI’s GPT-3, Google’s T5, or Facebook’s ResNeXt WSL). The Deep Bootstrap framework implies that online optimization is a crucial factor to success in both regimes.

Acknowledgements
We are thankful to our co-author, Behnam Neyshabur, for his great contributions to the paper and valuable feedback on the blog. We thank Boaz Barak, Chenyang Yuan, and Chiyuan Zhang for helpful comments on the blog and paper.


Accelerating Neural Networks on Mobile and Web with Sparse Inference

Posted by Artsiom Ablavatski and Marat Dukhan, Software Engineers, Google Research

On-device inference of neural networks enables a variety of real-time applications, like pose estimation and background blur, in a low-latency and privacy-conscious way. Using ML inference frameworks like TensorFlow Lite with the XNNPACK ML acceleration library, engineers optimize their models to run on a variety of devices by finding a sweet spot between model size, inference speed and the quality of the predictions.

One way to optimize a model is through use of sparse neural networks [1, 2, 3], which have a significant fraction of their weights set to zero. In general, this is a desirable quality as it not only reduces the model size via compression, but also makes it possible to skip a significant fraction of multiply-add operations, thereby speeding up inference. Further, it is possible to increase the number of parameters in a model and then sparsify it to match the quality of the original model, while still benefiting from the accelerated inference. However, the use of this technique remains limited in production largely due to a lack of tools to sparsify popular convolutional architectures as well as insufficient support for running these operations on-device.

Today we announce the release of a set of new features for the XNNPACK acceleration library and TensorFlow Lite that enable efficient inference of sparse networks, along with guidelines on how to sparsify neural networks, with the goal of helping researchers develop their own sparse on-device models. Developed in collaboration with DeepMind, these tools power a new generation of live perception experiences, including hand tracking in MediaPipe and background features in Google Meet, accelerating inference speed from 1.2 to 2.4 times, while reducing the model size by half. In this post, we provide a technical overview of sparse neural networks — from inducing sparsity during training to on-device deployment — and offer some ideas on how researchers might create their own sparse models.

Comparison of the processing time for the dense (left) and sparse (right) models of the same quality for Google Meet background features. For readability, the processing time shown is the moving average across 100 frames.

Sparsifying a Neural Network
Many modern deep learning architectures, like MobileNet and EfficientNetLite, are primarily composed of depthwise convolutions with a small spatial kernel and 1×1 convolutions that linearly combine features from the input image. While such architectures have a number of potential targets for sparsification, including the full 2D convolutions that frequently occur at the beginning of many networks or the depthwise convolutions, it is the 1×1 convolutions that are the most expensive operators as measured by inference time. Because they account for over 65% of the total compute, they are an optimal target for sparsification.

Architecture        Fraction of inference time in 1×1 convolutions
MobileNet           85%
MobileNetV2         71%
MobileNetV3         71%
EfficientNet-Lite   66%
Fraction of inference time spent in 1×1 convolutions for modern mobile architectures.

In modern on-device inference engines, like XNNPACK, the implementation of 1×1 convolutions, as well as of other operations in deep learning models, relies on the HWC tensor layout, in which the tensor dimensions correspond to the height, width, and channel (e.g., red, green or blue) of the input image. This tensor configuration allows the inference engine to process the channels corresponding to each spatial location (i.e., each pixel of an image) in parallel. However, this ordering of the tensor is not a good fit for sparse inference because it sets the channel as the innermost dimension of the tensor and makes it more computationally expensive to access.

Our updates to XNNPACK enable it to detect if a model is sparse. If so, it switches from its standard dense inference mode to sparse inference mode, in which it employs a CHW (channel, height, width) tensor layout. This reordering of the tensor allows for an accelerated implementation of the sparse 1×1 convolution kernel for two reasons: 1) entire spatial slices of the tensor can be skipped when the corresponding channel weight is zero following a single condition check, instead of a per-pixel test; and 2) when the channel weight is non-zero, the computation can be made more efficient by loading neighbouring pixels into the same memory unit. This enables us to process multiple pixels simultaneously, while also performing each operation in parallel across several threads. Together these changes result in a speed-up of 1.8x to 2.3x when at least 80% of the weights are zero.
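The effect of the CHW layout is easiest to see in a simplified reference implementation of a sparse 1×1 convolution. The plain-NumPy sketch below is for clarity only; the actual XNNPACK kernels are vectorized and multi-threaded.

```python
import numpy as np

def sparse_pointwise_conv_chw(x_chw, weights):
    """Simplified sparse 1x1 convolution in CHW layout.

    x_chw: input tensor of shape (C_in, H, W).
    weights: matrix of shape (C_out, C_in), with most entries equal to zero.
    """
    c_out, c_in = weights.shape
    _, h, w = x_chw.shape
    y = np.zeros((c_out, h, w), dtype=x_chw.dtype)
    for o in range(c_out):
        for i in range(c_in):
            wt = weights[o, i]
            if wt == 0.0:
                continue              # one check skips an entire HxW spatial slice
            y[o] += wt * x_chw[i]     # contiguous slice: neighboring pixels load together
    return y
```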

In order to avoid converting back and forth between the CHW tensor layout that is optimal for sparse inference and the standard HWC tensor layout after each operation, XNNPACK provides efficient implementations of several CNN operators in CHW layout.

Guidelines for Training Sparse Neural Networks
To create a sparse neural network, the guidelines included in this release suggest one start with a dense version and then gradually set a fraction of its weights to zero during training. This process is called pruning. Of the many available techniques for pruning, we recommend using magnitude pruning (available in the TF Model Optimization Toolkit) or the recently introduced RigL method. With a modest increase in training time, both of these can successfully sparsify deep learning models without degrading their quality. The resulting sparse models can be stored efficiently in a compressed format that reduces the size by a factor of two compared to their dense equivalent.
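For illustration, magnitude pruning with the TF Model Optimization Toolkit looks roughly like the sketch below; the 80% sparsity target and the schedule steps are placeholder values to be tuned for each model.

```python
import tensorflow_model_optimization as tfmot

def make_prunable(dense_model, end_step):
    """Wrap a dense Keras model so magnitude pruning gradually zeroes ~80% of its weights."""
    schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.8,   # ramp sparsity up during training
        begin_step=0, end_step=end_step)
    return tfmot.sparsity.keras.prune_low_magnitude(dense_model, pruning_schedule=schedule)

# pruned = make_prunable(dense_model, end_step=10000)
# pruned.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# pruned.fit(train_ds, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
# final_model = tfmot.sparsity.keras.strip_pruning(pruned)  # drop pruning wrappers before export
```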

The quality of sparse networks is influenced by several hyperparameters, including training time, learning rate and schedules for pruning. The TF Pruning API provides an excellent example of how to select these, as well as some tips for training such models. We recommend running hyperparameter searches to find the sweet spot for your application.

Applications
We demonstrate that it is possible to sparsify models for classification tasks, dense segmentation (e.g., the Meet background blur) and regression problems (MediaPipe Hands), which provides tangible benefits to users. For example, in the case of Google Meet, sparsification lowered the inference time of the model by 30%, which provided access to higher quality models for more users.

Model size comparisons for the dense and sparse models in MB. The models have been stored in 16- and 32-bit floating-point formats.

The approach to sparsity described here works best with architectures based on inverted residual blocks, such as MobileNetV2, MobileNetV3 and EfficientNetLite. The degree of sparsity in a network influences both inference speed and quality. Starting from a dense network of a fixed capacity, we found modest performance gains even at 30% sparsity. With increased sparsity, the quality of the model remains relatively close to the dense baseline until reaching 70% sparsity, beyond which there is a more pronounced drop in accuracy. However, one can compensate for the reduced accuracy at 70% sparsity by increasing the size of the base network by 20%, which results in faster inference times without degrading the quality of the model. No further changes are required to run the sparsified models, because XNNPACK can recognize and automatically enable sparse inference.

Ablation studies of different sparsity levels with respect to inference time (the smaller the better) and the quality measured by the Intersection over Union (IoU) for predicted segmentation mask.

Sparsity as an Automatic Alternative to Distillation
Background blur in Google Meet uses a segmentation model based on a modified MobileNetV3 backbone with attention blocks. We were able to speed up the model by 30% by applying a 70% sparsification, while preserving the quality of the foreground mask. We examined the predictions of the sparse and dense models on images from 17 geographic subregions, finding no significant difference, and released the details in the associated model card.

Similarly, MediaPipe Hands predicts hand landmarks in real-time on mobile and the web using a model based on the EfficientNetLite backbone. This backbone model was manually distilled from the large dense model, which is a computationally expensive, iterative process. Using the sparse version of the dense model instead of the distilled one, we were able to maintain the same inference speed without the labor-intensive process of distilling from a dense model. Compared with the dense model, the sparse model improved inference speed by a factor of two, while achieving the same landmark quality as the distilled model. In a sense, sparsification can be thought of as an automatic approach to unstructured model distillation, which can improve model performance without extensive manual effort. We evaluated the sparse model on the geodiverse dataset and made the model card publicly available.

Comparison of execution time for the dense (left), distilled (middle) and sparse (right) models of the same quality. Processing time of the dense model is 2x larger than that of the sparse or distilled models. The distilled model is taken from the official MediaPipe solution. The dense and sparse web demos are publicly available.

Future work
We find sparsification to be a simple yet powerful technique for improving CPU inference of neural networks. Sparse inference allows engineers to run larger models without incurring a significant performance or size overhead and offers a promising new direction for research. We are continuing to extend XNNPACK with wider support for operations in CHW layout and are exploring how it might be combined with other optimization techniques like quantization. We are excited to see what you might build with this technology!

Acknowledgments
Special thanks to all who worked on this project: Karthik Raveendran, Erich Elsen, Tingbo Hou‎, Trevor Gale, Siargey Pisarchyk, Yury Kartynnik, Yunlu Li, Utku Evci, Matsvei Zhdanovich, Sebastian Jansson, Stéphane Hulaud, Michael Hays, Juhyun Lee, Fan Zhang, Chuo-Ling Chang, Gregory Karpiak, Tyler Mullen, Jiuqiang Tang, Ming Guang Yong, Igor Kibalchich, and Matthias Grundmann.


PAIRED: A New Multi-agent Approach for Adversarial Environment Generation

Posted by Natasha Jaques, Google Research and Michael Dennis, UC Berkeley

The effectiveness of any machine learning method is critically dependent on its training data. In the case of reinforcement learning (RL), one can rely either on limited data collected by an agent interacting with the real world, or a simulated training environment that can be used to collect as much data as needed. This latter method of training in simulation is increasingly popular, but it has a problem — the RL agent can learn what is built into the simulator, but tends to be bad at generalizing to tasks that are even slightly different from the ones simulated. And obviously, building a simulator that covers all the complexity of the real world is extremely challenging.

An approach to address this is to automatically create more diverse training environments by randomizing all the parameters of the simulator, a process called domain randomization (DR). However, DR can fail even in very simple environments. For example, in the animation below, the blue agent is trying to navigate to the green goal. The left panel shows an environment created with DR where the positions of the obstacles and goal have been randomized. Many of these DR environments were used to train the agent, which was then transferred to the simple Four Rooms environment in the middle panel. Notice that the agent can’t find the goal. This is because it has not learned to walk around walls. Even though the wall configuration from the Four Rooms example could have been generated randomly in the DR training phase, it’s unlikely. As a result, the agent has not spent enough time training on walls similar to the Four Rooms structure, and is unable to reach the goal.

Domain randomization (left) does not effectively prepare an agent to transfer to previously unseen environments, such as the Four Rooms scenario (middle). To address this, a minimax adversary is used to construct previously unseen environments (right), but can result in creating situations that are impossible to solve.

Instead of just randomizing the environment parameters, one could train a second RL agent to learn how to set the environment parameters. This minimax adversary can be trained to minimize the performance of the first RL agent by finding and exploiting weaknesses in its policy, e.g. building wall configurations it has not encountered before. But again there is a problem. The right panel shows an environment built by a minimax adversary in which it is actually impossible for the agent to reach the goal. While the minimax adversary has succeeded in its task — it has minimized the performance of the original agent — it provides no opportunity for the agent to learn. Using a purely adversarial objective is not well suited to generating training environments, either.

In collaboration with UC Berkeley, we propose a new multi-agent approach for training the adversary in “Emergent Complexity and Zero-shot Transfer via Unsupervised Environment Design”, a publication recently presented at NeurIPS 2020. In this work we present an algorithm, Protagonist Antagonist Induced Regret Environment Design (PAIRED), that is based on minimax regret and prevents the adversary from creating impossible environments, while still enabling it to correct weaknesses in the agent’s policy. PAIRED incentivizes the adversary to tune the difficulty of the generated environments to be just outside the agent’s current abilities, leading to an automatic curriculum of increasingly challenging training tasks. We show that agents trained with PAIRED learn more complex behavior and generalize better to unknown test tasks. We have released open-source code for PAIRED on our GitHub repo.

PAIRED
To flexibly constrain the adversary, PAIRED introduces a third RL agent, which we call the antagonist agent, because it is allied with the adversarial agent, i.e., the one designing the environment. We rename our initial agent, the one navigating the environment, the protagonist. Once the adversary generates an environment, both the protagonist and antagonist play through that environment.

The adversary’s job is to maximize the antagonist’s reward while minimizing the protagonist’s reward. This means it must create environments that are feasible (because the antagonist can solve them and get a high score), but challenging to the protagonist (exploit weaknesses in its current policy). The gap between the two rewards is the regret — the adversary tries to maximize the regret, while the protagonist competes to minimize it.
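In symbols, writing R^A(E) and R^P(E) for the antagonist’s and protagonist’s returns in an environment E proposed by the adversary (our notation, schematically restating the objective):

\[
\text{Regret}(E) \;=\; R^{A}(E) - R^{P}(E), \qquad \text{adversary: } \max_{E} \text{Regret}(E), \qquad \text{protagonist: } \min \text{Regret}(E).
\]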

The methods discussed above (domain randomization, minimax regret and PAIRED) can be analyzed using the same theoretical framework, unsupervised environment design (UED), which we describe in detail in the paper. UED draws a connection between environment design and decision theory, enabling us to show that domain randomization is equivalent to the Principle of Insufficient Reason, the minimax adversary follows the Maximin Principle, and PAIRED is optimizing minimax regret. This formalism enables us to use tools from decision theory to understand the benefits and drawbacks of each method. Below, we show how each of these ideas works for environment design:

Domain randomization (a) generates unstructured environments that aren’t tailored to the agent’s learning progress. The minimax adversary (b) may create impossible environments. PAIRED (c) can generate challenging, structured environments, which are still possible for the agent to complete.

Curriculum Generation
What’s interesting about minimax regret is that it incentivizes the adversary to generate a curriculum of initially easy, then increasingly challenging environments. In most RL environments, the reward function will give a higher score for completing the task more efficiently, or in fewer timesteps. When this is true, we can show that regret incentivizes the adversary to create the easiest possible environment the protagonist can’t solve yet. To see this, let’s assume the antagonist is perfect, and always gets the highest score that it possibly can. Meanwhile, the protagonist is terrible, and gets a score of zero on everything. In that case, the regret just depends on the difficulty of the environment. Since easier environments can be completed in fewer timesteps, they allow the antagonist to get a higher score. Therefore, the regret of failing at an easy environment is greater than the regret of failing on a hard environment:
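Schematically, in the same notation as above (our reconstruction of the argument, with the protagonist’s return equal to zero by assumption):

\[
\text{Regret}(E) \;=\; R^{A}(E) - \underbrace{R^{P}(E)}_{=\,0} \;=\; R^{A}(E), \qquad R^{A}(E_{\text{easy}}) > R^{A}(E_{\text{hard}}) \;\Longrightarrow\; \text{Regret}(E_{\text{easy}}) > \text{Regret}(E_{\text{hard}}).
\]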

So, by maximizing regret, the adversary is searching for easy environments that the protagonist fails to solve. Once the protagonist learns to solve each environment, the adversary must move on to finding a slightly harder environment that the protagonist can’t solve. Thus, the adversary generates a curriculum of increasingly difficult tasks.

Results
We can see the curriculum emerging in the learning curves below, which plot the shortest path length of a maze the agents have successfully solved. Unlike minimax or domain randomization, the PAIRED adversary creates a curriculum of increasingly longer, but possible, mazes, enabling PAIRED agents to learn more complex behavior.

But can these different training schemes help an agent generalize better to unknown test tasks? Below, we see the zero-shot transfer performance of each algorithm on a series of challenging test tasks. As the complexity of the transfer environment increases, the performance gap between PAIRED and the baselines widens. For extremely difficult tasks like Labyrinth and Maze, PAIRED is the only method that can occasionally solve the task. These results provide promising evidence that PAIRED can be used to improve generalization for deep RL.

Admittedly, these simple gridworlds do not reflect the complexities of the real world tasks that many RL methods are attempting to solve. We address this in “Adversarial Environment Generation for Learning to Navigate the Web”, which examines the performance of PAIRED when applied to more complex problems, such as teaching RL agents to navigate web pages. We propose an improved version of PAIRED, and show how it can be used to train an adversary to generate a curriculum of increasingly challenging websites:

Above, you can see websites built by the adversary in the early, middle, and late training stages, which progress from using very few elements per page to many simultaneous elements, making the tasks progressively harder. We test whether agents trained on this curriculum can generalize to standardized web navigation tasks, and achieve a 75% success rate, with a 4x improvement over the strongest curriculum learning baseline:

Conclusions
Deep RL is very good at fitting a simulated training environment, but how can we build simulations that cover the complexity of the real world? One solution is to automate this process. We propose Unsupervised Environment Design (UED) as a framework that describes different methods for automatically creating a distribution of training environments, and show that UED subsumes prior work like domain randomization and minimax adversarial training. We think PAIRED is a good approach for UED, because regret maximization leads to a curriculum of increasingly challenging tasks, and prepares agents to transfer successfully to unknown test tasks.

Acknowledgements
We would like to recognize the co-authors of “Emergent Complexity and Zero-shot Transfer via Unsupervised Environment Design”: Michael Dennis, Natasha Jaques, Eugene Vinitsky, Alexandre Bayen, Stuart Russell, Andrew Critch, and Sergey Levine, as well as the co-authors of “Adversarial Environment Generation for Learning to Navigate the Web”: Izzeddin Gur, Natasha Jaques, Yingjie Miao, Jongwook Choi, Kevin Malta, Manoj Tiwari, Honglak Lee, and Aleksandra Faust. In addition, we thank Michael Chang, Marvin Zhang, Dale Schuurmans, Aleksandra Faust, Chase Kew, Jie Tan, Dennis Lee, Kelvin Xu, Abhishek Gupta, Adam Gleave, Rohin Shah, Daniel Filan, Lawrence Chan, Sam Toyer, Tyler Westenbroek, Igor Mordatch, Shane Gu, DJ Strouse, and Max Kleiman-Weiner for discussions that contributed to this work.
