FormNet: Beyond Sequential Modeling for Form-Based Document Understanding

Form-based document understanding is a growing research topic because of its practical potential for automatically converting unstructured text data into structured information to gain insight about a document’s contents. Recent sequence models built on self-attention, a mechanism that directly models relationships between all words in a selection of text, have demonstrated state-of-the-art performance on natural language tasks. A natural approach to handle form document understanding tasks is to first serialize the form documents (usually in a left-to-right, top-to-bottom fashion) and then apply state-of-the-art sequence models to them.

However, form documents often have more complex layouts that contain structured objects, such as tables, columns, and text blocks. Their variety of layout patterns makes serialization difficult, substantially limiting the performance of strict serialization approaches. These unique challenges in form document structural modeling have been largely underexplored in literature.

An illustration of the form document information extraction task using an example from the FUNSD dataset.

In “FormNet: Structural Encoding Beyond Sequential Modeling in Form Document Information Extraction”, presented at ACL 2022, we propose a structure-aware sequence model, called FormNet, to mitigate the sub-optimal serialization of forms for document information extraction. First, we design a Rich Attention (RichAtt) mechanism that leverages the 2D spatial relationship between word tokens for more accurate attention weight calculation. Then, we construct Super-Tokens (tokens that aggregate semantically meaningful information from neighboring tokens) for each word by embedding representations from their neighboring tokens through a graph convolutional network (GCN). Finally, we demonstrate that FormNet outperforms existing methods, while using less pre-training data, and achieves state-of-the-art performance on the CORD, FUNSD, and Payment benchmarks.

FormNet for Information Extraction
Given a form document, we first use the BERT-multilingual vocabulary and optical character recognition (OCR) engine to identify and tokenize words. We then feed the tokens and their corresponding 2D coordinates into a GCN for graph construction and message passing. Next, we use Extended Transformer Construction (ETC) layers with the proposed RichAtt mechanism to continue to process the GCN-encoded structure-aware tokens for schema learning (i.e., semantic entity extraction). Finally, we use the Viterbi algorithm, which finds a sequence that maximizes the posterior probability, to decode and obtain the final entities for output.
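
To make the final decoding step concrete, below is a minimal sketch of Viterbi decoding over per-token label scores with a label-transition matrix. The label names, scores, and transition constraints are illustrative placeholders, not FormNet’s actual schema or implementation.

```python
# Minimal Viterbi decoding sketch: recover the highest-scoring label sequence
# from per-token emission scores and pairwise transition scores.
import numpy as np

def viterbi_decode(emissions, transitions):
    """emissions: [num_tokens, num_labels] scores; transitions: [num_labels, num_labels]."""
    num_tokens, num_labels = emissions.shape
    score = emissions[0].copy()                    # best score of paths ending in each label
    backpointers = np.zeros((num_tokens, num_labels), dtype=int)
    for t in range(1, num_tokens):
        # candidate[i, j] = best path ending in label i at t-1, then moving to label j at t
        candidate = score[:, None] + transitions + emissions[t][None, :]
        backpointers[t] = candidate.argmax(axis=0)
        score = candidate.max(axis=0)
    best_path = [int(score.argmax())]              # follow back-pointers from the best end label
    for t in range(num_tokens - 1, 0, -1):
        best_path.append(int(backpointers[t, best_path[-1]]))
    return best_path[::-1]

labels = ["O", "B-header", "I-header", "B-question", "I-question"]   # placeholder label set
emissions = np.random.randn(6, len(labels))        # stand-in for model output scores
transitions = np.zeros((len(labels), len(labels)))
transitions[0, 2] = transitions[0, 4] = -1e4       # e.g., forbid O -> I-* transitions
print([labels[i] for i in viterbi_decode(emissions, transitions)])
```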

Extended Transformer Construction (ETC)
We adopt ETC as the FormNet model backbone. ETC scales to relatively long inputs by replacing standard attention, which has quadratic complexity, with a sparse global-local attention mechanism that distinguishes between global and long input tokens. The global tokens attend to and are attended by all tokens, but the long tokens attend only locally to other long tokens within a specified local radius, reducing the complexity so that it is more manageable for long sequences.
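
A rough sketch (not the ETC implementation) of what such a sparse attention pattern looks like: a small set of global tokens attends to, and is attended by, everything, while the remaining long-input tokens only attend within a fixed local radius. The token counts and radius below are arbitrary.

```python
# Build a boolean global-local attention mask: mask[i, j] == True means
# token i is allowed to attend to token j.
import numpy as np

def global_local_mask(num_global, num_long, local_radius):
    n = num_global + num_long
    mask = np.zeros((n, n), dtype=bool)
    mask[:num_global, :] = True          # global tokens attend to all tokens
    mask[:, :num_global] = True          # all tokens attend to global tokens
    for i in range(num_long):            # long tokens attend only within a local window
        lo = max(0, i - local_radius)
        hi = min(num_long, i + local_radius + 1)
        mask[num_global + i, num_global + lo:num_global + hi] = True
    return mask

print(global_local_mask(num_global=2, num_long=8, local_radius=2).astype(int))
```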

Rich Attention
Our novel architecture, RichAtt, sidesteps the deficiencies of absolute and relative embeddings by avoiding embeddings entirely. Instead, it computes the order of and log distance between pairs of tokens with respect to the x and y axes on the layout grid, and adjusts the pre-softmax attention scores of each pair as a direct function of these values.

In a traditional attention layer, each token representation is linearly transformed into a Query vector, a Key vector, and a Value vector. A token “looks” for other tokens from which it might want to absorb information (i.e., attend to) by finding the ones with Key vectors that create relatively high scores when matrix-multiplied (called Matmul) by its Query vector and then softmax-normalized. The token then sums together the Value vectors of all other tokens in the sentence, weighted by their score, and passes this up the network, where it will normally be added to the token’s original input vector.
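
The following is a minimal sketch of that standard attention computation, with made-up shapes and single-head projections; real transformer layers add multiple heads, masking, layer normalization, and so on.

```python
# Standard scaled dot-product attention over a small set of token vectors.
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(inputs, w_q, w_k, w_v):
    q, k, v = inputs @ w_q, inputs @ w_k, inputs @ w_v   # Query, Key, Value vectors
    scores = q @ k.T / np.sqrt(q.shape[-1])              # Matmul of Queries and Keys
    weights = softmax(scores, axis=-1)                   # softmax-normalized scores
    return weights @ v                                   # weighted sum of Value vectors

tokens = np.random.randn(5, 16)                          # 5 tokens, 16-dim representations
w_q, w_k, w_v = (np.random.randn(16, 16) for _ in range(3))
out = attention(tokens, w_q, w_k, w_v) + tokens          # residual: add original input vectors
print(out.shape)
```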

However, other features beyond the Query and Key vectors are often relevant to the decision of how strongly a token should attend to another given token, such as the order they’re in, how many other tokens separate them, or how many pixels apart they are. In order to incorporate these features into the system, we use a trainable parametric function paired with an error network, which takes the observed feature and the output of the parametric function and returns a penalty that reduces the dot product attention score.

The network uses the Query and Key vectors to consider what value some low-level feature (e.g., distance) should take if the tokens are related, and penalizes the attention score based on the error.

At a high level, for each attention head at each layer, FormNet examines each pair of token representations, determines the ideal features the tokens should have if there is a meaningful relationship between them, and penalizes the attention score according to how different the actual features are from the ideal ones. This allows the model to learn constraints on attention using logical implication.
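
The snippet below illustrates this idea in the simplest possible form, using only the log distance along the x axis as the low-level feature and a squared-error penalty. It is not the published RichAtt parameterization; the prediction heads and weights are hypothetical stand-ins.

```python
# Penalize pre-softmax attention scores by how far the observed log x-distance
# deviates from the distance "predicted" from the Query/Key pair.
import numpy as np

def rich_attention_scores(q, k, x_positions):
    scores = q @ k.T / np.sqrt(q.shape[-1])               # usual dot-product scores
    observed = np.log1p(np.abs(x_positions[:, None] - x_positions[None, :]))
    # "Ideal" log distance predicted from the pair of token representations
    # (single linear heads, purely for illustration).
    w_q_feat = np.random.randn(q.shape[-1]) * 0.01
    w_k_feat = np.random.randn(k.shape[-1]) * 0.01
    ideal = (q @ w_q_feat)[:, None] + (k @ w_k_feat)[None, :]
    penalty = (observed - ideal) ** 2                     # error network -> penalty
    return scores - penalty                               # penalized pre-softmax scores

q = np.random.randn(4, 16)
k = np.random.randn(4, 16)
x_positions = np.array([10.0, 12.0, 300.0, 305.0])        # pixel x-coordinates of the tokens
print(rich_attention_scores(q, k, x_positions).shape)
```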

A visualization of how RichAtt might act on a sentence. There are three adjectives that the word “crow” might attend to. “Lazy” is to the right, so it probably does not modify “crow” and its attention edge is penalized. “Sly” is many tokens away, so its attention edge is also penalized. “Cunning” receives no significant penalties, so by process of elimination, it is the best candidate for attention.

Furthermore, if one assumes that the softmax-normalized attention scores represent a probability distribution, and the distributions for the observed features are known, then this algorithm — including the exact choice of parametric functions and error functions — falls out algebraically, meaning FormNet has a mathematical correctness to it that is lacking from many alternatives (including relative embeddings).

Super-Tokens by Graph Learning
The key to sparsifying attention mechanisms in ETC for long sequence modeling is to have every token only attend to tokens that are nearby in the serialized sequence. Although the RichAtt mechanism empowers the transformers by taking the spatial layout structures into account, poor serialization can still block significant attention weight calculation between related word tokens.

To further mitigate the issue, we construct a graph to connect nearby tokens in a form document. We design the edges of the graph based on strong inductive biases so that they have higher probabilities of belonging to the same entity type. For each token, we obtain its Super-Token embedding by applying graph convolutions along these edges to aggregate semantically relevant information from neighboring tokens. We then use these Super-Tokens as an input to the RichAtt ETC architecture. This means that even though an entity may get broken up into multiple segments due to poor serialization, the Super-Tokens learned by the GCN will have retained much of the context of the entity phrase.
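
A toy sketch of the Super-Token idea follows: connect each token to a few nearest neighbors in 2D layout space, then run one graph-convolution step that mixes a token’s embedding with an aggregate of its neighbors’. The edge rule (k-nearest neighbors), weights, and sizes are made up for illustration and are not FormNet’s graph construction.

```python
import numpy as np

def knn_graph(coords, k=3):
    d = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)
    neighbors = np.argsort(d, axis=-1)[:, :k]
    adj = np.zeros((len(coords), len(coords)))
    for i, nbrs in enumerate(neighbors):
        adj[i, nbrs] = 1.0
    return adj

def graph_conv(token_embeddings, adj, w_self, w_nbr):
    deg = adj.sum(axis=-1, keepdims=True).clip(min=1.0)
    neighbor_mean = (adj @ token_embeddings) / deg        # average neighbor embedding
    super_tokens = token_embeddings @ w_self + neighbor_mean @ w_nbr
    return np.maximum(super_tokens, 0.0)                  # ReLU

coords = np.random.rand(10, 2) * 100                      # token (x, y) positions on the page
emb = np.random.randn(10, 32)
adj = knn_graph(coords, k=3)
w_self, w_nbr = np.random.randn(32, 32), np.random.randn(32, 32)
print(graph_conv(emb, adj, w_self, w_nbr).shape)          # (10, 32) Super-Token embeddings
```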

An illustration of the word-level graph, with blue edges between tokens, of a FUNSD document.

Key Results
The figure below shows model size vs. F1 score (the harmonic mean of precision and recall) for recent approaches on the CORD benchmark. FormNet-A2 outperforms the most recent DocFormer while using a model that is 2.5x smaller. FormNet-A3 achieves state-of-the-art performance with a 97.28% F1 score. For more experimental results, please refer to the paper.

Model Size vs. Entity Extraction F1 Score on CORD benchmark. FormNet significantly outperforms other recent approaches in absolute F1 performance and parameter efficiency.

We study the importance of RichAtt and Super-Token by GCN on the large-scale masked language modeling (MLM) pre-training task across three FormNets. Both RichAtt and GCN components improve upon the ETC baseline on reconstructing the masked tokens by a large margin, showing the effectiveness of their structural encoding capability on form documents. The best performance is obtained when incorporating both RichAtt and GCN.

Performance of the Masked-Language Modeling (MLM) pre-training. Both the proposed RichAtt and Super-Token by GCN components improve upon ETC baseline by a large margin, showing the effectiveness of their structural encoding capability on large-scale form documents.

Using BertViz, we visualize the local-to-local attention scores for specific examples from the CORD dataset for the standard ETC and FormNet models. Qualitatively, we confirm that for FormNet the tokens attend primarily to other tokens within the same visual block. Moreover, for that model, specific attention heads attend to tokens aligned horizontally, which is a strong signal of meaning for form documents. No clear attention pattern emerges for the ETC model, suggesting that RichAtt and Super-Tokens by GCN enable the model to learn the structural cues and leverage layout information effectively.

The attention scores for ETC and FormNet (ETC+RichAtt+GCN) models. Unlike the ETC model, the FormNet model makes tokens attend to other tokens within the same visual blocks, along with tokens aligned horizontally, thus strongly leveraging structural cues.

Conclusion
We present FormNet, a novel model architecture for form-based document understanding. We determine that the novel RichAtt mechanism and Super-Token components help the ETC transformer excel at form understanding in spite of sub-optimal, noisy serialization. We demonstrate that FormNet recovers local syntactic information that may have been lost during text serialization and achieves state-of-the-art performance on three benchmarks.

Acknowledgements
This research was conducted by Chen-Yu Lee, Chun-Liang Li, Timothy Dozat, Vincent Perot, Guolong Su, Nan Hua, Joshua Ainslie, Renshen Wang, Yasuhisa Fujii, and Tomas Pfister. Thanks to Evan Huang, Shengyang Dai, and Salem Elie Haykal for their valuable feedback, and Tom Small for creating the animation in this post.

Read More

Learning to Prompt for Continual Learning

Supervised learning is a common approach to machine learning (ML) in which the model is trained using data that is labeled appropriately for the task at hand. Ordinary supervised learning trains on independent and identically distributed (IID) data, where all training examples are sampled from a fixed set of classes, and the model has access to these examples throughout the entire training phase. In contrast, continual learning tackles the problem of training a single model on changing data distributions where different classification tasks are presented sequentially. This is particularly important, for example, to enable autonomous agents to process and interpret continuous streams of information in real-world scenarios.

To illustrate the difference between supervised and continual learning, consider two tasks: (1) classify cats vs. dogs and (2) classify pandas vs. koalas. In supervised learning, which uses IID, the model is given training data from both tasks and treats it as a single 4-class classification problem. However, in continual learning, these two tasks arrive sequentially, and the model only has access to the training data of the current task. As a result, such models tend to suffer from performance degradation on the previous tasks, a phenomenon called catastrophic forgetting.

Mainstream solutions try to address catastrophic forgetting by buffering past data in a “rehearsal buffer” and mixing it with current data to train the model. However, the performance of these solutions depends heavily on the size of the buffer and, in some cases, may not be possible at all due to data privacy concerns. Another branch of work designs task-specific components to avoid interference between tasks. But these methods often assume that the task at test time is known, which is not always true, and they require a large number of parameters. The limitations of these approaches raise critical questions for continual learning: (1) Is it possible to have a more effective and compact memory system that goes beyond buffering past data? (2) Can one automatically select relevant knowledge components for an arbitrary sample without knowing its task identity?

In “Learning to Prompt for Continual Learning”, presented at CVPR 2022, we attempt to answer these questions. Drawing inspiration from prompting techniques in natural language processing, we propose a novel continual learning framework called Learning to Prompt (L2P). Instead of continually re-learning all the model weights for each sequential task, we provide learnable task-relevant “instructions” (i.e., prompts) to guide pre-trained backbone models through sequential training via a pool of learnable prompt parameters. L2P is applicable to various challenging continual learning settings and outperforms previous state-of-the-art methods consistently on all benchmarks. It achieves competitive results against rehearsal-based methods while also being more memory efficient. Most importantly, L2P is the first to introduce the idea of prompting in the field of continual learning.

Compared with typical methods that adapt entire or partial model weights to tasks sequentially using a rehearsal buffer, L2P uses a single frozen backbone model and learns a prompt pool to conditionally instruct the model. “Model 0” indicates that the backbone model is fixed at the beginning.

Prompt Pool and Instance-Wise Query
Given a pre-trained Transformer model, “prompt-based learning” modifies the original input using a fixed template. Imagine a sentiment analysis task is given the input “I like this cat”. A prompt-based method will transform the input to “I like this cat. It looks X”, where the “X” is an empty slot to be predicted (e.g., “nice”, “cute”, etc.) and “It looks X” is the so-called prompt. By adding prompts to the input, one can condition the pre-trained models to solve many downstream tasks. While designing fixed prompts requires prior knowledge along with trial and error, prompt tuning prepends a set of learnable prompts to the input embedding to instruct the pre-trained backbone to learn a single downstream task, under the transfer learning setting.

In the continual learning scenario, L2P maintains a learnable prompt pool, where prompts can be flexibly grouped as subsets to work jointly. Specifically, each prompt is associated with a key that is learned by reducing the cosine similarity loss between matched input query features. These keys are then utilized by a query function to dynamically look up a subset of task-relevant prompts based on the input features. At test time, inputs are mapped by the query function to the top-N closest keys in the prompt pool, and the associated prompt embeddings are then fed to the rest of the model to generate the output prediction. At training, we optimize the prompt pool and the classification head via the cross-entropy loss.
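
Below is a simplified sketch of the instance-wise query step. The pool size, prompt length, use of a [CLS]-style query feature, and top-N value are all hypothetical; the point is only to show keys being matched by cosine similarity and the selected prompts being prepended to the input tokens.

```python
import numpy as np

def cosine_similarity(a, b):
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return a @ b.T

pool_size, prompt_len, dim, top_n = 10, 5, 64, 3
prompt_keys = np.random.randn(pool_size, dim)              # learnable keys, one per prompt
prompt_pool = np.random.randn(pool_size, prompt_len, dim)  # learnable prompt parameters

query_feature = np.random.randn(1, dim)                    # e.g., a feature of the input instance
sims = cosine_similarity(query_feature, prompt_keys)[0]
selected = np.argsort(-sims)[:top_n]                       # top-N closest keys in the pool

input_tokens = np.random.randn(196, dim)                   # patch embeddings of one image
selected_prompts = prompt_pool[selected].reshape(-1, dim)  # (top_n * prompt_len, dim)
extended_tokens = np.concatenate([selected_prompts, input_tokens], axis=0)
print(extended_tokens.shape)                               # fed to the frozen backbone
```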

Illustration of L2P at test time. First, L2P selects a subset of prompts from a key-value paired prompt pool based on our proposed instance-wise query mechanism. Then, L2P prepends the selected prompts to the input tokens. Finally, L2P feeds the extended tokens to the model for prediction.

Intuitively, similar input examples tend to choose similar sets of prompts and vice versa. Thus, prompts that are frequently shared encode more generic knowledge while other prompts encode more task-specific knowledge. Moreover, prompts store high-level instructions and keep lower-level pre-trained representations frozen, thus catastrophic forgetting is mitigated even without the necessity of a rehearsal buffer. The instance-wise query mechanism removes the necessity of knowing the task identity or boundaries, enabling this approach to address the under-investigated challenge of task-agnostic continual learning.

Effectiveness of L2P
We evaluate the effectiveness of L2P against different baseline methods using an ImageNet pre-trained Vision Transformer (ViT) on representative benchmarks. The naïve baseline, called Sequential in the graphs below, refers to training a single model sequentially on all tasks. The EWC model adds a regularization term to mitigate forgetting and the Rehearsal model saves past examples to a buffer for mixed training with current data. To measure the overall continual learning performance, we measure both the accuracy and the average difference between the best accuracy achieved during training and the final accuracy for all tasks (except the last task), which we call forgetting. We find that L2P outperforms the Sequential and EWC methods significantly in both metrics. Notably, L2P even surpasses the Rehearsal approach, which uses an additional buffer to save past data. Because the L2P approach is orthogonal to Rehearsal, its performance could be further improved if it, too, used a rehearsal buffer.

L2P outperforms baseline methods in both accuracy (top) and forgetting (bottom). Accuracy refers to the average accuracy for all tasks and forgetting is defined as the average difference between the best accuracy achieved during training and the final accuracy for all tasks (except the last task).

We also visualize the prompt selection result from our instance-wise query strategy on two different benchmarks, where one has similar tasks and the other has varied tasks. The results indicate that L2P promotes more knowledge sharing between similar tasks by having more shared prompts, and less knowledge sharing between varied tasks by having more task-specific prompts.

Prompt selection histograms for benchmarks of similar tasks (left) and varied tasks (right). The left benchmark has higher intra-task similarity, thus sharing prompts between tasks results in good performance, while the right benchmark favors more task-specific prompts.

Conclusion
In this work, we present L2P to address key challenges in continual learning from a new perspective. L2P does not require a rehearsal buffer or known task identity at test time to achieve high performance. Further, it can handle various complex continual learning scenarios, including the challenging task-agnostic setting. Because large-scale pre-trained models are widely used in the machine learning community for their robust performance on real-world problems, we believe that L2P opens a new learning paradigm towards practical continual learning applications.

Acknowledgements
We gratefully acknowledge the contributions of other co-authors, including Chen-Yu Lee, Han Zhang, Ruoxi Sun, Xiaoqi Ren, Guolong Su, Vincent Perot, Jennifer Dy, Tomas Pfister. We would also like to thank Chun-Liang Li, Jeremy Martin Kubica, Sayna Ebrahimi, Stratis Ioannidis, Nan Hua, and Emmanouil Koukoumidis, for their valuable discussions and feedback, and Tom Small for figure creation.

Read More

Locked-image Tuning: Adding Language Understanding to Image Models

The ability to classify images into categories has been transformed by deep learning. It has also been significantly accelerated by transfer learning, whereby models are first pre-trained on large datasets, like ImageNet, to learn visual representations that are then transferred via fine-tuning to a new task with less data (e.g., classifying animals). Previous works such as BiT and ViT employed these methods to achieve state-of-the-art performance on a wide range of classification tasks, such as the VTAB benchmark.

However, fine-tuning has some downsides: though pre-training is done only once, fine-tuning requires task-specific data and must be repeated for every new dataset. Multimodal contrastive learning is an alternative, recently popularized paradigm (e.g., CLIP, ALIGN) that overcomes these issues by instead learning how to match free-form text with images. These models can then solve new tasks by reformulating them as image-text matching problems, without extra data (referred to as “zero-shot” learning). Contrastive learning is flexible and easy to adapt to new tasks, but has its own limitations, namely the need for a lot of paired image-text data and weaker performance than transfer learning approaches.

With those limitations in mind, we propose “LiT: Zero-Shot Transfer with Locked-image Text Tuning”, to appear at CVPR 2022. LiT models learn to match text to an already pre-trained image encoder. This simple yet effective setup provides the best of both worlds: strong image representations from pre-training, plus flexible zero-shot transfer to new tasks via contrastive learning. LiT achieves state-of-the-art zero-shot classification accuracy, significantly closing the gap between the two styles of learning. We think the best way to understand is to try it yourself, so we’ve included a demo of LiT models at the end of this post.

Fine-tuning (left) requires task-specific data and training to adapt a pre-trained model to a new task. A LiT model (right) can be used with any task, without further data or adaptation.

Contrastive Learning on Image-Text Data
Contrastive learning models learn representations from “positive” and “negative” examples, such that representations for “positive” examples are similar to each other but different from “negative” examples.

Multimodal contrastive learning applies this to pairs of images and associated texts. An image encoder computes representations from images, and a text encoder does the same for texts. Each image representation is encouraged to be close to the representation of its associated text (“positive”), but distinct from the representation of other texts (“negatives”) in the data, and vice versa. This has typically been done with randomly initialized models (“from scratch”), meaning the encoders have to simultaneously learn representations and how to match them.
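
The following is a bare-bones sketch of such a symmetric image-text contrastive loss (CLIP-style); the batch size, embedding dimension, and temperature are arbitrary illustration values, and the encoders are assumed to have already produced the embeddings.

```python
import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def contrastive_loss(image_emb, text_emb, temperature=0.07):
    image_emb = image_emb / np.linalg.norm(image_emb, axis=-1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=-1, keepdims=True)
    logits = image_emb @ text_emb.T / temperature          # pairwise similarities
    labels = np.arange(len(logits))                        # the i-th image matches the i-th text
    loss_i2t = -log_softmax(logits, axis=1)[labels, labels].mean()   # image-to-text direction
    loss_t2i = -log_softmax(logits, axis=0)[labels, labels].mean()   # text-to-image direction
    return (loss_i2t + loss_t2i) / 2

images = np.random.randn(8, 256)                           # image-encoder outputs (stand-ins)
texts = np.random.randn(8, 256)                            # text-encoder outputs (stand-ins)
print(contrastive_loss(images, texts))
```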

Multimodal contrastive learning trains models to produce similar representations for closely matched images and texts.

This training can be done on noisy, loosely aligned pairs of image and text, which naturally occur on the web. This circumvents the need for manual labeling, and makes data scaling easy. Furthermore, the model learns much richer visual concepts — it’s not constrained to what’s defined in the classification label space. Instead of classifying an image as “coffee”, it can understand whether it’s “a small espresso in a white mug” or “a large latte in a red flask”.

Once trained, a model that aligns image and text can be used in many ways. For zero-shot classification, we compare image representations to text representations of the class names. For example, a “wombat vs jaguar” classifier can be built by computing the representations of the texts “jaguar” and “wombat”, and classifying an image as a jaguar if its representation better matches the former. This approach scales to thousands of classes and makes it very easy to solve classification tasks without the extra data necessary for fine-tuning. Another application of contrastive models is image search (a.k.a. image-text retrieval), by finding the image whose representation best matches that of a given text, or vice versa.
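
Here is a small sketch of zero-shot classification with an aligned image-text model: embed the class names with the text encoder, embed the image with the image encoder, and pick the class whose text embedding is most similar. The encoders below are stand-in random projections, purely to keep the example self-contained.

```python
import numpy as np

def embed(x, projection):
    z = x @ projection
    return z / np.linalg.norm(z, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
image_proj = rng.normal(size=(1024, 256))         # stand-in for the image encoder
text_proj = rng.normal(size=(300, 256))           # stand-in for the text encoder

class_names = ["jaguar", "wombat"]
class_text_features = rng.normal(size=(2, 300))   # stand-in for the encoded class-name texts
class_embeddings = embed(class_text_features, text_proj)

image_features = rng.normal(size=(1, 1024))       # stand-in for one image's features
image_embedding = embed(image_features, image_proj)

scores = image_embedding @ class_embeddings.T     # cosine similarity to each class text
print("prediction:", class_names[int(scores.argmax())])
```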

The Best of Both Worlds with Locked-image Tuning
As mentioned earlier, transfer learning achieves state-of-the-art accuracy, but requires per-task labels, datasets, and training. On the other hand, contrastive models are flexible, scalable, and easily adaptable to new tasks, but fall short in performance. To compare, at the time of writing, the state of the art on ImageNet classification using transfer learning is 90.94%, but the best contrastive zero-shot models achieve 76.4%.

LiT tuning bridges this gap: we contrastively train a text model to compute representations well aligned with the powerful ones available from a pre-trained image encoder. Importantly, for this to work well, the image encoder should be “locked”, that is, it should not be updated during training. This may be unintuitive since one usually expects the additional information from further training to increase performance, but we find that locking the image encoder consistently leads to better results.
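
A toy sketch of the locked setup: the pre-trained image tower is used as a fixed feature extractor, and only the text tower is updated. The loss here is a simple squared alignment penalty standing in for the actual contrastive objective, and all names and shapes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
image_tower = rng.normal(size=(1024, 256))   # pre-trained, locked: never updated
text_tower = rng.normal(size=(300, 256))     # trainable text tower

def train_step(images, texts, lr=1e-3):
    global text_tower
    image_emb = images @ image_tower          # treated as constants (locked tower)
    text_emb = texts @ text_tower
    diff = text_emb - image_emb               # toy alignment loss between paired embeddings
    loss = (diff ** 2).mean()
    grad_text = 2.0 * texts.T @ diff / diff.size   # gradient w.r.t. the text tower only
    text_tower = text_tower - lr * grad_text       # the image tower receives no update
    return loss

print(train_step(rng.normal(size=(8, 1024)), rng.normal(size=(8, 300))))
```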

LiT-tuning contrastively trains a text encoder to match a pre-trained image encoder. The text encoder learns to compute representations that align to those from the image encoder.

This can be considered an alternative to the classic fine-tuning stage, where the image encoder is separately adapted to every new classification task; instead we have one stage of LiT-tuning, after which the model can classify any data. LiT-tuned models achieve 84.5% zero-shot accuracy on ImageNet classification, showing significant improvements over previous methods that train models from scratch, and halving the performance gap between fine-tuning and contrastive learning.

Left: LiT-tuning significantly closes the gap between the best contrastive models and the best models fine-tuned with labels. Right: Using a pre-trained image encoder is always helpful, but locking it is surprisingly a key part of the recipe to success; unlocked image models (dashed) yield significantly worse performance.

An impressive benefit of contrastive models is increased robustness — they retain high accuracy on datasets that typically fool fine-tuned models, such as ObjectNet and ImageNet-C. Similarly, LiT-tuned models have high performance across various challenging versions of ImageNet, for example achieving a state-of-the-art 81.1% accuracy on ObjectNet.

LiT-tuning has other advantages. While prior contrastive works require large amounts of data and train for a very long time, the LiT approach is much less data hungry. LiT models trained on 24M publicly available image-text pairs rival the zero-shot classification performance of prior models trained on 400M image-text pairs of private data. The locked image encoder also leads to faster training with a smaller memory footprint. On larger datasets, image representations can be pre-computed; not running the image model during training further improves efficiency and also unlocks much larger batch sizes, which increases the number of “negatives” the model sees and is key to high-performance contrastive learning. The method works well with varied forms of image pre-training (e.g., including self-supervised learning), and with many publicly available image models. We hope that these benefits make LiT a great testbed for researchers.

Conclusion
We present Locked-image Tuning (LiT), which contrastively trains a text encoder to match image representations from a powerful pre-trained image encoder. This simple method is data and compute efficient, and substantially improves zero-shot classification performance compared to existing contrastive learning approaches.

Want to try it yourself?

A preview of the demo: use it to match free-form text descriptions to images and build your own zero-shot classifier!

We have prepared a small interactive demo to try some LiT-tuned models. We also provide a Colab with more advanced use cases and larger models, which are a great way to get started.

Acknowledgments
We would like to thank Xiaohua Zhai, Xiao Wang, Daniel Keysers, Alexander Kolesnikov, and Lucas Beyer who have co-authored the LiT paper and been involved in all aspects of its development, as well as the Brain team in Zürich. We also would like to thank Tom Small for creating the animations used in this blogpost.

Read More

Simple and Effective Zero-Shot Task-Oriented Dialogue

Modern conversational agents need to integrate with an ever-increasing number of services to perform a wide variety of tasks, from booking flights and finding restaurants, to playing music and telling jokes. Adding this functionality can be difficult — for each new task, one needs to collect new data and retrain the models that power the conversational agent. This is because most task-oriented dialogue (TOD) models are trained on a single task-specific ontology. An ontology is generally represented as a list of possible user intents (e.g., if the user wants to book a flight, if the user wants to play some music, etc.) and possible parameter slots to extract from the conversation (e.g., the date of the flight, the name of a song, and so on). A rigid ontology can be limiting, preventing the model from generalizing to new tasks or domains. For instance, a TOD model trained on a certain ontology only knows the intents in that ontology, and lacks the ability to generalize its knowledge to unseen intents. This is true even for new ontologies that overlap with ones already known to the agent — for example, if an agent already knows how to book train tickets, adding the ability to book airline tickets would require training on completely new data. Ideally, the agent should be able to leverage its existing knowledge from one ontology, and apply it to new ones.

New benchmarks, such as the Schema Guided Dialogue (SGD) dataset, have been designed to evaluate the ability to generalize to unseen tasks, by distilling each ontology into a schema of slots and intents. In the SGD setting, TOD models are trained on multiple schemas, and evaluated on how well they generalize to unseen ones — instead of how well they overfit to a single ontology. However, recent work shows the top models still have room for improvement.

To address this problem, we introduce two different sequence-to-sequence approaches toward zero-shot transfer for dialogue modeling, presented in the papers “Description-Driven Task-Oriented Dialogue” and “Show, Don’t Tell: Demonstrations Outperform Descriptions for Schema-Guided Task-Oriented Dialogue”. Both models condition on additional contextual information, either slot and intent descriptions, or single demonstrative examples. Results obtained on multiple dialogue state tracking benchmarks show that by doing away with the fixed schemas and ontologies, these new approaches lead to state-of-the-art results on the dialogue state tracking task with more efficient models. The source code for the described approaches can be found here.

Background: Dialogue State Tracking
To address the challenge of zero-shot transfer for dialogue models, we focus on the problem of Dialogue State Tracking (DST). DST is a fundamental problem for conversational agents, in which a model predicts the belief state of a conversation, i.e., the agent’s understanding of the user’s indicated preferences. The belief state is typically modeled as an assignment of values to slots for which the user has indicated a preference in the conversation. An example is shown below.

An example conversation and its ground truth slots and intents for dialogue state tracking. Here, the active user intent is “Book a train”, and pertinent information for booking this train is recorded in the slot values.

Description-Driven Task-Oriented Dialogue
In our first paper, we introduce Description-Driven Dialogue State Tracking (D3ST), a DST model that leverages slot and intent descriptions when making predictions about the belief state. D3ST is built on top of the T5 sequence-to-sequence language model, which was shown in previous work to be pretrained effectively for DST problems.

D3ST prompts the input sequence with slot and intent descriptions, allowing the T5 model to attend to both this contextual information and the conversation. Its ability to generalize comes from the formulation of these descriptions. Instead of using a name for each slot, we assign a random index for every slot. For categorical slots (i.e., slots that only take values from a small, predefined set), possible values are also arbitrarily enumerated and then listed. The same is done with intents, and together these descriptions form the schema representation to be included in the input string. This is concatenated with the conversation text and fed into the T5 model. The target output is the belief state and user intent, again identified by their assigned indices. An example is shown below.

An example of the D3ST input and output format. The red text contains slot descriptions, while the blue text contains intent descriptions. The yellow text contains the conversation utterances.

This forces the model to predict conversation contexts using a slot’s index, and not that specific slot. By randomizing the index we assign to each slot between different examples, we prevent the model from learning specific schema information. The slot with index 0 could be the “Train Departure” slot in one example, and the “Train Destination” in another — as such, the model is encouraged to use the slot description given in index 0 to find the correct value, and discouraged from overfitting to a specific schema. With this setup, a model that sees enough different tasks or domains will learn to generalize the action of belief state tracking and intent prediction.
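
Below is a rough sketch of how a D3ST-style input string might be assembled: slots get random indices, categorical values are enumerated under their slot, and the resulting schema text is concatenated with the conversation. The exact formatting, slot names, and example dialogue are illustrative, not the paper’s verbatim format.

```python
import random

def build_d3st_input(slots, intents, conversation, seed=None):
    rng = random.Random(seed)
    slot_ids = list(range(len(slots)))
    rng.shuffle(slot_ids)                       # randomize slot indices per example
    parts = []
    for idx, (_name, desc, values) in zip(slot_ids, slots):
        entry = f"{idx}: {desc}"
        if values:                              # categorical slot: enumerate possible values
            entry += " " + " ".join(f"{idx}{chr(97 + j)}) {v}" for j, v in enumerate(values))
        parts.append(entry)
    for j, (_name, desc) in enumerate(intents):
        parts.append(f"i{j}: {desc}")
    return " ".join(parts) + " " + conversation

slots = [
    ("train-departure", "departure city of the train", None),
    ("train-class", "travel class", ["economy", "business"]),
]
intents = [("book_train", "book a train ticket")]
conversation = "[user] I need a business class train from Cambridge."
print(build_d3st_input(slots, intents, conversation, seed=0))
```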

Show Don’t Tell
In our subsequent paper, “Show, Don’t Tell: Demonstrations Outperform Descriptions for Schema-Guided Task-Oriented Dialogue”, we employ a single annotated dialogue example that demonstrates the possible slots and values in a conversation, instead of relying on slot descriptions. In this sense, we “show” the semantics of the schema rather than “tell” the model through descriptions — hence the name “Show Don’t Tell” (SDT). SDT is also built on T5, and improves zero-shot performance beyond D3ST.

An example of the SDT input and output format. The text in red contains the demonstrative example, while the text in blue contains its ground truth belief state. The actual conversation for the model to predict is in yellow. While the D3ST prompt relies entirely on slot descriptions, the SDT prompt contains a concise example dialogue followed by the expected dialogue state annotations, resulting in more direct supervision.

The rationale for SDT’s single example demonstration is simple: there can still be ambiguities that are not fully captured in a slot or intent description, and require a concrete example to demonstrate. Moreover, from a developer’s standpoint, creating short dialogue examples to describe a schema can often be easier than writing descriptions that fully capture the meaning behind each slot and intent.

Benchmark Results
We evaluate both D3ST and SDT on a number of benchmarks, most notably the SGD dataset, which tests zero-shot generalization to unseen schemas in its test set. We evaluate our state tracking models on joint goal accuracy (JGA), the fraction of dialogue turns for which the model predicts an exactly correct belief state.
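
For reference, here is a minimal sketch of how joint goal accuracy can be computed: a turn counts as correct only if the predicted belief state matches the reference exactly, slot for slot. The slot names and values are made-up examples.

```python
def joint_goal_accuracy(predictions, references):
    """predictions/references: lists of per-turn {slot: value} dictionaries."""
    correct = sum(pred == ref for pred, ref in zip(predictions, references))
    return correct / len(references)

refs = [{"train-departure": "cambridge"},
        {"train-departure": "cambridge", "train-day": "friday"}]
preds = [{"train-departure": "cambridge"},
         {"train-departure": "cambridge"}]
print(joint_goal_accuracy(preds, refs))   # 0.5: the second turn misses a slot
```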

Both of our models either match or outperform existing state-of-the-art baselines (T5DST and paDST) at comparable model sizes, as shown below. In general, SDT performs slightly better than D3ST. Note that our models can be trained on different sizes of the underlying T5 language model. In addition, while the baseline models can only make predictions for one slot per forward pass, both our models can decode the entire dialogue state in a single forward pass — a much more efficient method in both training and inference.

Joint Goal Accuracy on the SGD dataset plotted against model size for existing baselines and our proposed models D3ST and SDT. Note that paDST* includes additional data augmentation.

Additional metrics are reported in both papers. D3ST exhibits state-of-the-art quality on the MultiWOZ dataset, with 75.9% JGA on MultiWOZ 2.4. Both D3ST and SDT show state-of-the-art performance in the MultiWOZ cross-domain leave-one-out setting. In addition, both D3ST and SDT were evaluated using the SGD-X dataset, and demonstrated strong robustness to linguistic variations in schema. These benchmarks all indicate that D3ST and SDT are state-of-the-art TOD models, with the ability to generalize to unseen tasks and domains.

Zero-Shot Capability
D3ST and SDT sometimes demonstrate a surprising ability to generalize to unseen tasks, and we saw many interesting examples when trying completely new dialogues with the model. We’ve included one such example below:

A D3ST model trained on the SGD dataset makes predictions (right) for an unseen meta conversation (left) about creating this blog post. The model predicts a completely correct belief state, even though it is not fine-tuned on anything related to blogs, authors or NLP.

Future Work
These papers demonstrate the feasibility of a zero-shot TOD system that can generalize to unseen tasks or domains. However, we’ve limited ourselves to the DST problem for now — we plan to extend this research to enable zero-shot dialogue policy modeling, allowing TOD systems to take actions following arbitrary instructions. In addition, the current input format can often lead to long input sequences, which can be slow for inference — we’re exploring new and more efficient methods to encode schema information.

Acknowledgements
This post reflects the combined work of Jeffrey Zhao, Raghav Gupta, Harrison Lee, Mingqiu Wang, Dian Yu, Yuan Cao, and Abhinav Rastogi. We’d like to thank Yonghui Wu and Izhak Shafran for their continued advice and guidance.

Read More

Lidar-Camera Deep Fusion for Multi-Modal 3D Detection

LiDAR and visual cameras are two types of complementary sensors used for 3D object detection in autonomous vehicles and robots. LiDAR, which is a remote sensing technique that uses light in the form of a pulsed laser to measure ranges, provides low-resolution shape and depth information, while cameras provide high-resolution shape and texture information. While the features captured by LiDAR and cameras should be merged together to provide optimal 3D object detection, it turns out that most state-of-the-art 3D object detectors use LiDAR as the only input. The main reason is that to develop robust 3D object detection models, most methods need to augment and transform the data from both modalities, making the accurate alignment of the features challenging.

Existing algorithms for fusing LiDAR and camera outputs, such as PointPainting, PointAugmenting, EPNet, 4D-Net and ContinuousFusion, generally follow two approaches — input-level fusion where the features are fused at an early stage, decorating points in the LiDAR point cloud with the corresponding camera features, or mid-level fusion where features are extracted from both sensors and then combined. Despite realizing the importance of effective alignment, these methods struggle to efficiently process the common scenario where features are enhanced and aggregated before fusion. This indicates that effectively fusing the signals from both sensors might not be straightforward and remains challenging.

In our CVPR 2022 paper, “DeepFusion: LiDAR-Camera Deep Fusion for Multi-Modal 3D Object Detection”, we introduce a fully end-to-end multi-modal 3D detection framework called DeepFusion that applies a simple yet effective deep-level feature fusion strategy to unify the signals from the two sensing modalities. Unlike conventional approaches that decorate raw LiDAR point clouds with manually selected camera features, our method fuses the deep camera and deep LiDAR features in an end-to-end framework. We begin by describing two novel techniques, InverseAug and LearnableAlign, that improve the quality of feature alignment and are applied to the development of DeepFusion. We then demonstrate state-of-the-art performance by DeepFusion on the Waymo Open Dataset, one of the largest datasets for automotive 3D object detection.

InverseAug: Accurate Alignment under Geometric Augmentation
To achieve good performance on existing 3D object detection benchmarks for autonomous cars, most methods require strong data augmentation during training to avoid overfitting. However, the necessity of data augmentation poses a non-trivial challenge in the DeepFusion pipeline. Specifically, the data from the two modalities use different augmentation strategies, e.g., rotating about the z-axis for 3D point clouds combined with random flipping for 2D camera images, often resulting in inaccurate alignment. Then the augmented LiDAR data has to go through a voxelization step that converts the point clouds into volume data stored in a three-dimensional array of voxels. The voxelized features are quite different compared to the raw data, making the alignment even more difficult. To address the alignment issue caused by geometry-related data augmentation, we introduce Inverse Augmentation (InverseAug), a technique that reverses the augmentation before fusion during the model’s training phase.

In the example below, we demonstrate the difficulties in aligning the augmented LiDAR data with the camera data. In this case, the LiDAR point cloud is augmented by rotation with the result that a given 3D key point, which could be any 3D coordinate, such as a LiDAR data point, cannot be easily aligned in 2D space simply through use of the original LiDAR and camera parameters. To make the localization feasible, InverseAug first stores the augmentation parameters before applying the geometry-related data augmentation. At the fusion stage, it reverses all data augmentation to get the original coordinate for the 3D key point, and then finds its corresponding 2D coordinates in the camera space.
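
A schematic sketch of this idea follows: remember the parameters of the geometric augmentation, undo it to recover a 3D key point’s original coordinates, then project that original point into the camera image. The projection is a bare pinhole model with made-up calibration, and it assumes the camera and LiDAR frames coincide, which a real pipeline would handle with proper extrinsics.

```python
import numpy as np

def rotate_z(points, angle):
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    return points @ rot.T

# Augmentation applied to the LiDAR point cloud; its parameters are stored.
aug_angle = np.deg2rad(15.0)
original_points = np.array([[10.0, 2.0, 1.0]])
augmented_points = rotate_z(original_points, aug_angle)

# At the fusion stage: reverse the augmentation to get original coordinates.
recovered_points = rotate_z(augmented_points, -aug_angle)

# Project the recovered 3D key point into 2D camera space (toy intrinsics).
intrinsics = np.array([[1000.0, 0.0, 640.0],
                       [0.0, 1000.0, 360.0],
                       [0.0, 0.0, 1.0]])
cam = recovered_points @ intrinsics.T
pixels = cam[:, :2] / cam[:, 2:3]
print(np.allclose(recovered_points, original_points), pixels)
```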

During training, InverseAug resolves the inaccurate alignment from geometric augmentation.
Left: Alignment without InverseAug. Right: Alignment quality improvement with InverseAug.

LearnableAlign: A Cross-Modality-Attention Module to Learn Alignment
We also introduce Learnable Alignment (LearnableAlign), a cross-modality-attention–based feature-level alignment technique, to improve the alignment quality. For input-level fusion methods, such as PointPainting and PointAugmenting, given a 3D LiDAR point, only the corresponding camera pixel can be exactly located as there is a one-to-one mapping. In contrast, when fusing deep features in the DeepFusion pipeline, each LiDAR feature represents a voxel containing a subset of points, and hence, its corresponding camera pixels are in a polygon. So the alignment becomes the problem of learning the mapping between a voxel cell and a set of pixels.

A naïve approach is to average over all pixels corresponding to the given voxel. However, intuitively, and as supported by our visualized results, these pixels are not equally important because the information from the LiDAR deep feature unequally aligns with every camera pixel. For example, some pixels may contain critical information for detection (e.g., the target object), while others may be less informative (e.g., consisting of backgrounds such as roads, plants, occluders, etc.).

LearnableAlign leverages a cross-modality attention mechanism to dynamically capture the correlations between two modalities. Here, the input contains the LiDAR features in a voxel cell, and all its corresponding camera features. The output of the attention is essentially a weighted sum of the camera features, where the weights are collectively determined by a function of the LiDAR and camera features. More specifically, LearnableAlign uses three fully-connected layers to respectively transform the LiDAR features to a vector (ql), and camera features to vectors (kc) and (vc). For each vector (ql), we compute the dot products between (ql) and (kc) to obtain the attention affinity matrix that contains correlations between the LiDAR features and the corresponding camera features. Normalized by a softmax operator, the attention affinity matrix is then used to calculate weights and aggregate the vectors (vc) that contain camera information. The aggregated camera information is then processed by a fully-connected layer, and concatenated (Concat) with the original LiDAR feature. The output is then fed into any standard 3D detection framework, such as PointPillars or CenterPoint for model training.
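
The sketch below follows that description for a single voxel cell, with arbitrary dimensions and initialization: the LiDAR feature is projected to a query, the corresponding camera features to keys and values, and the attention-weighted camera information is transformed and concatenated back onto the LiDAR feature.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def learnable_align(lidar_feat, cam_feats, w_q, w_k, w_v, w_out):
    q_l = lidar_feat @ w_q                        # query from the voxel's LiDAR feature
    k_c = cam_feats @ w_k                         # keys from the corresponding camera features
    v_c = cam_feats @ w_v                         # values from the camera features
    affinity = softmax(k_c @ q_l / np.sqrt(len(q_l)))   # attention weights over the pixels
    aggregated = affinity @ v_c                   # weighted sum of camera information
    return np.concatenate([lidar_feat, aggregated @ w_out])   # Concat with the LiDAR feature

d = 64
lidar_feat = np.random.randn(d)                   # deep feature of one voxel cell
cam_feats = np.random.randn(20, d)                # camera features inside its pixel polygon
w_q, w_k, w_v, w_out = (np.random.randn(d, d) * 0.05 for _ in range(4))
fused = learnable_align(lidar_feat, cam_feats, w_q, w_k, w_v, w_out)
print(fused.shape)                                # (2 * d,): passed on to the 3D detection head
```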

LearnableAlign leverages the cross-attention mechanism to align LiDAR and camera features.

DeepFusion: A Better Way to Fuse Information from Different Modalities
Powered by our two novel feature alignment techniques, we develop DeepFusion, a fully end-to-end multi-modal 3D detection framework. In the DeepFusion pipeline, the LiDAR points are first fed into an existing feature extractor (e.g., pillar feature net from PointPillars) to obtain LiDAR features (e.g., pseudo-images). In the meantime, the camera images are fed into a 2D image feature extractor (e.g., ResNet) to obtain camera features. Then, InverseAug and LearnableAlign are applied in order to fuse the camera and LiDAR features together. Finally, the fused features are processed by the remaining components of the selected 3D detection model (e.g., the backbone and detection head from PointPillars) to obtain the detection results.

The pipeline of DeepFusion.

Benchmark Results
We evaluate DeepFusion on the Waymo Open Dataset, one of the largest 3D detection challenges for autonomous cars, using the Average Precision with Heading (APH) metric under difficulty level 2, the default metric to rank a model’s performance on the leaderboard. Among the 70 participating teams from all over the world, the DeepFusion single and ensemble models achieve state-of-the-art performance in their corresponding categories.

The single DeepFusion model achieves new state-of-the-art performance on Waymo Open Dataset.
The Ensemble DeepFusion model outperforms all other methods on Waymo Open Dataset, ranking No. 1 on the leaderboard.

The Impact of InverseAug and LearnableAlign
We also conduct ablation studies on the effectiveness of the proposed InverseAug and LearnableAlign techniques. We demonstrate that both InverseAug and LearnableAlign individually contribute to a performance gain over the LiDAR-only model, and combining both can further yield an even more significant boost.

Ablation studies on InverseAug (IA) and LearnableAlign (LA) measured in average precision (AP) and APH. Combining both techniques contributes to the best performance gain.

Conclusion
We demonstrate that late-stage deep feature fusion can be more effective when features are aligned well, but aligning features from two different modalities can be challenging. To address this challenge, we propose two techniques, InverseAug and LearnableAlign, to improve the quality of alignment among multimodal features. By integrating these techniques into the fusion stage of our proposed DeepFusion method, we achieve state-of-the-art performance on the Waymo Open Dataset.

Acknowledgements
Special thanks to co-authors Tianjian Meng, Ben Caine, Jiquan Ngiam, Daiyi Peng, Junyang Shen, Bo Wu, Yifeng Lu, Denny Zhou, Quoc Le, Alan Yuille, and Mingxing Tan.

Read More

Investing in Eastern Europe’s AI future

It was an honor and a privilege to attend a special event in the Bulgarian capital, Sofia, today to launch INSAIT, the Institute for Computer Science, Artificial Intelligence and Technology. INSAIT is a new AI and computer science research institute that will provide truly world-class facilities.

It’s fantastic to see the country where I was born leading the charge in bridging Eastern Europe to the world-stage in computer science research.

The institute is modeled on the computer science departments of renowned institutions such as MIT, UC Berkeley and the Max-Planck Institute, and is backed by the Bulgarian government with an endowment fund of nearly $100 million. Its computer science and AI research will span topics such as machine learning, quantum computing, information security, robotics and many more. Within two years, INSAIT expects faculty and students to publish papers in top conferences.

Google is investing $3 million over the next three years to provide INSAIT with cloud computing resources and access to its Tensor Processing Unit Research Cloud, a specialized infrastructure for running high-performance machine learning models. Supported with additional investment from DeepMind and Amazon Web Services, INSAIT aims to attract and develop the best researchers, engineers and top PhD and MSc students.

I know there’s no shortage of talented researchers, computer scientists and engineers in Eastern Europe – indeed, Sofia is already ranked as one of Europe’s top tech cities – but historically, the lack of local facilities, funding and support has meant limited opportunities for basic research. INSAIT has been created in partnership with two of the world’s leading technology universities, ETH Zurich and EPFL Lausanne, and its supervisory and advisory boards consist of leading researchers who are committed to helping the institute achieve its ambitious goals.

INSAIT opens in September, and I know the team is particularly keen to receive applications from women and other groups that are often underrepresented in the world of tech.

Google is delighted to support these efforts, and I cannot wait to see what new innovation emerges from this promising venture.

Read More

Large-Scale Matrix Factorization on TPUs

Matrix factorization is one of the oldest, yet still widely used, techniques for learning how to recommend items such as songs or movies from user ratings. In its basic form, it approximates a large, sparse (i.e., mostly empty) matrix of user-item interactions with a product of two smaller, denser matrices representing learned item and user features. These dense matrices, in turn, can be used to recommend items to a user with which they haven’t interacted before.

Despite its algorithmic simplicity, matrix factorization can still achieve competitive performance in recommender benchmarks. Alternating least squares (ALS), and especially its implicit variation, is a fundamental algorithm to learn the parameters of matrix factorization. ALS is known for its high efficiency because it scales linearly in the number of rows, columns and non-zeros. Hence, this algorithm is very well suited for large-scale challenges. But, for very large real-world matrix factorization datasets, a single machine implementation would not suffice, and so, it would require a large distributed system. Most of the distributed implementations of matrix factorization that employ ALS leverage off-the-shelf CPU devices, and rightfully so, due to the inherently sparse nature of the problem (the input matrix is mostly empty).
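
For orientation, here is a compact sketch of plain (explicit-feedback) alternating least squares on a small, fully observed toy matrix; the implicit variant used at scale, and ALX itself, add confidence weighting, sparsity handling, and sharding on top of this basic alternating loop. All sizes and hyperparameters below are arbitrary.

```python
import numpy as np

def als(ratings, rank=4, reg=0.1, steps=20):
    num_users, num_items = ratings.shape
    rng = np.random.default_rng(0)
    user_f = rng.normal(scale=0.1, size=(num_users, rank))
    item_f = rng.normal(scale=0.1, size=(num_items, rank))
    eye = reg * np.eye(rank)
    for _ in range(steps):
        # Fix item factors, solve a small ridge-regression problem for the users.
        user_f = np.linalg.solve(item_f.T @ item_f + eye, item_f.T @ ratings.T).T
        # Fix user factors, solve for the items.
        item_f = np.linalg.solve(user_f.T @ user_f + eye, user_f.T @ ratings).T
    return user_f, item_f

ratings = np.random.default_rng(1).integers(1, 6, size=(50, 30)).astype(float)
u, v = als(ratings)
print(np.abs(ratings - u @ v.T).mean())          # reconstruction error on the toy matrix
```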

On the other hand, the recent success of deep learning, with its ever-growing demand for computational capacity, has spurred a new wave of research and progress on hardware accelerators such as Tensor Processing Units (TPUs). TPUs afford domain-specific hardware speedups, especially for use cases like deep learning that involve a large number of dense matrix multiplications. In particular, they allow significant speedups for traditional data-parallel workloads, such as training models with Stochastic Gradient Descent (SGD) in SPMD (single program multiple data) fashion. The SPMD approach has gained popularity in computations like training neural networks with gradient descent algorithms, and can be used for both data-parallel and model-parallel computations, where the parameters of the model are distributed across available devices. Nevertheless, while TPUs have been enormously attractive for methods based on SGD, it is not immediately clear whether a high-performance implementation of ALS, which requires a large number of distributed sparse matrix multiplies, can be developed for a large-scale cluster of TPU devices.

In “ALX: Large Scale Matrix Factorization on TPUs”, we explore a distributed ALS design that makes efficient use of the TPU architecture and can scale well to matrix factorization problems of the order of billions of rows and columns by scaling the number of available TPU cores. The approach we propose leverages a combination of model and data parallelism, where each TPU core both stores a portion of the embedding table and trains over a unique slice of data, grouped in mini-batches. In order to spur future research on large-scale matrix factorization methods and to illustrate the scalability properties of our own implementation, we also built and released a real world web link prediction dataset called WebGraph.

The figure shows the flow of data and computation through the ALX framework on TPU devices. Similar to SGD-based training procedures, each TPU core performs identical computation for its own batch of data in SPMD fashion, which allows for synchronous computation in parallel on multiple TPU cores. Each TPU core starts by gathering all the relevant item embeddings in the Sharded Gather stage. These materialized embeddings are used to solve for user embeddings, which are scattered to the relevant shard of the embedding table in the Sharded Scatter stage.

Dense Batching for Improved Efficiency
We designed ALX specifically for TPUs, exploiting unique properties of TPU architecture while overcoming a few interesting limitations. For instance, each TPU core has limited memory and restricts all tensors to have a static shape, but each example in a mini-batch can have a wildly varying number of items (i.e., inputs can be long and sparse). To resolve this, we break exceedingly long examples into multiple smaller examples of the same shape, a process called dense batching. More details about dense batching can be found in our paper.
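
The following is an illustrative sketch of that idea: every row of the sparse input must fit a static shape, so rows with more non-zeros than the chosen width are split into several fixed-width rows (padded with a sentinel) that share the same row id. The width, sentinel value, and function names are made up for illustration.

```python
import numpy as np

def densify(row_ids, item_lists, width, pad=-1):
    out_rows, out_items = [], []
    for row, items in zip(row_ids, item_lists):
        for start in range(0, len(items), width):   # split long rows into fixed-width chunks
            chunk = items[start:start + width]
            out_rows.append(row)
            out_items.append(chunk + [pad] * (width - len(chunk)))  # pad to static shape
    return np.array(out_rows), np.array(out_items)

rows = [0, 1, 2]
items = [[7, 9], [3, 5, 8, 11, 13], [2]]             # wildly varying numbers of items per row
row_ids, dense = densify(rows, items, width=3)
print(row_ids)     # [0 1 1 2] -- row 1 was split across two fixed-shape examples
print(dense)
```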

Illustrative example of how sparse batches are densified to increase efficiency on TPUs.

Uniform Sharding of Embedding Tables
With the batching problem solved, we next want to factorize a sparse matrix into two dense embedding matrices (e.g., user and item embeddings) such that the resulting dot product of embeddings approximates the original sparse matrix — this helps us infer predictions for all the positions in the original matrix, including those that were empty, which can be used to recommend items with which users haven’t interacted. Both of the resulting embedding tables (W and H in the figure below) can potentially be too large to fit in a single TPU core, thus requiring a distributed training setup for most large-scale use cases.

Most previous attempts of distributed matrix factorization use a parameter server architecture where the model parameters are stored on highly available servers, and the training data is processed in parallel by workers that are solely responsible for the learning task. In our case, since each TPU core has identical compute and memory, it’s wasteful to only use either memory for storing model parameters or compute for training. Thus, we designed our system such that each core is used to do both.

Illustrative example of factorizing a sparse matrix Y into two dense embedding matrices W and H.

In ALX, we uniformly divide both embedding tables, thus fully exploiting both the size of distributed memory available and the dedicated low-latency interconnects between TPUs. This is highly efficient for very large embedding tables and results in good performance for distributed gather and scatter operations.
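
A simplified sketch of uniform sharding: both embedding tables are split into equal row blocks, one block per core, so every core holds a slice of W and a slice of H and participates in both storage and compute. The sizes below are toy values chosen to divide evenly, and the gather is simulated locally rather than over a real TPU interconnect.

```python
import numpy as np

num_cores, num_users, num_items, rank = 4, 1000, 800, 16
W = np.random.randn(num_users, rank)             # user embedding table
H = np.random.randn(num_items, rank)             # item embedding table

w_shards = np.array_split(W, num_cores, axis=0)  # one contiguous row block per core
h_shards = np.array_split(H, num_cores, axis=0)

items_per_shard = -(-num_items // num_cores)     # ceiling division: rows owned per core
def gather_items(indices):
    """Simulated 'sharded gather': find the owning core and local row for each index."""
    owner = indices // items_per_shard
    local = indices % items_per_shard
    return np.stack([h_shards[o][l] for o, l in zip(owner, local)])

print(gather_items(np.array([0, 250, 799])).shape)   # (3, 16) gathered item embeddings
```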

Uniform sharding of both embedding tables (W and H) across TPU cores (in blue).

WebGraph
Since potential applications may involve very large data sets, scalability is an important opportunity for advancement in matrix factorization. To that end, we also release a large real-world web link prediction dataset called WebGraph. This dataset can be easily modeled as a matrix factorization problem where rows and columns are source and destination links, respectively, and the task is to predict destination links from each source link. We use WebGraph to illustrate the scaling properties of ALX.

The WebGraph dataset was generated from a single crawl performed by CommonCrawl in 2021 where we strip everything and keep only the link->outlinks data. Since the performance of a factorization method depends on the properties of the underlying graph, we created six versions of WebGraph, each varying in the sparsity pattern and locale, to study how well ALS performs on each.

  • To study locale-specific graphs, we filter based on two top level domains: ‘de’ and ‘in’, each producing a graph with an order of magnitude fewer nodes.
  • These graphs can still have arbitrary sparsity patterns and dangling links. Thus we further filter the nodes in each graph to have a minimum of either 10 or 50 inlinks and outlinks.

For easy access, we have made these available as a Tensorflow Dataset package. For reference, the biggest version, WebGraph-sparse, has more than 365M nodes and 30B edges. We create and publish both training and testing splits for evaluation purposes.

Results
We carefully tuned the system and quality parameters of ALX, guided by our observations about precision and the choice of linear solver. By storing the embedding tables in bfloat16 while feeding the linear solvers float32 inputs, we were able to halve the memory required for the embeddings while still avoiding problems arising from lower-precision values during the solve stage. For our linear solver, we selected conjugate gradients, which we found to be the fastest across the board on TPUs. We use embeddings of dimension 128 and train the model for 16 epochs. In our experience, hyperparameter tuning over both the norm penalty (λ) and the unobserved weight (α) has been indispensable for good recall metrics, as shown in the table below.
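As an illustration of this precision split, the sketch below performs a single row update with conjugate gradients in float32 while keeping the embeddings stored in bfloat16. It follows a generic implicit-ALS formulation with norm penalty λ and unobserved weight α, and is not the actual ALX solver.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def solve_one_row(h_obs_bf16, y_obs, gramian_f32, lam, alpha):
    """Update one user embedding given the items it interacted with.

    h_obs_bf16 : (n_obs, d) embeddings of observed items, stored in bfloat16.
    y_obs      : (n_obs,)   observed values for those items.
    gramian_f32: (d, d)     H^T H over all items, precomputed in float32.
    """
    h_obs = h_obs_bf16.astype(jnp.float32)   # upcast only for the solve stage
    d = h_obs.shape[1]
    # Normal equations of a generic implicit-ALS objective: observed entries
    # get unit weight, all entries get weight alpha, plus an L2 penalty lam.
    A = h_obs.T @ h_obs + alpha * gramian_f32 + lam * jnp.eye(d)
    b = h_obs.T @ y_obs.astype(jnp.float32)
    w, _ = cg(A, b)                           # conjugate gradients in float32
    return w.astype(jnp.bfloat16)             # store back in bfloat16
```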

Results obtained by running ALX on all versions of WebGraph dataset. Recall values of 1.0 denote perfect recall.

Scaling Analysis
Since the input data are processed in parallel across TPU cores, increasing the number of cores decreases training time, ideally in a linear fashion. But at the same time, a larger number of cores requires more network communication (due to the sharded embedding tables). Thanks to high-speed interconnects, this overhead can be negligible for a small number of cores, but as the number of cores increases, the overhead eventually slows down the ideal linear scaling.
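One way to see this trade-off is with a back-of-the-envelope cost model: per-epoch time is roughly a compute term that shrinks with the number of cores plus a communication term that grows with it. The constants below are made up purely for illustration.

```python
def epoch_time(num_cores, compute_total=3600.0, comm_per_core=0.3):
    """Toy model: time = parallelizable compute + sharding communication."""
    return compute_total / num_cores + comm_per_core * num_cores

for n in (8, 16, 32, 64, 128, 256, 512):
    print(n, round(epoch_time(n), 1))
# Time drops nearly linearly at first, then flattens (and can even rise)
# once the communication term dominates.
```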

To confirm this hypothesis, we analyze the scaling properties of the four biggest WebGraph variants in terms of training time as we increase the number of available TPU cores. As shown below, we empirically observe the predicted near-linear decrease in training time up to a sweet spot, after which the network overhead slows the decline.

Scaling analysis of running time as the number of TPU cores is increased. Each figure plots the time taken to train for one epoch in seconds.

Conclusion
For easy access and reproducibility, the ALX code is open-sourced and can be easily run on Google Cloud. In fact, we illustrate that a sparse matrix like WebGraph-dense, of size 135M x 135M (with 22B edges), can be factorized in a Colab connected to 8 TPU cores in less than a day. We have designed the ALX framework with scalability in mind. With 256 TPU cores, one epoch of the largest WebGraph variant, WebGraph-sparse (a 365M x 365M sparse matrix), takes around 20 minutes to finish (5.5 hours for the whole training run). The final model has around 100B parameters. We hope that ALX and WebGraph will be useful to both researchers and practitioners working in these fields. The code for ALX can be found here on GitHub!

Acknowledgements
The core team includes Steffen Rendle, Walid Krichene and Li Zhang. We thank many Google colleagues for helping at various stages of this project. In particular, we are grateful to the JAX team for numerous discussions, especially James Bradbury and Skye Wanderman-Milne; Blake Hechtman for help with XLA and Rasmus Larsen for useful discussions about performance of linear solvers on TPUs. Finally, we’re also grateful to Nicolas Mayoraz and John Anderson for providing useful feedback.


VDTTS: Visually-Driven Text-To-Speech

Recent years have seen a tremendous increase in the creation and serving of video content to users across the world in a variety of languages and over numerous platforms. The process of creating high-quality content can include several stages, from video capturing and captioning to video and audio editing. In some cases dialog is re-recorded in a studio (referred to as dialog replacement, post-sync or dubbing) in order to achieve high quality and replace original audio that might have been recorded in noisy conditions. However, the dialog replacement process can be difficult and tedious because the newly recorded audio needs to be well synced with the video, often requiring several edits to match the exact timing of mouth movements.

In “More than Words: In-the-Wild Visually-Driven Prosody for Text-to-Speech”, we present a proof-of-concept visually-driven text-to-speech model, called VDTTS, that automates the dialog replacement process. Given a text and the original video frames of the speaker, VDTTS is trained to generate the corresponding speech. As opposed to standard visual speech recognition models, which focus on the mouth region, we detect and crop full faces using MediaPipe to avoid potentially excluding information pertinent to the speaker’s delivery. This gives the VDTTS model enough information to generate speech that matches the video while also recovering aspects of prosody, such as timing and emotion. Despite not being explicitly trained to generate speech that is synchronized to the input video, the learned model still does so.

Given a text and video frames of a speaker, VDTTS generates speech with prosody that matches the video signal.

VDTTS Model
The VDTTS model resembles Tacotron at its core and has four main components: (1) text and video encoders that process the inputs; (2) a multi-source attention mechanism that connects encoders to a decoder; (3) a spectrogram decoder that incorporates the speaker embedding (similarly to VoiceFilter), and produces mel-spectrograms (which are a form of compressed representation in the frequency domain); and (4) a frozen, pretrained neural vocoder that produces waveforms from the mel-spectrograms.
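To make the four components concrete, here is a heavily simplified, schematic forward pass; every module name and interface below is a placeholder rather than the actual VDTTS implementation.

```python
def vdtts_forward(text_tokens, video_frames, speaker_embedding,
                  text_encoder, video_encoder, decoder, vocoder):
    """Schematic VDTTS-style forward pass; every callable here is a placeholder."""
    # (1) Encoders process the two input streams separately.
    text_features = text_encoder(text_tokens)      # e.g., (T_text, d)
    video_features = video_encoder(video_frames)   # e.g., (T_video, d)

    # (2) + (3) The decoder attends to both encoder outputs through a
    # multi-source attention mechanism and conditions on the speaker
    # embedding while emitting mel-spectrogram frames.
    mel = decoder(sources=[text_features, video_features],
                  speaker=speaker_embedding)       # (T_mel, n_mels)

    # (4) A frozen, pretrained neural vocoder converts mel frames to audio.
    return vocoder(mel)
```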

The overall architecture of VDTTS. Text and video encoders process the inputs and then a multisource attention mechanism connects these to a decoder that produces mel-spectrograms. A vocoder then produces waveforms from the mel-spectrograms to generate speech as an output.

We train VDTTS using video and text pairs from LSVSR in which the text corresponds to the exact words spoken by a person in a video. Throughout our testing, we have determined that VDTTS cannot generate speech for arbitrary text, which makes it less susceptible to misuse (e.g., the generation of fake content).

Quality
To showcase the unique strength of VDTTS, we selected two inference examples from the VoxCeleb2 test set and compared the performance of VDTTS to that of a standard text-to-speech (TTS) model. In both examples, the video frames provide prosody and word-timing cues, visual information that is not available to the TTS model.

In the first example, the speaker talks at a particular pace that can be seen as periodic gaps in the ground-truth mel-spectrogram (shown below). VDTTS preserves this characteristic and generates audio that is much closer to the ground-truth than the audio generated by standard TTS without access to the video.

Similarly, in the second example, the speaker takes long pauses between some of the words. These pauses are captured by VDTTS and are reflected in the video below, whereas the TTS does not capture this aspect of the speaker’s rhythm.

We also plot fundamental frequency (F0) charts to compare the pitch generated by each model to the ground-truth pitch. In both examples, the F0 curve of VDTTS fits the ground-truth much better than the TTS curve, both in the alignment of speech and silence, and also in how the pitch changes over time. See more original videos and VDTTS generated videos.

We present two examples, (a) and (b), from the VoxCeleb2 test set. From top to bottom: input face images, ground-truth (GT) mel-spectrogram, mel-spectrogram output of VDTTS, mel-spectrogram output of a standard TTS model, and two plots showing the normalized F0 (normalized by mean non-zero pitch, i.e., mean is only over voiced periods) of VDTTS and TTS compared to the ground-truth signal.

Video Samples

Video samples (left to right): Original, VDTTS, VDTTS video-only, TTS.
Original displays the original video clip. VDTTS displays the audio predicted using both the video frames and the text as input. VDTTS video-only displays audio predictions using video frames only. TTS displays audio predictions using text only. Top transcript: “of space for people to make their own judgments and to come to their own”. Bottom transcript: “absolutely love dancing I have no dance experience whatsoever but as that”.

Model Performance
We measured the VDTTS model’s performance on the VoxCeleb2 dataset and compared it to a standard TTS model and a TTS with length hint model (a TTS that receives the scene length). VDTTS outperforms both models by large margins in most of the aspects we measured: higher sync-to-video quality (measured by SyncNet distance), better speech quality as measured by mel cepstral distance (MCD), and lower Gross Pitch Error (GPE), the percentage of frames where the pitch differs by more than 20%, computed over frames for which voice is present in both the predicted and reference audio.
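As a concrete reading of the GPE definition above, one might compute it from per-frame pitch tracks as follows (the function itself is illustrative; the 20% threshold and the both-voiced convention come from the description in the text):

```python
import numpy as np

def gross_pitch_error(f0_pred, f0_ref, rel_threshold=0.20):
    """GPE: fraction of mutually voiced frames whose pitch is off by >20%.

    f0_pred, f0_ref: per-frame F0 in Hz, with 0 marking unvoiced frames.
    """
    voiced = (f0_pred > 0) & (f0_ref > 0)      # voice present in both signals
    if not np.any(voiced):
        return 0.0
    rel_err = np.abs(f0_pred[voiced] - f0_ref[voiced]) / f0_ref[voiced]
    return float(np.mean(rel_err > rel_threshold))
```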

SyncNet distance comparison between VDTTS, TTS and the TTS with Length hint (a lower metric is better).
Mel cepstral distance comparison between VDTTS, TTS and the TTS with Length hint (a lower metric is better).
Gross Pitch Error comparison between VDTTS, TTS and the TTS with Length hint (a lower metric is better).

Discussion and Future Work
Intriguingly, VDTTS produces video-synchronized speech without any explicit losses or constraints to promote this, suggesting that complexities such as synchronization losses or explicit modeling may be unnecessary.

While this is a proof-of-concept demonstration, we believe that in the future VDTTS can be upgraded to handle scenarios where the input text differs from the original video signal. This kind of model would be a valuable tool for tasks such as translation dubbing.

Acknowledgements
We would like to thank the co-authors of this research: Michelle Tadmor Ramanovich, Ye Jia, Brendan Shillingford, and Miaosen Wang. We are also grateful for the valued contributions, discussions, and feedback from Nadav Bar, Jay Tenenbaum, Zach Gleicher, Paul McCartney, Marco Tagliasacchi, and Yoni Tzafir.


How AI and imagery build a self-updating map

Building a map is complex, and keeping it up-to-date is even more challenging. Think about how often your city, town or neighborhood changes on a day-to-day basis. Businesses and shops open and close, stretches of highway are added, and roadways change. In today’s Maps 101 installment, we’ll dive into two ways Google Maps uses advancements in AI and imagery to help you see the latest information about the world around you every single day.

Automatically updating business hours

Over the past few years, businesses have experienced a lot of change — including constantly updating operating hours based on changing pandemic-related restrictions. To keep up with this pace of change, we developed a machine learning model that automatically identifies if business hours are likely wrong, then instantly updates them with AI-generated predictions.

Let’s look at Liam’s Lemonade Shop as an example. To start, our systems consider multiple factors — such as when Liam last updated their business profile, what we know about other shops’ hours, and the Popular Times information for the shop, which uses location trends to show when it tends to be busiest. Since it appears that Liam’s business profile hasn’t been updated in over a year and its busiest hours are typically Thursday afternoons — even though Google Maps says that it’s closed at that time — Liam’s business hours are likely out of date.
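Purely as an illustration of how such signals might be combined (this is not Google's actual model, and every name and threshold below is hypothetical), a toy staleness check could look like:

```python
from datetime import date

def hours_likely_stale(profile, popular_times):
    """Toy heuristic combining the kinds of signals described above.

    profile: dict with 'last_updated' (date) and 'hours' (day -> set of open hours).
    popular_times: dict mapping day -> list of busy hours from location trends.
    """
    years_since_update = (date.today() - profile["last_updated"]).days / 365
    # Flag a conflict if the shop is busiest at a time it is listed as closed.
    busy_while_listed_closed = any(
        hour not in profile["hours"].get(day, set())
        for day, hours in popular_times.items()
        for hour in hours
    )
    return years_since_update > 1 and busy_while_listed_closed
```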

Still images of a business' hours and Popular Times information on Google Maps

To see if business hours need updating, we check a store’s Popular Times information and when its business profile was last updated.

So what’s next? Our algorithms analyze the business hours of other nearby lemonade shops, information from Liam’s website, and Street View images of Liam’s storefront, looking specifically for business hour signs, to determine the most accurate business hour prediction. At the same time, we enlist the help of the Google Maps community — including Local Guides and even the business owners themselves through their Google Business Profile — to verify the information we predicted. In Argentina, Australia, Chile, France, Japan, Mexico, New Zealand, Peru, and the United States, we also use Duplex conversational technology to call businesses just like Liam’s and ask for their hours directly. With this new AI-first approach, we’re on track to update the hours for over 20 million businesses around the globe in the next six months – helping you know exactly when your favorite store, restaurant or cafe is open for business.

Road information that reflects the real world

We’re also experimenting with ways we can use imagery to make updates to other helpful information. For instance, starting in the U.S., we’re launching a third-party imagery pilot to let you see the most up-to-date speed limit information in your town, which can help keep you safe while driving. Here’s how it works:

Say our systems think that the speed limit information on a particular highway needs to be updated. With the help of third-party imagery partners that already gather roadway imagery to improve delivery routes, we can request a photo of the specific stretch of road that also includes a speed limit sign. If the partner has this photo available, we then use a combination of AI and help from our operations team to identify the sign in the image, extract the new speed limit information, and update Google Maps.

Picture of an intersection that has a speed limit sign

Representative imagery featuring a speed limit sign, with license plates blurred

Over time, this technology will bring more details to the map that can help make your drives safer and more efficient — like where potholes and school zones are or where new construction is happening. And as with all Google Maps features, we designed this pilot with privacy top of mind. For instance, we only reference images taken on public roads, and partners are required to blur information (like faces and license plates) to avoid potentially identifying someone. For an extra layer of privacy, we blur the photo again when we receive it and delete the photo after we use it to update the map.

AI, imagery and Duplex technology will continue to play a critical role in helping make Google Maps the most comprehensive and useful map possible. For more behind-the-scenes looks at the technology that powers Google Maps, check out the rest of our Maps 101 blog series.


Efficiently Initializing Reinforcement Learning With Prior Policies

Reinforcement learning (RL) can be used to train a policy to perform a task via trial and error, but a major challenge in RL is learning policies from scratch in environments with hard exploration challenges. For example, consider the setting depicted in the door-binary-v0 environment from the adroit manipulation suite, where an RL agent must control a hand in 3D space to open a door placed in front of it.

An RL agent must control a hand in 3D space to open a door placed in front of it. The agent receives a reward signal only when the door is completely open.

Since the agent receives no intermediary rewards, it cannot measure how close it is to completing the task, and so must explore the space randomly until it eventually opens the door. Given how long the task takes and the precise control required, this is extremely unlikely.

For tasks like this, we can avoid exploring the state space randomly by using prior information. This prior information helps the agent understand which states of the environment are good, and should be further explored. We could use offline data (i.e., data collected by human demonstrators, scripted policies, or other RL agents) to train a policy, then use it to initialize a new RL policy. In the case where we use neural networks to represent the policies, this would involve copying the pre-trained policy’s neural network over to the new RL policy. This procedure makes the new RL policy behave like the pre-trained policy. However, naïvely initializing a new RL policy like this often works poorly, especially for value-based RL methods, as shown below.

A policy is pre-trained on the antmaze-large-diverse-v0 D4RL environment with offline data (negative steps correspond to pre-training). We then use the policy to initialize actor-critic fine-tuning (positive steps starting from step 0) with this pre-trained policy as the initial actor. The critic is initialized randomly. The actor’s performance immediately drops and does not recover, as the untrained critic provides a poor learning signal and causes the good initial policy to be forgotten.

With the above in mind, in “Jump-Start Reinforcement Learning” (JSRL), we introduce a meta-algorithm that can use a pre-existing policy of any form to initialize any type of RL algorithm. JSRL uses two policies to learn tasks: a guide-policy, and an exploration-policy. The exploration-policy is an RL policy that is trained online with new experience that the agent collects from the environment, and the guide-policy is a pre-existing policy of any form that is not updated during online training. In this work, we focus on scenarios where the guide-policy is learned from demonstrations, but many other kinds of guide-policies can be used. JSRL creates a learning curriculum by rolling in the guide-policy, which is then followed by the self-improving exploration-policy, resulting in performance that compares to or improves on competitive IL+RL methods.

The JSRL Approach
The guide-policy can take any form: it could be a scripted policy, a policy trained with RL, or even a live human demonstrator. The only requirements are that the guide-policy is reasonable (i.e., better than random exploration), and it can select actions based on observations of the environment. Ideally, the guide-policy can reach poor or medium performance in the environment, but cannot further improve itself with additional fine-tuning. JSRL then allows us to leverage the progress of this guide-policy to take the performance even higher.

At the beginning of training, we roll out the guide-policy for a fixed number of steps so that the agent is closer to goal states. The exploration-policy then takes over and continues acting in the environment to reach these goals. As the performance of the exploration-policy improves, we gradually reduce the number of steps that the guide-policy takes, until the exploration-policy takes over completely. This process creates a curriculum of starting states for the exploration-policy such that in each curriculum stage, it only needs to learn to reach the initial states of prior curriculum stages.
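A minimal sketch of this roll-in scheme follows; the environment interface mimics a standard Gym-style API, and the schedule for shrinking the guide horizon is a simple placeholder rather than the exact curriculum from the paper.

```python
def jsrl_episode(env, guide_policy, exploration_policy, guide_steps):
    """Roll in with the guide-policy, then hand control to the explorer."""
    obs = env.reset()
    transitions, done, t = [], False, 0
    while not done:
        policy = guide_policy if t < guide_steps else exploration_policy
        action = policy(obs)
        next_obs, reward, done, _ = env.step(action)
        # Only the exploration-policy is trained on the collected experience;
        # the guide-policy stays frozen throughout online training.
        transitions.append((obs, action, reward, next_obs))
        obs, t = next_obs, t + 1
    return transitions

def shrink_guide_horizon(guide_steps, eval_return, target_return):
    # Placeholder curriculum: once the exploration-policy does well enough
    # from the current starting states, let it take over one step earlier.
    return max(0, guide_steps - 1) if eval_return >= target_return else guide_steps
```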

Here, the task is for the robot arm to pick up the blue block. The guide-policy can move the arm to the block, but it cannot pick it up. It controls the agent until it grips the block, then the exploration-policy takes over, eventually learning to pick up the block. As the exploration-policy improves, the guide-policy controls the agent less and less.

Comparison to IL+RL Baselines
Since JSRL can use a prior policy to initialize RL, a natural comparison would be to imitation and reinforcement learning (IL+RL) methods that train on offline datasets, then fine-tune the pre-trained policies with new online experience. We show how JSRL compares to competitive IL+RL methods on the D4RL benchmark tasks. These tasks include simulated robotic control environments, along with datasets of offline data from human demonstrators, planners, and other learned policies. Out of the D4RL tasks, we focus on the difficult ant maze and adroit dexterous manipulation environments.

Example ant maze (left) and adroit dexterous manipulation (right) environments.

For each experiment, we train on an offline dataset and then run online fine-tuning. We compare against algorithms designed specifically for each setting, which include AWAC, IQL, CQL, and behavioral cloning. While JSRL can be used in combination with any initial guide-policy or fine-tuning algorithm, we use our strongest baseline, IQL, as a pre-trained guide and for fine-tuning. The full D4RL dataset includes one million offline transitions for each ant maze task. Each transition is a tuple of the form (S, A, R, S’), which specifies the state the agent started in (S), the action the agent took (A), the reward the agent received (R), and the state the agent ended up in (S’) after taking action A. We find that JSRL performs well with as few as ten thousand offline transitions.
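For clarity, such a transition can be represented as a small typed container; the field names simply mirror the (S, A, R, S’) description above.

```python
from typing import Any, NamedTuple

class Transition(NamedTuple):
    state: Any        # S: the state the agent started in
    action: Any       # A: the action the agent took
    reward: float     # R: the reward the agent received
    next_state: Any   # S': the state the agent ended up in
```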

Average score (max=100) on the antmaze-medium-diverse-v0 environment from the D4RL benchmark suite. JSRL can improve even with limited access to offline transitions.

Vision-Based Robotic Tasks
Utilizing offline data is especially challenging in complex tasks such as vision-based robotic manipulation due to the curse of dimensionality. The high dimensionality of both the continuous-control action space and the pixel-based state space present scaling challenges for IL+RL methods in terms of the amount of data required to learn good policies. To study how JSRL scales to such settings, we focus on two difficult simulated robotic manipulation tasks: indiscriminate grasping (i.e., lifting any object) and instance grasping (i.e., lifting a specific target object).

A simulated robot arm is placed in front of a table with various categories of objects. When the robot lifts any object, a sparse reward is given for the indiscriminate grasping task. For the instance grasping task, a sparse reward is only given when a specific target object is grasped.

We compare JSRL against methods that are able to scale to complex vision-based robotics settings, such as QT-Opt and AW-Opt. Each method has access to the same offline dataset of successful demonstrations and is allowed to run online fine-tuning for up to 100,000 steps.

In these experiments, we use behavioral cloning as a guide-policy and combine JSRL with QT-Opt for fine-tuning. The combination of QT-Opt+JSRL improves faster than all other methods while achieving the highest success rate.

Mean grasping success for indiscriminate and instance grasping environments using 2k successful demonstrations.

Conclusion
We proposed JSRL, a method for leveraging a prior policy of any form to improve exploration for initializing RL tasks. Our algorithm creates a learning curriculum by rolling in a pre-existing guide-policy, which is then followed by the self-improving exploration-policy. The job of the exploration-policy is greatly simplified since it starts exploring from states closer to the goal. As the exploration-policy improves, the effect of the guide-policy diminishes, leading to a fully capable RL policy. In the future, we plan to apply JSRL to problems such as Sim2Real, and explore how we can leverage multiple guide-policies to train RL agents.

Acknowledgements
This work would not have been possible without Ikechukwu Uchendu, Ted Xiao, Yao Lu, Banghua Zhu, Mengyuan Yan, Joséphine Simon, Matthew Bennice, Chuyuan Fu, Cong Ma, Jiantao Jiao, Sergey Levine, and Karol Hausman. Special thanks to Tom Small for creating the animations for this post.
