Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16

Overview

In recent years, the growing complexity of AI models has placed ever-greater demands on hardware compute capability. Reduced-precision numeric formats have been proposed to address this problem. Bfloat16 is a custom 16-bit floating point format for AI that consists of one sign bit, eight exponent bits, and seven mantissa bits. Because it has the same dynamic range as float32, bfloat16 doesn’t require special handling such as loss scaling. Therefore, bfloat16 can serve as a drop-in replacement for float32 when running deep neural networks for both inference and training.
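To make the 1/8/7 bit layout concrete, here is a minimal pure-Python sketch (no PyTorch required) that converts a float32 value to bfloat16 bits and splits out the sign, exponent, and mantissa fields. It assumes round-to-nearest-even when the lower 16 bits of the float32 pattern are dropped and omits NaN handling; the helper names are illustrative only.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE 754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16_bits(x: float) -> int:
    """Keep the upper 16 bits of float32, rounding to nearest even (NaN not handled)."""
    bits = float32_bits(x)
    lsb = (bits >> 16) & 1  # lowest bit that survives the truncation
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF


def bfloat16_bits_to_float(b: int) -> float:
    """Widen bfloat16 back to float32 by appending 16 zero bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def fields(b: int) -> tuple:
    """Split a bfloat16 bit pattern into (sign, exponent, mantissa)."""
    return b >> 15, (b >> 7) & 0xFF, b & 0x7F


print(hex(to_bfloat16_bits(1.0)), fields(to_bfloat16_bits(1.0)))  # 0x3f80 (0, 127, 0)
# Large float32 values stay finite because the 8-bit exponent matches float32
print(bfloat16_bits_to_float(to_bfloat16_bits(3.0e38)))
```

Because the exponent field is identical to float32's, only mantissa precision is lost in the conversion, which is why no loss scaling is needed.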

The 3rd Gen Intel® Xeon® Scalable processor (codenamed Cooper Lake) is the first general-purpose x86 CPU with native bfloat16 support. Three new bfloat16 instructions were introduced in Intel® Advanced Vector Extensions-512 (Intel® AVX-512): VCVTNE2PS2BF16, VCVTNEPS2BF16, and VDPBF16PS. The first two instructions convert float32 to bfloat16, and the last one computes a dot product of bfloat16 pairs. On Cooper Lake, the theoretical bfloat16 compute throughput is double that of float32. On the next generation of Intel® Xeon® Scalable processors, bfloat16 compute throughput will be further enhanced through the Intel® Advanced Matrix Extensions (Intel® AMX) instruction set extension.
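The data flow of VDPBF16PS can be emulated in scalar Python for illustration: each 32-bit lane of the two sources holds a pair of bfloat16 values, and the lane's float32 accumulator receives the sum of the two products. This is a simplified sketch under stated assumptions: it rounds to float32 after each addition and ignores the instruction's exact denormal and rounding behavior, so treat it as a model of the data flow rather than a bit-exact reference.

```python
import struct


def bf16_to_f32(b: int) -> float:
    """A bfloat16 value is the upper half of a float32, so widening is a 16-bit shift."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def round_to_f32(x: float) -> float:
    """Round a Python (double) float to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def dpbf16ps(acc, a, b):
    """Simplified VDPBF16PS: acc[i] += a[2i]*b[2i] + a[2i+1]*b[2i+1], accumulated in float32.

    acc holds one float32 accumulator per lane; a and b hold bfloat16 bit
    patterns, two per lane.
    """
    out = []
    for i, total in enumerate(acc):
        for j in (2 * i, 2 * i + 1):
            total = round_to_f32(total + bf16_to_f32(a[j]) * bf16_to_f32(b[j]))
        out.append(total)
    return out


ONE, TWO, THREE, HALF = 0x3F80, 0x4000, 0x4040, 0x3F00  # bfloat16 bit patterns
print(dpbf16ps([0.0, 1.0], [ONE, TWO, THREE, HALF], [TWO, TWO, ONE, TWO]))  # [6.0, 5.0]
```

Each lane consumes two bfloat16 multiply pairs per 32-bit slot, which is where the doubled throughput over a float32 fused multiply-add comes from.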

Intel and Meta previously collaborated to enable bfloat16 in PyTorch, and the related work was published in an earlier blog during the launch of Cooper Lake. In that blog, we introduced the hardware advances for native bfloat16 support and showcased a 1.4x to 1.6x performance boost of bfloat16 over float32 on DLRM, ResNet-50, and ResNeXt-101-32x4d.

In this blog, we introduce the latest software enhancements for bfloat16 in PyTorch 1.12, which apply to a much broader range of user scenarios and deliver even higher performance gains.

Native Level Optimization on Bfloat16

On the PyTorch CPU bfloat16 path, compute-intensive operators such as convolution, linear, and bmm use oneDNN (oneAPI Deep Neural Network Library) to achieve optimal performance on Intel CPUs with AVX512_BF16 or AMX support. The other operators, such as tensor operators and neural network operators, are optimized at the PyTorch native level. We have extended bfloat16 kernel-level optimizations to the majority of operators on dense tensors, for both inference and training (sparse tensor bfloat16 support will be covered in future work), specifically:

  • Bfloat16 vectorization: Bfloat16 is stored as an unsigned 16-bit integer, so it must be cast to float32 for arithmetic operations such as add and mul. Specifically, each bfloat16 vector is converted to two float32 vectors, processed, and then converted back. For non-arithmetic operations such as cat and copy, the data is copied directly in memory and no data type conversion is involved.
  • Bfloat16 reduction: Reductions on bfloat16 data, e.g., sum, BatchNorm2d, and MaxPool2d, use float32 as the accumulation type to ensure numerical stability.
  • Channels Last optimization: For vision models, Channels Last is the preferred memory format over Channels First from a performance perspective. We have implemented fully optimized CPU kernels for all commonly used CV modules in the channels last memory format, covering both float32 and bfloat16.
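The reason for float32 accumulation can be seen in a small pure-Python simulation (the helpers below are illustrative, not PyTorch internals). Bfloat16 keeps only 8 bits of significand, so once a running sum reaches 256 the spacing between representable values is 2, and adding 1.0 rounds straight back to 256. Accumulating in float32 and rounding to bfloat16 only once, when the result is written, avoids this.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float to float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to bfloat16 (round-to-nearest-even on the upper 16 bits of float32)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) >> 16) & 0xFFFF
    # Widening back to float32 is a 16-bit left shift
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def sum_with_bf16_accumulator(values):
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)  # every partial sum is stored as bfloat16
    return acc


def sum_with_f32_accumulator(values):
    acc = 0.0
    for v in values:
        acc = to_f32(acc + v)   # partial sums stay in float32
    return to_bf16(acc)         # round to bfloat16 once, when writing the result


ones = [1.0] * 4096
print(sum_with_bf16_accumulator(ones))  # 256.0
print(sum_with_f32_accumulator(ones))   # 4096.0
```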

Run Bfloat16 with Auto Mixed Precision

To run a model in bfloat16, users can typically either explicitly convert the data and model to bfloat16, for example:

# with explicit conversion
input = input.to(dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)

or use the torch.amp (Automatic Mixed Precision) package. autocast instances serve as context managers or decorators that allow regions of your script to run in mixed precision, for example:

# with AMP
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = model(input)

Generally, the explicit conversion approach and the AMP approach have similar performance. Even so, we recommend running bfloat16 models with AMP because:

  • Better user experience with automatic fallback: If your script includes operators that don’t have bfloat16 support, autocast will implicitly convert them back to float32, while the explicitly converted model will raise a runtime error.

  • Mixed data types for activations and parameters: Unlike explicit conversion, which converts all model parameters to bfloat16, AMP runs with mixed data types. Specifically, inputs/outputs are kept in bfloat16 while parameters, e.g., weight and bias, are kept in float32. Mixing the data types of activations and parameters helps improve performance while maintaining accuracy.
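The mixed data types are easy to observe directly. The following sketch assumes a PyTorch version with CPU autocast support (1.10 or later) and uses a small hypothetical model in place of a real one; it checks the dtype of the activations produced under autocast and of the parameters, and also shows the decorator form.

```python
import torch

# Hypothetical small model; any nn.Module behaves the same way
model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU(), torch.nn.Linear(32, 8))
x = torch.randn(4, 64)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(out.dtype)              # torch.bfloat16: activations run in bfloat16
print(model[0].weight.dtype)  # torch.float32: parameters stay in float32


# autocast can also be used as a decorator
@torch.autocast(device_type="cpu", dtype=torch.bfloat16)
def predict(inp):
    return model(inp)


print(predict(x).dtype)       # torch.bfloat16
```

Nothing in the model is modified in place, so the same module can still be run in plain float32 outside the autocast region.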

Performance Gains

We benchmarked inference performance of TorchVision models on an Intel® Xeon® Platinum 8380H CPU @ 2.90GHz (codenamed Cooper Lake), running a single instance per socket (batch size = 2 x number of physical cores). Results show that bfloat16 delivers a 1.4x to 2.2x performance gain over float32.

The performance boost of bfloat16 over float32 comes primarily from three sources:

  • Compute-intensive operators take advantage of the new native bfloat16 instruction VDPBF16PS, which doubles the hardware compute throughput.
  • Bfloat16 has half the memory footprint of float32, so memory-bandwidth-bound operators can theoretically run up to twice as fast.
  • On Channels Last, we intentionally keep the same parallelization scheme for all memory-format-aware operators (this is not possible on Channels First), which increases data locality when each layer’s output is passed to the next layer. In effect, the data stays close to the CPU cores that consume it and is more likely to still reside in cache. Bfloat16 has a higher cache hit rate than float32 in such scenarios because of its smaller memory footprint.
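Putting these pieces together, a typical CPU inference setup converts both the model and the input to channels last and runs under bfloat16 autocast. This is a minimal sketch with a small hypothetical convolutional model rather than an actual TorchVision network, assuming a PyTorch version with CPU autocast.

```python
import torch

# Hypothetical small convolutional model standing in for a TorchVision network
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(16, 8, kernel_size=3, padding=1),
).eval()
x = torch.randn(2, 3, 32, 32)

# Channels Last keeps the logical NCHW shape but stores the data as NHWC
model = model.to(memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)
print(x.stride())  # (3072, 1, 96, 3): the channel dimension is now innermost

with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.shape, y.dtype)  # torch.Size([2, 8, 32, 32]) torch.bfloat16
```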

Conclusion & Future Work

In this blog, we introduced recent bfloat16 software optimizations in PyTorch 1.12. Results on the 3rd Gen Intel® Xeon® Scalable processor show that bfloat16 delivers a 1.4x to 2.2x performance gain over float32 on TorchVision models. Further improvement is expected on the next generation of Intel® Xeon® Scalable processors with AMX instruction support. Although the performance numbers in this blog were collected with TorchVision models, we expect the benefits to extend broadly to other topologies. We will continue to extend the bfloat16 optimization effort to a broader scope in the future!

Acknowledgement

The results presented in this blog are the product of a joint effort between the Meta and Intel PyTorch teams. Special thanks to Vitaly Fedyunin and Wei Wei from Meta, who spent precious time and gave substantial assistance! Together, we took one more step toward improving the PyTorch CPU ecosystem.

Introducing the PlayTorch app: Rapidly Create Mobile AI Experiences

In December, we announced PyTorch Live, a toolkit for building AI-powered mobile prototypes in minutes. The initial release included a command-line interface to set up a development environment and an SDK for building AI-powered experiences in React Native. Today, we’re excited to share that PyTorch Live will now be known as PlayTorch. This new release provides an improved and simplified developer experience. PlayTorch development is independent of the PyTorch project, and the PlayTorch code repository is moving into the Meta Research GitHub organization.

A New Workflow: The PlayTorch App

The PlayTorch team is excited to announce that we have partnered with Expo to change the way AI-powered mobile experiences are built. Our new release simplifies the process of building mobile AI experiences by eliminating the need for a complicated development environment. You will now be able to build cross-platform, AI-powered prototypes from the very browser you are using to read this blog.

To make this happen, we are releasing the PlayTorch app, which can run AI-powered experiences built in the web-based Expo Snack code editor.

The PlayTorch app can be downloaded from the Apple App Store and Google Play Store. With the app installed, you can head over to playtorch.dev/snack and write the code for your AI-powered PlayTorch Snack. When you want to try what you’ve built, you can use the PlayTorch app’s QR code scanner to scan the QR code on the Snack page and load the code to your device.

NOTE: PlayTorch Snacks will not work in the Expo Go app.

More to Explore in the PlayTorch App

AI Demos

The PlayTorch app comes with several examples of how you can build AI-powered experiences with a variety of machine learning models, from object detection to natural language processing. See what can be built with the PlayTorch SDK and get inspired to make something of your own as you play with the examples.

Sharing Your Creations

Any PlayTorch Snack that you run in the PlayTorch app can be shared with others instantly. When they open the link on their device, the PlayTorch app will load what you’ve built from the cloud so they can experience it firsthand.

When you have something you want to share, let us know on Discord or Twitter or embed the PlayTorch Snack on your own webpage.

SDK Overhaul

We learned a lot from the community after our initial launch in December and have been hard at work over the past several months to make the PlayTorch SDK (formerly known as PyTorch Live) simple, performant, and robust. In our initial version, the SDK relied on config files to define how a model ingested and output data.

Today, we are happy to announce the next version of our SDK can handle data processing in JavaScript for your prototypes with the new PlayTorch API that leverages the JavaScript Interface (JSI) to directly call C++ code. Not only have we completely redone the way you can interact with models, but we have also greatly expanded the variety of supported model architectures.

A New Data Processing API for Prototyping

With this JSI API, we now give users direct access to tensors (the data format used in machine learning). Instead of being limited to predefined transformations, you can now manipulate tensors however you like for your prototypes.

No more switching back and forth between code and config. You will now be able to write everything in JavaScript and have access to all of the type annotations and autocomplete features available to you in that language.

Check out our tutorials to see the new Data Processing API in action, take a deeper dive in the API docs, or inspect the code yourself on GitHub.

Expanded Use Cases

With the new version of the SDK, we have added support for several cutting-edge models.

Image-to-image transformations are now supported thanks to our robust JSI API, so you can see what your world would look like if it were an anime.

Translate French to English with an AI-powered translator using the Seq2Seq model.

Use DeepLab V3 to segment images!

Start Playing

If you want to start creating AI experiences yourself, head over to playtorch.dev and try out our tutorials. Each tutorial will guide you through building a simple AI-powered experience that you can instantly run on your phone and share with others.

How to Get Support

Join us on Discord, collaborate with us on GitHub, or follow us on Twitter. Got questions or feedback? We’d love to hear from you!

What Every User Should Know About Mixed Precision Training in PyTorch

Efficient training of modern neural networks often relies on using lower precision data types. Peak float16 matrix multiplication and convolution performance is 16x faster than peak float32 performance on A100 GPUs. And since the float16 and bfloat16 data types are only half the size of float32, they can double the performance of bandwidth-bound kernels and reduce the memory required to train a network, allowing for larger models, larger batches, or larger inputs. Using a module like torch.amp (short for “Automatic Mixed Precision”) makes it easy to get the speed and memory usage benefits of lower precision data types while preserving convergence behavior.

Going faster and using less memory is always advantageous – deep learning practitioners can test more model architectures and hyperparameters, and larger, more powerful models can be trained. Training very large models like those described in Narayanan et al. and Brown et al. (which take thousands of GPUs months to train even with expert handwritten optimizations) is infeasible without using mixed precision.

We’ve talked about mixed precision techniques before (here, here, and here), and this blog post is a summary of those techniques and an introduction if you’re new to mixed precision.

Mixed Precision Training in Practice

Mixed precision training techniques – the use of the lower precision float16 or bfloat16 data types alongside the float32 data type – are broadly applicable and effective. See Figure 1 for a sampling of models successfully trained with mixed precision, and Figures 2 and 3 for example speedups using torch.amp.

Figure 1: Sampling of DL Workloads Successfully Trained with float16 (Source).

Figure 2: Performance of mixed precision training using torch.amp on NVIDIA 8xV100 vs. float32 training on 8xV100 GPU. Bars represent the speedup factor of torch.amp over float32 (higher is better). (Source).

Figure 3: Performance of mixed precision training using torch.amp on NVIDIA 8xA100 vs. 8xV100 GPU. Bars represent the speedup factor of A100 over V100 (higher is better). (Source).

See the NVIDIA Deep Learning Examples repository for more sample mixed precision workloads.

Similar performance charts can be seen in 3D medical image analysis, gaze estimation, video synthesis, conditional GANs, and convolutional LSTMs. Huang et al. showed that mixed precision training is 1.5x to 5.5x faster than float32 on V100 GPUs, and an additional 1.3x to 2.5x faster on A100 GPUs, across a variety of networks. On very large networks, the need for mixed precision is even more evident. Narayanan et al. report that it would take 34 days to train GPT-3 175B on 1024 A100 GPUs (with a batch size of 1536), but it is estimated that it would take over a year using float32!

Getting Started With Mixed Precision Using torch.amp

torch.amp, introduced in PyTorch 1.6, makes it easy to leverage mixed precision training using the float16 or bfloat16 dtypes. See this blog post, tutorial, and documentation for more details. Figure 4 shows an example of applying AMP with grad scaling to a network.

import torch

# model, optimizer, loss_fn, and data_iter are assumed to be defined elsewhere
# Create the scaler once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
    optimizer.zero_grad()
    # Cast operations to mixed precision
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        output = model(data)
        loss = loss_fn(output, label)

    # Scale the loss and call backward()
    # to create scaled gradients
    scaler.scale(loss).backward()

    # Unscale the gradients and call
    # or skip optimizer.step()
    scaler.step(optimizer)

    # Update the scale for the next iteration
    scaler.update()

Figure 4: AMP recipe
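For comparison, the same loop with bfloat16 usually needs no gradient scaling, because bfloat16 has the same dynamic range as float32. The sketch below runs on CPU with a hypothetical toy regression problem so it can be executed anywhere; the model, data, and hyperparameters are illustrative assumptions, not a recommended recipe.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy regression problem standing in for a real model and dataset
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
data = torch.randn(256, 16)
target = data.sum(dim=1, keepdim=True)

losses = []
for step in range(100):
    optimizer.zero_grad()
    # Eligible ops (e.g. the Linear layers) run in bfloat16 inside this region
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        output = model(data)
    # Compute the loss in float32 outside the autocast region
    loss = loss_fn(output.float(), target)
    # No GradScaler: bfloat16 has the same exponent range as float32
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

The parameters and their gradients remain float32 throughout; only the autocast-eligible forward computations run in bfloat16.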

Picking The Right Approach

Out-of-the-box mixed precision training with either float16 or bfloat16 is effective at speeding up the convergence of many deep learning models, but some models may require more careful numerical accuracy management. Here are some options:

  • Full float32 precision. Floating point tensors and modules are created in float32 precision by default in PyTorch, but this is a historic artifact not representative of training most modern deep learning networks. It’s rare that networks need this much numerical accuracy.
  • Enabling TensorFloat32 (TF32) mode. On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. See the Accelerating AI Training with NVIDIA TF32 Tensor Cores blog post for more details. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too (see the documentation here for how to do so). It can significantly speed up computations with typically negligible loss of numerical accuracy.
  • Using torch.amp with bfloat16 or float16. Both these low precision floating point data types are usually comparably fast, but some networks may only converge with one vs the other. If a network requires more precision it may need to use float16, and if a network requires more dynamic range it may need to use bfloat16, whose dynamic range is equal to that of float32. If overflows are observed, for example, then we suggest trying bfloat16.
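The range versus precision trade-off can be checked directly with torch.finfo. The sketch below prints the largest finite value and machine epsilon of each format, shows a value that overflows in float16 but not in bfloat16, and sets the two TF32 switches (they only take effect on Ampere or later CUDA devices, but setting them should be harmless elsewhere).

```python
import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} max={info.max:.4g} eps={info.eps:.4g}")

# 70000 exceeds float16's largest finite value (65504) but is within bfloat16's range
big = torch.tensor(70000.0)
print(big.to(torch.float16), big.to(torch.bfloat16))

# TF32 switches: matmul is off by default, cuDNN convolutions are on by default
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
```

float16 has the smaller epsilon (more precision), while bfloat16 has the float32-sized range, which is why overflows suggest trying bfloat16.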

There are even more advanced options than those presented here, like using torch.amp’s autocasting for only parts of a model, or managing mixed precision directly. These topics are largely beyond the scope of this blog post, but see the “Best Practices” section below.

Best Practices

We strongly recommend using mixed precision with torch.amp or the TF32 mode (on Ampere and later CUDA devices) whenever possible when training a network. If one of those approaches doesn’t work, however, we recommend the following:

  • High Performance Computing (HPC) applications, regression tasks, and generative networks may simply require full float32 IEEE precision to converge as expected.
  • Try selectively applying torch.amp. In particular, we recommend first disabling it on regions performing operations from the torch.linalg module or when doing pre- or post-processing. These operations are often especially sensitive. Note that TF32 mode is a global switch and can’t be used selectively on regions of a network. Enable TF32 first to check whether a network’s operators are sensitive to the mode; if they are, disable it.
  • If you encounter type mismatches while using torch.amp we don’t suggest inserting manual casts to start. This error is indicative of something being off with the network, and it’s usually worth investigating first.
  • Figure out by experimentation if your network is sensitive to range and/or precision of a format. For example fine-tuning bfloat16-pretrained models in float16 can easily run into range issues in float16 because of the potentially large range from training in bfloat16, so users should stick with bfloat16 fine-tuning if the model was trained in bfloat16.
  • The performance gain of mixed precision training can depend on multiple factors (e.g. compute-bound vs memory-bound problems) and users should use the tuning guide to remove other bottlenecks in their training scripts. Although having similar theoretical performance benefits, BF16 and FP16 can have different speeds in practice. It’s recommended to try the mentioned formats and use the one with best speed while maintaining the desired numeric behavior.
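Selectively disabling autocast for a sensitive region, such as a call into torch.linalg, can be done by nesting an autocast context with enabled=False and casting its inputs to float32. A minimal CPU sketch with hypothetical tensors and a hypothetical helper name:

```python
import torch

torch.manual_seed(0)


def sensitive_solve(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Locally disable autocast and run the numerically sensitive op in float32
    with torch.autocast(device_type="cpu", enabled=False):
        return torch.linalg.solve(a.float(), b.float())


a = torch.randn(8, 8) + 8 * torch.eye(8)   # hypothetical well-conditioned system
b = torch.randn(8, 1)
proj = torch.nn.Linear(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    h = proj(a)                  # autocast-eligible op: runs in bfloat16
    x = sensitive_solve(a, b)    # excluded region: stays in float32

print(h.dtype, x.dtype)  # torch.bfloat16 torch.float32
```

The explicit .float() casts matter: inside a disabled region, tensors produced earlier under autocast may still be bfloat16.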

For more details, refer to the AMP Tutorial, Training Neural Networks with Tensor Cores, and see the post “More In-Depth Details of Floating Point Precision” on PyTorch Dev Discussion.

Conclusion

Mixed precision training is an essential tool for training deep learning models on modern hardware, and it will become even more important in the future as the performance gap between lower precision operations and float32 continues to grow on newer hardware, as reflected in Figure 5.

Figure 5: Relative peak throughput of float16 (FP16) vs float32 matrix multiplications on Volta and Ampere GPUs. On Ampere relative peak throughput for the TensorFloat32 (TF32) mode and bfloat16 matrix multiplications are shown, too. The relative peak throughput of low precision data types like float16 and bfloat16 vs. float32 matrix multiplications is expected to grow as new hardware is released.

PyTorch’s torch.amp module makes it easy to get started with mixed precision, and we highly recommend using it to train faster and reduce memory usage. torch.amp supports both float16 and bfloat16 mixed precision.

There are still some networks that are tricky to train with mixed precision, and for these networks we recommend trying TF32 accelerated matrix multiplications on Ampere and later CUDA hardware. Networks are rarely so precision sensitive that they require full float32 precision for every operation.

If you have questions or suggestions for torch.amp or mixed precision support in PyTorch then let us know by posting to the mixed precision category on the PyTorch Forums or filing an issue on the PyTorch GitHub page.

Case Study: PathAI Uses PyTorch to Improve Patient Outcomes with AI-powered Pathology

PathAI is the leading provider of AI-powered technology tools and services for pathology (the study of disease). Our platform was built to enable substantial improvements to the accuracy of diagnosis and the measurement of therapeutic efficacy for complex diseases, leveraging modern approaches in machine learning like image segmentation, graph neural networks, and multiple instance learning.

Traditional manual pathology is prone to subjectivity and observer variability that can negatively affect diagnoses and drug development trials. Before we dive into how we use PyTorch to improve our diagnosis workflow, let us first lay out the traditional analog pathology workflow without machine learning.

How Traditional Biopharma Works

There are many avenues that biopharma companies take to discover novel therapeutics or diagnostics. One of those avenues relies heavily on the analysis of pathology slides to answer a variety of questions: how does a particular cellular communication pathway work? Can a specific disease state be linked to the presence or lack of a particular protein? Why did a particular drug in a clinical trial work for some patients but not others? Might there be an association between patient outcomes and a novel biomarker?

To help answer these questions, biopharma companies rely on expert pathologists to analyze slides and help evaluate the questions they might have. 

As you might imagine, it takes an expert board-certified pathologist to make accurate interpretations and diagnoses. In one study, a single biopsy result was given to 36 different pathologists and the outcome was 18 different diagnoses, varying in severity from no treatment to aggressive treatment necessary. Pathologists also often solicit feedback from colleagues on difficult edge cases. Given the complexity of the problem, even with expert training and collaboration, pathologists can still have a hard time making a correct diagnosis. This potential variance can be the difference between a drug being approved and failing its clinical trial.

How PathAI utilizes machine learning to power drug development

PathAI develops machine learning models which provide insights for drug development R&D, for powering clinical trials, and for making diagnoses. To this end, PathAI leverages PyTorch for slide level inference using a variety of methods including graph neural networks (GNN) as well as multiple instance learning. In this context, “slides” refers to full size scanned images of glass slides, which are pieces of glass with a thin slice of tissue between them, stained to show various cell formations. PyTorch enables our teams using these different methodologies to share a common framework which is robust enough to work in all the conditions we need. PyTorch’s high level, imperative, and pythonic syntax allows us to prototype models quickly and then take those models to scale once we have the results we want. 

Multi-instance learning on gigabyte images

One of the uniquely challenging aspects of applying ML to pathology is the immense size of the images. These digital slides can often be 100,000 x 100,000 pixels or more in resolution and gigabytes in size. Loading the full image into GPU memory and applying traditional computer vision algorithms to it is an almost impossible task. It also takes a considerable amount of time and resources to have a full slide image (100k x 100k) annotated, especially when annotators need to be domain experts (board-certified pathologists). We often build models to predict image-level labels, like the presence of cancer, on a patient slide where the relevant region may cover only a few thousand pixels of the whole image. The cancerous area is sometimes a tiny fraction of the entire slide, which makes the ML problem similar to finding a needle in a haystack. On the other hand, some problems, like the prediction of certain histological biomarkers, require aggregating information from the whole slide, which is again hard due to the size of the images. All these factors add significant algorithmic, computational, and logistical complexity when applying ML techniques to pathology problems.
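A quick back-of-the-envelope calculation shows why whole-slide images cannot simply be loaded onto a GPU. It assumes an uncompressed 8-bit RGB image of 100,000 x 100,000 pixels and a hypothetical 224 x 224 patch size; real slides are stored compressed and in pyramids, so these numbers only indicate the scale.

```python
side = 100_000   # pixels per side, from the text
channels = 3     # assumption: 8-bit RGB, one byte per channel
patch = 224      # assumption: a common CNN input size

raw_bytes = side * side * channels
patches_per_side = side // patch
num_patches = patches_per_side ** 2

print(f"{raw_bytes / 1e9:.0f} GB uncompressed")   # 30 GB uncompressed
print(f"{num_patches:,} non-overlapping patches")  # 198,916 non-overlapping patches
```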

Breaking down the image into smaller patches, learning patch representations, and then pooling those representations to predict an image-level label is one way to solve this problem, as depicted in the image below. One popular method for doing this is called Multiple Instance Learning (MIL). Each patch is considered an ‘instance’ and a set of patches forms a ‘bag’. The individual patch representations are pooled together to predict a final bag-level label. Algorithmically, the individual patch instances in the bag do not require labels, which allows us to learn bag-level labels in a weakly supervised way. MIL models also use permutation-invariant pooling functions, which make the prediction independent of the order of patches and allow for efficient aggregation of information. Typically, attention-based pooling functions are used, which not only allow for efficient aggregation but also provide attention values for each patch in the bag. These values indicate the importance of the corresponding patch in the prediction and can be visualized to better understand the model predictions. This element of interpretability can be very important for driving adoption of these models in the real world, and we use variations like Additive MIL models to enable such spatial explainability. Computationally, MIL models circumvent the problem of applying neural networks to large image sizes, since patch representations are obtained independently of the size of the image.
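Attention-based MIL pooling can be sketched in a few lines of PyTorch. The module below follows the widely used attention-MIL formulation: a small MLP scores each patch embedding, a softmax over the bag turns the scores into weights, and the bag embedding is the weighted sum. It is a hypothetical illustration, not PathAI's implementation. Because the weighted sum does not depend on patch order, shuffling the bag leaves the prediction unchanged.

```python
import torch
from torch import nn


class AttentionMILPooling(nn.Module):
    """Hypothetical attention pooling: (bag_size, dim) patch embeddings -> bag logit."""

    def __init__(self, dim: int, hidden: int = 64):
        super().__init__()
        self.score = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))
        self.classifier = nn.Linear(dim, 1)

    def forward(self, patches: torch.Tensor):
        weights = torch.softmax(self.score(patches).squeeze(-1), dim=0)  # one weight per patch
        bag = (weights.unsqueeze(-1) * patches).sum(dim=0)              # permutation-invariant sum
        return self.classifier(bag), weights


torch.manual_seed(0)
pool = AttentionMILPooling(dim=32)
bag = torch.randn(12, 32)  # 12 patch embeddings, e.g. produced by a featurizer
logit, weights = pool(bag)
shuffled_logit, _ = pool(bag[torch.randperm(12)])
print(torch.allclose(logit, shuffled_logit, atol=1e-6))  # True
```

The returned weights are the per-patch attention values that can be mapped back onto the slide for visualization.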

At PathAI, we use custom MIL models based on deep nets to predict image-level labels. The overview of this process is as follows:

  1. Select patches from a slide using different sampling approaches.
  2. Construct a bag of patches based on random sampling or heuristic rules.
  3. Generate patch representations for each instance based on pre-trained models or large-scale representation learning models.
  4. Apply permutation invariant pooling functions to get the final slide-level score.

Now that we have walked through some of the high-level details around MIL in PyTorch, let’s look at some code to see how simple it is to go from ideation to code in production with PyTorch. We begin by defining a sampler, transformations, and our MIL dataset:

# Create a bag sampler which randomly samples patches from a slide
bag_sampler = RandomBagSampler(bag_size=12)

# Setup the transformations
crop_transform = FlipRotateCenterCrop(use_flips=True)

# Create the dataset which loads patches for each bag
train_dataset = MILDataset(
  bag_sampler=bag_sampler,
  samples_loader=sample_loader,
  transform=crop_transform,
)

After we have defined our sampler and dataset, we need to define the model we will actually train with said dataset. PyTorch’s familiar model definition syntax makes this easy to do while also allowing us to create bespoke models at the same time.

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
  input_dims=1024,
  hidden_dims=[256, 256],
  output_activation=StableSoftmax()
)

# Define the model which is a composition of the featurizer, pooling module and a classifier
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling = pooling)

Since these models are trained end-to-end, they offer a powerful way to go directly from a gigapixel whole slide image to a single label. Due to their wide applicability to different biological problems, two aspects of their implementation and deployment are important:

  1. Configurable control over each part of the pipeline including the data loaders, the modular parts of the model, and their interaction with each other.
  2. Ability to rapidly iterate through the ideate-implement-experiment-productionize loop.

PyTorch has various advantages when it comes to MIL modeling. It offers an intuitive way to create dynamic computational graphs with flexible control flow which is great for rapid research experimentation. The map-style datasets, configurable sampler and batch-samplers allow us to customize how we construct bags of patches, enabling faster experimentation. Since MIL models are IO heavy, data parallelism and pythonic data loaders make the task very efficient and user friendly. Lastly, the object-oriented nature of PyTorch enables building of reusable modules which aid in the rapid experimentation, maintainable implementation and ease of building compositional components of the pipeline.
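To make the map-style dataset point concrete, here is a hypothetical bag dataset: each index maps to one slide, and __getitem__ randomly samples a fixed number of patch embeddings from it to form a bag. The slide data is random and the class and parameter names are assumptions for illustration; they are not the RandomBagSampler or MILDataset classes shown earlier.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomBagDataset(Dataset):
    """Hypothetical MIL dataset: each item is (bag of patch features, slide label)."""

    def __init__(self, slides, labels, bag_size: int):
        self.slides = slides      # list of (num_patches, feature_dim) tensors
        self.labels = labels      # one label per slide
        self.bag_size = bag_size

    def __len__(self):
        return len(self.slides)

    def __getitem__(self, idx):
        patches = self.slides[idx]
        # Randomly sample bag_size patches (without replacement) to form the bag
        chosen = torch.randperm(patches.shape[0])[: self.bag_size]
        return patches[chosen], self.labels[idx]


torch.manual_seed(0)
slides = [torch.randn(n, 64) for n in (500, 800, 300, 1000)]  # slides differ in patch count
labels = torch.tensor([0.0, 1.0, 0.0, 1.0])
dataset = RandomBagDataset(slides, labels, bag_size=12)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

bags, bag_labels = next(iter(loader))
print(bags.shape, bag_labels.shape)  # torch.Size([2, 12, 64]) torch.Size([2])
```

Fixing the bag size lets the default collate function stack bags from slides of different sizes into one batch.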

Exploring spatial tissue organization with GNNs in PyTorch

In both healthy and diseased tissue, the spatial arrangement and structure of cells can oftentimes be as important as the cells themselves. For example, when assessing lung cancers, pathologists try to look at the overall grouping and structure of tumor cells (do they form solid sheets? Or do they occur in smaller, localized clusters?) to determine if the cancer belongs to specific subtypes which can have vastly different prognosis. Such spatial relationships between cells and other tissue structures can be modeled using graphs to capture tissue topology and cellular composition at the same time. Graph Neural Networks (GNNs) allow learning spatial patterns within these graphs that relate to other clinical variables, for example overexpression of genes in certain cancers.
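One simple way to turn a tissue image into such a graph, assuming cell centroids have already been detected, is to connect every pair of cells that are closer than a chosen radius. The sketch below builds a PyG-style edge_index (shape 2 x num_edges) with plain PyTorch; the coordinates, radius, and function name are hypothetical, and real pipelines may use other constructions such as k-nearest neighbors.

```python
import torch


def radius_graph_edges(centroids: torch.Tensor, radius: float) -> torch.Tensor:
    """Return a 2 x E edge_index connecting cells whose centroids are within `radius`."""
    n = centroids.shape[0]
    dist = torch.cdist(centroids, centroids)                      # pairwise Euclidean distances
    adjacency = (dist <= radius) & ~torch.eye(n, dtype=torch.bool)  # drop self-loops
    return adjacency.nonzero().t()                                # (source, target) index pairs


# Hypothetical cell centroids in pixel coordinates: two small, distant clusters
cells = torch.tensor([[0.0, 0.0], [3.0, 4.0], [30.0, 40.0], [33.0, 44.0]])
edge_index = radius_graph_edges(cells, radius=10.0)
print(edge_index.tolist())  # [[0, 1, 2, 3], [1, 0, 3, 2]]
```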

In late 2020, when PathAI started using GNNs on tissue samples, PyTorch had the best and most mature support for GNN functionality via the PyG package. This made PyTorch the natural choice for our team given that GNN models were something that we knew would be an important ML concept we wanted to explore. 

One of the main value-adds of GNNs in the context of tissue samples is that the graph itself can uncover spatial relationships that would otherwise be very difficult to find by visual inspection alone. In our recent AACR publication, we showed that by using GNNs, we can better understand the way the presence of immune cell aggregates (specifically tertiary lymphoid structures, or TLS) in the tumor microenvironment can influence patient prognosis. In this case, the GNN approach was used to predict expression of genes associated with the presence of TLS, and to identify histological features beyond the TLS region itself that are relevant to TLS. Such insights into gene expression are difficult to identify from tissue sample images when unassisted by ML models.

One of the most promising GNN variations we have had success with is self attention graph pooling. Let’s take a look at how we define our Self Attention Graph Pooling (SAGPool) model using PyTorch and PyG:

import torch
from torch.nn import Linear
from torch_geometric.nn import GraphConv, JumpingKnowledge, SAGPooling

class SAGPool(torch.nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
    self.convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for _ in range(num_layers - 1)])
    self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for _ in range(num_layers // 2)])
    self.jump = JumpingKnowledge(mode='cat')
    self.lin1 = Linear(num_layers * hidden_features, hidden_features)
    self.lin2 = Linear(hidden_features, out_features)
    self.out_activation = out_activation
    self.dropout = dropout

In the above code, we begin by defining a single graph convolution layer and then create two ModuleList containers, which allow us to hold a variable number of layers. We then fill these lists with num_layers - 1 additional GraphConv layers and num_layers // 2 SAGPooling layers. We finish our SAGPool definition by adding a JumpingKnowledge layer, two linear layers, our output activation function, and our dropout value. PyTorch’s intuitive syntax allows us to abstract away the complexity of working with state-of-the-art methods like SAGPooling while maintaining the familiar approach to model development.
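To make the pooling step concrete, here is a minimal, self-contained sketch of what a self-attention graph pooling layer computes, written in plain PyTorch for a single graph with a dense adjacency matrix. It is a simplified illustration, not PyG's SAGPooling (which works on sparse, batched edge_index inputs and, as configured above, scores nodes with GraphConv); the class name DenseSAGPoolSketch and its GCN-style scoring layer are our own assumptions. Each node gets a score from a graph convolution, so the score reflects both node features and graph topology; the top ceil(ratio * N) nodes are kept, their features are gated by their scores, and the adjacency is restricted to the kept nodes.

```python
import math

import torch
import torch.nn as nn


class DenseSAGPoolSketch(nn.Module):
    """Hypothetical, simplified self-attention graph pooling for one dense graph."""

    def __init__(self, in_features: int, ratio: float = 0.5):
        super().__init__()
        self.ratio = ratio
        # GCN-style scoring layer: projects each node's features to one scalar.
        self.score_proj = nn.Linear(in_features, 1, bias=False)

    def forward(self, x: torch.Tensor, adj: torch.Tensor):
        n = x.size(0)
        # Symmetrically normalized adjacency with self-loops: D^-1/2 (A + I) D^-1/2.
        a_hat = adj + torch.eye(n, dtype=x.dtype)
        d_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
        a_norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
        # Self-attention score per node, mixing features with neighborhood structure.
        score = torch.tanh(a_norm @ self.score_proj(x)).squeeze(-1)
        # Keep the ceil(ratio * n) highest-scoring nodes.
        k = max(1, math.ceil(self.ratio * n))
        perm = torch.topk(score, k).indices
        # Gate kept features by their scores and keep the induced subgraph.
        x_out = x[perm] * score[perm].unsqueeze(-1)
        adj_out = adj[perm][:, perm]
        return x_out, adj_out, perm


torch.manual_seed(0)
x = torch.randn(6, 4)  # 6 nodes with 4 features each
adj = (torch.rand(6, 6) > 0.5).float()
adj = ((adj + adj.T) > 0).float().fill_diagonal_(0)  # undirected, no self-loops
pool = DenseSAGPoolSketch(in_features=4, ratio=0.5)
x_p, adj_p, perm = pool(x, adj)  # keeps 3 of the 6 nodes
```

Gating the kept features by their tanh scores is what lets gradients reach the scoring layer, even though the top-k selection itself is not differentiable.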

Models like the SAGPool model described above are just one example of how GNNs with PyTorch allow us to explore novel ideas. We also recently explored multimodal CNN-GNN hybrid models, which ended up being 20% more accurate than traditional pathologist consensus scores. These innovations, and the interplay between traditional CNNs and GNNs, are again enabled by the short research-to-production model development loop.

Improving Patient Outcomes

In order to achieve our mission of improving patient outcomes with AI-powered pathology, PathAI needs to rely on an ML development framework that (1) facilitates quick iteration and easy extension (i.e., model configuration as code) during initial phases of development and exploration, (2) scales model training and inference to massive images, and (3) easily and robustly serves models for production uses of our products (in clinical trials and beyond). As we’ve demonstrated, PyTorch offers us all of these capabilities and more. We are incredibly excited about the future of PyTorch and cannot wait to see what other impactful challenges we can solve using the framework.


PyTorch 1.12: TorchArrow, Functional API for Modules and nvFuser, are now available

We are excited to announce the release of PyTorch 1.12 (release note)! This release is composed of over 3,124 commits from 433 contributors. Along with 1.12, we are releasing beta versions of AWS S3 Integration, PyTorch Vision Models on Channels Last on CPU, Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16 and FSDP API. We want to sincerely thank our dedicated community for your contributions.

Summary:

  • Functional APIs to functionally apply module computation with a given set of parameters
  • Complex32 and Complex Convolutions in PyTorch
  • DataPipes from TorchData fully backward compatible with DataLoader
  • functorch with improved coverage for APIs
  • nvFuser, a deep learning compiler for PyTorch
  • Changes to float32 matrix multiplication precision on Ampere and later CUDA hardware
  • TorchArrow, a new beta library for machine learning preprocessing over batch data

Frontend APIs

Introducing TorchArrow

We’ve got a new Beta release ready for you to try and use: TorchArrow. This is a library for machine learning preprocessing over batch data. It features a performant, Pandas-style, easy-to-use API to speed up your preprocessing workflows and development.

Currently, it provides a Python DataFrame interface with the following features:

  • High-performance CPU backend, vectorized and extensible User-Defined Functions (UDFs) with Velox
  • Seamless handoff with PyTorch or other model authoring, such as Tensor collation and easily plugging into PyTorch DataLoader and DataPipes
  • Zero copy for external readers via Arrow in-memory columnar format

For more details, please see our 10-minute tutorial, installation instructions, API documentation, and a prototype for data preprocessing in TorchRec.

(Beta) Functional API for Modules

PyTorch 1.12 introduces a new beta feature to functionally apply Module computation with a given set of parameters. Sometimes, the traditional PyTorch Module usage pattern that maintains a static set of parameters internally is too restrictive. This is often the case when implementing algorithms for meta-learning, where multiple sets of parameters may need to be maintained across optimizer steps.

The new torch.nn.utils.stateless.functional_call() API allows for:

  • Module computation with full flexibility over the set of parameters used
  • No need to reimplement your module in a functional way
  • Any parameter or buffer present in the module can be swapped with an externally-defined value for use in the call. Naming for referencing parameters / buffers follows the fully-qualified form in the module’s state_dict()

Example:

import torch
from torch import nn
from torch.nn.utils.stateless import functional_call

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 3)
        self.bn = nn.BatchNorm1d(3)
        self.fc2 = nn.Linear(3, 3)

    def forward(self, x):
        return self.fc2(self.bn(self.fc1(x)))

m = MyModule()

# Define parameter / buffer values to use during module computation.
my_weight = torch.randn(3, 3, requires_grad=True)
my_bias = torch.tensor([1., 2., 3.], requires_grad=True)
params_and_buffers = {
    'fc1.weight': my_weight,
    'fc1.bias': my_bias,
    # Custom buffer values can be used too.
    'bn.running_mean': torch.randn(3),
}

# Apply module computation to the input with the specified parameters / buffers.
inp = torch.randn(5, 3)
output = functional_call(m, params_and_buffers, inp)
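As a sketch of the meta-learning pattern mentioned above, the snippet below takes one manual inner-loop SGD step and then evaluates the module with the adapted parameters, without modifying the module's stored parameters. The toy nn.Linear model, the random data and the inner learning rate are illustrative assumptions; in PyTorch 2.0 and later the same function is also available as torch.func.functional_call.

```python
import torch
from torch import nn
from torch.nn.utils.stateless import functional_call

torch.manual_seed(0)
model = nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
w_before = model.weight.detach().clone()

# Start from the module's own parameters, keyed by their state_dict names.
params = dict(model.named_parameters())

# Inner step: compute the loss with `params`, then take one manual SGD step.
inner_lr = 0.1
inner_loss = nn.functional.mse_loss(functional_call(model, params, (x,)), y)
grads = torch.autograd.grad(inner_loss, list(params.values()), create_graph=True)
fast_params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

# Outer step: evaluate the same module with the adapted ("fast") parameters.
outer_loss = nn.functional.mse_loss(functional_call(model, fast_params, (x,)), y)
# Gradients flow through the inner update back to the original parameters.
outer_loss.backward()
```

Because create_graph=True keeps the inner gradient computation in the autograd graph, outer_loss.backward() produces second-order gradients, as MAML-style methods need.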

(Beta) Complex32 and Complex Convolutions in PyTorch

PyTorch today natively supports complex numbers, complex autograd, complex modules, and numerous complex operations, including linear algebra and Fast Fourier Transform (FFT) operators. Many libraries, including torchaudio and ESPNet, already make use of complex numbers in PyTorch, and PyTorch 1.12 further extends complex functionality with complex convolutions and the experimental complex32 (“complex half”) data type that enables half precision FFT operations. Due to bugs in the CUDA 11.3 package, we recommend using the CUDA 11.6 wheels if you are using complex numbers.
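A quick way to see what a complex convolution computes: for a complex input x = xr + i·xi and weight w = wr + i·wi, linearity gives a real part conv(xr, wr) - conv(xi, wi) and an imaginary part conv(xr, wi) + conv(xi, wr). The sketch below checks torch.nn.functional.conv1d on complex64 tensors against that decomposition; the shapes are arbitrary choices, and the complex32 data type is not exercised here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 16, dtype=torch.complex64)  # [batch, in_channels, length]
w = torch.randn(4, 2, 3, dtype=torch.complex64)   # [out_channels, in_channels, kernel_size]

# Complex convolution applied directly to complex tensors.
y = F.conv1d(x, w)

# The same result from four real convolutions, using linearity:
# (xr + i*xi) conv (wr + i*wi) = (xr conv wr - xi conv wi) + i*(xr conv wi + xi conv wr)
xr, xi, wr, wi = x.real, x.imag, w.real, w.imag
y_ref = torch.complex(
    F.conv1d(xr, wr) - F.conv1d(xi, wi),
    F.conv1d(xr, wi) + F.conv1d(xi, wr),
)
```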

(Beta) Forward-mode Automatic Differentiation

Forward-mode AD allows the computation of directional derivatives (or equivalently, Jacobian-vector products) eagerly in the forward pass. PyTorch 1.12 significantly improves the operator coverage for forward-mode AD. See our tutorial for more information.
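A minimal example of the forward-mode AD API in torch.autograd.forward_ad: a dual tensor pairs a primal value with a tangent (the direction v), and evaluating a function on it yields the Jacobian-vector product J·v alongside the output, in a single forward pass. The elementwise function f is an arbitrary choice so the result can be checked against its analytic derivative.

```python
import torch
import torch.autograd.forward_ad as fwAD


def f(x):
    return x.sin() * x


torch.manual_seed(0)
x = torch.randn(5)
v = torch.randn(5)  # tangent: the direction of the directional derivative

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)         # primal x paired with tangent v
    y, jvp = fwAD.unpack_dual(f(dual_x))  # output and J @ v from one forward pass

# f is elementwise, so J is diagonal with entries x*cos(x) + sin(x).
expected = (x * x.cos() + x.sin()) * v
```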

TorchData

BC DataLoader + DataPipe

`DataPipe` from TorchData becomes fully backward compatible with the existing `DataLoader` regarding shuffle determinism and dynamic sharding in both multiprocessing and distributed environments. For more details, please check out the tutorial.

(Beta) AWS S3 Integration

DataPipes based on AWSSDK have been integrated into TorchData. They provide the following features backed by native AWSSDK:

  • Retrieve a list of URLs from each S3 bucket based on prefix
    • Support for timeouts to prevent hanging indefinitely
    • Support for specifying the S3 bucket region
  • Load data from S3 URLs
    • Support for buffered and multi-part download
    • Support for specifying the S3 bucket region

AWS native DataPipes are still in the beta phase, and we will keep tuning them to improve their performance.

(Prototype) DataLoader2

DataLoader2 is now available in prototype mode. We are introducing new ways to interact between DataPipes, the DataLoading API, and backends (aka ReadingServices). The feature is stable in terms of API, but not yet functionally complete. We welcome early adopters and feedback, as well as potential contributors.

For more details, please check out the link.

functorch

Inspired by Google JAX, functorch is a library that offers composable vmap (vectorization) and autodiff transforms. It enables advanced autodiff use cases that would otherwise be tricky to express in PyTorch.

We’re excited to announce functorch 0.2.0 with a number of improvements and new experimental features.

Significantly improved coverage

We significantly improved coverage for functorch.jvp (our forward-mode autodiff API) and other APIs that rely on it (functorch.{jacfwd, hessian}).
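The sketch below applies the functorch transforms mentioned above to a toy scalar function f(x) = Σ xᵢ³, whose gradient 3x² and diagonal Hessian 6x are easy to check by hand. The function and input are illustrative assumptions; in PyTorch 2.0 and later the same transforms are also available under torch.func.

```python
import torch
from functorch import hessian, jacfwd, jacrev


def f(x):
    # Scalar-valued toy function: sum of cubes.
    return (x ** 3).sum()


x = torch.tensor([1.0, 2.0, 3.0])

grad_fwd = jacfwd(f)(x)  # forward-mode Jacobian of a scalar function: 3 * x**2
grad_rev = jacrev(f)(x)  # reverse-mode gives the same values
H = hessian(f)(x)        # 3x3 matrix with 6 * x on the diagonal
```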

(Prototype) functorch.experimental.functionalize

Given a function f, functionalize(f) returns a new function without mutations (with caveats). This is useful for constructing traces of PyTorch functions without in-place operations. For example, you can use make_fx(functionalize(f)) to construct a mutation-free trace of a PyTorch function. To learn more, please see the documentation.

For more details, please see our installation instructions, documentation, tutorials, and release notes.

Performance Improvements

Introducing nvFuser, a deep learning compiler for PyTorch

In PyTorch 1.12, TorchScript is updating its default fuser (for Volta and later CUDA accelerators) to nvFuser, which supports a wider range of operations and is faster than NNC, the previous fuser for CUDA devices. An upcoming blog post will elaborate on nvFuser and show how it speeds up training on a variety of networks.

See the nvFuser documentation for more details on usage and debugging.

Changes to float32 matrix multiplication precision on Ampere and later CUDA hardware

PyTorch supports a variety of “mixed precision” techniques, like the torch.amp (Automatic Mixed Precision) module and performing float32 matrix multiplications using the TensorFloat32 datatype on Ampere and later CUDA hardware for faster internal computations. In PyTorch 1.12 we’re changing the default behavior of float32 matrix multiplications to always use full IEEE fp32 precision, which is more precise but slower than using the TensorFloat32 datatype for internal computation. For devices with a particularly high ratio of TensorFloat32 to float32 throughput, such as the A100, this change in defaults can result in a large slowdown.

If you’ve been using TensorFloat32 matrix multiplications, you can continue to do so by setting torch.backends.cuda.matmul.allow_tf32 = True, which has been supported since PyTorch 1.7. Starting in PyTorch 1.12, the new matmul precision API can be used as well: torch.set_float32_matmul_precision("highest" | "high" | "medium")

To reiterate, PyTorch’s new default is “highest” precision for all device types. We think this provides better consistency across device types for matrix multiplications. Documentation for the new precision API can be found here. Setting the “high” or “medium” precision types will enable TensorFloat32 on Ampere and later CUDA devices. If you’re updating to PyTorch 1.12 and want to preserve the previous behavior and faster performance of matrix multiplications on Ampere devices, set the precision to “high”.
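A short sketch of the precision API. The TensorFloat32 setting only changes how float32 matmuls are computed internally on Ampere and later CUDA GPUs; the result dtype stays float32. Saving and restoring the previous value is our own convention to keep the global setting from leaking into other code.

```python
import torch

# Remember the current global setting ("highest" is the default since PyTorch 1.12).
previous = torch.get_float32_matmul_precision()

# Allow TensorFloat32 for float32 matmuls on Ampere and later CUDA GPUs.
torch.set_float32_matmul_precision("high")
current = torch.get_float32_matmul_precision()

a, b = torch.randn(64, 32), torch.randn(32, 16)
c = a @ b  # the result is still a float32 tensor

# Restore the previous global setting.
torch.set_float32_matmul_precision(previous)
```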

Using mixed precision techniques is essential for training many modern deep learning networks efficiently, and if you’re already using torch.amp this change is unlikely to affect you. If you’re not familiar with mixed precision training, see our upcoming “What Every User Should Know About Mixed Precision Training in PyTorch” blog post.

(Beta) Accelerating PyTorch Vision Models with Channels Last on CPU

Memory formats have a significant impact on performance when running vision models. Generally, Channels Last is more favorable from a performance perspective due to better data locality. PyTorch 1.12 covers the fundamental concepts of memory formats and demonstrates the performance benefits of Channels Last on popular PyTorch vision models on Intel® Xeon® Scalable processors.

  • Enables Channels Last memory format support for the commonly used operators in CV domain on CPU, applicable for both inference and training
  • Provides native level optimization on Channels Last kernels from ATen, applicable for both AVX2 and AVX512
  • Delivers 1.3x to 1.8x inference performance gain over Channels First for TorchVision models on Intel® Xeon® Ice Lake (or newer) CPUs
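A minimal sketch of running a model in Channels Last on CPU: both the model (its 4D weights) and the input are converted with memory_format=torch.channels_last. The toy conv/batchnorm/ReLU model is an assumption for illustration. The logical NCHW shape is unchanged and only the memory layout (NHWC) differs, so results match the default layout up to floating-point rounding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
).eval()
x = torch.randn(8, 3, 32, 32)

with torch.no_grad():
    ref = model(x)  # default contiguous (NCHW) layout

    # Convert the model weights (in place) and the input to Channels Last.
    model = model.to(memory_format=torch.channels_last)
    x_cl = x.to(memory_format=torch.channels_last)
    out = model(x_cl)

# Same logical NCHW shape; only the memory layout differs.
print(out.shape, out.is_contiguous(memory_format=torch.channels_last))
```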

(Beta) Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16

Reduced precision numeric formats like bfloat16 improve PyTorch performance across multiple deep learning training workloads. PyTorch 1.12 includes the latest software enhancements on bfloat16, which apply to a broader scope of user scenarios and deliver even higher performance gains. The main improvements include:

  • 2x hardware compute throughput vs. float32 with the new bfloat16 native instruction VDPBF16PS, introduced on Intel® Xeon® Cooper Lake CPUs
  • Half the memory footprint of float32, which speeds up memory-bandwidth-intensive operators
  • 1.4x to 2.2x inference performance gain over float32 for TorchVision models on Intel® Xeon® Cooper Lake (or newer) CPUs
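Two common ways to use bfloat16 for CPU inference are sketched below, with a toy MLP standing in for a real model: CPU autocast, which runs eligible operators such as linear and convolution in bfloat16, and converting the whole model and input with .to(torch.bfloat16). The element_size() check illustrates the halved memory footprint. Actual speedups depend on hardware support such as AVX512_BF16 or AMX, and the tolerance against the float32 reference is a loose assumption reflecting bfloat16's 7-bit mantissa.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()
x = torch.randn(32, 64)

with torch.no_grad():
    y_ref = model(x)  # float32 reference

    # Option 1: CPU autocast runs eligible ops (e.g. linear, conv) in bfloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        y_amp = model(x)

    # Option 2: convert the whole model and its input to bfloat16.
    model_bf16 = copy.deepcopy(model).to(torch.bfloat16)
    y_bf16 = model_bf16(x.to(torch.bfloat16))

# bfloat16 stores 2 bytes per element instead of 4.
print(x.element_size(), x.to(torch.bfloat16).element_size())
```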

(Prototype) Introducing Accelerated PyTorch Training on Mac

With the PyTorch 1.12 release, developers and researchers can now take advantage of Apple silicon GPUs for significantly faster model training. This unlocks the ability to perform machine learning workflows like prototyping and fine-tuning locally, right on Mac. Accelerated GPU training is enabled using Apple’s Metal Performance Shaders (MPS) as a backend. The benefits include performance speedup from accelerated GPU training and the ability to train larger networks or batch sizes locally. Learn more here.

[Figure: Accelerated GPU training and evaluation speedups over CPU-only (times faster)]

Alongside the new MPS device support, the M1 binaries for Core and Domain libraries that have been available for the last few releases are now an official prototype feature. These binaries can be used to run PyTorch natively on Apple Silicon.
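A common device-selection pattern for the MPS backend, sketched with a toy model: use the "mps" device when torch.backends.mps.is_available() reports it (Apple silicon on a supported macOS version), and fall back to the CPU otherwise, so the same script runs on any machine.

```python
import torch

# Prefer Apple's MPS backend when it is available, otherwise use the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = torch.nn.Linear(10, 2).to(device)
x = torch.randn(4, 10, device=device)
y = model(x)  # runs on the GPU via Metal when device is "mps"
```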

(Prototype) BetterTransformer: Fastpath execution for Transformer Encoder Inference

PyTorch now supports CPU and GPU fastpath implementations (“BetterTransformer”) for several Transformer Encoder modules, including TransformerEncoder, TransformerEncoderLayer, and MultiheadAttention (MHA). The BetterTransformer fastpath architecture is consistently faster, 2x for many common execution scenarios, depending on model and input characteristics. The new BetterTransformer-enabled modules are API compatible with previous releases of the PyTorch Transformer API and will accelerate existing models if they meet fastpath execution requirements; they can also read models trained with previous versions of PyTorch. PyTorch 1.12 includes:

  • BetterTransformer integration for Torchtext’s pretrained RoBERTa and XLM-R models
  • Torchtext which builds on the PyTorch Transformer API
  • Fastpath execution for improved performance by reducing execution overheads with fused kernels that combine multiple operators into a single kernel
  • Option to achieve additional speedups by taking advantage of data sparsity during the processing of padding tokens in natural-language processing (by setting enable_nested_tensor=True when creating a TransformerEncoder)
  • Diagnostics to help users understand why fastpath execution did not occur
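A sketch of an encoder configured so that it can take the fastpath with nested tensors: batch_first=True, eval mode, autograd disabled, and a src_key_padding_mask that marks padding tokens with True. The model sizes are arbitrary assumptions; whether the fastpath actually runs depends on the conditions PyTorch checks at call time, and the output shape is the same either way.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True).eval()

src = torch.randn(2, 5, 64)  # [batch, seq_len, d_model]
# True marks padding positions: the second sequence has only 3 real tokens.
padding_mask = torch.tensor([[False, False, False, False, False],
                             [False, False, False, True, True]])

# Fastpath requires, among other conditions, eval mode and no autograd.
with torch.no_grad():
    out = encoder(src, src_key_padding_mask=padding_mask)
```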

Distributed

(Beta) Fully Sharded Data Parallel (FSDP) API

FSDP API helps easily scale large model training by sharding a model’s parameters, gradients and optimizer states across data parallel workers while maintaining the simplicity of data parallelism. The prototype version was released in PyTorch 1.11 with a minimum set of features that helped scaling tests of models with up to 1T parameters.

In this beta release, the FSDP API added the following features to support various production workloads. Highlights of the newly added features include:

  1. Universal sharding strategy API – Users can easily change between sharding strategies with a single-line change, and thus compare and use DDP (only data sharding), FSDP (full model and data sharding), or ZeRO-2 (only sharding of optimizer states and gradients) to optimize memory and performance for their specific training needs
  2. Fine grained mixed precision policies – Users can specify a mix of half and full data types (bfloat16, fp16 or fp32) for model parameters, gradient communication, and buffers via mixed precision policies. Models are automatically saved in fp32 to allow for maximum portability
  3. Transformer auto wrapping policy – allows for optimal wrapping of Transformer-based models by registering the model’s layer class, and thus accelerates training performance
  4. Faster model initialization using device_id init – initialization is performed in a streaming fashion to avoid OOM issues and optimize init performance vs CPU init
  5. Rank0 streaming for full model saving of larger models – Fully sharded models can be saved by all GPUs streaming their shards to the rank 0 GPU, and the model is built in full state on the rank 0 CPU for saving
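The sharding-strategy and mixed-precision options listed above are expressed as plain configuration objects in torch.distributed.fsdp. The sketch below only constructs them, since wrapping a model requires an initialized process group; using bfloat16 for all three dtypes is an assumption, and the commented call shows where the objects would be passed.

```python
import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Mixed precision policy: dtypes for parameters, gradient reduction and buffers.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# FULL_SHARD (FSDP), SHARD_GRAD_OP (ZeRO-2 style) or NO_SHARD (DDP-like).
strategy = ShardingStrategy.SHARD_GRAD_OP

# Inside an initialized process group (e.g. a script launched with torchrun):
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   model = FSDP(model, sharding_strategy=strategy, mixed_precision=bf16_policy,
#                device_id=torch.cuda.current_device())
```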

For more details and example code, please checkout the documentation and the tutorial.

Thanks for reading. If you’re interested in these updates and want to join the PyTorch community, we encourage you to join the discussion forums and open GitHub issues. To get the latest news from PyTorch, follow us on Twitter, Medium, YouTube, and LinkedIn.

Cheers!

Team PyTorch


New library updates in PyTorch 1.12

We are bringing a number of improvements to the current PyTorch libraries, alongside the PyTorch 1.12 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch.

Summary:

  • TorchVision – Added multi-weight support API, new architectures, model variants, and pretrained weights. See the release notes here.
  • TorchAudio – Introduced beta features including a streaming API, a CTC beam search decoder, and new beamforming modules and methods. See the release notes here.
  • TorchText – Extended support for scriptable BERT tokenizer and added datasets for GLUE benchmark. See the release notes here.
  • TorchRec – Added EmbeddingModule benchmarks, examples for TwoTower Retrieval, inference and sequential embeddings, metrics, improved planner and demonstrated integration with production components. See the release notes here.
  • TorchX – Launch PyTorch trainers developed on local workspaces onto five different types of schedulers. See the release notes here.
  • FBGemm – Added and improved kernels for Recommendation Systems inference workloads, including table batched embedding bag, jagged tensor operations, and other special-case optimizations.

TorchVision v0.13

Multi-weight support API

TorchVision v0.13 offers a new Multi-weight support API for loading different weights to the existing model builder methods:

from torchvision.models import *

# Old weights with accuracy 76.130%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# New weights with accuracy 80.858%
resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Best available weights (currently alias for IMAGENET1K_V2)
# Note that these weights may change across versions
resnet50(weights=ResNet50_Weights.DEFAULT)

# Strings are also supported
resnet50(weights="IMAGENET1K_V2")

# No weights - random initialization
resnet50(weights=None)

The new API bundles along with the weights important details such as the preprocessing transforms and meta-data such as labels. Here is how to make the most out of it:

from torchvision.io import read_image
from torchvision.models import resnet50, ResNet50_Weights

img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

# Step 1: Initialize model with the best available weights
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(img).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score:.1f}%")

You can read more about the new API in the docs. To provide your feedback, please use this dedicated GitHub issue.

New architectures and model variants

Classification

The Swin Transformer and EfficientNetV2 are two popular classification models that are often used for downstream vision tasks. This release includes 6 pre-trained weights for their classification variants. Here is how to use the new models:

import torch
from torchvision.models import *

image = torch.rand(1, 3, 224, 224)
model = swin_t(weights="DEFAULT").eval()
prediction = model(image)

image = torch.rand(1, 3, 384, 384)
model = efficientnet_v2_s(weights="DEFAULT").eval()
prediction = model(image)

In addition to the above, we also provide new variants for existing architectures such as ShuffleNetV2, ResNeXt and MNASNet. The accuracies of all the new pre-trained models obtained on ImageNet-1K are seen below:

Model Acc@1 Acc@5
swin_t 81.474 95.776
swin_s 83.196 96.36
swin_b 83.582 96.64
efficientnet_v2_s 84.228 96.878
efficientnet_v2_m 85.112 97.156
efficientnet_v2_l 85.808 97.788
resnext101_64x4d 83.246 96.454
resnext101_64x4d (quantized) 82.898 96.326
shufflenet_v2_x1_5 72.996 91.086
shufflenet_v2_x1_5 (quantized) 72.052 90.700
shufflenet_v2_x2_0 76.230 93.006
shufflenet_v2_x2_0 (quantized) 75.354 92.488
mnasnet0_75 71.180 90.496
mnasnet1_3 76.506 93.522

We would like to thank Hu Ye for contributing to TorchVision the Swin Transformer implementation.

(BETA) Object Detection and Instance Segmentation

We have introduced 3 new model variants for RetinaNet, FasterRCNN and MaskRCNN that include several post-paper architectural optimizations and improved training recipes. All models can be used similarly:

import torch
from torchvision.models.detection import *

images = [torch.rand(3, 800, 600)]
model = retinanet_resnet50_fpn_v2(weights="DEFAULT")
# model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT")
# model = maskrcnn_resnet50_fpn_v2(weights="DEFAULT")
model.eval()
prediction = model(images)

Below we present the metrics of the new variants on COCO val2017. In parentheses we denote the improvement over the old variants:

Model Box mAP Mask mAP
retinanet_resnet50_fpn_v2 41.5 (+5.1)
fasterrcnn_resnet50_fpn_v2 46.7 (+9.7)
maskrcnn_resnet50_fpn_v2 47.4 (+9.5) 41.8 (+7.2)

We would like to thank Ross Girshick, Piotr Dollar, Vaibhav Aggarwal, Francisco Massa and Hu Ye for their past research and contributions to this work.

New pre-trained weights

SWAG weights

The ViT and RegNet model variants offer new pre-trained SWAG (Supervised Weakly from hashtAGs) weights. One of the biggest of these models achieves a whopping 88.6% accuracy on ImageNet-1K. We currently offer two versions of the weights: 1) fine-tuned end-to-end weights on ImageNet-1K (highest accuracy) and 2) frozen trunk weights with a linear classifier fit on ImageNet-1K (great for transfer learning). Below we see the detailed accuracies of each model variant:

Model Weights Acc@1 Acc@5
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1 86.012 98.054
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 83.976 97.244
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1 86.838 98.362
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 84.622 97.48
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1 88.228 98.682
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1 86.068 97.844
ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 85.304 97.65
ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 81.886 96.18
ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1 88.064 98.512
ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 85.146 97.422
ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1 88.552 98.694
ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1 85.708 97.73

The SWAG weights are released under the Attribution-NonCommercial 4.0 International license. We would like to thank Laura Gustafson, Mannat Singh and Aaron Adcock for their work and support in making the weights available to TorchVision.

Model Refresh

The release of the Multi-weight support API enabled us to refresh the most popular models and offer more accurate weights. On average, we improved each model by ~3 points. The new recipe was developed on top of ResNet50, and its details were covered in a previous blog post.

Model Old weights New weights
efficientnet_b1 78.642 79.838
mobilenet_v2 71.878 72.154
mobilenet_v3_large 74.042 75.274
regnet_y_400mf 74.046 75.804
regnet_y_800mf 76.42 78.828
regnet_y_1_6gf 77.95 80.876
regnet_y_3_2gf 78.948 81.982
regnet_y_8gf 80.032 82.828
regnet_y_16gf 80.424 82.886
regnet_y_32gf 80.878 83.368
regnet_x_400mf 72.834 74.864
regnet_x_800mf 75.212 77.522
regnet_x_1_6gf 77.04 79.668
regnet_x_3_2gf 78.364 81.196
regnet_x_8gf 79.344 81.682
regnet_x_16gf 80.058 82.716
regnet_x_32gf 80.622 83.014
resnet50 76.13 80.858
resnet50 (quantized) 75.92 80.282
resnet101 77.374 81.886
resnet152 78.312 82.284
resnext50_32x4d 77.618 81.198
resnext101_32x8d 79.312 82.834
resnext101_32x8d (quantized) 78.986 82.574
wide_resnet50_2 78.468 81.602
wide_resnet101_2 78.848 82.51

We would like to thank Piotr Dollar, Mannat Singh and Hugo Touvron for their past research and contributions to this work.

New Augmentations, Layers and Losses

This release brings a number of new primitives that can be used to produce SOTA models. Highlights include the AugMix data-augmentation method, the DropBlock layer, the cIoU/dIoU losses and many more. We would like to thank Aditya Oke, Abhijit Deo, Yassine Alouini and Hu Ye for contributing to the project and for helping us keep TorchVision relevant and fresh.

Documentation

We completely revamped our models documentation to make it easier to browse, and added various key information such as supported image sizes, or image pre-processing steps of pre-trained weights. We now have a main model page with various summary tables of available weights, and each model has a dedicated page. Each model builder is also documented on its own page, with more details about the available weights, including accuracy, minimal image size, link to training recipes, and other valuable info. For comparison, our previous models docs are here. To provide feedback on the new documentation, please use the dedicated GitHub issue.

TorchAudio v0.12

(BETA) Streaming API

StreamReader is TorchAudio’s new I/O API. It is backed by FFmpeg†, and allows users to:

  • Decode audio and video formats, including MP4 and AAC
  • Handle input forms, such as local files, network protocols, microphones, webcams, screen captures and file-like objects
  • Iterate over and decode chunk-by-chunk, while changing the sample rate or frame rate
  • Apply audio and video filters, such as low-pass filter and image scaling
  • Decode video with Nvidia’s hardware-based decoder (NVDEC)

For usage details, please check out the documentation and tutorials.

† To use StreamReader, FFmpeg libraries are required. Please install FFmpeg. The coverage of codecs depends on how these libraries are configured. TorchAudio official binaries are compiled to work with FFmpeg 4 libraries; FFmpeg 5 can be used if TorchAudio is built from source.

(BETA) CTC Beam Search Decoder

TorchAudio integrates the wav2letter CTC beam search decoder from Flashlight (GitHub). The addition of this inference time decoder enables running end-to-end CTC ASR evaluation using TorchAudio utils.

Customizable lexicon and lexicon-free decoders are supported, and both can be used either with a KenLM n-gram language model or without a language model. TorchAudio additionally supports downloading token, lexicon, and pretrained KenLM files for the LibriSpeech dataset.

For usage details, please check out the documentation and ASR inference tutorial.

(BETA) New Beamforming Modules and Methods

To improve flexibility in usage, the release adds two new beamforming modules under torchaudio.transforms: SoudenMVDR and RTFMVDR. The main differences from MVDR are:

  • Use power spectral density (PSD) and relative transfer function (RTF) matrices as inputs instead of time-frequency masks. The module can be integrated with neural networks that directly predict complex-valued STFT coefficients of speech and noise
  • Add ‘reference_channel’ as an input argument in the forward method, to allow users to select the reference channel in model training or dynamically change the reference channel in inference

Besides the two modules, new function-level beamforming methods are added under torchaudio.functional.

For usage details, please check out the documentation at torchaudio.transforms and torchaudio.functional and the Speech Enhancement with MVDR Beamforming tutorial.

TorchText v0.13

GLUE Datasets

We increased the number of datasets in TorchText from 22 to 30 by adding the remaining 8 datasets from the GLUE benchmark (SST-2 was already supported). The complete list of GLUE datasets is as follows:

  • CoLA (paper): Single sentence binary classification acceptability task
  • SST-2 (paper): Single sentence binary classification sentiment task
  • MRPC (paper): Dual sentence binary classification paraphrase task
  • QQP: Dual sentence binary classification paraphrase task
  • STS-B (paper): Dual sentence float regression sentence similarity task
  • MNLI (paper): Sentence ternary classification NLI task
  • QNLI (paper): Sentence binary classification QA and NLI tasks
  • RTE (paper): Dual sentence binary classification NLI task
  • WNLI (paper): Dual sentence binary classification coreference and NLI tasks

Scriptable BERT Tokenizer

TorchText has extended support for scriptable tokenizers by adding the WordPiece tokenizer used in BERT. It is one of the commonly used algorithms for splitting input text into subword units and was introduced in Japanese and Korean Voice Search (Schuster et al., 2012).

TorchScriptability support allows users to embed BERT text preprocessing natively in C++ without needing the Python runtime. As TorchText now supports the CMake build system to natively link torchtext binaries with application code, users can easily integrate BERT tokenizers for deployment needs.
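To illustrate what a WordPiece tokenizer does at lookup time, here is a plain-Python sketch of greedy longest-match-first tokenization of a single word, with BERT-style "##" prefixes on continuation pieces. It is not TorchText's implementation, and the toy vocabulary is made up for the example; real tokenizers also perform basic pre-tokenization (whitespace and punctuation splitting, optional lowercasing) before this step.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", max_chars=100):
    """Greedy longest-match-first WordPiece tokenization of one word (sketch)."""
    if len(word) > max_chars:
        return [unk_token]
    tokens, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation of a word
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:  # nothing matches: the whole word is unknown
            return [unk_token]
        tokens.append(piece)
        start = end
    return tokens


vocab = {"un", "##aff", "##able", "play", "##ing"}
print(wordpiece_tokenize("unaffable", vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("playing", vocab))    # ['play', '##ing']
print(wordpiece_tokenize("xyz", vocab))        # ['[UNK]']
```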

For usage details, please refer to the corresponding documentation.

TorchRec v0.2.0

EmbeddingModule + DLRM benchmarks

A set of benchmarking tests, showing performance characteristics of TorchRec’s base modules and research models built out of TorchRec.

TwoTower Retrieval Example, with FAISS

We provide an example demonstrating training a distributed TwoTower (i.e. User-Item) Retrieval model that is sharded using TorchRec. The projected item embeddings are added to an IVFPQ FAISS index for candidate generation. The retrieval model and KNN lookup are bundled in a PyTorch model for efficient end-to-end retrieval.
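The core of two-tower retrieval can be sketched in a few lines: one tower embeds users, another embeds items, item embeddings are computed once offline, and retrieval is a top-k search by dot product. In this sketch, exact brute-force search with torch.topk stands in for the IVFPQ FAISS index, the Tower module is a hypothetical untrained network, and training and sharding with TorchRec are omitted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_users, num_items, dim = 100, 1000, 32


class Tower(nn.Module):
    """Hypothetical tower: maps an ID to an L2-normalized embedding."""

    def __init__(self, num_ids, dim):
        super().__init__()
        self.emb = nn.Embedding(num_ids, dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, ids):
        return nn.functional.normalize(self.mlp(self.emb(ids)), dim=-1)


user_tower, item_tower = Tower(num_users, dim), Tower(num_items, dim)

with torch.no_grad():
    # Offline: embed every item once (what would be added to an ANN index).
    item_index = item_tower(torch.arange(num_items))  # [num_items, dim]

    # Online: embed the query users and retrieve the top-k items by dot product.
    user_emb = user_tower(torch.tensor([3, 42]))       # [2, dim]
    scores = user_emb @ item_index.T                   # [2, num_items]
    top_scores, top_items = scores.topk(k=10, dim=-1)
```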

Integrations

We demonstrate that TorchRec works out of the box with many components commonly used alongside PyTorch models in production-like systems, such as:

  • Training a TorchRec model on Ray Clusters utilizing the Torchx Ray scheduler
  • Preprocessing and DataLoading with NVTabular on DLRM
  • Training a TorchRec model with on-the-fly preprocessing with TorchArrow showcasing RecSys domain UDFs

Sequential Embeddings Example: Bert4Rec

We provide an example, using TorchRec, that reimplements the BERT4REC paper, showcasing EmbeddingCollection for non-pooled embeddings. Using DistributedModelParallel we see a 35% QPS gain over conventional data parallelism.

(Beta) Planner

The TorchRec library includes a built-in planner that selects a near-optimal sharding plan for a given model. The planner attempts to identify the best sharding plan by evaluating a series of proposals which are statically analyzed and fed into an integer partitioner. The planner is able to automatically adjust plans for a wide range of hardware setups, allowing users to scale performance seamlessly from local development environment to large scale production hardware. See this notebook for a more detailed tutorial.
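As a much simpler stand-in for the planner's proposal-plus-partitioner search, the sketch below shows greedy largest-first placement of embedding tables onto devices, always choosing the currently least-loaded device. The function greedy_place and the table sizes are hypothetical; TorchRec's planner also weighs sharding types and compute and communication costs, which this toy ignores.

```python
def greedy_place(table_sizes, num_devices):
    """Place tables largest-first, each onto the least-loaded device.

    Returns (placement: table name -> device index, per-device total load).
    """
    loads = [0] * num_devices
    placement = {}
    for name, size in sorted(table_sizes.items(), key=lambda kv: kv[1], reverse=True):
        device = min(range(num_devices), key=lambda d: loads[d])  # ties -> lowest index
        placement[name] = device
        loads[device] += size
    return placement, loads


# Hypothetical embedding table sizes (e.g. in GB).
tables = {"user_id": 40, "item_id": 35, "category": 10, "country": 5, "hour": 2}
placement, loads = greedy_place(tables, num_devices=2)
print(placement, loads)
```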

(Beta) Inference

TorchRec Inference is a C++ library that supports multi-GPU inference. The TorchRec library is used to shard models written and packaged in Python via torch.package (an alternative to TorchScript). The torch.deploy library is used to serve inference from C++ by launching multiple Python interpreters carrying the packaged model, thus sidestepping the GIL. Two models are provided as examples: DLRM multi-GPU (sharded via TorchRec) and DLRM single-GPU.

(Beta) RecMetrics

RecMetrics is a metrics library that collects common utilities and optimizations for Recommendation models. It extends torchmetrics.

  • A centralized metrics module that allows users to add new metrics
  • Commonly used metrics, including AUC, Calibration, CTR, MSE/RMSE, NE & Throughput
  • Optimization for metrics related operations to reduce the overhead of metric computation
  • Checkpointing

(Prototype) Single process Batched + Fused Embeddings

Previously TorchRec’s abstractions (EmbeddingBagCollection/EmbeddingCollection) over FBGEMM kernels, which provide benefits such as table batching, optimizer fusion, and UVM placement, could only be used in conjunction with DistributedModelParallel. We’ve decoupled these notions from sharding, and introduced the FusedEmbeddingBagCollection, which can be used as a standalone module, with all of the above features, and can also be sharded.

TorchX v0.2.0

TorchX is a job launcher that makes it easier to run PyTorch in distributed training clusters with many scheduler integrations including Kubernetes and Slurm. We’re excited to release TorchX 0.2.0 with a number of improvements. TorchX is currently being used in production in both on-premise and cloud environments.

Check out the quickstart to start launching local and remote jobs.

Workspaces

TorchX now supports workspaces, which allow users to easily launch training jobs using their local workspace. TorchX can automatically build a patch with your local training code on top of a base image to minimize iteration time and time to training.

.torchxconfig

Specifying options in .torchxconfig saves you from having to type long CLI commands each time you launch a job. You can also define project level generic configs and drop a config file in your home directory for user-level overrides.

Expanded Scheduler Support

TorchX now supports AWS Batch and Ray (experimental) schedulers in addition to our existing integrations.

Distributed Training On All Schedulers

The TorchX dist.ddp component now works on all schedulers without any configuration. Distributed training workers will automatically discover each other when using torchelastic via the built-in dist.ddp component.

Hyper Parameter Optimization

TorchX integrates with Ax to let you scale hyper-parameter optimizations (HPO) by launching the search trials onto remote clusters.

File and Device Mounts

TorchX now supports remote filesystem mounts and custom devices. This enables your PyTorch jobs to efficiently access cloud storage such as NFS or Lustre. Device mounts enable the use of network accelerators like InfiniBand and custom inference/training accelerators.

FBGEMM v0.2.0

The FBGEMM library contains optimized kernels meant to improve the performance of PyTorch workloads. We’ve added a number of new features and optimizations over the last few months that we are excited to report.

Inference Table Batched Embedding (TBE)

The table batched embedding bag (TBE) operator is an important base operation for embedding lookup for recommendation system inference on GPU. We added the following enhancements for performance and flexibility:

Alignment restriction removed

  • Previously, embedding dimension × data type size had to be a multiple of 4 bytes; now it only needs to be a multiple of 1 byte.

Unified Virtual Memory (UVM) caching kernel optimizations

  • UVM caching kernels now scale linearly with the number of tables using UVM caching. Previously, the overhead was similar to having all tables use UVM caching.
  • UVM caching kernel overhead is much smaller than before

Inference FP8 Table Batched Embedding (TBE)

The table batched embedding bag (TBE) previously supported FP32, FP16, INT8, INT4, and INT2 embedding weight types. While these weight types work well in many models, we integrated FP8 weight types (in both GPU and CPU operations) to allow numerical and performance evaluations of FP8 in our models. Compared to INT8, FP8 does not require the additional bias and scale storage and calculations. Additionally, the next-generation H100 GPUs support FP8 on Tensor Cores (mainly for matmul ops).

Jagged Tensor Kernels

We added optimized kernels to speed up TorchRec JaggedTensor. The purpose of JaggedTensor is to handle the case where one dimension of the input data is “jagged”, meaning that each consecutive row in a given dimension may be a different length, which is often the case with sparse feature inputs in recommendation systems. The internal representation is shown below:

We added ops for converting jagged tensors from sparse to dense formats and back, performing matrix multiplications with jagged tensors, and elementwise ops.
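As an illustration, assume the common layout of a flat values buffer plus an offsets array, where row i owns values[offsets[i]:offsets[i + 1]]. The plain-PyTorch sketch below shows what jagged-to-dense and dense-to-jagged conversions compute under that assumption; the function names are hypothetical, and the Python loops only describe the semantics, they are not the optimized FBGEMM kernels.

```python
import torch

def jagged_to_padded_dense(values, offsets, max_len, padding_value=0.0):
    """Pack a jagged (values, offsets) pair into a [num_rows, max_len] dense tensor.

    Row i owns values[offsets[i]:offsets[i + 1]]; shorter rows are padded
    and longer rows are truncated to max_len.
    """
    num_rows = offsets.numel() - 1
    dense = torch.full((num_rows, max_len), padding_value, dtype=values.dtype)
    for i in range(num_rows):
        start, end = int(offsets[i]), int(offsets[i + 1])
        row = values[start:end][:max_len]
        dense[i, : row.numel()] = row
    return dense

def dense_to_jagged(dense, lengths):
    """Inverse direction: keep the first lengths[i] entries of each row."""
    values = torch.cat([dense[i, : int(n)] for i, n in enumerate(lengths)])
    offsets = torch.zeros(lengths.numel() + 1, dtype=torch.long)
    offsets[1:] = torch.cumsum(lengths, dim=0)
    return values, offsets

# Three rows of lengths 2, 0 and 4 stored back to back.
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
offsets = torch.tensor([0, 2, 2, 6])
dense = jagged_to_padded_dense(values, offsets, max_len=4)
round_trip_values, round_trip_offsets = dense_to_jagged(dense, offsets[1:] - offsets[:-1])
```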

Optimized permute102-baddbmm-permute102

It is difficult to fuse various matrix multiplications when the batch size is not the batch size of the model; switching the batch dimension is a quick solution. We created the permute102_baddbmm_permute102 operation, which switches the first and second dimensions, performs the batched matrix multiplication, and then switches them back. Currently we only support the forward pass with the FP16 data type; FP32 and the backward pass will be supported in the future.
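The unfused computation that this operation replaces can be written with standard PyTorch ops. The sketch below is a reference for the semantics only, under assumed shapes (a is [M, B, K] with the batch in dimension 1, w is [B, K, N], bias broadcastable to [B, M, N]); the function name is hypothetical, it uses float32, and it does not call the FBGEMM operator.

```python
import torch

def permute102_baddbmm_permute102_reference(bias, a, w):
    """Unfused reference: returns out[m, b] = bias[b] + a[m, b] @ w[b], shape [M, B, N]."""
    a_bmk = a.permute(1, 0, 2)               # [M, B, K] -> [B, M, K]
    out_bmn = torch.baddbmm(bias, a_bmk, w)  # batched matmul over B, plus bias
    return out_bmn.permute(1, 0, 2)          # [B, M, N] -> [M, B, N]

torch.manual_seed(0)
a = torch.randn(5, 3, 4)     # M=5, B=3, K=4
w = torch.randn(3, 4, 2)     # B=3, K=4, N=2
bias = torch.randn(3, 1, 2)  # broadcasts over M
out = permute102_baddbmm_permute102_reference(bias, a, w)
```

A fused kernel can avoid materializing the two permuted copies, which is presumably where the benefit comes from.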

Optimized index_select for dim 0 index selection

index_select is normally used as part of a sparse operation. While PyTorch supports a generic index_select for an arbitrary-dimension index selection, its performance for a special case like the dim 0 index selection is suboptimal. For this reason, we implement a specialized index_select for dim 0. In some cases, we have observed 1.4x performance gain from FBGEMM’s index_select compared to the one from PyTorch (using uniform index distribution).
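For dim 0, each output row is a copy of one whole input row, which is the special structure a dedicated kernel can exploit. The hypothetical reference function below spells out those semantics with an explicit loop; it is for illustration only, not FBGEMM's implementation, and it should agree with torch.index_select(x, 0, index).

```python
import torch

def index_select_dim0_reference(x, index):
    """Reference semantics of dim-0 selection: out[i] = x[index[i]]."""
    out = x.new_empty((index.numel(),) + tuple(x.shape[1:]))
    for i, row in enumerate(index.tolist()):
        out[i].copy_(x[row])  # copy one whole row; indices may repeat
    return out

x = torch.arange(12.0).reshape(4, 3)
index = torch.tensor([2, 0, 2])
out = index_select_dim0_reference(x, index)
```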

More about the implementation of influential instances can be found on our GitHub page and tutorials.

Thanks for reading! If you’re interested in these updates and want to join the PyTorch community, we encourage you to join the discussion forums and open GitHub issues. To get the latest news from PyTorch, follow us on Twitter, Medium, YouTube, and LinkedIn.

Cheers!

Team PyTorch


A BetterTransformer for Fast Transformer Inference

tl;dr Transformers achieve state-of-the-art performance for NLP, and are becoming popular for a myriad of other tasks. They are computationally expensive which has been a blocker to their widespread productionisation. Launching with PyTorch 1.12, BetterTransformer implements a backwards-compatible fast path of torch.nn.TransformerEncoder for Transformer Encoder Inference and does not require model authors to modify their models. BetterTransformer improvements can exceed 2x in speedup and throughput for many common execution scenarios. To use BetterTransformer, install PyTorch 1.12 and start using high-quality, high-performance Transformer models with the PyTorch API today.

Diagram of the Transformer Encoder Architecture (from “Attention Is All You Need“). During Inference, the entire module will execute as a single PyTorch-native function.

In this blog post, we share the following topics — Performance Improvements, Backwards compatibility, and Taking advantage of the FastPath. Learn more about these topics below.

Performance Improvements

BetterTransformer launches with accelerated native implementations of MultiHeadAttention and TransformerEncoderLayer for CPUs and GPUs. These fast paths are integrated in the standard PyTorch Transformer APIs, and will accelerate TransformerEncoder, TransformerEncoderLayer and MultiHeadAttention nn.modules. These new modules implement two types of optimizations: (1) fused kernels that combine multiple individual operators normally used to implement Transformers into a more efficient implementation, and (2) exploitation of sparsity in the inputs to avoid performing unnecessary operations on padding tokens. Padding tokens frequently account for a large fraction of input batches in many Transformer models used for Natural Language Processing.

Backwards compatibility

Advantageously, no model changes are necessary to benefit from the performance boost offered by BetterTransformer. To benefit from fast path execution, inputs and operating conditions must meet certain criteria (see below). While the internal implementation of Transformer APIs has changed, PyTorch 1.12 maintains strict compatibility with Transformer modules shipped in previous versions, enabling PyTorch users to use models created and trained with previous PyTorch releases while benefiting from BetterTransformer improvements.

In addition to enabling the PyTorch nn.Modules, BetterTransformer provides improvements for PyTorch libraries. Performance benefits will become available through two different enablement paths:

  1. Transparent acceleration: Current users of PyTorch nn.Modules such as MultiHeadAttention as well as higher-level Transformer components will benefit from the improved performance of the new nn.Modules automatically. An example of this is the visual transformer (ViT) implementation used in the torchvision library (code link).

  2. Torchtext library acceleration: As part of this project, we have optimized Torchtext to build on the PyTorch core API to benefit from BetterTransformer enhancements while maintaining strict and transparent compatibility with previous library versions and models trained with previous Torchtext versions. Using PyTorch Transformers in Torchtext also ensures that Torchtext will benefit from expected future enhancements to the PyTorch Transformer implementation.

Taking advantage of the Fastpath

BetterTransformer is a fastpath for the PyTorch Transformer API. The fastpath is a native, specialized implementation of key Transformer functions for CPU and GPU that applies to common Transformer use cases.

To take advantage of input sparsity (i.e. padding) in accelerating your model (see Figure 2), set the keyword argument enable_nested_tensor=True when instantiating a TransformerEncoder and pass in the src_key_padding_mask argument (which denotes padding tokens) during inference. This requires the padding mask to be contiguous, which is the typical case.
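A minimal sketch of this setup, assuming batch_first=True inputs (listed among the fastpath conditions in the PyTorch documentation) and small, arbitrary sizes. If any fastpath condition is not met, the same code still runs on the regular implementation.

```python
import torch
import torch.nn as nn

d_model, nhead = 16, 4
layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2, enable_nested_tensor=True)
encoder.eval()  # the fastpath only applies in inference mode

batch = torch.randn(2, 5, d_model)  # [batch, seq, feature]
# True marks a padding token; padding only at the end of each row keeps the mask contiguous.
padding_mask = torch.tensor([
    [False, False, False, False, False],  # full-length sequence
    [False, False, False, True, True],    # length-3 sequence padded to 5
])

with torch.no_grad():  # no gradient tape is recorded
    out = encoder(batch, src_key_padding_mask=padding_mask)
```

Only the non-padded positions of each row carry meaningful outputs.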

Currently, the BetterTransformer speedup only applies to transformer encoder models used in inference. To benefit from fastpath execution, models must be composed of any of the following components: TransformerEncoder, TransformerEncoderLayer or MultiheadAttention (MHA). Fastpath execution is also subject to some criteria. Most importantly, the model must be executed in inference mode and operate on input tensors that do not collect gradient tape information (e.g., running with torch.no_grad). The full list of conditions can be found at these links for nn.MultiHeadAttention and nn.TransformerEncoder, respectively. If the criteria are not met, control flows to the legacy PyTorch 1.11 Transformer implementation which has the same API, but lacks the fastpath performance boost.
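The same criteria can be illustrated with a standalone nn.MultiheadAttention module: it is in eval mode, autograd is disabled with torch.inference_mode(), and query, key and value are the same batch_first tensor (self-attention). This is a hedged sketch with arbitrary sizes; the linked condition lists remain the authoritative requirements for a given PyTorch version.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
mha.eval()  # training must be disabled

x = torch.randn(2, 5, 16)  # [batch, seq, embed]
with torch.inference_mode():  # no gradient tape is recorded
    # Self-attention: the same tensor is passed as query, key and value.
    # need_weights=False skips returning the attention weights (None is returned).
    attn_out, attn_weights = mha(x, x, x, need_weights=False)
```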

Other transformer models (such as decoder models) which use the PyTorch MultiheadAttention module will benefit from the BetterTransformer fastpath. Planned future work is to expand the end-to-end BetterTransformer fastpath to models based on TransformerDecoder to support popular seq2seq and decoder-only (e.g., OPT) model architectures, and to training.

Speedups

The following graphs show the performance achieved for the BERT-base model with small and large-scale inputs:

Figure 1: PyTorch 1.12 Improvements with BetterTransformer fastpath execution

Figure 2: PyTorch 1.12 Improvements with BetterTransformer fastpath execution with sparsity optimization enabled by enable_nested_tensor=True

BetterTransformer includes two types of optimization: (1) fused kernels implementing multiple operations more efficiently in a single kernel, and (2) exploiting sparsity by avoiding unnecessary processing on padding tokens. Enhanced performance for small input sizes benefits primarily from the fused kernel implementations, and shows a constant performance improvement regardless of padding amount. While large inputs still benefit from fused kernels, the computation heavy processing limits the benefits that may be obtained by the fused kernels as baseline performance is already closer to the theoretical peak. However, as we increase the amount of padding, performance increases dramatically as increasingly large amounts of computation can be avoided by exploiting the sparsity introduced by padding in NLP workloads.

Future Work

As part of our ongoing work on PyTorch BetterTransformer, we are working on extending BetterTransformer improvements to Transformer Decoders. We aim to expand beyond inference to training as well.

We are partnering to enable BetterTransformer on additional libraries such as FairSeq, MetaSeq, and HuggingFace to benefit all Transformer-based PyTorch models. We’ll provide future updates on the progress of BetterTransformer accelerations for the larger PyTorch ecosystem as part of this blog series.

Acknowledgements: The authors would like to thank Lin Qiao, Ajit Mathews, Andrew Tulloch, Dmytro Dzhulgakov, Natalia Gimelshein, Emad El-Haraty, Mark Saroufim, Adnan Aziz, Geeta Chauhan, and Hamid Shojanazeri for their support, contributions and many helpful suggestions throughout the course of this project, and in the preparation of this blog.


How Computational Graphs are Executed in PyTorch

Welcome to the last entry in our series on understanding the autograd engine of PyTorch!
If you haven’t read parts 1 and 2, check them out now to understand how PyTorch creates the computational graph for the backward pass!

This post is based on PyTorch v1.11, so some highlighted parts may differ across versions.

PyTorch autograd graph execution

The last post showed how PyTorch constructs the graph to calculate the outputs’ derivatives w.r.t. the inputs when executing the forward pass. Now we will see how the execution of the backward pass is coordinated and done by looking at the whole process, starting from Python down to the lower C++ level internals.

What Happens when Calling backward()/grad() from Python

Using variable.backward()

After doing all our calculations with an input set to require the gradient, we call .backward() on the result to initiate the backward pass execution.

>>> x = torch.tensor([0.5, 0.75], requires_grad=True)
>>> y = torch.exp(x).sum()
>>> y.backward()
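Since y is the sum of exp(x_i), its derivative with respect to each x_i is exp(x_i), so after backward() the gradient accumulated into x.grad should equal torch.exp(x). A small self-contained check:

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)
y = torch.exp(x).sum()
y.backward()

# d/dx_i of sum_j exp(x_j) is exp(x_i)
print(x.grad)  # tensor([1.6487, 2.1170])
```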

Calling .backward() on a tensor results in a call to torch.autograd.backward().

# torch/_tensor.py

def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=None):
    
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

torch.autograd.backward() checks the arguments and calls the autograd engine in the C++ layer.

def backward(
    tensors: _TensorOrTensors,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    grad_variables: Optional[_TensorOrTensors] = None,
    inputs: Optional[_TensorOrTensors] = None,
) -> None:
    

    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("'inputs' argument to backward() cannot be empty.")

    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else \
        tuple(inputs) if inputs is not None else tuple()

    grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
    grad_tensors_ = _make_grads(tensors, grad_tensors_)
    if retain_graph is None:
        retain_graph = create_graph

    Variable._execution_engine.run_backward(
        tensors, grad_tensors_, retain_graph, create_graph, inputs,
        allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag

First, whether or not the grad_tensors argument was specified, there is a call to the _make_grads function. This is used to check the provided grad_tensors or to specify their default value by looking at the shapes of the tensors argument. Check the first blog post for details on the default value for the grad_tensors of the backward pass. In short, this function supplies the vector of the vector-Jacobian product when the caller did not specify it.
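To make the role of _make_grads concrete, the snippet below (using only public APIs) compares an implicit and an explicit gradient for a scalar output, and shows that a non-scalar output needs an explicit vector v, in which case backward computes the vector-Jacobian product v^T J. For z = 2x the Jacobian is 2I, so the result is 2v.

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)

# Scalar output: no argument behaves like gradient=torch.ones_like(y).
torch.exp(x).sum().backward()
implicit = x.grad.clone()

x.grad = None
y = torch.exp(x).sum()
y.backward(torch.ones_like(y))
explicit = x.grad.clone()

# Non-scalar output: the vector v has to be supplied explicitly.
x.grad = None
z = x * 2                     # Jacobian is 2 * I
v = torch.tensor([1.0, 10.0])
z.backward(v)
vjp = x.grad.clone()          # v^T J = [2.0, 20.0]

try:
    (x * 2).backward()        # no gradient given for a non-scalar output
    raised = False
except RuntimeError:
    raised = True
```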

In the above code, Variable has an _execution_engine attribute that is defined in torch.autograd.variable to be of type ImperativeEngine, the C++ engine exported to Python and declared in torch/csrc/autograd/python_engine.cpp. In the following sections, we explain in detail how this object executes the backward pass.

Note that the torch.autograd.backward function has an inputs optional argument. This argument is used when we want to calculate the .grad field of only a subset of input tensors in the forward pass.

>>> x = torch.tensor([0.5, 0.75], requires_grad=True)
>>> y = torch.tensor([0.1, 0.90], requires_grad=True)
>>> z = torch.exp(x * y).sum()
>>> torch.autograd.backward([z], inputs=[x])
>>> x.grad
tensor([0.1051, 1.7676])
>>> y.grad  # None
>>>

Using torch.autograd.grad

An alternative to backward() is torch.autograd.grad(). The main difference from backward() is that grad() returns a tuple of tensors with the gradients of the outputs w.r.t. the tensors in the inputs argument, instead of storing them in the .grad field of the tensors. As you can see, the grad() code shown below is very similar to backward.

def grad(
    outputs: _TensorOrTensors,
    inputs: _TensorOrTensors,
    grad_outputs: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
    only_inputs: bool = True,
    allow_unused: bool = False,
    is_grads_batched: bool = False
) -> Tuple[torch.Tensor, ...]:
    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    overridable_args = outputs + inputs
    if has_torch_function(overridable_args):
        return handle_torch_function(
            grad,
            overridable_args,
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            retain_graph=retain_graph,
            create_graph=create_graph,
            only_inputs=only_inputs,
            allow_unused=allow_unused,
        )

    grad_outputs_ = _tensor_or_tensors_to_tuple(grad_outputs, len(outputs))
    grad_outputs_ = _make_grads(outputs, grad_outputs_)

    if retain_graph is None:
        retain_graph = create_graph

    if is_grads_batched:
        ...  # The batched-gradients path (is_grads_batched=True) is not covered here
    else:
        return Variable._execution_engine.run_backward(
            outputs, grad_outputs_, retain_graph, create_graph, inputs,
            allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
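The earlier inputs example, rewritten with torch.autograd.grad(): the gradient comes back as a returned tuple, while x.grad stays untouched because the engine is called with accumulate_grad=False.

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)
y = torch.tensor([0.1, 0.90], requires_grad=True)
z = torch.exp(x * y).sum()

# grad() returns the gradients instead of writing them into .grad
(dz_dx,) = torch.autograd.grad(z, inputs=[x])
print(dz_dx)   # tensor([0.1051, 1.7676]), the values backward() stored in x.grad
print(x.grad)  # None: nothing was accumulated
```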

Figure 1 shows the computational graph with the backward() and grad() arguments highlighted in red and blue, respectively:

Figure 1: Correspondence of `backward`/`grad` arguments in the graphs.

Going Inside the Autograd Engine

Refreshing Concepts: Nodes and Edges

As we saw in part 2 of this series, the computational graph comprises Node and Edge objects. Please read that post if you haven’t done so yet.

Nodes

Node objects are defined in torch/csrc/autograd/function.h, and they provide an overload of operator() for the associated function and a list of edges to do the graph traversal. Note that Node is a base class that autograd functions inherit from and override the apply method to execute the backward function.

struct TORCH_API Node : std::enable_shared_from_this<Node> {
  ...
  /// Evaluates the function on the given inputs and returns the result of the
  /// function call.
  variable_list operator()(variable_list&& inputs) {
    ...
  }

 protected:
  /// Performs the `Node`'s actual operation.
  virtual variable_list apply(variable_list&& inputs) = 0;

  edge_list next_edges_;
  uint64_t topological_nr_ = 0;
  ...
};

There is an attribute called topological_nr_ in every node object. This number is used to optimize the graph execution, as it allows discarding graph branches under certain conditions. The topological number is the longest distance between this node and any leaf node, as shown in Figure 2. Its main property is that for any pair of nodes x, y in a directed graph, topo_nr(x) < topo_nr(y) means that there is no path from x to y. This reduces the number of paths in the graph that need to be traversed. Check the topological_nr() method comment for further details.

Figure 2: Example of the Topological Number calculation
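The topological number can be computed recursively: leaves get 0, and every other node gets one more than the largest number among the nodes its edges point to. The number strictly decreases along every edge, so a path from x to y implies topo_nr(x) > topo_nr(y), which gives the pruning test. The toy graph and function names below are hypothetical and not the graph of Figure 2.

```python
from functools import lru_cache

# Hypothetical graph: node -> nodes on its outgoing edges.
# Leaves (e.g. AccumulateGrad nodes) have no outgoing edges.
GRAPH = {
    "root": ["mul", "exp"],
    "exp": ["mul"],
    "mul": ["acc_x", "acc_y"],
    "acc_x": [],
    "acc_y": [],
}

@lru_cache(maxsize=None)
def topo_nr(node):
    """Longest distance from `node` to any leaf; leaves have 0."""
    children = GRAPH[node]
    if not children:
        return 0
    return 1 + max(topo_nr(child) for child in children)

def may_reach(x, y):
    """Cheap pruning test: if topo_nr(x) <= topo_nr(y), there is no path from x to y."""
    return topo_nr(x) > topo_nr(y)
```

A False from may_reach is a proof that no path exists; a True only means a path is possible.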

Edges

The Edge object links Nodes together, and its implementation is straightforward.

struct Edge {
  ...
  /// The function this `Edge` points to.
  std::shared_ptr<Node> function;
  /// The identifier of a particular input to the function.
  uint32_t input_nr;
};

It only requires a function pointer to the Node and an input number, which is the index of the output from the forward function this edge points to. When preparing the set of gradients before calling “function”, we know that what is flowing from this edge should be accumulated in the “input_nr”th argument. Note that the input/output naming is flipped here: this is an input to the backward function.
Edge objects are constructed using the gradient_edge function.

Edge gradient_edge(const Variable& self) {
  if (const auto& gradient = self.grad_fn()) {
    return Edge(gradient, self.output_nr());
  } else {
    return Edge(grad_accumulator(self), 0);
  }
}
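From Python, edges are visible through a node's next_functions attribute, which lists (node, input_nr) pairs. For the earlier exp/sum example, the chain ends at an AccumulateGrad node with input_nr 0, matching the leaf branch of gradient_edge:

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)
y = torch.exp(x).sum()

# Each grad_fn exposes its outgoing edges as (node, input_nr) pairs.
sum_node = y.grad_fn
exp_node, exp_input_nr = sum_node.next_functions[0]
acc_node, acc_input_nr = exp_node.next_functions[0]

print(sum_node.name(), exp_node.name())       # e.g. SumBackward0 ExpBackward0
print(type(acc_node).__name__, acc_input_nr)  # AccumulateGrad 0 (x is a leaf)
```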

Entering the C++ Realm

Once torch.autograd.backward() has been invoked, the THPEngine_run_backward routine starts the graph traversal. The following is a schema of the function body:

PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwargs)
{
  HANDLE_TH_ERRORS
  PyObject *tensors = nullptr;
  PyObject *grad_tensors = nullptr;
  unsigned char keep_graph = 0;
  unsigned char create_graph = 0;
  PyObject *inputs = nullptr;
  unsigned char allow_unreachable = 0;
  unsigned char accumulate_grad = 0;

  // Convert the python arguments to C++ objects
  const char *accepted_kwargs[] = { // NOLINT
      "tensors", "grad_tensors", "keep_graph", "create_graph", "inputs",
      "allow_unreachable", "accumulate_grad", nullptr
  };
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OObb|Obb", (char**)accepted_kwargs,
        &tensors, &grad_tensors, &keep_graph, &create_graph, &inputs, &allow_unreachable, &accumulate_grad))
    return nullptr;

  // Prepare arguments: fills `roots` and `grads` (num_tensors = number of entries in `tensors`)
  for (const auto i : c10::irange(num_tensors)) {
    // Check that the tensors require gradients
  }

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    // Prepare outputs
  }

  variable_list outputs;
  {
    // Calls the actual autograd engine
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }
  // Clean up and finish
  END_HANDLE_TH_ERRORS
}

First, we prepare the input arguments after converting the PyObject arguments to actual C++ objects. The tensors list contains the tensors from which we start the backward pass. These tensors are converted to edges using torch::autograd::impl::gradient_edge and added to a list called roots where the graph traversal starts.

  edge_list roots;
  roots.reserve(num_tensors);
  variable_list grads;
  grads.reserve(num_tensors);
  for (const auto i : c10::irange(num_tensors)) {
    PyObject *_tensor = PyTuple_GET_ITEM(tensors, i);
    const auto& variable = THPVariable_Unpack(_tensor);
    auto gradient_edge = torch::autograd::impl::gradient_edge(variable);
    roots.push_back(std::move(gradient_edge));

    PyObject *grad = PyTuple_GET_ITEM(grad_tensors, i);
    if (THPVariable_Check(grad)) {
      const Variable& grad_var = THPVariable_Unpack(grad);
      grads.push_back(grad_var);
    }
  }

Now, if the inputs argument was specified in backward or we used the torch.autograd.grad API, the following code creates a list of edges to accumulate the gradients in the specified tensors at the end of the computation. The engine uses this later to optimize the execution, as it doesn’t add the gradients to all the leaf nodes, just the specified ones.

  std::vector<Edge> output_edges;
  if (inputs != nullptr) {
    int num_inputs = PyTuple_GET_SIZE(inputs);
    output_edges.reserve(num_inputs);
    for (const auto i : c10::irange(num_inputs)) {
      PyObject *input = PyTuple_GET_ITEM(inputs, i);
      const auto& tensor = THPVariable_Unpack(input);
      const auto output_nr = tensor.output_nr();
      auto grad_fn = tensor.grad_fn();
      if (!grad_fn) {
        grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
      }
      if (accumulate_grad) {
        tensor.retain_grad();
      }
      if (!grad_fn) {
        output_edges.emplace_back(std::make_shared<Identity>(), 0);
      } else {
        output_edges.emplace_back(grad_fn, output_nr);
      }
    }
  }

The next step is the actual graph traversal and node function execution, and finally, the cleanup and return.

  {
    // Calls the actual autograd engine
    pybind11::gil_scoped_release no_gil;
    auto& engine = python::PythonEngine::get_python_engine();
    outputs = engine.execute(roots, grads, keep_graph, create_graph, accumulate_grad, output_edges);
  }
  // Clean up and finish
}

Starting the Real Execution

engine.execute is defined in torch/csrc/autograd/engine.cpp.

There are two distinct steps here:

  • Analyze the graph to find dependencies between functions
  • Create worker threads that traverse the graph

Data Structures Used for the Execution

GraphTask

All the execution metadata is managed by the GraphTask class in torch/csrc/autograd/engine.h

struct GraphTask: std::enable_shared_from_this<GraphTask> {
  std::atomic<uint64_t> outstanding_tasks_{0};
  //  … 
  std::unordered_map<Node*, InputBuffer> not_ready_;
  std::unordered_map<Node*, int> dependencies_;

  struct ExecInfo {
     // …
  };
  std::unordered_map<Node*, ExecInfo> exec_info_;
  std::vector<Variable> captured_vars_;
  // …
  std::shared_ptr<ReadyQueue> cpu_ready_queue_;
};

Here we see a series of variables dedicated to maintaining the execution state. outstanding_tasks_ tracks the number of tasks left to be executed for the backward pass to complete. not_ready_ holds the input arguments for the Nodes that are not ready to be executed. dependencies_ tracks the number of predecessors that a Node has. When the count reaches 0, the Node is ready for execution; it is placed in a ready queue to be retrieved and executed later.
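The countdown on dependencies_ can be sketched in a few lines of single-threaded Python. The graph, node names, and the plain dict and deque below are hypothetical stand-ins for Node, dependencies_ and the ready queue; the real engine additionally buffers incoming gradients in not_ready_ and runs per-device worker threads.

```python
from collections import deque

# Hypothetical graph: node -> nodes on its outgoing edges.
next_edges = {"root": ["a", "b"], "a": ["c"], "b": ["c"], "c": []}

# Stand-in for dependencies_: number of edges arriving at each node.
dependencies = {"a": 1, "b": 1, "c": 2}
ready_queue = deque(["root"])
execution_order = []

while ready_queue:
    node = ready_queue.popleft()
    execution_order.append(node)    # "run" the node's backward function
    for nxt in next_edges[node]:
        dependencies[nxt] -= 1      # one more predecessor has delivered its gradient
        if dependencies[nxt] == 0:  # all inputs have arrived: the node is ready
            ready_queue.append(nxt)

print(execution_order)  # ['root', 'a', 'b', 'c']
```

Node c runs only after both a and b, because it has two incoming edges.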

exec_info_ and the associated ExecInfo struct are used only when the inputs argument is specified or it is a call to autograd.grad(). They allow filtering out paths of the graph that are not needed, since gradients are calculated only for the variables in the inputs list.

captured_vars_ is where the results of the graph execution are temporarily stored if we used the torch.autograd.grad() API instead of torch.autograd.backward(), since grad() returns the gradients as tensors instead of just filling the .grad field of the inputs.

NodeTask

The NodeTask struct is a basic class that holds an fn_ pointer to the node to execute and an inputs_ buffer to store the input arguments to this function. Note that the functions executed by the backward pass are the derivatives specified in the derivatives.yaml file, or the user-provided backward function when using custom functions as described in the second blog post.

The inputs_ buffer is also where the output gradients of the previously executed functions are aggregated, and it is defined as a std::vector<Variable> container with facilities to accumulate values at a given position.

struct NodeTask {
  std::weak_ptr<GraphTask> base_;
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here.  Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
};
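A hypothetical, much simplified stand-in for InputBuffer: one slot per input of the backward function, where the first gradient is stored and later ones are summed in, which is the implicit "addition" described in the NodeTask comment. The real InputBuffer also handles details such as CUDA streams that are omitted here.

```python
import torch

class ToyInputBuffer:
    """Hypothetical simplification of InputBuffer: one slot per backward input."""

    def __init__(self, size):
        self.slots = [None] * size

    def add(self, pos, grad):
        # First gradient for a slot is stored; later ones are accumulated.
        if self.slots[pos] is None:
            self.slots[pos] = grad
        else:
            self.slots[pos] = self.slots[pos] + grad

buf = ToyInputBuffer(2)
buf.add(0, torch.tensor([1.0, 2.0]))
buf.add(0, torch.tensor([0.5, 0.5]))  # a second edge feeding the same input
buf.add(1, torch.tensor([3.0]))
```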

GraphRoot

The GraphRoot is a special function used to hold multiple input variables in a single place. The code is pretty simple as it only acts as a container of variables.

struct TORCH_API GraphRoot : public Node {
  GraphRoot(edge_list functions, variable_list inputs)
      : Node(std::move(functions)),
        outputs(std::move(inputs)) {
    for (const auto& t : outputs) {
      add_input_metadata(t);
    }
  }

  variable_list apply(variable_list&& inputs) override {
    return outputs;
  }

  // The root gradients handed to the engine; apply() simply returns them.
  variable_list outputs;
};

AccumulateGrad

This function is set during graph creation in gradient_edge when the Variable object doesn’t have a grad_fn, that is, when it is a leaf node.

    if (const auto& gradient = self.grad_fn()) {
      // …
    } else {
      return Edge(grad_accumulator(self), 0);
    }

The function body is defined in torch/csrc/autograd/functions/accumulate_grad.cpp and it essentially accumulates the input grads in the object’s .grad attribute.

auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
  check_input_variables("AccumulateGrad", grads, 1, 0);
  // ...

  at::Tensor new_grad = callHooks(variable, std::move(grads[0]));
  std::lock_guard<std::mutex> lock(mutex_);

  at::Tensor& grad = variable.mutable_grad();
  accumulateGrad(
      variable,
      grad,
      new_grad,
      1 + !post_hooks().empty() /* num_expected_refs */,
      [&grad](at::Tensor&& grad_update) { grad = std::move(grad_update); });
  return variable_list();
}



accumulateGrad does several checks on the tensors’ format and eventually performs the variable_grad += new_grad; accumulation.
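Because AccumulateGrad adds into .grad rather than overwriting it, running two backward passes on freshly built graphs leaves twice the gradient in x.grad. This is why training loops reset gradients between iterations; the sketch below does so by setting x.grad to None.

```python
import torch

x = torch.tensor([0.5, 0.75], requires_grad=True)

for _ in range(2):
    y = torch.exp(x).sum()  # build a fresh graph on every iteration
    y.backward()            # AccumulateGrad adds into x.grad instead of overwriting it

accumulated = x.grad.clone()  # holds 2 * exp(x) after two passes
x.grad = None                 # reset before the next round of backward passes
```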

Preparing the graph for execution

Now, let’s walk through Engine::execute. The first thing to do, besides argument consistency checks, is to create the actual GraphTask object we described above. This object keeps all the metadata of the graph execution.

auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {

  validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
    return msg;
  });

  // Checks

  auto graph_task = std::make_shared<GraphTask>(
      /* keep_graph */ keep_graph,
      /* create_graph */ create_graph,
      /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
      /* cpu_ready_queue */ local_ready_queue);

  // If we receive a single root, skip creating extra root node
  // …
  // Prepare graph by computing dependencies
  // …
  // Queue the root 
  // …
  // launch execution
  // …
}

After creating the GraphTask, we use its associated function if we only have one root node. If we have multiple root nodes, we create a special GraphRoot object as described before.

  bool skip_dummy_node = roots.size() == 1;
  auto graph_root = skip_dummy_node ?
    roots.at(0).function :
    std::make_shared<GraphRoot>(roots, inputs);

The next step is to fill the dependencies_ map in the GraphTask object, since the engine must know when it can execute a task. The outputs here are the inputs argument passed to the torch.autograd.backward() call in Python. The names are reversed because the gradients w.r.t. the inputs of the forward pass are now the outputs of the backward pass. From now on, there is no concept of forward/backward, only graph traversal and execution.

  auto min_topo_nr = compute_min_topological_nr(outputs);
  // Now compute the dependencies for all executable functions
  compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);

  if (!outputs.empty()) {
    graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);
  }

Here we preprocess the graph for the execution of the nodes. First, compute_min_topological_nr is called to obtain the minimum topological number of the tensors specified in outputs (0 if no inputs kwarg was supplied to .backward or input for .grad). This value is used to prune paths in the graph that lead to input variables whose gradients we don’t want or need to calculate.

Second comes the compute_dependencies call. This function is a very simple graph traversal that starts with the root Node and, for each of the edges in node.next_edges(), increments the counter in dependencies_. Figure 3 shows the result of the dependencies calculation for the example graph. Note that the number of dependencies of any node is just the number of edges arriving at it.

Figure 3: Number of dependencies for each node

Finally, init_to_execute populates the GraphTask::exec_info_ map when inputs were specified in the Python backward call. It traverses the graph again, starting from the root, and records in the exec_info_ map the intermediate nodes needed to calculate only the gradients of the given inputs.
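The selection done by init_to_execute can be sketched in the same toy model. A node is needed if it is one of the requested nodes or if any of its successors is needed. This is a hypothetical simplification: the real exec_info_ entries also record which gradients to capture. The graph below is made up.

```python
def nodes_needed(root, next_edges, targets):
    """Return the nodes on some path from `root` to any node in `targets`.

    `targets` plays the role of the nodes for the tensors passed as inputs=;
    every other node can be skipped during execution.
    """
    targets = set(targets)
    memo = {}

    def needed(node):
        if node not in memo:
            # Visit all successors (no short-circuit) so each gets a memo entry.
            flags = [needed(nxt) for nxt in next_edges.get(node, [])]
            memo[node] = node in targets or any(flags)
        return memo[node]

    needed(root)
    # memo now holds exactly the nodes reachable from root.
    return {node for node, flag in memo.items() if flag}


graph = {"root": ["A", "B"], "A": ["C"], "B": ["D"], "C": [], "D": []}
print(sorted(nodes_needed("root", graph, ["C"])))  # ['A', 'C', 'root']
```

Recursion is fine for a toy graph. A very deep graph would call for an explicit stack instead.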

  // Queue the root
  if (skip_dummy_node) {
    InputBuffer input_buffer(roots.at(0).function->num_inputs());
    auto input = inputs.at(0);


    input_buffer.add(roots.at(0).input_nr,
                      std::move(input),
                      input_stream,
                      opt_next_stream);

    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {
    execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
  }
  // Avoid a refcount bump for the Future, since we check for refcount in
  // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1)
  // in dist_engine.cpp).
  auto& fut = graph_task->future_result_;
  fut->wait();
  return fut->value().toTensorVector();
}

Now we are ready to start the actual execution by creating the InputBuffer. If we have only one root variable, we begin by copying the value of the inputs tensor (the gradients passed to Python backward) into the root's input position (roots.at(0).input_nr) of input_buffer. This is a small optimization that avoids running the GraphRoot node for no reason. Also, if the rest of the graph is not on the CPU, we start directly on that device's worker, whereas the GraphRoot node is always placed on the CPU ready queue. Details of the workers and ready queues are explained in the section below.

On the other hand, if we have multiple roots, the GraphRoot object also holds the inputs, so it is enough to pass it an empty InputBuffer.

Graph Traversal and Node Execution

Devices, Threads and Queues

Before diving into the actual execution, we need to see how the engine is structured.

First of all, the engine is multithreaded, with one thread per device. For example, the caller thread is associated with the CPU, while additional threads are created and associated with each GPU or other device available in the system. Each thread tracks its device in the thread-local worker_device variable. Each thread also has its own queue of tasks to be executed, local_ready_queue, which is likewise kept in thread-local storage. This is where work is queued for the thread to execute in the thread_main function, which is explained later.

You may wonder how the engine decides on which device a task should be executed. The InputBuffer class has a device() function that returns the first non-CPU device among all its tensors. This function is used together with Engine::ready_queue to select the queue in which to place a task.

auto Engine::ready_queue(std::shared_ptr<ReadyQueue> cpu_ready_queue, at::Device device) -> std::shared_ptr<ReadyQueue>{
  if (device.type() == at::kCPU || device.type() == at::DeviceType::Meta) {
    return cpu_ready_queue;
  } else {
    // See Note [Allocating GPUs to autograd threads]
    return device_ready_queues_.at(device.index());
  }
}
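To make the routing rule concrete, here is a small Python model of InputBuffer::device() combined with Engine::ready_queue. The Device class is an assumption made for illustration only, and so is the use of strings in place of real queue objects.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    type: str  # "cpu", "cuda", "meta", ...
    index: int = -1


CPU = Device("cpu")


def input_buffer_device(tensor_devices):
    """Like InputBuffer::device(): the first non-CPU device, else the CPU."""
    for device in tensor_devices:
        if device.type != "cpu":
            return device
    return CPU


def ready_queue(cpu_ready_queue, device_ready_queues, device):
    """Like Engine::ready_queue(): CPU and meta work stays on the CPU queue."""
    if device.type in ("cpu", "meta"):
        return cpu_ready_queue
    return device_ready_queues[device.index]


# Strings stand in for the actual queue objects.
queues = ["cuda:0 queue", "cuda:1 queue"]
print(ready_queue("cpu queue", queues, input_buffer_device([CPU, Device("cuda", 1)])))  # cuda:1 queue
print(ready_queue("cpu queue", queues, input_buffer_device([CPU, CPU])))  # cpu queue
```

A single tensor on a GPU is enough to send the whole task to that GPU's thread.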

The ReadyQueue object is defined in torch/csrc/autograd/engine.h. It is a simple wrapper over std::priority_queue that allows a thread to wait for a task when the queue is empty. One interesting property of the ReadyQueue is that it increments the GraphTask::outstanding_tasks_ value, which is used to determine whether the execution has completed.

auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
  {
    std::lock_guard<std::mutex> lock(mutex_);
    if (incrementOutstandingTasks) {
      std::shared_ptr<GraphTask> graph_task = item.base_.lock();
      ++graph_task->outstanding_tasks_;
    }
    heap_.push(std::move(item));
  }
  not_empty_.notify_one();
}

auto ReadyQueue::pop() -> NodeTask {
  std::unique_lock<std::mutex> lock(mutex_);
  not_empty_.wait(lock, [this]{ return !heap_.empty(); });
  auto task = std::move(const_cast<NodeTask&>(heap_.top())); heap_.pop();
  return task;
}
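The push/pop pair above is essentially a blocking priority queue. The sketch below is a minimal Python analogue. It assumes that tasks with a higher sequence number should be popped first. The engine's own comparator also takes other factors into account, and the GraphTask class and fn_name field here are simplified stand-ins.

```python
import heapq
import itertools
import threading


class GraphTask:
    """Holds the outstanding-task counter (a std::atomic in the real engine)."""

    def __init__(self):
        self._lock = threading.Lock()
        self.outstanding_tasks = 0

    def add_outstanding(self, delta):
        with self._lock:
            self.outstanding_tasks += delta


class ReadyQueue:
    """Toy ReadyQueue: a priority heap plus a condition variable."""

    def __init__(self):
        self._heap = []
        self._not_empty = threading.Condition()
        self._order = itertools.count()  # tie-breaker so entries never compare tasks

    def push(self, graph_task, sequence_nr, fn_name, increment_outstanding=True):
        with self._not_empty:
            if increment_outstanding:
                graph_task.add_outstanding(1)
            # heapq is a min-heap: negate sequence_nr so the highest one pops first.
            heapq.heappush(self._heap, (-sequence_nr, next(self._order), graph_task, fn_name))
            self._not_empty.notify()

    def pop(self):
        with self._not_empty:
            # Block until a task is available, like not_empty_.wait(lock, ...).
            self._not_empty.wait_for(lambda: self._heap)
            _, _, graph_task, fn_name = heapq.heappop(self._heap)
            return graph_task, fn_name


gt = GraphTask()
q = ReadyQueue()
q.push(gt, sequence_nr=3, fn_name="MulBackward0")
q.push(gt, sequence_nr=7, fn_name="AddBackward0")
print(gt.outstanding_tasks)  # 2
print(q.pop()[1])  # AddBackward0
```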

Reentrant Backward

A reentrant backward happens when one of the tasks in a backward pass calls backward again. This is not a very common case, but it can be used to reduce memory utilization, because it can avoid saving intermediate results. For more information, check this PyTorch forum post.

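import torch


class ReentrantBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sum()

    @staticmethod
    def backward(ctx, grad_output):
        # Let's compute the backward by using autograd
        (input,) = ctx.saved_tensors
        input = input.detach().requires_grad_()
        with torch.enable_grad():
            out = input.sum()
        out.backward()  # REENTRANT CALL!!
        # input.grad now holds d(out)/d(input); scale it by the incoming gradient
        return input.grad * grad_output


x = torch.randn(4, requires_grad=True)
ReentrantBackward.apply(x).backward()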

Here, we call backward() inside backward() for a user-defined custom autograd function.
This situation can lead to deadlocks, because the first backward needs to wait for the second one to complete, but some internal implementation details can prevent the second backward from completing, as explained in the dedicated subsection.

Thread Initialization

execute_with_graph_task is in charge of initializing the threads that carry out the computation, and of placing the root node in the queue of the device that produced it.

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {

  initialize_device_threads_pool();
  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  if (worker_device == NO_DEVICE) {
    set_device(CPU_DEVICE);
    graph_task->owner_ = worker_device;
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
    lock.unlock();
    thread_main(graph_task);
    worker_device = NO_DEVICE;
  } else {
     // This deals with reentrant backwards, we will see it later.
  }
  return graph_task->future_result_;
}

First, this function initializes the device threads by calling initialize_device_threads_pool(), where several things happen:

  • One ReadyQueue per device is created.
  • One thread per non-CPU device is created.
  • A thread-local worker_device variable is set to track the current device associated with each thread.
  • The thread_main function is called, and the threads wait for tasks to be put in their queues.

Next, it uses the ready_queue function to retrieve the queue for the root node. The choice is based on the device that holds the tensors in the input_buffer. The main thread, which also runs the Python interpreter, has its worker_device set to NO_DEVICE and is in charge of executing functions whose tensors all live on the CPU. If worker_device is set to any other value, the graph execution has already started and .backward() was called inside a running Node, creating a reentrant backward call. This case is explained later. For now, the main thread places the task in the queue and calls thread_main.

Where the Magic Happens

It’s been a long way, but finally we are ready to traverse the graph and execute the nodes. Each of the spawned threads, as well as the main thread, calls thread_main.

auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {

  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      NodeTask task = local_ready_queue->pop();

      if (task.isShutdownTask_) {
        break;
      }

      if (!(local_graph_task = task.base_.lock())) {
        // GraphTask for function is no longer valid, skipping further
        // execution.
        continue;
      }

      if (task.fn_ && !local_graph_task->has_error_.load()) {
        at::ThreadLocalStateGuard tls_guard(local_graph_task->thread_locals_);

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            evaluate_function(
                local_graph_task,
                task.fn_.get(),
                task.inputs_,
                local_graph_task->cpu_ready_queue_);
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }
    }

    // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();
      auto base_owner = local_graph_task->owner_;
      if (worker_device != base_owner) {
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }
  }
}

The code here is simple, given the local_ready_queue assigned to each thread in thread-local storage. The threads loop until there are no tasks left to execute in the graph. Note that for device-associated threads, the passed graph_task argument is nullptr, and they block in local_ready_queue->pop() until a task is pushed to their queue. After some consistency checks (whether the task is a shutdown task and whether the graph is still valid), we reach the actual function invocation in evaluate_function.

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            evaluate_function(
                local_graph_task,
                task.fn_.get(),
                task.inputs_,
                local_graph_task->cpu_ready_queue_);
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }

After calling evaluate_function, we check whether the graph_task execution is complete by looking at the outstanding_tasks_ counter. This counter is incremented when a task is pushed to a queue and decremented right after a task is executed. local_graph_task->completed() then checks, among other things, whether it has dropped to 0. When the execution is done, mark_as_completed_and_run_post_processing() delivers the result. If we called torch.autograd.grad() instead of torch.autograd.backward(), that result is the tensors collected in captured_vars_, since grad() returns tensors instead of storing them in the .grad attribute of the inputs. Finally, if the current thread is not the owner of the graph task, we wake up the owner thread, which may be waiting on its queue, by sending it a dummy task.

    // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();
      auto base_owner = local_graph_task->owner_;
      if (worker_device != base_owner) {
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }
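The two-thread Python sketch below shows why the dummy task matters. The owner thread blocks in its own queue while another thread runs the last task. Without a wake-up, the owner would never re-check the completion condition. Plain dicts and queue.Queue objects stand in for the GraphTask and the ready queues, and the task name "NodeOnDevice" is hypothetical.

```python
import queue
import threading


def run_with_remote_last_task():
    """Owner thread waits in its own queue while a device thread runs the last task."""
    owner_queue = queue.Queue()   # the owner's (CPU) ready queue
    device_queue = queue.Queue()  # a device thread's ready queue
    lock = threading.Lock()
    state = {"outstanding": 0, "completed": False}
    log = []

    def device_thread():
        task = device_queue.get()  # block until work arrives
        log.append(f"device ran {task}")
        with lock:
            state["outstanding"] -= 1  # --outstanding_tasks_
            state["completed"] = state["outstanding"] == 0
            done = state["completed"]
        if done:
            # Not the owner: wake it with a dummy task, the analogue of
            # NodeTask(local_graph_task, nullptr, InputBuffer(0)).
            owner_queue.put(None)

    worker = threading.Thread(target=device_thread)
    worker.start()

    with lock:
        state["outstanding"] += 1  # pushing a task increments the counter
    device_queue.put("NodeOnDevice")

    # Owner loop, like thread_main(graph_task): keep popping until completed.
    while True:
        with lock:
            if state["completed"]:
                break
        owner_queue.get()  # without the dummy task this could block forever
    worker.join()
    return log


print(run_with_remote_last_task())  # ['device ran NodeOnDevice']
```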

Calling the Function and Unlocking New Tasks

evaluate_function serves three purposes:

  1. Run the function.
  2. Accumulate its results in the InputBuffers of the next nodes.
  3. Decrease the dependency counters of the next nodes and enqueue for execution the tasks whose counter reaches 0.

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {

  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    // Check if the function needs to be executed
    auto& fn_info = exec_info_.at(func);
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

  auto outputs = call_function(graph_task, func, inputs);

  auto& fn = *func;
  if (!graph_task->keep_graph_) {
    fn.release_variables();
  }

Initially, we check the exec_info_ map of the GraphTask structure to determine if the current node needs to be executed. Remember that if this map is empty, all the nodes are executed because we are calculating the grads for all the inputs of the forward pass.

After this check, the function is executed by running call_function. Its implementation is very straightforward: it calls the actual derivative function and any registered hooks.

  int num_outputs = outputs.size();
  if (num_outputs == 0) {
    // Records leaf stream (if applicable)
    return;
  }

  if (AnomalyMode::is_enabled()) {
    // check for nan values in result
  }

Next, we check the outputs of the function after call_function is done. If the number of outputs is 0, there are no following nodes to be executed, so we can safely return. This is the case for the AccumulateGrad nodes associated with the leaf nodes.

Also, if anomaly detection is enabled, this is where the gradients are checked for NaN values.


  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (const auto i : c10::irange(num_outputs)) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i);

    if (!next.is_valid()) continue;

   

We have now executed a grad_fn that has returned one gradient for each input of the associated forward-pass function. As we saw in the previous blog post, there is an Edge object for each of these input tensors, pointing to the grad_fn of the function that produced them in the forward pass. Essentially, output 0 of the node in the backward pass corresponds to the first argument of the associated forward-pass function. Figure 4 shows how the outputs of a backward function are related to the inputs of the forward function. Note that the outputs of grad_fn C are the gradients of z w.r.t. the inputs of Function C.

Figure 4: Correspondence between forward and backward functions inputs and outputs

We now iterate through these edges and check if the associated functions are ready to be executed.

 // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get());

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);
      is_ready = true;
    }

    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());

For this, we check the graph_task->dependencies_ map. We decrement the counter, and if it reaches 0, we mark the function pointed to by the edge as ready to be executed. Next, we prepare the input buffers of the tasks indicated by the next edges.

    if (not_ready_it == not_ready.end()) {
      if (!exec_info_.empty()) {
        // Skip functions that aren't supposed to be executed
      }

      // Creates an InputBuffer and moves the output to the corresponding input position
      InputBuffer input_buffer(next.function->num_inputs());
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }

Here, we look for the task in the graph_task->not_ready_ map. If it is not present, we create a new InputBuffer object and store the current output at the input_nr position indicated by the edge. If the task is ready to be executed, we enqueue it in the ready queue of the appropriate device. Otherwise, we store the buffer in not_ready_ until its remaining dependencies arrive. If the task is not ready and we have seen it before, it is already present in the not_ready_ map.

    } else {
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;
      // Accumulates into buffer
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(NodeTask(graph_task, next.function, std::move(input_buffer)));
        not_ready.erase(not_ready_it);
      }
    }
  }
}

In this case, we accumulate the output in the existing input_buffer instead of creating a new one, and we enqueue the task once it is ready. When all the tasks of the graph have been processed, the GraphTask is marked as completed, and the thread that started the backward call exits its thread_main loop.
The whole process is summarized in the animation in Figure 5. It shows how a thread takes tasks from the ready queue and decrements the dependencies of the next nodes, unlocking them for execution.

Figure 5: Animation of the execution of the computational graph
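The single-threaded Python toy below puts these pieces together and follows the same steps:

  1. Count the dependencies.
  2. Run a node.
  3. Accumulate each output into the next node's input buffer, keeping the buffer in not_ready until it is complete.
  4. Enqueue a node once its counter reaches 0.

It is a sketch under simplifying assumptions: scalar gradients, a single FIFO instead of per-device priority queues, and no hooks or streams. The Node and accumulate_grad names are hypothetical. The example graph is z = x * y + x, so at x = 2, y = 3 we expect dz/dx = y + 1 = 4 and dz/dy = x = 2.

```python
from collections import Counter


class Node:
    """Toy backward node: `fn` maps a list of input grads to a list of output grads.

    next_edges holds one (Node, input_nr) pair per output, or None for an
    output that feeds nothing (the engine's `!next.is_valid()` case).
    """

    def __init__(self, name, num_inputs, fn, next_edges):
        self.name, self.num_inputs, self.fn, self.next_edges = name, num_inputs, fn, next_edges


def execute(root, root_grad):
    # compute_dependencies: count the edges arriving at each node.
    dependencies, stack, seen = Counter(), [root], {root}
    while stack:
        node = stack.pop()
        for edge in node.next_edges:
            if edge is not None:
                dependencies[edge[0]] += 1
                if edge[0] not in seen:
                    seen.add(edge[0])
                    stack.append(edge[0])

    not_ready = {}                 # Node -> partially filled input buffer
    ready = [(root, [root_grad])]  # FIFO stand-in for the per-device ReadyQueue
    while ready:
        func, inputs = ready.pop(0)
        outputs = func.fn(inputs)  # call_function
        for output, edge in zip(outputs, func.next_edges):
            if edge is None:
                continue
            nxt, input_nr = edge
            dependencies[nxt] -= 1
            is_ready = dependencies[nxt] == 0
            buffer = not_ready.pop(nxt, None)
            if buffer is None:
                buffer = [0.0] * nxt.num_inputs  # new InputBuffer
            buffer[input_nr] += output           # InputBuffer::add accumulates
            if is_ready:
                ready.append((nxt, buffer))
            else:
                not_ready[nxt] = buffer


grads = {}


def accumulate_grad(name):
    """Leaf node: stores the gradient and has no outputs, like AccumulateGrad."""
    def fn(inputs):
        grads[name] = grads.get(name, 0.0) + inputs[0]
        return []
    return Node(f"AccumulateGrad({name})", 1, fn, [])


# Backward graph of z = x * y + x with x = 2, y = 3.
x, y = 2.0, 3.0
acc_x, acc_y = accumulate_grad("x"), accumulate_grad("y")
mul = Node("MulBackward", 1, lambda g: [g[0] * y, g[0] * x], [(acc_x, 0), (acc_y, 0)])
add = Node("AddBackward", 1, lambda g: [g[0], g[0]], [(mul, 0), (acc_x, 0)])
execute(add, 1.0)
print(grads)  # {'x': 4.0, 'y': 2.0}
```

AccumulateGrad(x) has two dependencies. It waits in not_ready with a partial gradient of 1.0 until MulBackward adds its 3.0 contribution, and only then is it enqueued.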

Flow with Reentrant Backward

As we saw above, a reentrant backward occurs when the currently executed function makes a nested call to backward. When this happens, the thread running this function goes all the way down to execute_with_graph_task, as in the non-reentrant case, but this is where things differ.

c10::intrusive_ptr<at::ivalue::Future> Engine::execute_with_graph_task(
    const std::shared_ptr<GraphTask>& graph_task,
    std::shared_ptr<Node> graph_root,
    InputBuffer&& input_buffer) {

  initialize_device_threads_pool();
  // Lock mutex for GraphTask.
  std::unique_lock<std::mutex> lock(graph_task->mutex_);

  auto queue = ready_queue(graph_task->cpu_ready_queue_, input_buffer.device());

  if (worker_device == NO_DEVICE) {
    //Regular case
  } else {
    // If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
    //    backward call from that device.
    graph_task->owner_ = worker_device;

    // Now that all the non-thread safe fields of the graph_task have been populated,
    // we can enqueue it.
    queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));

    if (current_depth >= max_recursion_depth_) {
      // If reached the max depth, switch to a different thread
      add_thread_pool_task(graph_task);
    } else {
      ++total_depth;
      ++current_depth;
      lock.unlock();
      thread_main(graph_task);
      --current_depth;
      --total_depth;
    }
  }
  return graph_task->future_result_;
}

Here, execute_with_graph_task detects a reentrant call and checks the current number of nested calls. If it has reached the limit, the graph task is handed off to a pool of threads that takes care of its execution. If not, we execute this reentrant call regularly on the current thread.
The limit of nested calls was originally set to avoid stack overflows caused by reentrant calls creating very large call stacks. However, the number was further reduced when sanitizer tests were added, because of the maximum number of locks a thread can hold at a given moment. This can be seen in torch/csrc/autograd/engine.h.
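The depth check can be imitated in Python with a thread-local counter. In this hypothetical sketch, MAX_RECURSION_DEPTH is deliberately tiny. A plain new thread plus join() stand in for add_thread_pool_task and for waiting on the graph task's future.

```python
import threading

MAX_RECURSION_DEPTH = 3   # hypothetical, deliberately tiny value
_tls = threading.local()  # stands in for the thread_local current_depth


def current_depth():
    return getattr(_tls, "depth", 0)


def reentrant_backward(run):
    """Mimic the reentrant branch of execute_with_graph_task."""
    if current_depth() >= MAX_RECURSION_DEPTH:
        # Too deep: run the nested graph on another thread, which starts with an
        # empty stack, and wait for it (the caller would wait on the future).
        worker = threading.Thread(target=run)
        worker.start()
        worker.join()
    else:
        _tls.depth = current_depth() + 1  # ++current_depth
        try:
            run()                         # thread_main(graph_task) on this thread
        finally:
            _tls.depth -= 1               # --current_depth


def run_graph(level, max_level, trace):
    """A toy graph whose single node calls backward() again until max_level."""
    trace.append((level, threading.current_thread().name, current_depth()))
    if level < max_level:
        reentrant_backward(lambda: run_graph(level + 1, max_level, trace))


trace = []
run_graph(0, 6, trace)
for level, thread_name, depth in trace:
    print(level, thread_name, depth)
```

Levels 0 to 3 run on the calling thread with depths 0 to 3. Level 4 then starts on a fresh thread whose depth counter begins again at 0.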

When this maximum depth is reached, the graph task is passed to the add_thread_pool_task function, which creates a new thread if needed.

void Engine::add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task) {
  std::unique_lock<std::mutex> lck(thread_pool_shared_->mutex_);
  // if we have pending graph_task objects to be processed, create a worker.
   bool create_thread = (thread_pool_shared_->num_workers_ <= thread_pool_shared_->graphtasks_queue_.size());
  thread_pool_shared_->graphtasks_queue_.push(graph_task);


  lck.unlock();
  if (create_thread) {
    std::thread t(&Engine::reentrant_thread_init, this);
    t.detach();
  }

  thread_pool_shared_->work_.notify_one();
}



Before going in depth, let’s look at the thread_pool_shared_ object in the Engine, which manages all the information related to the threads associated with reentrant backward calls.

  struct ThreadPoolShared {
    unsigned int num_workers_;
    std::condition_variable work_;
    std::mutex mutex_;
    std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;

    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
    ThreadPoolShared() : num_workers_(0) {}
 };



ThreadPoolShared is a simple container holding a queue of GraphTask objects with synchronization mechanisms and the number of current workers.

Now it is easy to understand how add_thread_pool_task creates a thread when there are graph_task objects enqueued and insufficient workers to process them.

add_thread_pool_task initializes a thread by executing reentrant_thread_init:

void Engine::reentrant_thread_init() {
  at::init_num_threads();
  auto tp_shared = thread_pool_shared_;
  while(true) {
    std::unique_lock<std::mutex> lk(tp_shared->mutex_);
    ++thread_pool_shared_->num_workers_;
    tp_shared->work_.wait(lk, [&tp_shared]{ return !tp_shared->graphtasks_queue_.empty();});
    --thread_pool_shared_->num_workers_;
    auto task = tp_shared->graphtasks_queue_.front();
    tp_shared->graphtasks_queue_.pop();
    lk.unlock();
    std::shared_ptr<GraphTask> graph_task;
    if (!(graph_task = task.lock())) {
      continue;
    }
    set_device(graph_task->owner_);
    // set the local_ready_queue to the ready queue on the graph_task->owner_ device
    local_ready_queue = ready_queue_by_index(graph_task->cpu_ready_queue_, graph_task->owner_);
    total_depth = graph_task->reentrant_depth_;
    thread_main(graph_task);
  }
}



The code is straightforward. The newly created thread waits on thread_pool_shared_->graphtasks_queue_ for reentrant backward graphs to become available, and then executes them. Notice that this thread uses the ready queue associated with the device of the thread that started the call. It finds that device through the graph_task->owner_ field set in the execute_with_graph_task function.
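The worker-creation rule in add_thread_pool_task can be reproduced with a condition variable and a deque: spawn a thread only when the idle workers cannot cover the queued graph tasks. This is a simplified Python sketch. Callables replace weak references to GraphTask objects, and the device and ready-queue setup done by reentrant_thread_init is omitted.

```python
import collections
import threading


class ThreadPoolShared:
    """Toy counterpart of the engine's ThreadPoolShared struct."""

    def __init__(self):
        self.num_workers = 0               # workers currently idle-waiting
        self.work = threading.Condition()  # mutex_ and work_ in one object
        self.graphtasks_queue = collections.deque()


pool = ThreadPoolShared()
threads_created = 0


def reentrant_thread_init():
    while True:
        with pool.work:
            pool.num_workers += 1
            pool.work.wait_for(lambda: pool.graphtasks_queue)
            pool.num_workers -= 1
            task = pool.graphtasks_queue.popleft()
        task()  # the engine calls thread_main(graph_task) here


def add_thread_pool_task(task):
    global threads_created
    with pool.work:
        # Spawn a worker only if the idle workers cannot cover the queued tasks.
        create_thread = pool.num_workers <= len(pool.graphtasks_queue)
        pool.graphtasks_queue.append(task)
        pool.work.notify()
    if create_thread:
        threads_created += 1
        threading.Thread(target=reentrant_thread_init, daemon=True).start()


first = threading.Event()
add_thread_pool_task(first.set)
first.wait(5)
```

Finished workers go back to waiting and count as idle. A second task submitted after the first worker is idle again is therefore picked up by that same thread instead of a new one.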

Error Handling

Whenever an error happens in one of the worker threads, it is propagated to the thread that called backward.

To achieve this, thread_main has a try/catch block that catches any exception raised during the Node function call and stores it in the associated GraphTask object.

        try {
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          {
            evaluate_function(
                local_graph_task,
                task.fn_.get(),
                task.inputs_,
                local_graph_task->cpu_ready_queue_);
          }
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }

thread_on_exception and the functions it calls end up setting the exception in the local_graph_task object.

void Engine::thread_on_exception(
    std::shared_ptr<GraphTask> graph_task,
    const std::shared_ptr<Node>& fn,
    std::exception& e) {
  graph_task->set_exception(std::current_exception(), fn);
}

void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
  if (!has_error_.exchange(true)) {
    if (AnomalyMode::is_enabled() && fn) {
      fn->metadata()->print_stack(fn->name());
    }
  }
}

void GraphTask::set_exception(
    std::exception_ptr eptr,
    const std::shared_ptr<Node>& fn) {
  set_exception_without_signal(fn);
  if (!future_completed_.exchange(true)) {
    // NOLINTNEXTLINE(performance-move-const-arg)
    future_result_->setError(std::move(eptr));
  }
}

set_exception sets the has_error_ flag to true (through set_exception_without_signal) and calls the setError function of the future_result_ object. This causes the error to be rethrown in the caller thread when future_result_->value() is accessed.

 IValue value() {
    std::unique_lock<std::mutex> lock(mutex_);
    AT_ASSERT(completed());
    if (eptr_) {
      std::rethrow_exception(eptr_);
    }
    return value_;
  }
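Python's concurrent.futures follows the same pattern: the first error is recorded on a shared future and rethrown wherever the result is read. The sketch below is an analogy rather than the engine's code. A worker thread fails and stores the exception with Future.set_exception, and the calling thread sees it re-raised from Future.result. The lock-protected flag plays the role of the atomic has_error_ exchange.

```python
import threading
from concurrent.futures import Future


def backward_with_failing_node():
    future_result = Future()      # plays the role of graph_task->future_result_
    lock = threading.Lock()
    state = {"has_error": False}  # plays the role of has_error_

    def worker():
        try:
            raise RuntimeError("toy failure inside a backward node")
        except Exception as exc:  # thread_on_exception(...)
            with lock:            # an atomic exchange in the engine
                first_error = not state["has_error"]
                state["has_error"] = True
            if first_error:
                future_result.set_exception(exc)

    threading.Thread(target=worker).start()
    # Like fut->value(): wait for completion, then re-raise the stored error.
    return future_result.result(timeout=5)


try:
    backward_with_failing_node()
except RuntimeError as err:
    print("caught in the calling thread:", err)
```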

Closing Remarks

This has been the last post in this series on how PyTorch performs automatic differentiation. We hope you enjoyed reading it and that you are now familiar enough with PyTorch internals to start contributing to PyTorch development!


Geospatial deep learning with TorchGeo

TorchGeo is a PyTorch domain library providing datasets, samplers, transforms, and pre-trained models specific to geospatial data.

https://github.com/microsoft/torchgeo

For decades, Earth observation satellites, aircraft, and more recently UAV platforms have been collecting increasing amounts of imagery of the Earth’s surface. With information about seasonal and long-term trends, remotely sensed imagery can be invaluable for solving some of the greatest challenges to humanity, including climate change adaptation, natural disaster monitoring, water resource management, and food security for a growing global population. From a computer vision perspective, this includes applications like land cover mapping (semantic segmentation), deforestation and flood monitoring (change detection), glacial flow (pixel tracking), hurricane tracking and intensity estimation (regression), and building and road detection (object detection, instance segmentation). By leveraging recent advancements in deep learning architectures, cheaper and more powerful GPUs, and petabytes of freely available satellite imagery datasets, we can come closer to solving these important problems.

National Oceanic and Atmospheric Administration satellite image of Hurricane Katrina, taken on August 28, 2005 (source). Geospatial machine learning libraries like TorchGeo can be used to detect, track, and predict future trajectories of hurricanes and other natural disasters.

The challenges

In traditional computer vision datasets, such as ImageNet, the image files themselves tend to be rather simple and easy to work with. Most images have 3 spectral bands (RGB), are stored in common file formats like PNG or JPEG, and can be easily loaded with popular software libraries like PIL or OpenCV. Each image in these datasets is usually small enough to pass directly into a neural network. Furthermore, most of these datasets contain a finite number of well-curated images that are assumed to be independent and identically distributed, making train-val-test splits straightforward. As a result of this relative homogeneity, the same pre-trained models (e.g., CNNs pretrained on ImageNet) have been shown to be effective across a wide range of vision tasks using transfer learning methods. Existing libraries, such as torchvision, handle these simple cases well, and have been used to make large advances in vision tasks over the past decade.

Remote sensing imagery is not so uniform. Instead of simple RGB images, satellites tend to capture images that are multispectral (Landsat 8 has 11 spectral bands) or even hyperspectral (Hyperion has 242 spectral bands). These images capture information at a wider range of wavelengths (400 nm–15 µm), far outside of the visible spectrum. Different satellites also have very different spatial resolutions: GOES has a resolution of 4 km/px, Maxar imagery is 30 cm/px, and drone imagery resolution can be as high as 7 mm/px. These datasets almost always have a temporal component, with satellite revisits that are daily, weekly, or biweekly. Images often overlap with other images in the dataset and need to be stitched together based on geographic metadata. These images tend to be very large (e.g., 10K x 10K pixels), so it isn’t possible to pass an entire image through a neural network. This data is distributed in hundreds of different raster and vector file formats like GeoTIFF and ESRI Shapefile, requiring specialty libraries like GDAL to load.

From left to right: Mercator, Albers Equal Area, and Interrupted Goode Homolosine projections (source). Geospatial data is associated with one of many different types of reference systems that project the 3D Earth onto a 2D representation. Combining data from different sources often involves re-projecting to a common reference system in order to ensure that all layers are aligned.

Although each image is 2D, the Earth itself is 3D. In order to stitch together images, they first need to be projected onto a 2D representation of the Earth, called a coordinate reference system (CRS). Most people are familiar with equal angle representations like Mercator that distort the size of regions (Greenland looks larger than Africa even though Africa is 15x larger), but there are many other CRSs that are commonly used. Each dataset may use a different CRS, and each image within a single dataset may also be in a unique CRS. In order to use data from multiple layers, they must all share a common CRS, otherwise the data won’t be properly aligned. For those who aren’t familiar with remote sensing data, this can be a daunting task.
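The size distortion of Mercator can be quantified. At latitude φ, the projection stretches distances by a factor of sec(φ) in both directions, so areas are inflated by sec²(φ). A short Python check of that factor at a few latitudes:

```python
import math


def mercator_area_inflation(lat_deg):
    """Area scale factor of the Mercator projection at a given latitude.

    Mercator stretches distances by sec(lat) in both directions,
    so areas are inflated by sec(lat) ** 2.
    """
    return 1.0 / math.cos(math.radians(lat_deg)) ** 2


for lat in (0, 30, 60, 72):
    print(f"{lat:>2} deg: x{mercator_area_inflation(lat):.1f}")
# prints x1.0, x1.3, x4.0 and x10.5 respectively
```

Regions at high latitudes, such as Greenland, are therefore drawn several times larger relative to equatorial regions such as central Africa.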

Even if you correctly georeference images during indexing, if you don’t project them to a common CRS, you’ll end up with rotated images with nodata values around them, and the images won’t be pixel-aligned.

The solution

At the moment, it can be quite challenging to work with both deep learning models and geospatial data without having expertise in both of these very different fields. To address these challenges, we’ve built TorchGeo, a PyTorch domain library for working with geospatial data. TorchGeo is designed to make it simple:

  1. for machine learning experts to work with geospatial data, and
  2. for remote sensing experts to explore machine learning solutions.

TorchGeo is not just a research project, but a production-quality library that uses continuous integration to test every commit with a range of Python versions on a range of platforms (Linux, macOS, Windows). It can be easily installed with any of your favorite package managers, including pip, conda, and spack:

$ pip install torchgeo

TorchGeo is designed to have the same API as other PyTorch domain libraries like torchvision, torchtext, and torchaudio. If you already use torchvision in your workflow for computer vision datasets, you can switch to TorchGeo by changing only a few lines of code. All TorchGeo datasets and samplers are compatible with the PyTorch DataLoader class, meaning that you can take advantage of wrapper libraries like PyTorch Lightning for distributed training. In the following sections, we’ll explore possible use cases for TorchGeo to show how simple it is to use.

Geospatial datasets and samplers

Example application in which we combine A) a scene from Landsat 8 and B) Cropland Data Layer labels, even though these files are in different EPSG projections. We want to sample patches C) and D) from these datasets using a geospatial bounding box as an index.

Many remote sensing applications involve working with geospatial datasets, that is, datasets with geographic metadata. In TorchGeo, we define a GeoDataset class to represent these kinds of datasets. Instead of being indexed by an integer, each GeoDataset is indexed by a spatiotemporal bounding box, meaning that two or more datasets covering different geographic extents can be intelligently combined.

In this example, we show how easy it is to work with geospatial data and to sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we’ll only use the bands that both satellites have in common. We’ll create a single dataset including all images from both Landsat 7 and 8 data by taking the union between these two datasets.

from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, Landsat7, Landsat8, stack_samples
from torchgeo.samplers import RandomGeoSampler

landsat7 = Landsat7(root="...")
landsat8 = Landsat8(root="...", bands=Landsat8.all_bands[1:-2])
landsat = landsat7 | landsat8

Next, we take the intersection between this dataset and the CDL dataset. We want to take the intersection instead of the union to ensure that we only sample from regions where we have both Landsat and CDL data. Note that we can automatically download and checksum CDL data. Also note that each of these datasets may contain files in different CRSs or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution is used.

cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl

This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire contiguous United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of samplers. In this example, we’ll use a random sampler that returns 256 x 256 pixel images and 10,000 samples per epoch. We’ll also use a custom collation function to combine each sample dictionary into a mini-batch of samples.

sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)

This data loader can now be used in your normal training/evaluation pipeline.

for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]

    # train a model, or make predictions using a pre-trained model

Many applications involve intelligently composing datasets based on geospatial metadata like this. For example, users may want to:

  • Combine datasets for multiple image sources and treat them as equivalent (e.g., Landsat 7 and 8)
  • Combine datasets for disparate geospatial locations (e.g., Chesapeake NY and PA)

These combinations require that all queries are present in at least one dataset, and can be created using a UnionDataset. Similarly, users may want to:

  • Combine image and target labels and sample from both simultaneously (e.g., Landsat and CDL)
  • Combine datasets for multiple image sources for multimodal learning or data fusion (e.g., Landsat and Sentinel)

These combinations require that all queries are present in both datasets, and can be created using an IntersectionDataset. TorchGeo automatically composes these datasets for you when you use the intersection (&) and union (|) operators.
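The difference between the two operators can be illustrated with a toy model of spatiotemporal bounding boxes. The `BBox` class and the `union_covers`/`intersection_covers` helpers below are hypothetical and only mimic the query semantics described above: a query is valid for a union if at least one dataset covers it, and for an intersection only if every dataset covers it. TorchGeo's own implementation uses a spatial index and is more involved.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class BBox:
    """Hypothetical spatiotemporal bounding box: x/y in CRS units, t as a timestamp."""
    minx: float
    maxx: float
    miny: float
    maxy: float
    mint: float
    maxt: float

    def contains(self, q: "BBox") -> bool:
        """True if query box `q` lies entirely inside this box."""
        return (self.minx <= q.minx and q.maxx <= self.maxx
                and self.miny <= q.miny and q.maxy <= self.maxy
                and self.mint <= q.mint and q.maxt <= self.maxt)


def dataset_covers(tiles: List[BBox], q: BBox) -> bool:
    # Simplification: a dataset covers q if any single tile footprint contains it.
    return any(tile.contains(q) for tile in tiles)


def union_covers(a: List[BBox], b: List[BBox], q: BBox) -> bool:
    # Union semantics: the query must be present in at least one dataset.
    return dataset_covers(a, q) or dataset_covers(b, q)


def intersection_covers(a: List[BBox], b: List[BBox], q: BBox) -> bool:
    # Intersection semantics: the query must be present in both datasets.
    return dataset_covers(a, q) and dataset_covers(b, q)


landsat = [BBox(0, 10, 0, 10, 0, 100)]
cdl = [BBox(5, 20, 0, 10, 0, 365)]
query = BBox(1, 3, 1, 3, 10, 20)  # inside the Landsat footprint only
print(union_covers(landsat, cdl, query))         # True
print(intersection_covers(landsat, cdl, query))  # False
```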

Multispectral and geospatial transforms

In deep learning, it’s common to augment and transform the data so that models are robust to variations in the input space. Geospatial data can have variations such as seasonal changes and warping effects, as well as image processing and capture issues like cloud cover and atmospheric distortion. TorchGeo utilizes augmentations and transforms from the Kornia library, which supports GPU acceleration and multispectral imagery with more than 3 channels.

Traditional geospatial analyses compute and visualize spectral indices, which are combinations of multispectral bands. Spectral indices are designed to highlight areas of interest in a multispectral image relevant to some application, such as vegetation health, areas of man-made change or increasing urbanization, or snow cover. TorchGeo supports numerous transforms, which can compute common spectral indices and append them as additional bands to a multispectral image tensor.

Below, we show a simple example where we compute the Normalized Difference Vegetation Index (NDVI) on a Sentinel-2 image. NDVI measures the presence of vegetation and vegetation health and is computed as the normalized difference between the red and near-infrared (NIR) spectral bands. Spectral index transforms operate on sample dictionaries returned from TorchGeo datasets and append the resulting spectral index to the image channel dimension.
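The index itself is simple arithmetic, NDVI = (NIR - Red) / (NIR + Red). The sketch below computes it with plain PyTorch and appends it as an extra channel, which is roughly what a spectral index transform does. The function name `append_ndvi`, the small epsilon that guards against division by zero, and the toy band layout are assumptions for illustration, not TorchGeo's implementation.

```python
import torch


def append_ndvi(image: torch.Tensor, index_red: int, index_nir: int, eps: float = 1e-10) -> torch.Tensor:
    """Append NDVI as a new channel to a (C, H, W) multispectral image tensor."""
    red = image[index_red].float()
    nir = image[index_nir].float()
    # Normalized difference; eps avoids 0/0 where both bands are zero.
    ndvi = (nir - red) / (nir + red + eps)
    # Add a channel dimension and concatenate along the channel axis.
    return torch.cat([image.float(), ndvi.unsqueeze(0)], dim=0)


# Toy 4-band, 2x2 image; band 0 is treated as red and band 3 as NIR.
image = torch.zeros(4, 2, 2)
image[0] = torch.tensor([[0.1, 0.5], [0.3, 0.0]])  # red
image[3] = torch.tensor([[0.5, 0.1], [0.3, 0.0]])  # NIR
out = append_ndvi(image, index_red=0, index_nir=3)
print(out.shape)  # torch.Size([5, 2, 2])
```

Pixels where NIR exceeds red (typical of vegetation) get positive values, and pixels where red exceeds NIR get negative values, so the result always falls in [-1, 1].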

First, we instantiate a Sentinel-2 dataset and load a sample image. Then, we plot the true color (RGB) representation of this data to see the region we are looking at.

import matplotlib.pyplot as plt
from torchgeo.datasets import Sentinel2
from torchgeo.transforms import AppendNDVI

dataset = Sentinel2(root="...")
sample = dataset[...]
fig = dataset.plot(sample)
plt.show()

Next, we instantiate an NDVI transform and apply it, appending the new channel to the end of the image. Sentinel-2 imagery uses index 0 for its red band and index 3 for its NIR band. In order to visualize the data, we also rescale the NDVI channel: NDVI values can range from -1 to 1, but we want to use the range 0 to 1 for plotting.

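transform = AppendNDVI(index_red=0, index_nir=3)
sample = transform(sample)  # NDVI is appended as the last image channel
sample["image"][-1] = (sample["image"][-1] + 1) / 2  # rescale NDVI from [-1, 1] to [0, 1]
plt.imshow(sample["image"][-1], cmap="RdYlGn")  # low NDVI in red, high NDVI in green
plt.show()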

True color (left) and NDVI (right) of the Texas Hill Region, taken on November 16, 2018 by the Sentinel-2 satellite. In the NDVI image, red indicates water bodies, yellow indicates barren soil, light green indicates unhealthy vegetation, and dark green indicates healthy vegetation.

Benchmark datasets

One of the driving factors behind progress in computer vision is the existence of standardized benchmark datasets like ImageNet and MNIST. Using these datasets, researchers can directly compare the performance of different models and training procedures to determine which performs best. In the remote sensing domain, there are many such datasets, but due to the aforementioned difficulties of working with this data and the lack of existing libraries for loading these datasets, many researchers opt to use their own custom datasets.

One of the goals of TorchGeo is to provide easy-to-use data loaders for these existing datasets. TorchGeo includes a number of benchmark datasets: datasets that include both input images and target labels. This includes datasets for tasks like image classification, regression, semantic segmentation, object detection, instance segmentation, change detection, and more.

If you’ve used torchvision before, these types of datasets should be familiar. In this example, we’ll create a dataset for the Northwestern Polytechnical University (NWPU) very-high-resolution ten-class (VHR-10) geospatial object detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.

from torch.utils.data import DataLoader
from torchgeo.datasets import VHR10

dataset = VHR10(root="...", download=True, checksum=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

for batch in dataloader:
    image = batch["image"]
    label = batch["label"]

    # train a model, or make predictions using a pre-trained model

All TorchGeo datasets are compatible with PyTorch data loaders, making them easy to integrate into existing training workflows. The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each TorchGeo dataset returns a dictionary with a key for each PyTorch Tensor, rather than a tuple.
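To see what dictionary samples look like in practice, here is a minimal toy `Dataset` (the `ToyDictDataset` name and its fake contents are hypothetical) whose `__getitem__` returns a dictionary of tensors. PyTorch's default collate function batches dictionaries key by key, so fixed-size tensors are stacked into one batched tensor per key. Samples with variable-size contents, such as a varying number of objects per image, generally need a custom `collate_fn` instead.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    """Hypothetical dataset that returns a dict of tensors, like TorchGeo samples."""

    def __init__(self, n: int = 10):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> dict:
        image = torch.full((3, 32, 32), float(idx))  # fake 3-band image
        label = torch.tensor(idx % 2)                 # fake class label
        return {"image": image, "label": label}


loader = DataLoader(ToyDictDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 3, 32, 32])
print(batch["label"])        # tensor([0, 1, 0, 1])
```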

Example predictions from a Mask R-CNN model trained on the NWPU VHR-10 dataset. The model predicts sharp bounding boxes and masks for all objects with high confidence scores.

Reproducibility with PyTorch Lightning

Another key goal of TorchGeo is reproducibility. For many of these benchmark datasets, there is no predefined train-val-test split, or the predefined split has issues with class imbalance or geographic distribution. As a result, the performance metrics reported in the literature either can’t be reproduced, or aren’t indicative of how well a pre-trained model would work in a different geographic location.
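A basic ingredient of any reproducible split is deterministic randomness. The sketch below uses `torch.utils.data.random_split` with a seeded `torch.Generator`, so the same seed always yields the same train/val/test indices. The 80/10/10 proportions, the seed, and the stand-in `TensorDataset` are arbitrary choices; note that a purely random split does not address the geographic-distribution issue raised above, which calls for splitting by region instead.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset of 100 samples (a real dataset would go here).
data = TensorDataset(torch.arange(100))


def split(dataset, seed: int = 42):
    n = len(dataset)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    n_test = n - n_train - n_val  # remainder goes to test so the lengths sum to n
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val, n_test], generator=generator)


train_a, val_a, test_a = split(data)
train_b, val_b, test_b = split(data)
# The same seed gives identical index assignments on every run.
print(train_a.indices == train_b.indices)     # True
print(len(train_a), len(val_a), len(test_a))  # 80 10 10
```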

In order to facilitate direct comparisons between results published in the literature and further reduce the boilerplate code needed to run experiments with datasets in TorchGeo, we have created PyTorch Lightning datamodules with well-defined train-val-test splits and trainers for various tasks like classification, regression, and semantic segmentation. These datamodules show how to incorporate augmentations from the Kornia library, include preprocessing transforms (with pre-calculated channel statistics), and let users easily experiment with hyperparameters related to the data itself (as opposed to the modeling process). Training a semantic segmentation model on the Inria Aerial Image Labeling dataset is as easy as a few imports and four lines of code.

from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask

datamodule = InriaAerialImageLabelingDataModule(root_dir="...", batch_size=64, num_workers=6)
task = SemanticSegmentationTask(model="resnet18", pretrained=True, learning_rate=0.1)
trainer = Trainer(gpus=1, default_root_dir="...")

trainer.fit(model=task, datamodule=datamodule)

Building segmentations produced by a U-Net model trained on the Inria Aerial Image Labeling dataset. Reproducing these results is as simple as a few imports and four lines of code, making comparison of different models and training techniques simple and easy.

In our preprint we show a set of results that use the aforementioned datamodules and trainers to benchmark simple modeling approaches for several of the datasets in TorchGeo. For example, we find that a simple ResNet-50 can achieve state-of-the-art performance on the So2Sat dataset. These types of baseline results are important for evaluating the contribution of different modeling choices when tackling problems with remotely sensed data.

Future work and contributing

There is still a lot of remaining work to be done in order to make TorchGeo as easy to use as possible, especially for users without prior deep learning experience. One of the ways in which we plan to achieve this is by expanding our tutorials to include subjects like “writing a custom dataset” and “transfer learning”, or tasks like “land cover mapping” and “object detection”.

Another important project we are working on is pre-training models. Most remote sensing researchers work with very small labeled datasets, and could benefit from pre-trained models and transfer learning approaches. TorchGeo is the first deep learning library to provide models pre-trained on multispectral imagery. Our goal is to provide models for different image modalities (optical, SAR, multispectral) and specific platforms (Landsat, Sentinel, MODIS) as well as benchmark results showing their performance with different amounts of training data. Self-supervised learning is a promising method for training such models. Satellite imagery datasets often contain petabytes of imagery, but accurately labeled datasets are much harder to come by. Self-supervised learning methods will allow us to train directly on the raw imagery without needing large labeled datasets.

Aside from these larger projects, we’re always looking to add new datasets, data augmentation transforms, and sampling strategies. If you’re Python savvy and interested in contributing to TorchGeo, we would love to see contributions! TorchGeo is open source under an MIT license, so you can use it in almost any project.


If you like TorchGeo, give us a star on GitHub! And if you use TorchGeo in your work, please cite our paper.

Acknowledgments

We would like to thank all TorchGeo contributors for their efforts in creating the library, the Microsoft AI for Good program for support, and the PyTorch Team for their guidance. This research is part of the Blue Waters sustained-petascale computing project, which is supported by the National Science Foundation (awards OCI-0725070 and ACI-1238993), the State of Illinois, and as of December, 2019, the National Geospatial-Intelligence Agency. Blue Waters is a joint effort of the University of Illinois at Urbana-Champaign and its National Center for Supercomputing Applications. The research was supported in part by NSF grants IIS-1908104, OAC-1934634, and DBI-2021898.


How Disney Improved Activity Recognition Through Multimodal Approaches with PyTorch

Introduction

Among the many things Disney Media & Entertainment Distribution (DMED) is responsible for is the management and distribution of a huge array of media assets, including news, sports, entertainment and features, episodic programs, marketing and advertising, and more.

Our team focuses on media annotation as part of DMED Technology’s content platforms group. In our day-to-day work, we automatically analyze a variety of content that constantly challenges the efficiency of our machine learning workflow and the accuracy of our models.

Several of our colleagues recently discussed the workflow efficiencies that we achieved by switching to an end-to-end video analysis pipeline using PyTorch, as well as how we approach animated character recognition. We invite you to read more about both in this previous post.

While the conversion to an end-to-end PyTorch pipeline is a solution that any company might benefit from, animated character recognition was a uniquely Disney concept and solution.

In this article we will focus on activity recognition, which is a general challenge across industries — but with some specific opportunities when leveraged in the media production field, because we can combine audio, video, and subtitles to provide a solution.

Experimenting with Multimodality

Working on a multimodal problem adds more complexity to the usual training pipelines. Having multiple information modes for each example means that the multimodal pipeline has to have specific implementations to process each mode in the dataset. Usually after this processing step, the pipeline has to merge or fuse the outputs.

Our initial experiments in multimodality were completed using the MMF framework. MMF is a modular framework for vision and language multimodal research. MMF contains reference implementations of state-of-the-art vision and language models and has also powered multiple research projects at Meta AI Research (as seen in this poster presented at PyTorch Ecosystem Day 2020). Along with the recent release of TorchMultimodal, a PyTorch library for training state-of-the-art multimodal models at scale, MMF highlights the growing interest in multimodal understanding.

MMF tackles this complexity with modular management of all the elements of the pipeline through a wide set of different implementations for specific modules, ranging from the processing of the modalities to the fusion of the processed information.

In our scenario, MMF was a great entry point to experiment with multimodality. It allowed us to iterate quickly by combining audio, video and closed captioning and experiment at different levels of scale with certain multimodal models, shifting from a single GPU to TPU Pods.

Multimodal Transformers

Using MMF as our workbench, we started with a model that concatenated the features from each modality, and then evolved it into a pipeline that included a Transformer-based fusion module to combine the different input modes.
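For reference, the simplest form of a concatenation-based starting point can be sketched as follows: pool each modality's features into a single vector, concatenate the vectors, and classify. This is a generic late-fusion baseline with made-up feature sizes (512 for video and audio, 1024 for text) and a hypothetical `ConcatFusionClassifier` name, not the team's actual model.

```python
import torch
from torch import nn


class ConcatFusionClassifier(nn.Module):
    """Hypothetical late-fusion baseline: concatenate per-modality vectors, then classify."""

    def __init__(self, dims=(512, 512, 1024), num_classes: int = 15):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(sum(dims), 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, video, audio, text):
        # Each input is (batch, tokens, dim); mean-pool over tokens to get (batch, dim).
        pooled = [x.mean(dim=1) for x in (video, audio, text)]
        return self.head(torch.cat(pooled, dim=-1))


model = ConcatFusionClassifier()
logits = model(torch.randn(2, 16, 512), torch.randn(2, 49, 512), torch.randn(2, 32, 1024))
print(logits.shape)  # torch.Size([2, 15])
```

Because pooling happens before fusion, this baseline cannot relate individual frames, audio regions, and words to one another, which is the kind of interaction a Transformer-based fusion module is meant to capture.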

Specifically, we made use of the fusion module called MMFTransformer, developed in collaboration with the Meta AI Research team. It is based on VisualBERT, with the modifications needed to work with text, audio, and video.

Despite decent results with the out-of-the-box MMFTransformer implementation, we were still far from our goal, and the Transformer-based models required more data than we had available.

Searching for less data-hungry solutions

Searching for less data-hungry solutions, our team started studying MLP-Mixer. This architecture, proposed by the Google Brain team, provides an alternative to well-established architectures such as convolutions or self-attention for computer vision tasks.

MLP-Mixer

The core idea behind Mixer variants is to replace the convolutions of CNNs, or the self-attention mechanisms of Transformers, with multilayer perceptrons. This change in architecture favors model performance in high-data regimes (especially relative to Transformers), while also raising questions about the inductive biases built into convolutions and self-attention layers.

These architectures perform well on image classification tasks: they split the image into patches, flatten each patch into a 1D vector, and pass the resulting sequence through a stack of Mixer layers.
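The patch-splitting step can be written in a few lines of PyTorch. This sketch assumes a batch of square images whose side is divisible by the patch size, and the `patchify` name is ours; in the original MLP-Mixer design, each flattened patch is then linearly projected to the hidden dimension before entering the Mixer layers.

```python
import torch
from torch import nn


def patchify(images: torch.Tensor, patch: int) -> torch.Tensor:
    """Split (B, C, H, W) images into flattened non-overlapping patches: (B, N, C*patch*patch)."""
    b, c, h, w = images.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    x = images.reshape(b, c, h // patch, patch, w // patch, patch)
    # Reorder to (B, rows, cols, C, patch, patch), then flatten each patch into one vector.
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // patch) * (w // patch), c * patch * patch)


images = torch.randn(2, 3, 224, 224)
patches = patchify(images, patch=16)  # 14 x 14 = 196 patches of 3*16*16 = 768 values
embed = nn.Linear(3 * 16 * 16, 512)   # per-patch linear projection
tokens = embed(patches)
print(patches.shape, tokens.shape)  # torch.Size([2, 196, 768]) torch.Size([2, 196, 512])
```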

Inspired by the advantages of Mixer-based architectures, our team looked for parallels with the problems we try to solve in video classification: instead of a single image, we have a set of frames that need to be classified, along with audio and closed captions as additional modalities.

Activity Recognition reinterpreting the MLP-Mixer

Our proposal takes the core idea of the MLP-Mixer, applying multilayer perceptrons to a sequence and to its transpose, and extends it into a multimodal framework that lets us process video, audio, and text with the same architecture.
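For readers unfamiliar with the architecture, a single Mixer layer in the style of the original MLP-Mixer can be sketched in PyTorch as below: a token-mixing MLP applied to the transposed sequence (mixing information across positions), followed by a channel-mixing MLP applied to each position, each preceded by LayerNorm and wrapped in a skip connection. The hidden sizes are arbitrary, and this is the generic layer rather than the team's exact implementation.

```python
import torch
from torch import nn


class MlpBlock(nn.Module):
    """Two-layer MLP with GELU, as used inside a Mixer layer."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        return self.net(x)


class MixerLayer(nn.Module):
    def __init__(self, num_tokens: int, dim: int, token_hidden: int = 256, channel_hidden: int = 1024):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.token_mlp = MlpBlock(num_tokens, token_hidden)
        self.norm2 = nn.LayerNorm(dim)
        self.channel_mlp = MlpBlock(dim, channel_hidden)

    def forward(self, x):  # x: (batch, tokens, dim)
        # Token mixing: transpose to (batch, dim, tokens) so the MLP mixes across tokens.
        y = self.norm1(x).transpose(1, 2)
        x = x + self.token_mlp(y).transpose(1, 2)
        # Channel mixing: the MLP acts on each token's feature vector independently.
        return x + self.channel_mlp(self.norm2(x))


layer = MixerLayer(num_tokens=196, dim=512)
out = layer(torch.randn(2, 196, 512))
print(out.shape)  # torch.Size([2, 196, 512])
```

One consequence of the token-mixing MLP is that the sequence length is fixed when the layer is built, so inputs must always contribute the same number of tokens.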

For each of the modalities, we use different extractors that will provide embeddings describing the content. Given the embeddings of each modality, the MLP-Mixer architecture solves the problem of deciding which of the modalities might be the most important, while also weighing how much each modality contributes to the final labeling.

For example, when it comes to detecting laughs, sometimes the key information is in the audio or in the frames, and in other cases there is a strong signal in the closed captions.

For video, we tried two approaches: processing each frame separately with a ResNet34 to get a sequence of embeddings, and using a video-specific model called R3D. The two models were pre-trained on ImageNet and Kinetics400, respectively.

To process the audio, we use the pretrained ResNet34, and we remove the final layers to be able to extract 2D embeddings from the audio spectrograms (for 224×224 images we end up with 7×7 embeddings).

For closed captions, we use a pre-trained BERT-large with all layers frozen except the embeddings and LayerNorms.
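This kind of partial freezing can be expressed generically in PyTorch by freezing every parameter and then re-enabling those that belong to embedding and LayerNorm modules. The sketch uses a small stand-in encoder rather than BERT-large, and the function name is ours; applying it to a pretrained BERT implementation assumes that its embedding and normalization layers are `nn.Embedding` and `nn.LayerNorm` instances, which is worth checking for the library version you use.

```python
import torch
from torch import nn


def freeze_except_embeddings_and_norms(model: nn.Module) -> None:
    """Freeze every parameter, then re-enable those in Embedding and LayerNorm modules."""
    for p in model.parameters():
        p.requires_grad_(False)
    for module in model.modules():
        if isinstance(module, (nn.Embedding, nn.LayerNorm)):
            for p in module.parameters():
                p.requires_grad_(True)


# Small stand-in for a text encoder: token embeddings followed by Transformer layers.
encoder = nn.Sequential(
    nn.Embedding(1000, 64),
    nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=64, nhead=4, batch_first=True), num_layers=2
    ),
)
freeze_except_embeddings_and_norms(encoder)
trainable = [name for name, p in encoder.named_parameters() if p.requires_grad]
print(trainable)
```

Only the parameters that still require gradients then need to be passed to the optimizer, for example `torch.optim.AdamW(p for p in encoder.parameters() if p.requires_grad)`.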

Once we have extracted the embeddings from each modality, we concatenate them into a single sequence and pass it through a set of MLP-Mixer blocks; we then use average pooling and a classification head to get predictions.
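Putting the pieces together, the fusion stage described above might look like the following sketch. The per-modality linear projections to a shared width are our assumption (the backbones produce embeddings of different sizes, so some alignment is needed before concatenation), as are the token counts, the width, the number of blocks, and the `MultimodalMixer` name; the backbones themselves are omitted and replaced by random embeddings.

```python
import torch
from torch import nn


class MixerLayer(nn.Module):
    """Compact Mixer layer: token-mixing MLP, then channel-mixing MLP, with residuals."""

    def __init__(self, tokens: int, dim: int, expansion: int = 2):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.token_mlp = nn.Sequential(
            nn.Linear(tokens, tokens * expansion), nn.GELU(), nn.Linear(tokens * expansion, tokens)
        )
        self.channel_mlp = nn.Sequential(
            nn.Linear(dim, dim * expansion), nn.GELU(), nn.Linear(dim * expansion, dim)
        )

    def forward(self, x):  # (batch, tokens, dim)
        x = x + self.token_mlp(self.norm1(x).transpose(1, 2)).transpose(1, 2)
        return x + self.channel_mlp(self.norm2(x))


class MultimodalMixer(nn.Module):
    def __init__(self, modality_shapes, dim=256, depth=4, num_classes=15):
        super().__init__()
        # modality_shapes: list of (num_tokens, embedding_dim), one entry per modality.
        self.proj = nn.ModuleList(nn.Linear(d, dim) for _, d in modality_shapes)
        total_tokens = sum(n for n, _ in modality_shapes)
        self.mixer = nn.Sequential(*(MixerLayer(total_tokens, dim) for _ in range(depth)))
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, *embeddings):
        # Project each modality to a shared width and concatenate along the token axis.
        seq = torch.cat([p(e) for p, e in zip(self.proj, embeddings)], dim=1)
        seq = self.norm(self.mixer(seq))
        return self.head(seq.mean(dim=1))  # average pooling over tokens, then classify


# Hypothetical shapes: 16 frame embeddings, a 7x7 audio grid flattened to 49 tokens, 32 text tokens.
shapes = [(16, 512), (49, 512), (32, 1024)]
model = MultimodalMixer(shapes)
video, audio, text = (torch.randn(2, n, d) for n, d in shapes)
print(model(video, audio, text).shape)  # torch.Size([2, 15])
```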

Our experiments have been performed on a custom, manually labeled dataset for activity recognition with 15 classes, which we know from experiments are hard and cannot all be predicted accurately using a single modality.

These experiments have shown a significant increase in performance using our approach, especially in a low/mid-data regime (75K training samples).

When it comes to using only Text and Audio, our experiments showed a 15 percent improvement in accuracy over using a classifier on top of the features extracted by state-of-the-art backbones.

Using Text, Audio and Video, we have seen a 17 percent improvement in accuracy over using Meta AI’s MMF framework, which uses a VisualBERT-like model to combine modalities using more powerful state-of-the-art backbones.

We have since extended the initial model to cover up to 55 activity classes and 45 event classes. One of the challenges we hope to address in the future is including all activities and events, even the less frequent ones.

Interpreting the MLP-Mixer mode combinations

An MLP-Mixer is a stack of multilayer perceptrons. Very roughly, it can be approximated as a linear operation, in the sense that once it is trained, its weights are fixed and the input directly determines the output.

Under that approximation, we can also assume that for an input consisting of N×M numbers, there is an N×M matrix that, when multiplied elementwise with the input, approximates the MLP-Mixer’s predictions for a class.

We will call this matrix a stencil, and if we have access to it, we can find what parts of the input embeddings are responsible for a specific prediction.
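Written as a formula, one way to express this approximation is shown below, where $X \in \mathbb{R}^{N \times M}$ is the input, $S^{(c)}$ is the stencil for class $c$, and $f_c$ is the Mixer’s score for that class. The notation is ours, and summing the elementwise product over all entries (ignoring any bias term) is our reading of how the stencil yields a prediction.

```latex
f_c(X) \;\approx\; \sum_{i=1}^{N} \sum_{j=1}^{M} S^{(c)}_{ij}\, X_{ij} \;=\; \big\langle S^{(c)}, X \big\rangle_F
```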

You can think of it as a punch card with holes in specific positions. Only information in those positions will pass and contribute to a specific prediction. So we can measure the intensity of the input at those positions.

Of course, this is an oversimplification, and there won’t exist a unique stencil that perfectly represents all of the contributions of the input to a class (otherwise that would mean that the problem could be solved linearly). So this should be used for visualization purposes only, not as an accurate predictor.

Once we have a set of stencils for each class, we can effortlessly measure input contribution without relying on any external visualization techniques.
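Given a stencil, measuring contribution reduces to an elementwise product followed by sums over the token ranges that belong to each modality. The sketch assumes a concatenated sequence whose per-modality token ranges are known (the ranges below are hypothetical) and uses the magnitude of the product as the "intensity"; under the linear approximation this is a visualization aid, not an exact attribution.

```python
import torch


def modality_contributions(x: torch.Tensor, stencil: torch.Tensor, ranges: dict) -> dict:
    """Share of stencil-weighted intensity per modality.

    x, stencil: (tokens, dim) input embeddings and a class stencil of the same shape.
    ranges: modality name -> (start, end) token indices in the concatenated sequence.
    """
    intensity = (x * stencil).abs()  # the "punch card": only stencil positions pass through
    totals = {name: intensity[s:e].sum().item() for name, (s, e) in ranges.items()}
    grand_total = sum(totals.values())
    if grand_total == 0:
        raise ValueError("stencil lets no input through")
    return {name: t / grand_total for name, t in totals.items()}


ranges = {"video": (0, 16), "audio": (16, 65), "text": (65, 97)}
torch.manual_seed(0)
x = torch.randn(97, 256)
stencil = torch.zeros(97, 256)
stencil[65:97] = 1.0  # a stencil that only lets text tokens through
print(modality_contributions(x, stencil, ranges))  # {'video': 0.0, 'audio': 0.0, 'text': 1.0}
```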

To find a stencil, we can start from a “random noise” stencil and optimize it to maximize the activations for a specific class by just back-propagating through the MLP-Mixer.
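A minimal version of that optimization loop is sketched below, read as activation maximization: the stencil is fed to the model as its input, initialized with random noise, and updated by gradient ascent on one class logit while the model’s weights stay frozen. The tiny stand-in model, the step count, the learning rate, and the small L2 penalty that keeps the stencil bounded are all assumptions, not the team’s settings.

```python
import torch
from torch import nn


def find_stencil(model: nn.Module, shape, target: int, steps: int = 200, lr: float = 0.05,
                 l2: float = 1e-3, seed: int = 0) -> torch.Tensor:
    """Optimize a (tokens, dim) stencil that maximizes the logit of class `target`."""
    torch.manual_seed(seed)
    for p in model.parameters():
        p.requires_grad_(False)  # freeze the trained network (modifies the model in place)
    stencil = torch.randn(*shape, requires_grad=True)  # start from random noise
    opt = torch.optim.Adam([stencil], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        logit = model(stencil.unsqueeze(0))[0, target]
        # Maximize the target logit; the L2 term keeps values from growing without bound.
        loss = -logit + l2 * stencil.pow(2).sum()
        loss.backward()
        opt.step()
    return stencil.detach()


# Stand-in for a trained MLP-Mixer: any module mapping (batch, tokens, dim) -> (batch, classes).
model = nn.Sequential(nn.Flatten(), nn.Linear(10 * 8, 15)).eval()
s = find_stencil(model, (10, 8), target=3)
print(s.shape)  # torch.Size([10, 8])
```

Running the loop with different seeds produces different, equally valid stencils, which is what motivates the clustering step.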

By doing this, we can end up with many valid stencils, which we can reduce to a few by clustering them into groups of similar stencils with K-means and averaging each cluster.
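The reduction step can be sketched with a small K-means written directly in PyTorch: flatten each stencil into a vector, cluster, and average the members of each cluster (which is the cluster centroid). The number of clusters, the iteration count, and the deterministic farthest-point initialization are our assumptions; a library implementation such as scikit-learn’s `KMeans` would serve equally well.

```python
import torch


def reduce_stencils(stencils: torch.Tensor, k: int, iters: int = 50, seed: int = 0) -> torch.Tensor:
    """Cluster (n, tokens, dim) stencils with K-means and return (k, tokens, dim) cluster means."""
    n = stencils.shape[0]
    flat = stencils.reshape(n, -1)
    g = torch.Generator().manual_seed(seed)
    # Farthest-point initialization: start from a random stencil, then repeatedly
    # add the stencil farthest from all centroids chosen so far.
    chosen = [int(torch.randint(n, (1,), generator=g))]
    for _ in range(1, k):
        dist = torch.cdist(flat, flat[chosen]).min(dim=1).values
        chosen.append(int(dist.argmax()))
    centroids = flat[chosen].clone()
    for _ in range(iters):
        assign = torch.cdist(flat, centroids).argmin(dim=1)  # nearest centroid per stencil
        for j in range(k):
            members = flat[assign == j]
            if len(members) > 0:  # keep the old centroid if a cluster empties out
                centroids[j] = members.mean(dim=0)
    return centroids.reshape(k, *stencils.shape[1:])


# Two synthetic groups of noisy stencils around different patterns (+1 and -1).
torch.manual_seed(0)
a = torch.ones(5, 4)
b = -torch.ones(5, 4)
stencils = torch.cat([a + 0.1 * torch.randn(20, 5, 4), b + 0.1 * torch.randn(20, 5, 4)])
reps = reduce_stencils(stencils, k=2)
print(reps.shape)  # torch.Size([2, 5, 4])
```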

Using the Mixer to get the best of each world

MLP-Mixer, used as an image classification model without convolutional layers, requires a lot of data, because its lack of inductive bias, one of the model’s strengths overall, becomes a weakness in low-data domains.

When used to combine information previously extracted by large pretrained backbones (as opposed to being used as a full end-to-end solution), Mixers shine. The Mixer’s strength lies in finding temporal or structural coherence between different inputs. For example, in video-related tasks we could extract embeddings from the frames using a powerful, pretrained model that understands what is going on at frame level and use the Mixer to make sense of them in a sequential manner.

This way of using the Mixer allows us to work with limited amounts of data and still get better results than what was achieved with Transformers. This is because Mixers seem to be more stable during training and seem to pay attention to all the inputs, while Transformers tend to collapse and pay attention only to some modalities/parts of the sequence.

Acknowledgements: We would like to thank the Meta AI Research and Partner Engineering teams for this collaboration.
