Announcing PyTorch Conference 2022

We are excited to announce that the PyTorch Conference returns in-person as a satellite event to NeurIPS (Neural Information Processing Systems) in New Orleans on Dec. 2nd.

We changed the name from PyTorch Developer Day to PyTorch Conference to signify the turning of a new chapter as we look to the future of PyTorch, encompassing the entire PyTorch community. This conference will bring together leading researchers, academics and developers from the Machine Learning (ML) and Deep Learning (DL) communities for a set of talks and a poster session covering new PyTorch software releases, use cases in academia and industry, and ML/DL development and production trends.

EVENT OVERVIEW

When: Dec 2nd, 2022 (In-Person and Virtual)

Where: New Orleans, Louisiana (USA) | Virtual option as well

SCHEDULE

All times are in Central Standard Time.

8:00-9:00 am   Registration/Check in

9:00-11:20 am   Keynote & Technical Talks

11:30-1:00 pm   Lunch

1:00-3:00 pm   Poster Session & Breakouts

3:00-4:00 pm   Community/Partner Talks

4:00-5:00 pm   Panel Discussion

Agenda subject to change.

All talks will be livestreamed and available to the public. The in-person event will be by invitation only as space is limited. If you’d like to apply to attend in person, please submit all requests here.

LINKS

Read More

PyTorch strengthens its governance by joining the Linux Foundation

Today, I am proud to announce that PyTorch is moving to the Linux Foundation (LF) as a top-level project under the name PyTorch Foundation. The core mission of the Linux Foundation is the collaborative development of open source software. With a governing board of leaders from AMD, Amazon Web Services (AWS), Google Cloud, Meta, Microsoft Azure and NVIDIA, this model aligns with where PyTorch stands today and what it needs to travel forward. The creation of the PyTorch Foundation will ensure business decisions are being made in a transparent and open manner by a diverse group of members for years to come. The technical decisions remain in control of individual maintainers. I’m excited that the Linux Foundation will be our new home as they have notable experience supporting large open-source projects like ours such as Kubernetes and NodeJS. At this pivotal moment, I want to take a look back at how we started, share why we are moving, and what’s ahead.

This January, PyTorch celebrated its 5 year anniversary! I reflected on what it meant to me in this tweet thread, and this conversation with my colleagues Mike Schroepfer, Lin Qiao, and Yann LeCun. When we started PyTorch development in 2016, it was a collective effort by a band of people from the [Lua]Torch community with a big chunk of people and funding from Meta and individuals contributing from NVIDIA, Twitter and other entities.

Since 2017, PyTorch has grown far beyond our initial vision. With over 2,400 contributors who have built nearly 154,000 projects using PyTorch as a foundation, PyTorch has become one of the primary platforms for AI research, as well as commercial production use. We’ve seen its impact across industry and academia, from large companies to numerous university courses at Stanford, NYU, EPFL, Oxford, and other academic institutions. As a maintainer of PyTorch, the journey has been extremely fulfilling, with the impact of the project seen in various fields from self-driving cars to healthcare to aerospace.

As PyTorch grew, many companies have made foundational investments around it. While Meta remains the largest contributor to PyTorch, companies such as AMD, Amazon Web Services (AWS), Google Cloud, HuggingFace, Lightning AI, Microsoft Azure, Nvidia, and many others have made significant investments, including both technical contributions and community building efforts. They’ve established teams around PyTorch or filled significant voids within the PyTorch community and sent countless contributions to the PyTorch core and to the ecosystem around it — PyTorch is an important part of their future. With PyTorch continuing to grow as a multi-stakeholder project, it’s time to move to a broader open-source foundation.

The business governance of PyTorch was fairly unstructured for quite some time after launch – we operated like a scrappy startup. Team members at Meta spent the time and energy to structure this properly and organize PyTorch into an organizationally healthier entity. Meta helped PyTorch introduce many structures, such as Contributor License Agreements, Branding Guidelines, and Trademark registration. Keeping PyTorch’s organizational health in check is essential and beneficial for the community. The next stage of our organizational progress is to support the interests of multiple stakeholders, hence the move to a foundation. We chose the Linux Foundation as it has vast experience hosting large multi-stakeholder open-source projects with the right balance of organizational structure and finding specific solutions for these projects.

Simultaneously, the technical governance of PyTorch has been a loosely structured community model of open-source development — a set of people maintaining PyTorch by area, with their responsibility often tied to their individual identity rather than their employment. While we kept a codified list at the PyTorch – Persons of Interest page, the technical governance was neither formalized nor codified. As PyTorch scales as a community, the next step is to structure and codify it. The PyTorch Technical Governance now supports a hierarchical maintainer structure and clearly outlines processes around day-to-day work and escalations. This doesn’t change how we run things, but it does add discipline and openness that, at our scale, feels essential and timely.

It’s been an exciting journey since 2016. I am grateful for the experiences and people I’ve met along the way. PyTorch started with a small group of contributors that has grown and diversified over the years, all bringing in new ideas and innovations that would not have been possible without our community. We want to continue the open-source spirit – for the community and by the community. Thank you to our contributors, maintainers, users, supporters and new foundation members. We look forward to the next chapter of PyTorch with the PyTorch Foundation.

Read More

Fast Beam Search Decoding in PyTorch with TorchAudio and Flashlight Text

Beam search decoding with industry-leading speed from Flashlight Text (part of the Flashlight ML framework) is now available with official support in TorchAudio, bringing high-performance beam search and text utilities for speech and text applications built on top of PyTorch. The current integration supports CTC-style decoding, but it can be used for any modeling setting that outputs token-level probability distributions over time steps.

A brief beam search refresher

In speech and language settings, beam search is an efficient, greedy algorithm that can convert sequences of continuous values (i.e. probabilities or scores) into graphs or sequences (i.e. tokens, word-pieces, words) using optional constraints on valid sequences (i.e. a lexicon), optional external scoring (i.e. an LM which scores valid sequences), and other score adjustments for particular sequences.

In the example that follows, we’ll consider a token set of {ϵ, a, b}, where ϵ is a special token that we can imagine denotes a space between words or a pause in speech. Graphics here and below are taken from Awni Hannun’s excellent distill.pub writeup on CTC and beam search.

With a greedy-like approach, beam search considers the next viable token given an existing sequence of tokens — in the example above, a, b, b is a valid sequence, but a, b, a is not. We rank each possible next token at each step of the beam search according to a scoring function. A scoring function s typically looks something like:
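
One common form, consistent with the variables described below, is:

$$ s(\hat{y}, x) = \log P(\hat{y} \mid x) + \alpha \log P(\hat{y}) + \beta\,|\hat{y}| $$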

Where ŷ is a potential path/sequence of tokens, x is the input (P(ŷ|x) represents the model’s predictions over time), and 𝛼 is a weight on the language model probability (P(y) the probability of the sequence under the language model). Some scoring functions add 𝜷 which adjusts a score based on the length of the predicted sequence |ŷ|. This particular scoring function is used in FAIR’s prior work on end-to-end ASR, and there are many variations on scoring functions which can vary across application areas.

Given a particular sequence, to assess the next viable token in that sequence (perhaps constrained by a set of allowed words or sequences, such as a lexicon of words), the beam search algorithm scores the sequence with each candidate token added, and sorts token candidates based on those scores. For efficiency and since the number of paths is exponential in the token set size, the top-k highest-scoring candidates are kept — k represents the beam size.
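
As an illustration of this top-k pruning step, here is a minimal, framework-free sketch. It ignores CTC-specific details such as blank handling and hypothesis merging, and assumes a score_fn that implements a scoring function like the one above:

def beam_search(step_log_probs, tokens, beam_size, score_fn):
    """step_log_probs: one dict per time step mapping token -> log P(token | x, t).
    score_fn(seq, model_logp): overall score (e.g. model + LM + length adjustment)."""
    beams = [((), 0.0)]  # (token sequence, accumulated model log-probability)
    for log_probs in step_log_probs:
        candidates = []
        for seq, seq_logp in beams:
            for tok in tokens:
                candidates.append((seq + (tok,), seq_logp + log_probs[tok]))
        # rank all extensions with the scoring function and keep the top-k (the "beam")
        candidates.sort(key=lambda c: score_fn(c[0], c[1]), reverse=True)
        beams = candidates[:beam_size]
    return beams

# toy usage: two time steps over tokens {ϵ, a, b}, score = model log-probability only
steps = [{"ϵ": -2.0, "a": -0.2, "b": -1.5}, {"ϵ": -1.0, "a": -1.2, "b": -0.4}]
print(beam_search(steps, ["ϵ", "a", "b"], beam_size=2, score_fn=lambda seq, lp: lp))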

There are many other nuances with how beam search can progress: similar hypothesis sequences can be “merged”, for instance.

The scoring function can be further augmented to up/down-weight token insertion or long or short words. Scoring with stronger external language models, while incurring computational cost, can also significantly improve performance; this is frequently referred to as LM fusion. There are many other knobs to tune for decoding — these are documented in TorchAudio’s documentation and explored further in TorchAudio’s ASR Inference tutorial. Since decoding is quite efficient, parameters can be easily swept and tuned.

Beam search has been used in ASR extensively over the years in far too many works to cite, and in strong, recent results and systems including wav2vec 2.0 and NVIDIA’s NeMo.

Why beam search?

Beam search remains a fast competitor to heavier-weight decoding approaches such as RNN-Transducer that Google has invested in putting on-device and has shown strong results with on common benchmarks. Autoregressive text models at scale can benefit from beam search as well. Among other things, beam search gives:

  • A flexible performance/latency tradeoff — by adjusting beam size and the external LM, users can sacrifice latency for accuracy or pay for more accurate results with a small latency cost. Decoding with no external LM can improve results at very little performance cost.
  • Portability without retraining — existing neural models can benefit from multiple decoding setups and plug-and-play with external LMs without training or fine-tuning.
  • A compelling complexity/accuracy tradeoff — adding beam search to an existing modeling pipeline incurs little additional complexity and can improve performance.

Performance Benchmarks

The most commonly used beam search decoding libraries that support external language model integration include Kensho’s pyctcdecode and NVIDIA’s NeMo toolkit. We benchmark the TorchAudio + Flashlight decoder against them with a wav2vec 2.0 base model trained on 100 hours of audio, evaluated on LibriSpeech dev-other with the official KenLM 3-gram LM. Benchmarks were run on Intel E5-2698 CPUs on a single thread. All computation was in-memory — KenLM memory mapping was disabled as it wasn’t widely supported.

When benchmarking, we measure the time-to-WER (word error rate) — because of subtle differences in the implementation of decoding algorithms and the complex relationships between parameters and decoding speed, some hyperparameters differed across runs. To fairly assess performance, we first sweep for parameters that achieve a baseline WER, minimizing beam size if possible.

Decoding performance on Librispeech dev-other of a pretrained wav2vec 2.0 model. TorchAudio + Flashlight decoding outperforms by an order of magnitude at low WERs.

Time-to-WER results, deferring to smaller beam size, across decoders. The TorchAudio + Flashlight decoder scales far better with larger beam sizes and at lower WERs.

TorchAudio API and Usage

TorchAudio provides a Python API for CTC beam search decoding, with support for the following:

  • lexicon and lexicon-free decoding
  • KenLM n-gram language model integration
  • character and word-piece decoding
  • sample pretrained LibriSpeech KenLM models and corresponding lexicon and token files
  • various customizable beam search parameters (beam size, pruning threshold, LM weight…)

To set up the decoder, use the factory function torchaudio.models.decoder.ctc_decoder:

from torchaudio.models.decoder import ctc_decoder, download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")

decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
    nbest=1,
    # ... additional optional customizable args ...
)

Given emissions of shape (batch, time, num_tokens), the decoder will compute and return a List of batch Lists, each consisting of the nbest hypotheses corresponding to the emissions. Each hypothesis can be further broken down into tokens, words (if a lexicon is provided), score, and timesteps components.

emissions = acoustic_model(waveforms)  # (B, T, N)
batch_hypotheses = decoder(emissions)  # List[List[CTCHypothesis]]

# transcript for a lexicon decoder
transcripts = [" ".join(hypo[0].words) for hypo in batch_hypotheses]

# transcript for a lexicon free decoder, splitting by sil token
batch_tokens = [decoder.idxs_to_tokens(hypo[0].tokens) for hypo in batch_hypotheses]
transcripts = ["".join(tokens) for tokens in batch_tokens]

Please refer to the documentation for more API details, and the tutorial (ASR Inference Decoding) or sample inference script for more usage examples.

Upcoming Improvements

Full NNLM support — decoding with large neural language models (e.g. transformers) remains somewhat unexplored at scale. This is already supported in Flashlight, and we plan to add support in TorchAudio as well, allowing users to use custom decoder-compatible LMs. Custom word-level language models are already available in the nightly TorchAudio build and are slated to be released in TorchAudio 0.13.

Autoregressive/seq2seq decoding — Flashlight Text also supports sequence-to-sequence (seq2seq) decoding for autoregressive models, which we hope to add bindings for and add to TorchAudio and TorchText with efficient GPU implementations as well.

Better build support — to benefit from improvements in Flashlight Text, TorchAudio will directly submodule Flashlight Text to make upstreaming modifications and improvements easier. This is already in effect in the nightly TorchAudio build, and is slated to be released in TorchAudio 0.13.

Citation

To cite the decoder, please use the following:

@inproceedings{kahn2022flashlight,
  title={Flashlight: Enabling innovation in tools for machine learning},
  author={Kahn, Jacob D and Pratap, Vineel and Likhomanenko, Tatiana and Xu, Qiantong and Hannun, Awni and Cai, Jeff and Tomasello, Paden and Lee, Ann and Grave, Edouard and Avidov, Gilad and others},
  booktitle={International Conference on Machine Learning},
  pages={10557--10574},
  year={2022},
  organization={PMLR}
}
@inproceedings{yang2022torchaudio,
  title={Torchaudio: Building blocks for audio and speech processing},
  author={Yang, Yao-Yuan and Hira, Moto and Ni, Zhaoheng and Astafurov, Artyom and Chen, Caroline and Puhrsch, Christian and Pollack, David and Genzel, Dmitriy and Greenberg, Donny and Yang, Edward Z and others},
  booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={6982--6986},
  year={2022},
  organization={IEEE}
}

Read More

Introducing nvFuser, a deep learning compiler for PyTorch

nvFuser is a Deep Learning Compiler for NVIDIA GPUs that automatically just-in-time compiles fast and flexible kernels to reliably accelerate users’ networks. It provides significant speedups for deep learning networks running on Volta and later CUDA accelerators by generating fast custom “fusion” kernels at runtime. nvFuser is specifically designed to meet the unique requirements of the PyTorch community, and it supports diverse network architectures and programs with dynamic inputs of varying shapes and strides.
In this blog post we’ll describe nvFuser and how it’s used today, show the significant performance improvements it can obtain on models from HuggingFace and TIMM, and look ahead to nvFuser in PyTorch 1.13 and beyond. If you would like to know more about how and why fusion improves the speed of training for Deep Learning networks, please see our previous talks on nvFuser from GTC 2022 and GTC 2021.
nvFuser relies on a graph representation of PyTorch operations to optimize and accelerate. Since PyTorch has an eager execution model, the PyTorch operations users are running are not directly accessible as a whole program that can be optimized by a system like nvFuser. Therefore, users must utilize systems built on top of nvFuser which are capable of capturing users’ programs and translating them into a form that is optimizable by nvFuser. These higher level systems then pass these captured operations to nvFuser, so that nvFuser can optimize the execution of the user’s script for NVIDIA GPUs. There are three systems that capture, translate, and pass user programs to nvFuser for optimization:

  • TorchScript jit.script

    • This system directly parses sections of an annotated python script to translate into its own representation what the user is doing. This system then applies its own version of auto differentiation to the graph, and passes sections of the subsequent forward and backwards graphs to nvFuser for optimization.
  • FuncTorch

    • This system doesn’t directly look at the user python script, instead inserting a mechanism that captures PyTorch operations as they’re being run. We refer to this type of capture system as “trace program acquisition”, since we’re tracing what has been performed. FuncTorch doesn’t perform its own auto differentiation – it simply traces PyTorch’s autograd directly to get backward graphs.
  • TorchDynamo

    • TorchDynamo is another program acquisition mechanism built on top of FuncTorch. TorchDynamo parses the Python bytecode produced from the user script in order to select portions to trace with FuncTorch. The benefit of TorchDynamo is that it’s able to apply decorators to a user’s script, effectively isolating what should be sent to FuncTorch, making it easier for FuncTorch to successfully trace complex Python scripts.

These systems are available for users to interact with directly while nvFuser automatically and seamlessly optimizes performance critical regions of the user’s code. These systems automatically send parsed user programs to nvFuser so nvFuser can:

  1. Analyze the operations being run on GPUs
  2. Plan parallelization and optimization strategies for those operations
  3. Apply those strategies in generated GPU code
  4. Runtime-compile the generated optimized GPU functions
  5. Execute those CUDA kernels on subsequent iterations

It is important to note nvFuser does not yet support all PyTorch operations, and there are still some scenarios that are actively being improved in nvFuser that are discussed herein. However, nvFuser does support many DL performance critical operations today, and the number of supported operations will grow in subsequent PyTorch releases. nvFuser is capable of generating highly specialized and optimized GPU functions for the operations it does have support for. This means nvFuser is able to power new PyTorch systems like TorchDynamo and FuncTorch to combine the flexibility PyTorch is known for with unbeatable performance.
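
For example, here is a minimal sketch of the TorchScript (jit.script) path, assuming a CUDA-capable setup where nvFuser is the active TorchScript fuser. After a couple of profiling runs, the chain of pointwise operations can be handed to nvFuser as a single fused kernel:

import torch

def bias_gelu(x, bias):
    # a chain of pointwise ops -- a prime candidate for fusion into one kernel
    return torch.nn.functional.gelu(x + bias)

scripted = torch.jit.script(bias_gelu)  # jit.script parses and captures the function

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
bias = torch.randn(1024, device="cuda", dtype=torch.float16)

# The first few calls profile shapes/dtypes; subsequent calls can execute the
# fuser-generated kernel for the fused region.
for _ in range(5):
    out = scripted(x, bias)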

nvFuser Performance

Before getting into how to use nvFuser, in this section we’ll show the improvements in training speed nvFuser provides for a variety of models from the HuggingFace Transformers and PyTorch Image Models (TIMM) repositories, and we will discuss current gaps in nvFuser performance that are under development today. All performance numbers in this section were taken using an NVIDIA A100 40GB GPU, and used either FuncTorch alone or FuncTorch with TorchDynamo.

HuggingFace Transformer Benchmarks

nvFuser can dramatically accelerate training of HuggingFace Transformers when combined with another important optimization (more on that in a moment). Performance improvements can be seen in Figure 1 to range between 1.12x and 1.50x across a subset of popular HuggingFace Transformer networks.

Figure 1: Performance gains of 8 training scenarios from HuggingFace’s Transformer repository. First performance boost in the dark green is due to replacing the optimizer with an NVIDIA Apex fused AdamW optimizer. The light green is due to adding nvFuser. Models were run with batch size and sequence lengths of [64, 128], [8, 512], [2, 1024], [64, 128], [8, 512], [8, src_seql=512, tgt_seql=128], [8, src_seql=1024, tgt_seql=128], and [8, 512] respectively. All networks were run with Automatic Mixed Precision (AMP) enabled with dtype=float16.

While these speedups are significant, it’s important to understand that nvFuser doesn’t (yet) automate everything about running networks quickly. For HuggingFace Transformers, for example, it was important to use the AdamW fused optimizer from NVIDIA’s Apex repository as the optimizer otherwise consumed a large portion of runtime. Using the fused AdamW optimizer to make the network faster exposes the next major performance bottleneck — memory bound operations. These operations are optimized by nvFuser, providing another large performance boost. With the fused optimizer and nvFuser enabled, the training speed of these networks improved between 1.12x to 1.5x.
HuggingFace Transformer models were run with the torch.amp module. (“amp” stands for Automatic Mixed Precision; see the “What Every User Should Know About Mixed Precision Training in PyTorch” blog post for details.) An option to use nvFuser was added to HuggingFace’s Trainer. If you have TorchDynamo installed, you can activate it to enable nvFuser in HuggingFace by passing torchdynamo = 'nvfuser' to the Trainer class.
nvFuser has great support for normalization kernels and related fusions frequently found in Natural Language Processing (NLP) models, and it is recommended users try nvFuser in their NLP workloads.

PyTorch Image Models (TIMM) Benchmarks

nvFuser can also significantly reduce the training time of TIMM networks, up to over 1.3x vs. eager PyTorch, and up to 1.44x vs. eager PyTorch when combined with the torch.amp module. Figure 1 shows nvFuser’s speedup without torch.amp, and when torch.amp is used with the NHWC (“channels last”) and NCHW (“channels first”) formats. nvFuser is integrated in TIMM through FuncTorch tracing directly (without TorchDynamo) and can be used by adding the --aot-autograd command line argument when running the TIMM benchmark or training script.

Figure 1: The Y-axis is the performance gain nvFuser provides over not using nvFuser. A value of 1.0 means no change in perf, 2.0 would mean nvFuser is twice as fast, 0.5 would mean nvFuser takes twice the time to run. Square markers are with float16 Automatic Mixed Precision (AMP) and channels first contiguous inputs, circle markers are float32 inputs, and triangles are with float16 AMP and channels last contiguous inputs. Missing data points are due to an error being encountered when tracing.

When running with float32 precision, nvFuser provides a 1.12x geometric mean (“geomean”) speedup on TIMM networks, and when running with torch.amp and “channels first” it provides a 1.14x geomean speedup. However, nvFuser currently doesn’t speed up torch.amp with “channels last” training (a 0.9x geomean regression), so we recommend not using it in those cases. We are actively working on improving “channels last” performance now, and soon we will have two additional optimization strategies (grid persistent optimizations for channels-last normalizations and fast transposes) which we expect will provide speedups comparable to “channels first” in PyTorch version 1.13 and later. Many of nvFuser’s optimizations can also help in inference cases. However, in PyTorch, when running inference on small batch sizes, performance is typically limited by CPU overhead, which nvFuser can’t completely remove or fix. Therefore, typically the most important optimization for inference is to enable CUDA Graphs when possible. Once CUDA Graphs are enabled, it can also be beneficial to enable fusion through nvFuser. Performance of inference is shown in Figure 2 and Figure 3. Inference is only run with float16 AMP as it is uncommon to run inference workloads in full float32 precision.

Figure 2: Performance gains of enabling CUDA Graphs, and CUDA Graphs with nvFuser compared to the performance of native PyTorch without CUDA Graphs and nvFuser across TIMM models with float16 AMP, channels first inputs, and a batch size of 1 and 8 respectively. There is a geomean speedup of 2.74x with CUDA Graphs and 2.71x with CUDA Graphs + nvFuser respectively. nvFuser provides a maximum regression of 0.68x and a maximum performance gain of 2.74x (relative to CUDA Graphs without nvFuser). Performance gain is measured relative to the average time per iteration PyTorch takes without CUDA Graphs and without nvFuser. Models are sorted by how much additional performance nvFuser is providing.

Figure 3: Performance gains of enabling CUDA Graphs, and CUDA Graphs with nvFuser compared to the performance of native PyTorch without CUDA Graphs and nvFuser across TIMM models with AMP, channels last inputs, and a batch size of 1 and 8 respectively. There is a geomean speedup of 2.29x with CUDA Graphs and 2.95x with CUDA Graphs + nvFuser respectively. nvFuser provides a maximum regression of 0.86x and a maximum performance gain of 3.82x (relative to CUDA Graphs without nvFuser). Performance gain is measured relative to the average time per iteration PyTorch takes without CUDA Graphs and without nvFuser. Models are sorted by how much additional performance nvFuser is providing.

So far, nvFuser performance has not been tuned for inference workloads, so its performance benefit is not consistent across all cases. However, there are still many models that benefit significantly from nvFuser during inference, and we encourage users to try nvFuser in inference workloads to see if they would benefit today. Performance of nvFuser in inference workloads will improve in the future, and if you’re interested in nvFuser for inference workloads, please reach out to us on the PyTorch forums.
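
As a rough illustration of the CUDA Graphs recommendation above, a minimal capture-and-replay sketch for small-batch inference might look like the following. The model here is a simple stand-in; to additionally benefit from fusion inside the captured region, the model could first be scripted with torch.jit.script:

import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(), torch.nn.Linear(16, 10),
).cuda().eval()

static_input = torch.randn(8, 3, 224, 224, device="cuda")

# Warm up on a side stream before capture, as required by CUDA Graphs
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture a single inference iteration into a graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_output = model(static_input)

# Replay: copy each new batch into the static input buffer, then replay the graph
new_batch = torch.randn(8, 3, 224, 224, device="cuda")
static_input.copy_(new_batch)
g.replay()
print(static_output.shape)  # results are written into static_output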

Getting Started – Accelerate Your Scripts with nvFuser

We’ve created a tutorial demonstrating how to take advantage of nvFuser to accelerate part of a standard transformer block, and how nvFuser can be used to define fast and novel operations. There are still some rough edges in nvFuser that we’re working hard on improving, as we’ve outlined in this blog post. However, we’ve also demonstrated some great improvements for training speed on multiple networks in HuggingFace and TIMM, and we expect there are opportunities in your networks where nvFuser can help today, and many more where it will help in the future.
If you would like to learn more about nvFuser we recommend watching our presentations from NVIDIA’s GTC conference GTC 2022 and GTC 2021.

Read More

Accelerating PyTorch Vision Models with Channels Last on CPU

Overview

Memory format has a significant impact on performance when running vision models; generally, Channels Last is more favorable from a performance perspective due to better data locality.

This blog will introduce fundamental concepts of memory formats and demonstrate performance benefits using Channels Last on popular PyTorch vision models on Intel® Xeon® Scalable processors.

Memory Formats Introduction

Memory format refers to data representation that describes how a multidimensional (nD) array is stored in linear (1D) memory address space. The concept of memory format has two aspects:

  • Physical Order is the layout of data storage in physical memory. For vision models, we usually talk about NCHW and NHWC. These are descriptions of the physical memory layout, also referred to as Channels First and Channels Last respectively.
  • Logical Order is a convention on how to describe tensor shape and stride. In PyTorch, this convention is NCHW. No matter what the physical order is, tensor shape and stride will always be depicted in the order of NCHW.

Fig-1 is the physical memory layout of a tensor with shape of [1, 3, 4, 4] on both Channels First and Channels Last memory format (channels denoted as R, G, B respectively):

Fig-1 Physical memory layout of Channels First and Channels Last

Memory Formats Propagation

The general rule for PyTorch memory format propagation is to preserve the input tensor’s memory format, which means a Channels First input will generate a Channels First output and a Channels Last input will generate a Channels Last output.
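
A quick way to see this rule in action (a minimal sketch):

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 32, 32)

# Channels First in -> Channels First out
y = conv(x)
print(y.is_contiguous())  # True

# Channels Last in -> Channels Last out
y_cl = conv(x.to(memory_format=torch.channels_last))
print(y_cl.is_contiguous(memory_format=torch.channels_last))  # True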

For Convolution layers, PyTorch uses oneDNN (oneAPI Deep Neural Network Library) by default to achieve optimal performance on Intel CPUs. Since it is physically impossible to achieve highly optimized performance directly with the Channels First memory format, input and weight are first converted to a blocked format and then computed. oneDNN may choose different blocked formats according to input shapes, data type and hardware architecture, for vectorization and cache reuse purposes. The blocked format is opaque to PyTorch, so the output needs to be converted back to Channels First. Though the blocked format brings about optimal computing performance, the format conversions may add overhead and therefore offset the performance gain.

On the other hand, oneDNN is optimized for the Channels Last memory format and can use it directly for optimal performance, so PyTorch simply passes a memory view to oneDNN, which means the input and output tensor conversions are avoided. Fig-2 indicates the memory format propagation behavior of convolution on PyTorch CPU (the solid arrow indicates a memory format conversion, and the dashed arrow indicates a memory view):

Fig-2 CPU Conv memory format propagation

On PyTorch, the default memory format is Channels First. In case a particular operator doesn’t have support for Channels Last, the NHWC input would be treated as a non-contiguous NCHW tensor and therefore fall back to Channels First, which will consume precious memory bandwidth on the CPU and result in suboptimal performance.

Therefore, it is very important to extend the scope of Channels Last support for optimal performance. We have implemented Channels Last kernels for the commonly used operators in the CV domain, applicable for both inference and training, such as:

  • Activations (e.g., ReLU, PReLU, etc.)
  • Convolution (e.g., Conv2d)
  • Normalization (e.g., BatchNorm2d, GroupNorm, etc.)
  • Pooling (e.g., AdaptiveAvgPool2d, MaxPool2d, etc.)
  • Shuffle (e.g., ChannelShuffle, PixelShuffle)

Refer to Operators-with-Channels-Last-support for details.

Native Level Optimization on Channels Last

As mentioned above, PyTorch uses oneDNN to achieve optimal performance on Intel CPUs for convolutions. The rest of the memory-format-aware operators are optimized at the PyTorch native level, which doesn’t require any third-party library support.

  • Cache-friendly parallelization scheme: we keep the same parallelization scheme for all the memory-format-aware operators; this helps increase data locality when passing each layer’s output to the next.
  • Vectorization on multiple archs: generally, we can vectorize on the innermost dimension in the Channels Last memory format, and each of the vectorized CPU kernels is generated for both AVX2 and AVX512.

While contributing Channels Last kernels, we tried our best to optimize the Channels First counterparts as well. The fact is that for some operators, such as Convolution and Pooling, it is physically impossible to achieve optimal performance on Channels First.

Run Vision Models on Channels Last

The Channels Last related APIs are documented at PyTorch memory format tutorial. Typically, we can convert a 4D tensor from Channels First to Channels Last by:

# convert x to channels last
# suppose x’s shape is (N, C, H, W)
# then x’s stride will be (HWC, 1, WC, C)
x = x.to(memory_format=torch.channels_last)

To run models in the Channels Last memory format, you simply need to convert the input and the model to Channels Last and then you are ready to go. The following is a minimal example showing how to run ResNet50 with TorchVision in the Channels Last memory format:

import torch
from torchvision.models import resnet50

N, C, H, W = 1, 3, 224, 224
x = torch.rand(N, C, H, W)
model = resnet50()
model.eval()

# convert input and model to channels last
x = x.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
model(x)

The Channels Last optimization is implemented at the native kernel level, which means you may apply other functionalities such as torch.fx and TorchScript together with Channels Last as well.

Performance Gains

We benchmarked inference performance of TorchVision models on Intel® Xeon® Platinum 8380 CPU @ 2.3 GHz, single instance per socket (batch size = 2 x number of physical cores). Results show that Channels Last has 1.3x to 1.8x performance gain over Channels First.

The performance gain primarily comes from two aspects:

  • For Convolution layers, Channels Last saves the memory format conversion to blocked format for activations, which improves the overall computation efficiency.
  • For Pooling and Upsampling layers, Channels Last can use vectorized logic along the innermost dimension, e.g., “C”, while Channels First can’t.

For layers that are not memory-format aware, Channels Last and Channels First have the same performance.

Conclusion & Future Work

In this blog we introduced the fundamental concepts of Channels Last and demonstrated the performance benefits of using Channels Last on CPU for vision models. The current work is limited to 2D models, and we will extend the optimization effort to 3D models in the near future!

Acknowledgement

The results presented in this blog are a joint effort of the Meta and Intel PyTorch teams. Special thanks to Vitaly Fedyunin and Wei Wei from Meta who spent precious time and gave substantial assistance! Together we made one more step on the path of improving the PyTorch CPU ecosystem.

References

Read More

Easily list and initialize models with new APIs in TorchVision

TorchVision now supports listing and initializing all available built-in models and weights by name. This new API builds upon the recently introduced Multi-weight support API, is currently in Beta, and addresses a long-standing request from the community.

You can try out the new API in the latest nightly release of TorchVision. We’re looking to collect feedback ahead of finalizing the feature in TorchVision v0.14. We have created a dedicated Github Issue where you can post your comments, questions and suggestions!

Querying and initializing available models

Before the new model registration API, developers had to query the __dict__ attribute of the modules in order to list all available models or to fetch a specific model builder method by its name:

# Initialize a model by its name:
model = torchvision.models.__dict__[model_name]()

# List available models:
available_models = [
    k for k, v in torchvision.models.__dict__.items()
    if callable(v) and k[0].islower() and k[0] != "_"
]

The above approach does not always produce the expected results and is hard to discover. For example, since the get_weight() method is exposed publicly under the same module, it will be included in the list despite not being a model. In general, reducing the verbosity (fewer imports, shorter names, etc.) and being able to initialize models and weights directly from their names (better support for configs, TorchHub, etc.) was feedback previously provided by the community. To solve this problem, we have developed a model registration API.

A new approach

We’ve added 4 new methods under the torchvision.models module:

from torchvision.models import get_model, get_model_weights, get_weight, list_models

The styles and naming conventions align closely with a prototype mechanism proposed by Philip Meier for the Datasets V2 API, aiming to offer a similar user experience. The model registration methods are kept private on purpose as we currently focus only on supporting the built-in models of TorchVision.

List models

Listing all available models in TorchVision can be done with a single function call:

>>> list_models()
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', 'quantized_mobilenet_v3_large', ...]

To list the available models of specific submodules:

>>> list_models(module=torchvision.models)
['alexnet', 'mobilenet_v3_large', 'mobilenet_v3_small', ...]
>>> list_models(module=torchvision.models.quantization)
['quantized_mobilenet_v3_large', ...]

Initialize models

Now that you know which models are available, you can easily initialize a model with pre-trained weights:

>>> get_model("quantized_mobilenet_v3_large", weights="DEFAULT")
QuantizableMobileNetV3(
  (features): Sequential(
   ....
   )
)

Get weights

Sometimes, while working with config files or using TorchHub, you might have the name of a specific weight entry and wish to get its instance. This can be easily done with the following method:

>>> get_weight("ResNet50_Weights.IMAGENET1K_V2")
ResNet50_Weights.IMAGENET1K_V2

To get the enum class with all available weights of a specific model you can use either its name:

>>> get_model_weights("quantized_mobilenet_v3_large")
<enum 'MobileNet_V3_Large_QuantizedWeights'>

Or its model builder method:

>>> get_model_weights(torchvision.models.quantization.mobilenet_v3_large)
<enum 'MobileNet_V3_Large_QuantizedWeights'>

TorchHub support

The new methods are also available via TorchHub:

import torch

# Fetching a specific weight entry by its name:
weights = torch.hub.load("pytorch/vision", "get_weight", weights="ResNet50_Weights.IMAGENET1K_V2")

# Fetching the weights enum class to list all available entries:
weight_enum = torch.hub.load("pytorch/vision", "get_model_weights", name="resnet50")
print([weight for weight in weight_enum])

Putting it all together

For example, if you wanted to retrieve all the small-sized models with pre-trained weights and initialize one of them, it’s a matter of using the above APIs:

import torchvision
from torchvision.models import get_model, get_model_weights, list_models


max_params = 5000000

tiny_models = []
for model_name in list_models(module=torchvision.models):
    weights_enum = get_model_weights(model_name)
    if len([w for w in weights_enum if w.meta["num_params"] <= max_params]) > 0:
        tiny_models.append(model_name)

print(tiny_models)
# ['mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mobilenet_v2', ...]

model = get_model(tiny_models[0], weights="DEFAULT")
print(sum(x.numel() for x in model.state_dict().values()))
# 2239188

For more technical details please see the original RFC. Please spare a few minutes to provide your feedback on the new API, as this is crucial for graduating it from beta and including it in the next release. You can do this on the dedicated Github Issue. We are looking forward to reading your comments!

Read More

Empowering PyTorch on Intel® Xeon® Scalable processors with Bfloat16

Overview

In recent years, the growing complexity of AI models has placed increasing demands on hardware compute capability. Reduced-precision numeric formats have been proposed to address this problem. Bfloat16 is a custom 16-bit floating point format for AI which consists of one sign bit, eight exponent bits, and seven mantissa bits. With the same dynamic range as float32, bfloat16 doesn’t require special handling such as loss scaling. Therefore, bfloat16 is a drop-in replacement for float32 when running deep neural networks for both inference and training.
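
The dynamic-range claim is easy to verify from Python (a small sketch using torch.finfo):

import torch

for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max:.3e}, smallest normal={info.tiny:.3e}")

# float32 and bfloat16 both top out around 3.4e38 (same 8-bit exponent),
# while float16 maxes out at 65504 -- which is why float16 typically needs loss scaling.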

The 3rd Gen Intel® Xeon® Scalable processor (codenamed Cooper Lake) is the first general purpose x86 CPU with native bfloat16 support. Three new bfloat16 instructions were introduced in Intel® Advanced Vector Extensions-512 (Intel® AVX-512): VCVTNE2PS2BF16, VCVTNEPS2BF16, and VDPBF16PS. The first two instructions perform conversion from float32 to bfloat16, and the last one performs a dot product of bfloat16 pairs. Bfloat16 theoretical compute throughput is doubled over float32 on Cooper Lake. On the next generation of Intel® Xeon® Scalable Processors, bfloat16 compute throughput will be further enhanced through the Advanced Matrix Extensions (Intel® AMX) instruction set extension.

Intel and Meta previously collaborated to enable bfloat16 on PyTorch, and the related work was published in an earlier blog during launch of Cooper Lake. In that blog, we introduced the hardware advancement for native bfloat16 support and showcased a performance boost of 1.4x to 1.6x of bfloat16 over float32 from DLRM, ResNet-50 and ResNext-101-32x4d.

In this blog, we will introduce the latest software enhancements for bfloat16 in PyTorch 1.12, which apply to a much broader scope of user scenarios and showcase even higher performance gains.

Native Level Optimization on Bfloat16

On the PyTorch CPU bfloat16 path, the compute-intensive operators, e.g., convolution, linear and bmm, use oneDNN (oneAPI Deep Neural Network Library) to achieve optimal performance on Intel CPUs with AVX512_BF16 or AMX support. The other operators, such as tensor operators and neural network operators, are optimized at the PyTorch native level. We have extended bfloat16 kernel-level optimizations to the majority of operators on dense tensors, applicable to both inference and training (sparse tensor bfloat16 support will be covered in future work), specifically:

  • Bfloat16 vectorization: Bfloat16 is stored as an unsigned 16-bit integer, which requires it to be cast to float32 for arithmetic operations such as add, mul, etc. Specifically, each bfloat16 vector is converted to two float32 vectors, processed accordingly and then converted back. For non-arithmetic operations such as cat, copy, etc., it is a straight memory copy and no data type conversion is involved.
  • Bfloat16 reduction: Reduction on bfloat16 data uses float32 as accumulation type to guarantee numerical stability, e.g., sum, BatchNorm2d, MaxPool2d, etc.
  • Channels Last optimization: For vision models, Channels Last is the preferable memory format over Channels First from performance perspective. We have implemented fully optimized CPU kernels for all the commonly used CV modules on channels last memory format, taking care of both float32 and bfloat16.

Run Bfloat16 with Auto Mixed Precision

To run a model in bfloat16, typically the user can either explicitly convert the data and model to bfloat16, for example:

# with explicit conversion
input = input.to(dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)

or utilize the torch.amp (Automatic Mixed Precision) package. The autocast instance serves as a context manager or decorator that allows regions of your script to run in mixed precision, for example:

# with AMP
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = model(input)

Generally, the explicit conversion approach and the AMP approach have similar performance. Even so, we recommend running bfloat16 models with AMP, because:

  • Better user experience with automatic fallback: If your script includes operators that don’t have bfloat16 support, autocast will implicitly convert them back to float32, while the explicitly converted model will give a runtime error.

  • Mixed data type for activation and parameters: Unlike the explicit conversion, which converts all the model parameters to bfloat16, AMP mode runs in mixed data types. To be specific, input/output will be kept in bfloat16 while parameters, e.g., weight/bias, will be kept in float32. The mixed data types of activations and parameters help improve performance while maintaining accuracy, as the sketch below illustrates.
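
A minimal sketch of this behavior on CPU (parameter dtypes are unchanged while autocast-eligible ops produce bfloat16 activations):

import torch

linear = torch.nn.Linear(64, 64)        # parameters are created in float32
x = torch.randn(8, 64)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = linear(x)

print(linear.weight.dtype)  # torch.float32 -- parameters stay in float32
print(y.dtype)              # torch.bfloat16 -- activations run in bfloat16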

Performance Gains

We benchmarked inference performance of TorchVision models on Intel® Xeon® Platinum 8380H CPU @ 2.90GHz (codenamed Cooper Lake), single instance per socket (batch size = 2 x number of physical cores). Results show that bfloat16 has 1.4x to 2.2x performance gain over float32.

The performance boost of bfloat16 over float32 primarily comes from 3 aspects:

  • The compute intensive operators take advantage of the new bfloat16 native instruction VDPBF16PS which doubles the hardware compute throughput.
  • Bfloat16 has only half the memory footprint of float32, so theoretically the memory-bandwidth-intensive operators will be twice as fast.
  • On Channels Last, we intentionally keep the same parallelization scheme for all the memory-format-aware operators (this can’t be done on Channels First), which increases data locality when passing each layer’s output to the next. This keeps the data closer to the CPU cores, and since the data tends to reside in cache anyway, bfloat16 achieves a higher cache hit rate than float32 in such scenarios due to its smaller memory footprint.

Conclusion & Future Work

In this blog, we introduced recent software optimizations for bfloat16 in PyTorch 1.12. Results on the 3rd Gen Intel® Xeon® Scalable processor show that bfloat16 has a 1.4x to 2.2x performance gain over float32 on TorchVision models. Further improvement is expected on the next generation of Intel® Xeon® Scalable Processors with AMX instruction support. Though the performance numbers in this blog were collected with TorchVision models, the benefit is broad across all topologies, and we will continue to extend the bfloat16 optimization effort to a broader scope in the future!

Acknowledgement

The results presented in this blog are a joint effort of the Meta and Intel PyTorch teams. Special thanks to Vitaly Fedyunin and Wei Wei from Meta who spent precious time and gave substantial assistance! Together we made one more step on the path of improving the PyTorch CPU ecosystem.

Reference

Read More

Introducing the PlayTorch app: Rapidly Create Mobile AI Experiences

In December, we announced PyTorch Live, a toolkit for building AI-powered mobile prototypes in minutes. The initial release included a command-line interface to set up a development environment and an SDK for building AI-powered experiences in React Native. Today, we’re excited to share that PyTorch Live will now be known as PlayTorch. This new release provides an improved and simplified developer experience. PlayTorch development is independent from the PyTorch project and the PlayTorch code repository is moving into the Meta Research GitHub organization.

A New Workflow: The PlayTorch App

The PlayTorch team is excited to announce that we have partnered with Expo to change the way AI powered mobile experiences are built. Our new release simplifies the process of building mobile AI experiences by eliminating the need for a complicated development environment. You will now be able to build cross platform AI powered prototypes from the very browser you are using to read this blog.

In order to make this happen, we are releasing the PlayTorch app which is able to run AI-powered experiences built in the Expo Snack web based code editor.

The PlayTorch app can be downloaded from the Apple App Store and Google Play Store. With the app installed, you can head over to playtorch.dev/snack and write the code for your AI-powered PlayTorch Snack. When you want to try what you’ve built, you can use the PlayTorch app’s QR code scanner to scan the QR code on the Snack page and load the code to your device.

NOTE: PlayTorch Snacks will not work in the Expo Go app.

More to Explore in the PlayTorch App

AI Demos

The PlayTorch app comes with several examples of how you can build AI powered experiences with a variety of different machine learning models from object detection to natural language processing. See what can be built with the PlayTorch SDK and be inspired to make something of your own as you play with the examples.

Sharing Your Creations

Any PlayTorch Snack that you run in the PlayTorch app can be shared with others in an instant. When they open the link on their device, the PlayTorch app will instantly load what you’ve built from the cloud so they can experience it first hand.

When you have something you want to share, let us know on Discord or Twitter or embed the PlayTorch Snack on your own webpage.

SDK Overhaul

We learned a lot from the community after our initial launch in December and have been hard at work over the past several months to make the PlayTorch SDK (formerly known as PyTorch Live) simple, performant, and robust. In our initial version, the SDK relied on config files to define how a model ingested and output data.

Today, we are happy to announce the next version of our SDK can handle data processing in JavaScript for your prototypes with the new PlayTorch API that leverages the JavaScript Interface (JSI) to directly call C++ code. Not only have we completely redone the way you can interact with models, but we have also greatly expanded the variety of supported model architectures.

A New Data Processing API for Prototyping

With this JSI API, we now allow users direct access to tensors (data format for machine learning). Instead of only having access to predefined transformations, you can now manipulate tensors however you would like for your prototypes.

No more switching back and forth between code and config. You will now be able to write everything in JavaScript and have access to all of the type annotations and autocomplete features available to you in those languages.

Check out our tutorials to see the new Data Processing API in action, take a deeper dive in the API docs, or inspect the code yourself on GitHub.

Expanded Use Cases

With the new version of the SDK, we have added support for several cutting edge models.

Image-to-image transformations are now supported thanks to our robust JSI API, so you can see what your world would look like if it were an anime.

Translate French to English with an AI powered translator using the Seq2Seq model.

Use DeepLab V3 to segment images!

Start Playing

If you want to start creating AI experiences yourself, head over to playtorch.dev and try out our tutorials. Each tutorial will guide you through building a simple AI powered experience that you can instantly run on your phone and share with others.

How to Get Support

Join us on Discord, collaborate with us on GitHub, or follow us on Twitter. Got questions or feedback? We’d love to hear from you!

Read More

What Every User Should Know About Mixed Precision Training in PyTorch

Efficient training of modern neural networks often relies on using lower precision data types. Peak float16 matrix multiplication and convolution performance is 16x faster than peak float32 performance on A100 GPUs. And since the float16 and bfloat16 data types are only half the size of float32 they can double the performance of bandwidth-bound kernels and reduce the memory required to train a network, allowing for larger models, larger batches, or larger inputs. Using a module like torch.amp (short for “Automatic Mixed Precision”) makes it easy to get the speed and memory usage benefits of lower precision data types while preserving convergence behavior.

Going faster and using less memory is always advantageous – deep learning practitioners can test more model architectures and hyperparameters, and larger, more powerful models can be trained. Training very large models like those described in Narayanan et al. and Brown et al. (which take thousands of GPUs months to train even with expert handwritten optimizations) is infeasible without using mixed precision.

We’ve talked about mixed precision techniques before (here, here, and here), and this blog post is a summary of those techniques and an introduction if you’re new to mixed precision.

Mixed Precision Training in Practice

Mixed precision training techniques – the use of the lower precision float16 or bfloat16 data types alongside the float32 data type – are broadly applicable and effective. See Figure 1 for a sampling of models successfully trained with mixed precision, and Figures 2 and 3 for example speedups using torch.amp.

Figure 1: Sampling of DL Workloads Successfully Trained with float16 (Source).

Figure 2: Performance of mixed precision training using torch.amp on NVIDIA 8xV100 vs. float32 training on 8xV100 GPU. Bars represent the speedup factor of torch.amp over float32.
(Higher is better.) (Source).

Figure 3. Performance of mixed precision training using torch.amp on NVIDIA 8xA100 vs. 8xV100 GPU. Bars represent the speedup factor of A100 over V100.
(Higher is Better.) (Source).

See the NVIDIA Deep Learning Examples repository for more sample mixed precision workloads.

Similar performance charts can be seen in 3D medical image analysis, gaze estimation, video synthesis, conditional GANs, and convolutional LSTMs. Huang et al. showed that mixed precision training is 1.5x to 5.5x faster over float32 on V100 GPUs, and an additional 1.3x to 2.5x faster on A100 GPUs on a variety of networks. On very large networks the need for mixed precision is even more evident. Narayanan et al. reports that it would take 34 days to train GPT-3 175B on 1024 A100 GPUs (with a batch size of 1536), but it’s estimated it would take over a year using float32!

Getting Started With Mixed Precision Using torch.amp

torch.amp, introduced in PyTorch 1.6, makes it easy to leverage mixed precision training using the float16 or bfloat16 dtypes. See this blog post, tutorial, and documentation for more details. Figure 4 shows an example of applying AMP with grad scaling to a network.

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()

for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
      loss = model(data)

   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()

   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)

   # Updates the scale for next iteration
   scaler.update()

Figure 4: AMP recipe

Picking The Right Approach

Out-of-the-box mixed precision training with either float16 or bfloat16 is effective at speeding up the convergence of many deep learning models, but some models may require more careful numerical accuracy management. Here are some options:

  • Full float32 precision. Floating point tensors and modules are created in float32 precision by default in PyTorch, but this is a historic artifact not representative of training most modern deep learning networks. It’s rare that networks need this much numerical accuracy.
  • Enabling TensorFloat32 (TF32) mode. On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. See the Accelerating AI Training with NVIDIA TF32 Tensor Cores blog post for more details. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too (see the documentation for how to do so, and the sketch after this list). It can significantly speed up computations with typically negligible loss of numerical accuracy.
  • Using torch.amp with bfloat16 or float16. Both these low precision floating point data types are usually comparably fast, but some networks may only converge with one vs the other. If a network requires more precision it may need to use float16, and if a network requires more dynamic range it may need to use bfloat16, whose dynamic range is equal to that of float32. If overflows are observed, for example, then we suggest trying bfloat16.
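
A minimal sketch of enabling TF32 for matrix multiplications (convolutions already default to TF32 on Ampere and later devices):

import torch

# opt in to TF32 for matmuls; convolutions use TF32 by default
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True  # explicit, though this is already the default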

There are even more advanced options than those presented here, like using torch.amp’s autocasting for only parts of a model, or managing mixed precision directly. These topics are largely beyond the scope of this blog post, but see the “Best Practices” section below.

Best Practices

We strongly recommend using mixed precision with torch.amp or the TF32 mode (on Ampere and later CUDA devices) whenever possible when training a network. If one of those approaches doesn’t work, however, we recommend the following:

  • High Performance Computing (HPC) applications, regression tasks, and generative networks may simply require full float32 IEEE precision to converge as expected.
  • Try selectively applying torch.amp. In particular we recommend first disabling it on regions performing operations from the torch.linalg module or when doing pre- or post-processing. These operations are often especially sensitive (a sketch of selectively disabling autocast appears after this list). Note that TF32 mode is a global switch and can’t be used selectively on regions of a network. Enable TF32 first to check if a network’s operators are sensitive to the mode, otherwise disable it.
  • If you encounter type mismatches while using torch.amp we don’t suggest inserting manual casts to start. This error is indicative of something being off with the network, and it’s usually worth investigating first.
  • Figure out by experimentation whether your network is sensitive to the range and/or precision of a format. For example, fine-tuning a bfloat16-pretrained model in float16 can easily run into range issues, because values learned during bfloat16 training can exceed float16's representable range; if the model was trained in bfloat16, stick with bfloat16 for fine-tuning.
  • The performance gain of mixed precision training can depend on multiple factors (e.g. compute-bound vs memory-bound problems), and users should use the tuning guide to remove other bottlenecks in their training scripts. Although bfloat16 and float16 have similar theoretical performance benefits, they can have different speeds in practice. We recommend trying both formats and using the faster one that maintains the desired numerical behavior.
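
For the selective application of torch.amp mentioned in the list, one common pattern is to nest a disabled autocast region inside an enabled one. Below is a minimal, illustrative sketch (the module and its layers are hypothetical, not taken from any specific library); note the explicit cast back to float32 before the sensitive torch.linalg call, since autocast may have produced float16 activations:

import torch
import torch.nn as nn

class HybridPrecisionNet(nn.Module):
   def __init__(self, dim=64):
      super().__init__()
      self.encoder = nn.Linear(dim, dim)
      self.head = nn.Linear(dim, 1)
      self.register_buffer("A", torch.eye(dim))

   def forward(self, x):
      # Runs in mixed precision under the caller's autocast context
      x = self.encoder(x)

      # Locally opt out of autocast for a precision-sensitive operation
      with torch.amp.autocast(device_type="cuda", enabled=False):
         x = torch.linalg.solve(self.A, x.float().unsqueeze(-1)).squeeze(-1)

      return self.head(x)

The outer float16 autocast region (as in Figure 4) still covers the rest of the network; only the solve runs in full float32.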

For more details, refer to the AMP Tutorial, Training Neural Networks with Tensor Cores, and see the post “More In-Depth Details of Floating Point Precision” on PyTorch Dev Discussion.

Conclusion

Mixed precision training is an essential tool for training deep learning models on modern hardware, and it will become even more important in the future as the performance gap between lower precision operations and float32 continues to grow on newer hardware, as reflected in Figure 5.

Figure 5: Relative peak throughput of float16 (FP16) vs float32 matrix multiplications on Volta and Ampere GPUs. On Ampere relative peak throughput for the TensorFloat32 (TF32) mode and bfloat16 matrix multiplications are shown, too. The relative peak throughput of low precision data types like float16 and bfloat16 vs. float32 matrix multiplications is expected to grow as new hardware is released.

PyTorch’s torch.amp module makes it easy to get started with mixed precision, and we highly recommend using it to train faster and reduce memory usage. torch.amp supports both float16 and bfloat16 mixed precision.

There are still some networks that are tricky to train with mixed precision, and for these networks we recommend trying TF32 accelerated matrix multiplications on Ampere and later CUDA hardware. Networks are rarely so precision sensitive that they require full float32 precision for every operation.

If you have questions or suggestions for torch.amp or mixed precision support in PyTorch then let us know by posting to the mixed precision category on the PyTorch Forums or filing an issue on the PyTorch GitHub page.

Read More

Case Study: PathAI Uses PyTorch to Improve Patient Outcomes with AI-powered Pathology

PathAI is the leading provider of AI-powered technology tools and services for pathology (the study of disease). Our platform was built to enable substantial improvements to the accuracy of diagnosis and the measurement of therapeutic efficacy for complex diseases, leveraging modern approaches in machine learning like image segmentation, graph neural networks, and multiple instance learning.

Traditional manual pathology is prone to subjectivity and observer variability that can negatively affect diagnoses and drug development trials. Before we dive into how we use PyTorch to improve our diagnosis workflow, let us first lay out the traditional analog pathology workflow without machine learning.

How Traditional Biopharma Works

There are many avenues that biopharma companies take to discover novel therapeutics or diagnostics. One of those avenues relies heavily on the analysis of pathology slides to answer a variety of questions: how does a particular cellular communication pathway work? Can a specific disease state be linked to the presence or lack of a particular protein? Why did a particular drug in a clinical trial work for some patients but not others? Might there be an association between patient outcomes and a novel biomarker?

To help answer these questions, biopharma companies rely on expert pathologists to analyze slides and help evaluate the questions they might have. 

As you might imagine, it takes an expert, board-certified pathologist to make accurate interpretations and diagnoses. In one study, a single biopsy result was given to 36 different pathologists, and the outcome was 18 different diagnoses, varying in severity from no treatment to aggressive treatment necessary. Pathologists also often solicit feedback from colleagues on difficult edge cases. Given the complexity of the problem, even with expert training and collaboration, pathologists can still have a hard time making a correct diagnosis. This potential variance can be the difference between a drug being approved and it failing the clinical trial.

How PathAI utilizes machine learning to power drug development

PathAI develops machine learning models which provide insights for drug development R&D, for powering clinical trials, and for making diagnoses. To this end, PathAI leverages PyTorch for slide level inference using a variety of methods including graph neural networks (GNN) as well as multiple instance learning. In this context, “slides” refers to full size scanned images of glass slides, which are pieces of glass with a thin slice of tissue between them, stained to show various cell formations. PyTorch enables our teams using these different methodologies to share a common framework which is robust enough to work in all the conditions we need. PyTorch’s high level, imperative, and pythonic syntax allows us to prototype models quickly and then take those models to scale once we have the results we want. 

Multi-instance learning on gigabyte images

One of the uniquely challenging aspects of applying ML to pathology is the immense size of the images. These digital slides can often be 100,000 x 100,000 pixels or more in resolution and gigabytes in size. Loading the full image into GPU memory and applying traditional computer vision algorithms to it is an almost impossible task. It also takes a considerable amount of both time and resources to have a full slide image (100k x 100k) annotated, especially when annotators need to be domain experts (board-certified pathologists). We often build models to predict image-level labels, like the presence of cancer, on a patient slide, where the label-relevant region may cover only a few thousand pixels of the whole image. The cancerous area is sometimes a tiny fraction of the entire slide, which makes the ML problem similar to finding a needle in a haystack. On the other hand, some problems, like the prediction of certain histological biomarkers, require an aggregation of information from the whole slide, which is again hard due to the size of the images. All these factors add significant algorithmic, computational, and logistical complexity when applying ML techniques to pathology problems.

Breaking the image down into smaller patches, learning patch representations, and then pooling those representations to predict an image-level label is one way to solve this problem, as depicted in the image below. One popular method for doing this is called Multiple Instance Learning (MIL). Each patch is considered an ‘instance’ and a set of patches forms a ‘bag’. The individual patch representations are pooled together to predict a final bag-level label. Algorithmically, the individual patch instances in the bag do not require labels, which allows us to learn bag-level labels in a weakly-supervised way. MIL models also use permutation invariant pooling functions, which make the prediction independent of the order of patches and allow for efficient aggregation of information. Typically, attention based pooling functions are used, which not only allow for efficient aggregation but also provide attention values for each patch in the bag. These values indicate the importance of the corresponding patch to the prediction and can be visualized to better understand the model’s predictions. This element of interpretability can be very important for driving adoption of these models in the real world, and we use variations like Additive MIL models to enable such spatial explainability. Computationally, MIL models circumvent the problem of applying neural networks to large images, since patch representations are obtained independently of the size of the image.
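
To make this concrete, here is a minimal sketch of an attention-based MIL pooling head in plain PyTorch (the class name and dimensions are illustrative, not PathAI's implementation):

import torch
import torch.nn as nn

class AttentionMILPooling(nn.Module):
  """Permutation invariant attention pooling over a bag of patch embeddings."""
  def __init__(self, embed_dim=1024, attn_dim=256):
    super().__init__()
    self.attention = nn.Sequential(
      nn.Linear(embed_dim, attn_dim),
      nn.Tanh(),
      nn.Linear(attn_dim, 1),
    )
    self.classifier = nn.Linear(embed_dim, 1)

  def forward(self, patch_embeddings):
    # patch_embeddings: (num_patches, embed_dim) for one bag/slide
    attn_logits = self.attention(patch_embeddings)                # (num_patches, 1)
    attn_weights = torch.softmax(attn_logits, dim=0)              # sums to 1 over patches
    bag_embedding = (attn_weights * patch_embeddings).sum(dim=0)  # (embed_dim,)
    slide_logit = self.classifier(bag_embedding)                  # bag-level prediction
    return slide_logit, attn_weights

Because the weighted sum does not depend on the order of patches, the prediction is permutation invariant, and the returned attention weights are the per-patch importance scores that can be visualized for interpretability.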

At PathAI, we use custom MIL models based on deep nets to predict image-level labels. The overview of this process is as follows:

  1. Select patches from a slide using different sampling approaches.
  2. Construct a bag of patches based on random sampling or heuristic rules.
  3. Generate patch representations for each instance based on pre-trained models or large-scale representation learning models.
  4. Apply permutation invariant pooling functions to get the final slide-level score.

Now that we have walked through some of the high-level details around MIL in PyTorch, let’s look at some code to see how simple it is to go from ideation to code in production with PyTorch. We begin by defining a sampler, transformations, and our MIL dataset:

# Create a bag sampler which randomly samples patches from a slide
bag_sampler = RandomBagSampler(bag_size=12)

# Setup the transformations
crop_transform = FlipRotateCenterCrop(use_flips=True)

# Create the dataset which loads patches for each bag
train_dataset = MILDataset(
  bag_sampler=bag_sampler,
  samples_loader=sample_loader,
  transform=crop_transform,
)

After we have defined our sampler and dataset, we need to define the model we will actually train with said dataset. PyTorch’s familiar model definition syntax makes this easy to do while also allowing us to create bespoke models at the same time.

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
  input_dims=1024,
  hidden_dims=[256, 256],
  output_activation=StableSoftmax()
)

# Define the model which is a composition of the featurizer, pooling module and a classifier
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling=pooling)
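
The composed model can then be trained with standard PyTorch machinery. A hedged sketch, assuming MILDataset behaves like an ordinary map-style dataset yielding a bag of patches and a slide-level label (the batching and device handling in PathAI's pipeline may differ):

import torch
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=8, num_workers=4, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.BCEWithLogitsLoss()

model.train()
for bags, labels in train_loader:
  optimizer.zero_grad()
  logits = model(bags)  # slide-level scores
  loss = loss_fn(logits.squeeze(-1), labels.float())
  loss.backward()
  optimizer.step()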

Since these models are trained end-to-end, they offer a powerful way to go directly from a gigapixel whole slide image to a single label. Due to their wide applicability to different biological problems, two aspects of their implementation and deployment are important:

  1. Configurable control over each part of the pipeline including the data loaders, the modular parts of the model, and their interaction with each other.
  2. Ability to rapidly iterate through the ideate-implement-experiment-productionize loop.

PyTorch has various advantages when it comes to MIL modeling. It offers an intuitive way to create dynamic computational graphs with flexible control flow, which is great for rapid research experimentation. The map-style datasets, configurable sampler, and batch-samplers allow us to customize how we construct bags of patches, enabling faster experimentation. Since MIL models are IO heavy, data parallelism and pythonic data loaders make the task very efficient and user friendly. Lastly, the object-oriented nature of PyTorch enables building reusable modules, which aid in rapid experimentation, maintainable implementation, and ease of composing the pipeline's components.

Exploring spatial tissue organization with GNNs in PyTorch

In both healthy and diseased tissue, the spatial arrangement and structure of cells can oftentimes be as important as the cells themselves. For example, when assessing lung cancers, pathologists try to look at the overall grouping and structure of tumor cells (do they form solid sheets? Or do they occur in smaller, localized clusters?) to determine if the cancer belongs to specific subtypes which can have vastly different prognosis. Such spatial relationships between cells and other tissue structures can be modeled using graphs to capture tissue topology and cellular composition at the same time. Graph Neural Networks (GNNs) allow learning spatial patterns within these graphs that relate to other clinical variables, for example overexpression of genes in certain cancers.
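
As an illustration of how such tissue graphs can be represented, here is a minimal sketch using PyG, where each node is a detected cell with a feature vector and edges connect spatially nearby cells (cell detection and featurization are assumed to happen upstream, and knn_graph requires the torch-cluster package):

import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph

num_cells, num_features = 5000, 64
cell_positions = torch.rand(num_cells, 2)             # (x, y) centroids on the slide
cell_features = torch.randn(num_cells, num_features)  # e.g. morphology/stain features

# Connect each cell to its k nearest spatial neighbors to capture tissue topology
edge_index = knn_graph(cell_positions, k=8)

tissue_graph = Data(x=cell_features, edge_index=edge_index, pos=cell_positions)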

In late 2020, when PathAI started using GNNs on tissue samples, PyTorch had the best and most mature support for GNN functionality via the PyG package. This made PyTorch the natural choice for our team given that GNN models were something that we knew would be an important ML concept we wanted to explore. 

One of the main value-adds of GNNs in the context of tissue samples is that the graph itself can uncover spatial relationships that would otherwise be very difficult to find by visual inspection alone. In our recent AACR publication, we showed that by using GNNs we can better understand the way the presence of immune cell aggregates (specifically tertiary lymphoid structures, or TLS) in the tumor microenvironment can influence patient prognosis. In this case, the GNN approach was used to predict the expression of genes associated with the presence of TLS and to identify histological features beyond the TLS region itself that are relevant to TLS. Such insights into gene expression are difficult to identify from tissue sample images when unassisted by ML models.

One of the most promising GNN variations we have had success with is self attention graph pooling. Let’s take a look at how we define our Self Attention Graph Pooling (SAGPool) model using PyTorch and PyG:

import torch
from torch.nn import Linear
from torch_geometric.nn import GraphConv, SAGPooling, JumpingKnowledge

class SAGPool(torch.nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
    self.convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for i in range(num_layers - 1)])
    self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for i in range((num_layers) // 2)])
    self.jump = JumpingKnowledge(mode='cat')
    self.lin1 = Linear(num_layers * hidden_features, hidden_features)
    self.lin2 = Linear(hidden_features, out_features)
    self.out_activation = out_activation
    self.dropout = dropout

In the above code, we begin by defining a single graph convolutional layer and then add two module lists, which allow us to use a variable number of layers. We then take our empty module lists and append a variable number of GraphConv layers followed by a variable number of SAGPooling layers. We finish up our SAGPool definition by adding a JumpingKnowledge layer, two linear layers, our activation function, and our dropout value. PyTorch’s intuitive syntax allows us to abstract away the complexity of working with state-of-the-art methods like SAG Pooling while also maintaining the common approach to model development we are familiar with.
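
The snippet above shows only the constructor. A hedged sketch of a matching forward pass, following the usual PyG pattern of interleaving convolutions with SAGPooling and reading out each stage with global mean pooling (the exact interleaving and readout are assumptions, not PathAI's code, and out_activation is assumed to be a callable module), might look like the following method inside the same class:

# Requires: import torch.nn.functional as F
#           from torch_geometric.nn import global_mean_pool

def forward(self, data):
  x, edge_index, batch = data.x, data.edge_index, data.batch
  x = F.relu(self.conv1(x, edge_index))
  xs = [global_mean_pool(x, batch)]  # graph-level readout after each conv
  for i, conv in enumerate(self.convs):
    x = F.relu(conv(x, edge_index))
    xs.append(global_mean_pool(x, batch))
    if i % 2 == 0 and i // 2 < len(self.pools):
      # Coarsen the graph with self-attention graph pooling
      x, edge_index, _, batch, _, _ = self.pools[i // 2](x, edge_index, batch=batch)
  x = self.jump(xs)  # concatenate per-layer readouts
  x = F.relu(self.lin1(x))
  x = F.dropout(x, p=self.dropout, training=self.training)
  return self.out_activation(self.lin2(x))

Here data would be a PyG batch (as produced by torch_geometric.loader.DataLoader), so batch maps each node to its graph for the global pooling step.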

Models like the SAGPool model described above are just one example of how GNNs with PyTorch are allowing us to explore new and novel ideas. We also recently explored multimodal CNN-GNN hybrid models, which ended up being 20% more accurate than traditional pathologist consensus scores. These innovations and the interplay between traditional CNNs and GNNs are again enabled by the short research-to-production model development loop.

Improving Patient Outcomes

In order to achieve our mission of improving patient outcomes with AI-powered pathology, PathAI needs to rely on an ML development framework that (1) facilitates quick iteration and easy extension (i.e., model configuration as code) during the initial phases of development and exploration, (2) scales model training and inference to massive images, and (3) easily and robustly serves models for production uses of our products (in clinical trials and beyond). As we’ve demonstrated, PyTorch offers us all of these capabilities and more. We are incredibly excited about the future of PyTorch and cannot wait to see what other impactful challenges we can solve using the framework.

Read More