Accelerated Generative Diffusion Models with PyTorch 2

TL;DR: PyTorch 2.0 nightly offers out-of-the-box performance improvement for Generative Diffusion models by using the new torch.compile() compiler and optimized implementations of Multihead Attention integrated with PyTorch 2.

Introduction

A large part of the recent progress in Generative AI came from denoising diffusion models, which allow producing high quality images and videos from text prompts. This family includes Imagen, DALLE, Latent Diffusion, and others. However, all models in this family share a common drawback: generation is rather slow, due to the iterative nature of the sampling process by which the images are produced. This makes it important to optimize the code running inside the sampling loop.

We took an open source implementation of a popular text-to-image diffusion model as a starting point and accelerated its generation using two optimizations available in PyTorch 2: compilation and fast attention implementation. Together with a few minor memory processing improvements in the code, these optimizations give up to 49% inference speedup relative to the original implementation without xFormers, and 39% inference speedup relative to using the original code with xFormers (excluding the compilation time), depending on the GPU architecture and batch size. Importantly, the speedup comes without a need to install xFormers or any other extra dependencies.

The table below shows the improvement in runtime between the original implementation with xFormers installed and our optimized version with PyTorch-integrated memory efficient attention (originally developed for and released in the xFormers library) and PyTorch compilation. The compilation time is excluded.

Runtime improvement in % compared to original+xFormers

See the absolute runtime numbers in section “Benchmarking setup and results summary”

GPU Batch size 1 Batch size 2 Batch size 4
P100 (no compilation) -3.8 0.44 5.47
T4 2.12 10.51 14.2
A10 -2.34 8.99 10.57
V100 18.63 6.39 10.43
A100 38.5 20.33 12.17

One can notice the following:

  • The improvements are significant for powerful GPUs like A100 and V100. For those GPUs the improvement is most pronounced for batch size 1
  • For less powerful GPUs we observe smaller speedups (or in two cases slight regressions). The batch size trend is reversed here: improvement is larger for larger batches

In the following sections we describe the applied optimizations and provide detailed benchmarking data, comparing the generation time with various optimization features on/off.

Specifically, we benchmark 5 configurations and the plots below compare their absolute performance for different GPUs and batch sizes. For definitions of these configurations see the section “Benchmarking setup and results summary”.

Benchmark of denoising diffusion text-to-image generation across GPU architectures, batch size 1

Benchmark of denoising diffusion text-to-image generation across GPU architectures, batch size 2

Benchmark of denoising diffusion text-to-image generation across GPU architectures, batch size 4

Optimizations

Here we’ll go into more detail about the optimizations introduced into the model code. These optimizations rely on features of PyTorch 2.0 which has been released recently.

Optimized Attention

One part of the code which we optimized is the scaled dot-product attention. Attention is known to be a heavy operation: a naive implementation materializes the attention matrix, leading to time and memory complexity quadratic in sequence length. It is common for diffusion models to use attention (CrossAttention) as part of Transformer blocks in multiple parts of the U-Net. Since the U-Net runs at every sampling step, this becomes a critical point to optimize. Instead of a custom attention implementation one can use torch.nn.MultiheadAttention, which in PyTorch 2 has an optimized attention implementation integrated into it. This optimization schematically boils down to the following pseudocode:

class CrossAttention(nn.Module):
    def __init__(self, ...):
        # Create projection matrices: Q, K, V, out_proj
        ...
    def forward(self, x, context=None, mask=None):
        # Compute out = SoftMax(Q*K^T/sqrt(d))*V
        # Return out_proj(out)
        ...

gets replaced with

class CrossAttention(nn.Module):
    def __init__(self, ...):
        self.mha = nn.MultiheadAttention(...)
    def forward(self, x, context):
        # nn.MultiheadAttention returns (output, attn_weights); keep only the output
        return self.mha(x, context, context, need_weights=False)[0]

The optimized implementation of attention was available already in PyTorch 1.13 (see here) and widely adopted (see e.g. HuggingFace transformers library example). In particular, it integrates memory-efficient attention from the xFormers library and flash attention from https://arxiv.org/abs/2205.14135. PyTorch 2.0 expands this to additional attention functions such as cross attention and custom kernels for further acceleration, making it applicable to diffusion models.

Flash attention is available on GPUs with compute capability SM 7.5 or SM 8.x – for example, on T4, A10, and A100, which are included in our benchmark (you can check compute capability of each NVIDIA GPU here). However, in our tests on A100 the memory efficient attention performed better than flash attention for the particular case of diffusion models, due to the small number of attention heads and small batch size. PyTorch understands this and in this case chooses memory efficient attention over flash attention when both are available (see the logic here). For full control over the attention backends (memory-efficient attention, flash attention, “vanilla math”, or any future ones), power users can enable and disable them manually with the help of the context manager torch.backends.cuda.sdp_kernel.
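For example, one can pin a particular backend for a region of code like this (a minimal sketch; tensor shapes here are arbitrary and only for illustration):

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the memory-efficient backend inside this region;
# flash attention and the math fallback are disabled.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v)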

Compilation

Compilation is a new feature of PyTorch 2.0, enabling significant speedups with a very simple user experience. To invoke the default behavior, simply wrap a PyTorch module or a function into torch.compile:

model = torch.compile(model)

The PyTorch compiler then turns Python code into a set of instructions which can be executed efficiently without Python overhead. The compilation happens dynamically the first time the code is executed. With the default behavior, under the hood PyTorch utilizes TorchDynamo to compile the code and TorchInductor to further optimize it. See this tutorial for more details.

Although the one-liner above is enough for compilation, certain modifications in the code can squeeze a larger speedup. In particular, one should avoid so-called graph breaks – places in the code which PyTorch can’t compile. As opposed to previous PyTorch compilation approaches (like TorchScript), PyTorch 2 compiler doesn’t break in this case. Instead it falls back on eager execution – so the code runs, but with reduced performance. We introduced a few minor changes to the model code to get rid of graph breaks. This included eliminating functions from libraries not supported by the compiler, such as inspect.isfunction and einops.rearrange. See this doc to learn more about graph breaks and how to eliminate them.

Theoretically, one can apply torch.compile on the whole diffusion sampling loop. However, in practice it is enough to just compile the U-Net. The reason is that torch.compile doesn’t yet have a loop analyzer and would recompile the code for each iteration of the sampling loop. Moreover, compiled sampler code is likely to generate graph breaks – so one would need to adjust it if one wants to get a good performance from the compiled version.
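A minimal sketch of compiling only the U-Net could look like the following (the model.model.diffusion_model attribute path is an assumption about the Latent Diffusion code layout, not something defined above):

import torch

def compile_unet(model):
    # Compile only the U-Net; the sampling loop around it stays in eager mode.
    model.model.diffusion_model = torch.compile(model.model.diffusion_model)
    return model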

Note that compilation requires GPU compute capability >= SM 7.0 to run in non-eager mode. This covers all GPUs in our benchmarks – T4, V100, A10, A100 – except for P100 (see the full list).

Other optimizations

In addition, we have improved efficiency of GPU memory operations by eliminating some common pitfalls, e.g. creating a tensor on GPU directly rather than creating it on CPU and later moving to GPU. The places where such optimizations were necessary were determined by line-profiling and looking at CPU/GPU traces and Flame Graphs.
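For illustration, a typical fix of this kind looks as follows (a generic sketch, not a literal line from the model code):

import torch

device = torch.device("cuda")

# Pitfall: the tensor is first allocated on the CPU and then copied to the GPU.
noise = torch.randn(4, 4, 64, 64).to(device)

# Better: allocate directly on the GPU and skip the host allocation and transfer.
noise = torch.randn(4, 4, 64, 64, device=device)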

Benchmarking setup and results summary

We have two versions of code to compare: original and optimized. On top of this, several optimization features (xFormers, PyTorch memory efficient attention, compilation) can be turned on/off. Overall, as mentioned in the introduction, we will be benchmarking 5 configurations:

  • Original code without xFormers
  • Original code with xFormers
  • Optimized code with vanilla math attention backend and no compilation
  • Optimized code with memory-efficient attention backend and no compilation
  • Optimized code with memory-efficient attention backend and compilation

As the original version, we took the code that uses PyTorch 1.12 and a custom implementation of attention. The optimized version uses nn.MultiheadAttention in CrossAttention and PyTorch 2.0.0.dev20230111+cu117. It also has a few other minor optimizations in PyTorch-related code.

The table below shows the runtime of each version of the code in seconds, and the percentage improvement compared to the original with xFormers. The compilation time is excluded.

Runtimes for batch size 1. In parenthesis – relative improvement with respect to the “Original with xFormers” row

Configuration P100 T4 A10 V100 A100
Original without xFormers 30.4s (-19.3%) 29.8s (-77.3%) 13.0s (-83.9%) 10.9s (-33.1%) 8.0s (-19.3%)
Original with xFormers 25.5s (0.0%) 16.8s (0.0%) 7.1s (0.0%) 8.2s (0.0%) 6.7s (0.0%)
Optimized with vanilla math attention, no compilation 27.3s (-7.0%) 19.9s (-18.7%) 13.2s (-87.2%) 7.5s (8.7%) 5.7s (15.1%)
Optimized with mem. efficient attention, no compilation 26.5s (-3.8%) 16.8s (0.2%) 7.1s (-0.8%) 6.9s (16.0%) 5.3s (20.6%)
Optimized with mem. efficient attention and compilation - 16.4s (2.1%) 7.2s (-2.3%) 6.6s (18.6%) 4.1s (38.5%)

Runtimes for batch size 2

Configuration P100 T4 A10 V100 A100
Original without xFormers 58.0s (-21.6%) 57.6s (-84.0%) 24.4s (-95.2%) 18.6s (-63.0%) 12.0s (-50.6%)
Original with xFormers 47.7s (0.0%) 31.3s (0.0%) 12.5s (0.0%) 11.4s (0.0%) 8.0s (0.0%)
Optimized with vanilla math attention, no compilation 49.3s (-3.5%) 37.9s (-21.0%) 17.8s (-42.2%) 12.7s (-10.7%) 7.8s (1.8%)
Optimized with mem. efficient attention, no compilation 47.5s (0.4%) 31.2s (0.5%) 12.2s (2.6%) 11.5s (-0.7%) 7.0s (12.6%)
Optimized with mem. efficient attention and compilation - 28.0s (10.5%) 11.4s (9.0%) 10.7s (6.4%) 6.4s (20.3%)

Runtimes for batch size 4

Configuration P100 T4 A10 V100 A100
Original without xFormers 117.9s (-20.0%) 112.4s (-81.8%) 47.2s (-101.7%) 35.8s (-71.9%) 22.8s (-78.9%)
Original with xFormers 98.3s (0.0%) 61.8s (0.0%) 23.4s (0.0%) 20.8s (0.0%) 12.7s (0.0%)
Optimized with vanilla math attention, no compilation 101.1s (-2.9%) 73.0s (-18.0%) 28.3s (-21.0%) 23.3s (-11.9%) 14.5s (-13.9%)
Optimized with mem. efficient attention, no compilation 92.9s (5.5%) 61.1s (1.2%) 23.9s (-1.9%) 20.8s (-0.1%) 12.8s (-0.9%)
Optimized with mem. efficient attention and compilation - 53.1s (14.2%) 20.9s (10.6%) 18.6s (10.4%) 11.2s (12.2%)

To minimize fluctuations and external influence on the performance of the benchmarked code, we ran each version of the code one after another, and then repeated this sequence 10 times: A, B, C, D, E, A, B, … So the results of a typical run would look like the one in the picture below. Note that one shouldn’t rely on comparison of absolute run times between different graphs, but comparison of run times inside one graph is pretty reliable, thanks to our benchmarking setup.

Denoising diffusion model generation benchmarks

Each run of the text-to-image generation script produces several batches, the number of which is regulated by the CLI parameter --n_iter. In the benchmarks we used n_iter = 2, but introduced an additional “warm-up” iteration, which doesn’t contribute to the run time. This was necessary for the runs with compilation, because compilation happens the first time the code runs, and so the first iteration is much longer than all subsequent ones. To make the comparison fair, we also introduced this additional “warm-up” iteration to all other runs.
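A minimal sketch of this timing scheme (function and variable names are illustrative, not taken from the benchmark script):

import time
import torch

def timed_run(generate_batch, n_iter=2):
    generate_batch()                 # warm-up iteration: absorbs compilation, not timed
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iter):          # timed iterations
        generate_batch()
    torch.cuda.synchronize()
    return time.perf_counter() - start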

The numbers in the table above are for 2 iterations (plus a “warm-up” one), the prompt “A photo”, seed 1, the PLMS sampler, and autocast turned on.

Benchmarks were done using P100, V100, A100, A10 and T4 GPUs. The T4 benchmarks were done in Google Colab Pro. The A10 benchmarks were done on g5.4xlarge AWS instances with 1 GPU.

Conclusions and next steps

We have shown that new features of PyTorch 2 – the compiler and the optimized attention implementation – give performance improvements exceeding or comparable with what previously required installation of an external dependency (xFormers). PyTorch achieved this, in particular, by integrating memory efficient attention from xFormers into its codebase. This is a significant improvement for user experience, given that xFormers, being a state-of-the-art library, in many scenarios requires a custom installation process and long builds.

There are a few natural directions in which this work can be continued:

  • The optimizations we implemented and described here are only benchmarked for text-to-image inference so far. It would be interesting to see how they affect training performance. PyTorch compilation can be directly applied to training; enabling training with PyTorch optimized attention is on the roadmap
  • We intentionally minimized changes to the original model code. Further profiling and optimization can probably bring more improvements
  • At the moment compilation is applied only to the U-Net model inside the sampler. Since there is a lot happening outside of U-Net (e.g. operations directly in the sampling loop), it would be beneficial to compile the whole sampler. However, this would require analysis of the compilation process to avoid recompilation at every sampling step
  • Current code only applies compilation within the PLMS sampler, but it should be trivial to extend it to other samplers
  • Besides text-to-image generation, diffusion models are also applied to other tasks – image-to-image and inpainting. It would be interesting to measure how their performance improves from PyTorch 2 optimizations

See if you can increase performance of open source diffusion models using the methods we described, and share the results!

Resources

Acknowledgements

We would like to thank Geeta Chauhan, Natalia Gimelshein, Patrick Labatut, Bert Maher, Mark Saroufim, Michael Voznesensky and Francisco Massa for their valuable advice and early feedback on the text.

Special thanks to Yudong Tao for initiating the work on using PyTorch native attention in diffusion models.

Read More

Straggler Mitigation On PyTorch DDP By Hierarchical SGD

PyTorch DDP has been widely adopted across the industry for distributed training, which by default runs synchronous SGD to synchronize gradients across model replicas at every step. The performance of this technique is critical for fast iteration during model exploration as well as for resource and cost savings. To resolve a ubiquitous performance bottleneck introduced by slow nodes in large-scale training, Cruise and Meta co-developed a solution based on the Hierarchical SGD algorithm to significantly accelerate training in the presence of these stragglers.

The Need For Straggler Mitigation

In a DDP setup, a straggler problem can occur when one or more processes run much slower (“stragglers”) than other processes. When this happens, all the processes have to wait for the stragglers before synchronizing gradients and completing the communication, which essentially bottlenecks distributed performance to the slowest worker. As a result, even when training relatively small models, the communication cost can still be a major performance bottleneck.

Potential Causes of Stragglers

Severe straggler issues are usually caused by workload imbalance before synchronization, and many factors can contribute to this imbalance. For instance, some data loader workers in the distributed environment can become stragglers, because some input examples can be outliers in terms of the data size, or the data transfer of some examples can be drastically slowed down due to unstable network I/O, or the on-the-fly data transformation costs can have a high variance.

Besides data loading, other phases before gradient synchronization can also cause stragglers, such as unbalanced workloads of embedding table lookup during the forward pass in recommendation systems.

The Appearance of Stragglers

If we profile DDP training jobs that have stragglers, we can find that some processes may have much higher gradient synchronization costs (a.k.a., allreducing gradients) than other processes at a certain step. As a result, the distributed performance can be dominated by the communication cost even if the model size is very small. In this case, some processes run faster than the straggler(s) at a step, and hence they have to wait for the stragglers and spend a much longer time on allreduce.

The screenshots below show two trace files output by the PyTorch profiler for one use case. Each screenshot profiles 3 steps.

  • The first screenshot shows that a process has a very high allreduce cost in both the first and the third steps, because this process reaches the synchronization phase earlier than the straggler(s), and it spends more time waiting. On the other hand, the allreduce cost is relatively small in the second step; this suggests that either 1) there is no straggler at this step, or 2) this process is the straggler among all the processes, so it does not need to wait for any other process.

chart showing allreduce cost

Both the 1st and the 3rd Steps Are Slowed Down by Stragglers

  • The second screenshot shows a normal case without stragglers. In this case, all the gradient synchronizations are relatively short.

chart showing normal case without stragglers

Normal Case Without Stragglers

Hierarchical SGD in PyTorch

Recently hierarchical SGD has been proposed to optimize the communication costs by mainly reducing the total amount of data transfer in large-scale distributed training, and multiple convergence analyses have been provided (example). As the main novelty of this post, at Cruise we leveraged hierarchical SGD to mitigate stragglers, which can also occur when training relatively small models. Our implementation was upstreamed by Cruise to PyTorch in early 2022.

How Does Hierarchical SGD Work?

As the name implies, hierarchical SGD organizes all the processes into groups at different levels as a hierarchy, and runs synchronization by following the rules below:

  • All the groups at the same level have the same number of processes, and the processes in these groups synchronize at the same frequency concurrently, where the synchronization period is pre-defined by the user.
  • The higher the level of a group, the larger the synchronization period used, as the synchronization becomes more expensive.
  • When multiple overlapping groups are supposed to synchronize according to their periods, to reduce redundant synchronization and avoid data race across groups, only the highest-level group runs synchronization.

The following figure illustrates an example of 4-level hierarchy SGD among 16 processes on 8 machines, each of which has 2 GPUs:

  1. Level 1: Each process runs mini-batch SGD locally;
  2. Level 2: Each 4-process group across 2 machines runs synchronization every 2 steps;
  3. Level 3: Each 8-process group across 4 machines runs synchronization every 4 steps;
  4. Level 4: The global process group of all 16 processes over 8 machines runs synchronization every 8 steps.

Particularly, when the step number is divisible by 8, only the global synchronization across all 16 processes is executed, and when the step number is divisible by 4 but not by 8, only the synchronization within the 8-process groups is executed.
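The rule that only the highest-level group due at a step synchronizes can be sketched as follows (a toy illustration of the example above, not PyTorch library code):

from collections import OrderedDict

# Period -> group size for levels 2-4 of the example above (level 1 is local SGD).
period_group_size_dict = OrderedDict([(2, 4), (4, 8), (8, 16)])

def group_size_to_sync(step):
    # Return the size of the largest group whose period divides `step`,
    # or None if only local mini-batch SGD runs at this step.
    chosen = None
    for period, group_size in period_group_size_dict.items():
        if step % period == 0:
            chosen = group_size   # later (larger) periods override smaller ones
    return chosen

# Steps 1-8 -> [None, 4, None, 8, None, 4, None, 16]
print([group_size_to_sync(step) for step in range(1, 9)])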

An example of 4-level hierarchy SGD among 16 processes on 8 machines, each of which has 2 GPUs

Intuitively, hierarchical SGD can be viewed as an extension of local SGD, which only has a two-level hierarchy – every process runs mini-batch SGD locally and then synchronizes globally at a certain frequency. This also helps explain why, just like local SGD, hierarchical SGD synchronizes model parameters instead of gradients. Otherwise, gradient descent would be mathematically incorrect when the synchronization period is greater than 1.

Why Can Hierarchical SGD Mitigate Stragglers?

The key insight here is that, when there is a random straggler, it only directly slows down a relatively small group of processes instead of all the processes. Next time another random straggler is very likely to slow down a different small group, and hence a hierarchy can help smooth out the straggler effect.

The example below assumes that there is one random straggler among a total of 8 processes at every step. After 4 steps, vanilla DDP that runs synchronous SGD will be slowed down by a straggler 4 times, because it runs global synchronization at every step. In contrast, hierarchical SGD runs synchronization within the groups of 4 processes after the first two steps, and then a global synchronization after another two steps. We can see that both the first two and the last two stragglers have a large overlap, and hence the performance loss can be mitigated.

flow diagram

Essentially, the mitigation effect of this hierarchical SGD example lies between that of local SGD at a frequency of every 2 steps and of every 4 steps. The main advantage of hierarchical SGD over local SGD is better convergence efficiency at the same global synchronization frequency, because hierarchical SGD allows more low-level synchronization. Moreover, it is possible for hierarchical SGD to achieve model parity with a global synchronization frequency lower than local SGD, leading to higher training performance, especially in large-scale distributed training.

Ease of Use

Straggler mitigation is not a novel study in distributed training. Multiple approaches have been proposed, such as gossip SGD, data encoding, and gradient coding, as well as some particularly designed for parameter-server architectures, including backup workers and stale synchronous parallel. However, to the best of our knowledge, before this effort we had not found a good open-source PyTorch implementation of straggler mitigation that can work like a plugin to our training system at Cruise. In contrast, our implementation requires only minimal changes – no need to modify the existing code or tune any existing hyperparameters. This is a very appealing advantage for industry users.

As the code example below shows, only a few lines need to be added to the setup of the DDP model, and the training loop code can stay untouched. As explained previously, hierarchical SGD is an extended form of local SGD, so the enablement can be quite similar to local SGD (see PyTorch docs of PostLocalSGDOptimizer):

  1. Register a post-local SGD communication hook to run a warmup stage of fully synchronous SGD and defer hierarchical SGD.
  2. Create a post-local SGD optimizer that wraps an existing local optimizer and a hierarchical SGD configuration.
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)
from torch.distributed.optim import PostLocalSGDOptimizer

# `model`, `rank`, and `optim` are assumed to be defined earlier in the training script.

ddp_model = nn.parallel.DistributedDataParallel(
    module=model,
    device_ids=[rank],
)

# Register a post-local SGD communication hook for the warmup.
subgroup, _ = torch.distributed.new_subgroups()
state = PostLocalSGDState(subgroup=subgroup, start_localSGD_iter=1_000)
ddp_model.register_comm_hook(state, post_localSGD_hook)

# Wraps the existing (local) optimizer to run hierarchical model averaging.
optim = PostLocalSGDOptimizer(
  optim=optim,
  averager=hierarchicalSGD.HierarchicalModelAverager(
    # The config runs a 4-level hierarchy SGD among 128 processes:
    # 1) Each process runs mini-batch SGD locally;
    # 2) Each 8-process group synchronize every 2 steps;
    # 3) Each 32-process group synchronize every 4 steps;
    # 4) All 128 processes synchronize every 8 steps.
    period_group_size_dict=OrderedDict([(2, 8), (4, 32), (8, 128)]),
    # Do not run hierarchical SGD until 1K steps for model parity.
    warmup_steps=1_000)
)

Algorithm Hyperparameters

Hierarchical SGD has two major hyperparameters: period_group_size_dict and warmup_steps.

  • period_group_size_dict is an ordered dictionary mapping from synchronization period to process group size, used for initializing process groups of different sizes in a hierarchy to synchronize parameters concurrently. A larger group is expected to use a larger synchronization period.
  • warmup_steps specifies the number of steps of the warmup stage, during which synchronous SGD is run before hierarchical SGD. Similar to the post-local SGD algorithm, a warmup stage is usually recommended to achieve a higher accuracy. The value should be the same as the start_localSGD_iter arg used in PostLocalSGDState when post_localSGD_hook is registered. Typically the warmup stage should at least cover the beginning of training, when the loss decreases drastically.

A subtle difference between the PyTorch implementation and the initial design proposed by the relevant papers is that, after the warmup stage, by default the processes within each host still run intra-host gradient synchronization at every step. This is because:

  1. The intra-host communication is relatively cheap, and it can usually significantly accelerate the convergence;
  2. The intra-host group (of size 4 or 8 for most industry users) can usually be a good choice of the smallest group of processes that synchronize most frequently in hierarchical SGD. If the synchronization period is 1, then gradient synchronization is faster than model parameter synchronization (a.k.a., model averaging), because DDP automatically overlaps gradient synchronization and the backward pass.

Such intra-host gradient synchronization can be disabled by setting the post_local_gradient_allreduce arg to False in PostLocalSGDState.

Demonstration

Now we demonstrate that hierarchical SGD can accelerate distributed training by mitigating stragglers.

Experimental Setup

We compared the performance of hierarchical SGD against local SGD and synchronous SGD on ResNet18 (model size: 45MB). Since the model is so small, the training is not bottlenecked by data transfer cost during synchronization. To avoid the noise incurred by data loading from remote storage, the input data was randomly simulated in memory. We varied the number of GPUs used by training from 64 to 256. The batch size per worker is 32, and the number of training iterations is 1,000. Since we don’t evaluate convergence efficiency in this set of experiments, warmup is not enabled.

We also emulated stragglers at a rate of 1% on 128 and 256 GPUs, and 2% on 64 GPUs, to ensure at least one straggler at every step on average. These stragglers randomly appear on different CUDA devices. Each straggler stalls for 1 second in addition to the normal per-step training time (~55ms in our setup). This can be perceived as a practical scenario where 1% or 2% of input data are outliers in terms of the data pre-processing cost (I/O and/or data transformation on the fly) during training, and such cost is 20X+ larger than the average.

The code snippet below shows how a straggler can be emulated in the training loop. We applied it to a ResNet model, and it can be easily applied to the other models as well.

     # Requires `import random` and `import time` at the top of the training script.
     loss = loss_fn(y_pred, y)
     # Emulate a straggler that lags for 1 second at a rate of 1%.
     if random.randint(1, 100) == 1:
         time.sleep(1)
     loss.backward()
     optimizer.step()

The experiments are conducted on a us-central1 GCP cluster. Each machine has 4 NVIDIA Tesla T4 GPUs with 16 GB memory per GPU, connected through a 32 Gbit/s ethernet network. Each instance also features 96 vCPUs and 360 GB RAM.

Architecture ResNet18 (45MB)
Workers 64, 128, 256
Backend NCCL
GPU Tesla T4, 16 GB memory
Batch size 32 x # of workers
Straggler Duration 1 sec
Straggler Rate 1% on 128 and 256 GPUs, 2% on 64 GPUs

We used multiple configurations for both local SGD and hierarchical SGD. Local SGD runs global synchronization every 2, 4, and 8 steps, respectively.

We ran hierarchical SGD with the following configurations:

  1. On 64 GPUs:
    1. Each 8-process group, each 32-process group, and the global 64-process group synchronizes every 2, 4, and 8 steps, respectively. Denoted as “HSGD 2-8,4-32,8-64”.
    2. Each 32-process group and the global 64-process group synchronizes every 4 and 8 steps, respectively. Denoted as “HSGD 4-32,8-64”.
  2. On 128 GPUs:
    1. Each 8-process group, 32-process group, and the global 128-process group synchronizes every 2, 4, and 8 steps, respectively. Denoted as “HSGD 2-8,4-32,8-128”.
    2. Each 32-process group and the global 128-process group synchronizes every 4 and 8 steps, respectively. Denoted as “HSGD 4-32,8-128”.
  3. On 256 GPUs:
    1. Each 4-process group, 16-process group, 64-process group, and the global 256-process group synchronizes every 1, 2, 4, and 8 steps, respectively. Denoted as “HSGD 1-4,2-16,4-64,8-256”.
    2. Each 8-process group, 64-process group, and the global 256-process group synchronizes every 2, 4, and 8 steps. Denoted as “HSGD 2-8,4-64,8-256”.
    3. Each 16-process group and the global 256-process group synchronizes every 4 and 8 steps, respectively. Denoted as “HSGD 4-16,8-256”.

Experimental Results

The figures below show the speedups of different communication schemes against the baseline of synchronous SGD, with the emulated stragglers. We can make the following observations:

  1. As expected, we can see that both hierarchical SGD and local SGD can achieve a higher speedup with a lower synchronization frequency.
  2. The speedups of the hierarchical SGD schemes are 2.08X-2.45X on 64 GPUs, 2.57X-2.68X on 128 GPUs, and 2.63X-3.25X on 256 GPUs, respectively. This shows that hierarchical SGD can significantly mitigate stragglers, and such mitigation can be more effective at a larger scale.
  3. The performance of local SGD with the synchronization period of 2 steps and 8 steps can be perceived as the lower bound and upper bound of the experimented hierarchical SGD schemes, respectively. This is because the hierarchical SGD schemes synchronize less frequently than every 2 steps globally, but their low-level synchronizations within small groups are extra overhead in comparison with global synchronization every 8 steps.

Overall, hierarchical SGD can provide a finer-grained trade-off between communication cost and model quality than local SGD. Therefore, when local SGD at a relatively large synchronization period like 8 or 4 cannot give a satisfactory convergence efficiency, hierarchical SGD can have a much better chance to achieve both a good speedup and a model parity.

Since only simulated data is used in the experiments, we did not demonstrate the model parity here, which in practice can be achieved in two ways:

  1. Tuning the hyperparameters including both hierarchy and warmup steps;
  2. For some cases, hierarchical SGD could lead to a slightly lower quality than the original model for the same number of training steps (i.e., lower convergence rate), but with a speedup like 2X+ per training step, it is still possible to achieve model parity with more steps but still less total training time.

Speedups on 64 GPUs

Speedups on 128 GPUs

Speedups on 256 GPUs

Limitations

Before applying hierarchical SGD to straggler mitigation, the user should be aware of a few limitations of this approach:

  1. This approach can only mitigate non-persistent stragglers, which occur to different workers at different times. For the case of persistent stragglers, which can be caused by hardware degradation or a network issue on a specific host, these stragglers will slow down the same low-level subgroup every time, leading to nearly no straggler mitigation.
  2. This approach can only mitigate low-frequency stragglers. E.g., if 30% workers can randomly become stragglers at every step, then most low-level synchronizations will still be slowed down by stragglers. As a result, hierarchical SGD may not show an obvious performance advantage over synchronous SGD.
  3. Since hierarchical SGD applies model averaging that does not overlap with backward like gradient averaging used by vanilla DDP, its performance gain of straggler mitigation must outweigh the performance loss of no overlap between communication and backward pass. Therefore, if stragglers only slow down training by less than 10%, hierarchical SGD may not be able to bring much speedup. This limitation can be addressed by overlapping optimizer step and backward pass in the future.
  4. Since hierarchical SGD is less well-studied than local SGD, there is no guarantee that hierarchical SGD with a finer-grained synchronization granularity can converge faster than certain advanced forms of local SGD, such as SlowMo, which can improve convergence efficiency with slow momentum. However, to the best of our knowledge, these advanced algorithms cannot be natively supported as a PyTorch DDP plugin like hierarchical SGD yet.

Acknowledgements

We would like to thank Cruise teammates Bo Tian, Sergei Vorobev, Eugene Selivonchyk, Tsugn-Hsien Lee, Dan Ring, Ian Ackerman, Lei Chen, Maegan Chew, Viet Anh To, Xiaohui Long, Zeyu Chen, Alexander Sidorov, Igor Tsvetkov, Xin Hu, Manav Kataria, Marina Rubtsova, and Mohamed Fawzy, as well as Meta teammates Shen Li, Yanli Zhao, Suraj Subramanian, Hamid Shojanzeri, Anjali Sridhar and Bernard Nguyen for the support.

Read More

Celebrate PyTorch* 2.0 with New Performance Features for AI Developers

Congratulations to the PyTorch Foundation for its release of PyTorch* 2.0! In this blog, I discuss the four features for which Intel made significant contributions to PyTorch 2.0:

  1. TorchInductor
  2. GNN
  3. INT8 Inference Optimization
  4. oneDNN Graph API

We at Intel are delighted to be part of the PyTorch community and appreciate the collaboration with and feedback from our colleagues at Meta as we co-developed these features.

Let’s get started.

1. TorchInductor CPU FP32 Inference Optimized

As part of the PyTorch 2.0 compilation stack, TorchInductor CPU backend optimization brings notable performance improvements via graph compilation over the PyTorch eager mode.

The TorchInductor CPU backend is sped up by leveraging the technologies from the Intel® Extension for PyTorch for Conv/GEMM ops with post-op fusion and weight prepacking, and PyTorch ATen CPU kernels for memory-bound ops with explicit vectorization on top of OpenMP*-based thread parallelization.

With these optimizations on top of the powerful loop fusions in TorchInductor codegen, we achieved up to a 1.7x FP32 inference performance boost over three representative deep learning benchmark suites: TorchBench, HuggingFace, and timm. Training and low-precision support are under development.

See the Improvements

The performance improvements on various backends are tracked on this TorchInductor CPU Performance Dashboard.

Improve Graph Neural Network (GNN) in PyG for Inference and Training Performance on CPU

GNN is a powerful tool to analyze graph structure data. This feature is designed to improve GNN inference and training performance on Intel® CPUs, including the new 4th Gen Intel® Xeon® Scalable processors.

PyTorch Geometric (PyG) is a very popular library built upon PyTorch to perform GNN workflows. Currently on CPU, GNN models of PyG run slowly due to the lack of GNN-related sparse matrix multiplication operations (i.e., SpMM_reduce) and the lack of several critical kernel-level optimizations (scatter/gather, etc.) tuned for GNN compute.

To address this, optimizations are provided for message passing between adjacent neural network nodes:

  • scatter_reduce: performance hotspot in message-passing when the edge index is stored in coordinate format (COO).
  • gather: backward computation of scatter_reduce, specially tuned for the GNN compute when the index is an expanded tensor.
  • torch.sparse.mm with reduce flag: performance hotspot in message-passing when the edge index is stored in compressed sparse row (CSR) format. Supported reduce flags: sum, mean, amax, amin.
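As an illustration of the first point, the sum-aggregation pattern that scatter_reduce targets in COO message passing can be sketched as follows (a toy example, not PyG internals):

import torch

# Toy graph: 4 nodes with 16-dim features, edges stored in COO format (src -> dst).
edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes
                           [1, 0, 1, 2]])  # destination nodes
x = torch.randn(4, 16)

src, dst = edge_index
# Sum messages from source nodes into their destination nodes.
out = torch.zeros_like(x).scatter_reduce(
    0, dst.unsqueeze(-1).expand(-1, x.size(1)), x[src], reduce="sum"
)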

End-to-end performance benchmark results for both inference and training on 3rd Gen Intel® Xeon® Scalable processors 8380 platform and on 4th Gen 8480+ platform are discussed in Accelerating PyG on Intel CPUs.

Optimize int8 Inference with Unified Quantization Backend for x86 CPU Platforms

The new X86 quantization backend is a combination of FBGEMM (Facebook General Matrix-Matrix Multiplication) and oneAPI Deep Neural Network Library (oneDNN) backends and replaces FBGEMM as the default quantization backend for x86 platforms. The result: better end-to-end int8 inference performance than FBGEMM.

Users access the x86 quantization backend by default for x86 platforms, and the selection between different kernels is automatically done behind the scenes. The rules of selection are based on prior performance testing data done by Intel during feature development. Thus, the x86 backend replaces FBGEMM and may offer better performance, depending on the use case.
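For users who prefer to be explicit, the backend can also be selected by name (a minimal sketch; on x86 platforms this should already be the default):

import torch
from torch.ao.quantization import get_default_qconfig

# Select the unified x86 quantization backend explicitly.
torch.backends.quantized.engine = "x86"
qconfig = get_default_qconfig("x86")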

The selection rules are:

  • On platforms without VNNI (e.g., Intel® Core™ i7 processors), FBGEMM is always used.
  • On platforms with VNNI (e.g., 2nd-4th Gen Intel® Xeon® Scalable processors and future platforms):
    • For linear, FBGEMM is always used.
    • For convolution layers, FBGEMM is used for depth-wise convolution whose layers > 100; otherwise, oneDNN is used.

Note that the kernels will continue to evolve, so the selection rules above are subject to change to achieve better performance. Performance metrics for throughput speedup ratios of the unified x86 backend vs. pure FBGEMM are discussed in [RFC] Unified quantization backend for x86 CPU platforms #83888.

Leverage oneDNN Graph API to Accelerate Inference on CPU

oneDNN Graph API extends oneDNN with a flexible graph API to maximize the optimization opportunity for generating efficient code on Intel® AI hardware. It automatically identifies the graph partitions to be accelerated via fusion. The fusion patterns focus on fusing compute-intensive operations such as convolution, matmul, and their neighbor operations for both inference and training use cases.

Currently, BFloat16 and Float32 datatypes are supported and only inference workloads can be optimized. BF16 is only optimized on machines with Intel® Advanced Vector Extensions 512 (Intel® AVX-512) BF16 support.

Few or no modifications are needed in PyTorch to support newer oneDNN Graph fusions/optimized kernels. To use oneDNN Graph, users can:

  • Either use the API torch.jit.enable_onednn_fusion(True) before JIT tracing a model, OR …
  • Use its context manager, viz. with torch.jit.fuser("fuser3").
  • For accelerating BFloat16 inference, we rely on eager-mode AMP (Automatic Mixed Precision) support in PyTorch and disable JIT mode’s AMP.
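A minimal sketch of the first option, assuming an FP32 torchvision model (the model choice and input shape are illustrative):

import torch
import torchvision

torch.jit.enable_onednn_fusion(True)               # enable oneDNN Graph fusion

model = torchvision.models.resnet50(weights=None).eval()
example = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    traced = torch.jit.trace(model, example)
    traced = torch.jit.freeze(traced)
    traced(example)   # first iterations trigger fusion and compilation
    traced(example)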

See the PyTorch performance tuning guide.

Next Steps

Get the Software

Try out PyTorch 2.0 and realize the performance benefits for yourself from these Intel-contributed features.

We encourage you to check out Intel’s other AI Tools and Framework optimizations and learn about the open, standards-based oneAPI multiarchitecture, multivendor programming model that forms the foundation of Intel’s AI software portfolio.

For more details about 4th Gen Intel Xeon Scalable processor, visit AI Platform where you can learn about how Intel is empowering developers to run high-performance, efficient end-to-end AI pipelines.

PyTorch Resources

Read More

PyTorch & OpenXLA: The Path Forward

As we celebrate the release of OpenXLA, PyTorch 2.0, and PyTorch/XLA 2.0, it’s worth taking a step back and sharing where we see it all going in the short to medium term. With PyTorch adoption leading in the AI space and XLA supporting best-in-class compiler features, PyTorch/XLA is well positioned to provide a cutting edge development stack for both model training and inference. To achieve this, we see investments in three main areas:

  • Training Large Models – Large language models (LLM) and diffusion models have quickly risen in popularity and many cutting edge applications today are built on them. Further to this, training these models requires scale and more specifically the ability to train across thousands of accelerators. To achieve this we are investing in features such as AMP for mixed precision training, PjRt for increased runtime performance, SPMD / FSDP for efficient model sharding, Dynamic Shapes to enable new research approaches, and faster data loading through Ray and tf.data. Some of these features are already available in experimental or beta stages, and others are coming up this year with many heavily leveraging the underlying OpenXLA compiler stack.
  • Model Inference – With large models continuing to grow in size and computational cost, deployment becomes the next challenge as these models continue to find their way into applications. With the introduction of Dynamo in the PyTorch 2.0 release, PyTorch/XLA delivers performance-competitive inference. We are, however, incorporating additional inference-oriented features, including model serving support, Dynamo for sharded large models, and quantization via Torch.Export and StableHLO.
  • Ecosystem integration – We are expanding integration with Hugging Face and PyTorch Lightning so users can take advantage of upcoming PyTorch/XLA cutting edge features (e.g. FSDP support in Hugging Face) and the downstream OpenXLA features (e.g. Quantization) through familiar APIs.

Additionally, PyTorch/XLA is set to migrate to the open source OpenXLA as its default downstream compiler; allowing the PyTorch community to gain access to a leading, framework-agnostic compiler stack that enjoys industry-wide contribution and innovation. To achieve this, we will begin supporting StableHLO. As a result, OpenXLA will replace the existing TF:XLA dependency, overall streamlining the dependencies and creating leverage from the broader compiler ecosystem. PyTorch/XLA will also sunset the XRT runtime after migration. You can see the resulting high level stack below with the TensorFlow dependency stricken out:

the upcoming PyTorch/XLA features and integrations

Figure: the upcoming PyTorch/XLA features and integrations are illustrated here

We cannot be more excited about what’s ahead for PyTorch/XLA and invite the community to join us. PyTorch/XLA is developed fully in open source, so please file issues, submit pull requests, and send RFCs to GitHub so that we can openly collaborate. You can also try out PyTorch/XLA for yourself on various XLA devices including TPUs and GPUs.

Cheers,
The PyTorch/XLA Team at Google

Read More

Experience the power of PyTorch 2.0 on AMD Solutions

PyTorch 2.0 represents a significant step forward for the PyTorch machine learning framework. The stable release of PyTorch 2.0 brings new features that unlock even higher performance, while remaining backward compatible with prior releases and retaining the Pythonic focus which has helped to make PyTorch so enthusiastically adopted by the AI/ML community. AMD has long been a strong proponent of PyTorch, and we are delighted that PyTorch 2.0 stable release includes support for AMD Instinct™ and Radeon™ GPUs that are supported by the ROCm™ software platform.

Along with the stable PyTorch 2.0 release, the Beta includes torch.compile underpinned by TorchInductor, with support for AMD Instinct and Radeon GPUs through the OpenAI Triton deep learning compiler. Through TorchInductor, developers can now generate low-level code using Triton that is portable and comparable in performance to hand-written kernels built on native, hardware-centric kernel programming models.

Compilers like Triton can optimize the code generated by machine learning frameworks such as PyTorch for multiple AI accelerators including AMD Instinct GPU accelerator by leveraging hardware-specific features of the AMD CDNA™ GPU architecture. This makes it easy for developers and users to switch seamlessly from any HW to AMD Instinct GPU accelerators and get great out of the box performance.

In addition, compilers like Triton can also enable developers to use high-level programming languages, such as Python, to write machine learning code that can be efficiently compiled and executed on specialized hardware. This can help greatly improve the productivity of machine learning developers, as they can focus on the algorithmic aspects of their models and rely on the compiler to generate efficient code.

OpenAI Triton is a just-in-time (JIT) compiler that optimizes and accelerates the execution of deep learning models on various hardware architectures such as CPUs, GPUs, and ASICs. Here is a high-level overview:

  1. Model Loading: The Triton server loads a trained deep learning model from a storage location, typically a file in a model format such as torchfx graphs.
  2. Graph Optimization: Triton optimizes the graph representation of the loaded model. This includes transformations such as common subexpression elimination, dead code elimination, and operator fusion, which can help reduce memory usage and computational overhead.
  3. Tensor Memory Allocation: Triton allocates memory for the tensors used by the model. This includes input and output tensors, as well as intermediate tensors created during computation.
  4. Hardware-Specific Optimization: Triton applies hardware-specific optimizations to the optimized graph representation of the model. These optimizations can include using low-level hardware instructions, optimizing data movement between different types of memory, and leveraging hardware-specific data structures that leverages domain specific architectures like CDNA on AMD Instinct GPUs
  5. Code Generation: Triton generates efficient machine code for the optimized graph representation of the model. This code can then be executed on the hardware platform for which it was optimized.
  6. Execution: Triton executes the generated code on the hardware platform, typically in a just-in-time fashion. Triton can also dynamically adjust the batch size and other parameters of the model during execution to maximize performance.
  7. Result Return: Triton returns the results of the computation to the client that requested the inference.

By design, PyTorch 2.0 is backward compatible to earlier PyTorch releases. This holds true for the ROCm build of PyTorch 2.0 as well. Developers using PyTorch with AMD GPUs can migrate to PyTorch 2.0 with the confidence that their existing code will continue to work without any required changes, so there is no penalty to access the improvements that come with this release. On the other hand, using PyTorch 2.0 and TorchInductor can result in significant performance improvement over the default eager-mode as shown below.
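As a minimal sketch (a ROCm build of PyTorch exposes AMD GPUs through the same "cuda" device type, so enabling TorchInductor is the usual one-liner; the model and shapes here are only an example):

import torch
import torchvision

model = torchvision.models.resnet50(weights=None).cuda().eval()
compiled_model = torch.compile(model)    # TorchInductor generates Triton kernels

x = torch.randn(8, 3, 224, 224, device="cuda")
with torch.no_grad():
    y = compiled_model(x)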

The initial results using AMD Instinct MI250 GPUs already show strong performance improvement with minimal optimization on TorchInductor compared to the default eager-mode. We see an average performance increase of up to 1.54X on 44 out of the 45 models in the HuggingFace benchmark suite, with CamemBert, DistillGPT2 and T5Small being a few of the standout models with up to 1.5X or more performance improvement over eager-mode. We are looking forward to continued engagement with members of the PyTorch team at Meta to enable further optimization of the ROCm software stack and additional performance improvements for future PyTorch releases.

GPU performance improvement for TorchInductor vs eager-mode

Image 1: AMD MI250 GPU performance improvement for TorchInductor vs eager-mode using HuggingFace MI200-89.

PyTorch 2.0 follows the same set of install options as before to build and install for supporting AMD GPUs. These include an installable Python package hosted at pytorch.org, AMD’s public PyTorch docker image, and of course the option to build from source using the upstream PyTorch repository. As with PyTorch builds for other platforms, the specific command line to be run for pip-based install is provided by the configurator at https://pytorch.org/get-started/locally/.

The GPUs supported by the ROCm software platform which forms the basis for PyTorch support on AMD GPUs are documented at https://docs.amd.com/bundle/Hardware_and_Software_Reference_Guide/page/Hardware_and_Software_Support.html

Conclusion

PyTorch 2.0 represents a major step in continuing to broaden support for ML developers by increasing performance while maintaining a simple, Pythonic interface. This performance uplift is made possible in large part by the new TorchInductor infrastructure, which in turn harnesses the Triton ML programming language and just-in-time compiler. AMD’s support for these technologies allows users to realize the full promise of the new PyTorch architecture. Our GPU support in PyTorch 2.0 is just one manifestation of a larger vision around AI and machine learning. AI/ML plays an important role in multiple AMD product lines, including Instinct and Radeon GPUs, Alveo™ data center accelerators, and both Ryzen™ and EPYC processors. These hardware and software initiatives are all part of AMD’s Pervasive AI vision, and we look forward to addressing the many new challenges and opportunities of this dynamic space.

MI200-89 – PyTorch Inductor mode HuggingFace Transformers training speedup, running the standard PyTorch 2.0 test suite, over PyTorch eager-mode comparison based on AMD internal testing on a single GCD as of 3/10/2023 using a 2P AMD EPYC™ 7763 production server with 4x AMD Instinct™ MI250 (128GB HBM2e) 560W GPUs with Infinity Fabric™ technology; host ROCm™ 5.3, guest ROCm™ 5.4.4, PyTorch 2.0.0, Triton 2.0. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.

© 2023 Advanced Micro Devices, Inc. All rights reserved. AMD, the AMD Arrow logo, AMD CDNA, AMD Instinct, EPYC, Radeon, ROCm, Ryzen, and combinations thereof are trademarks of Advanced Micro Devices, Inc. Other product names used in this publication are for identification purposes only and may be trademarks of their respective owners.

Read More

Accelerated PyTorch 2 Transformers

The PyTorch 2.0 release includes a new high-performance implementation of the PyTorch Transformer API with the goal of making training and deployment of state-of-the-art Transformer models affordable. Following the successful release of “fastpath” inference execution (“Better Transformer”), this release introduces high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA).

You can take advantage of the new fused SDPA kernels either by calling the new SDPA operator directly (as described in the SDPA tutorial), or transparently via integration into the pre-existing PyTorch Transformer API. All features of the PyTorch Transformer API will continue to work compatibly, with many features mapped to high-performance SDPA kernels, while other features cannot be supported at higher performance (e.g., need_weights, as described below), and expanded high-performance support for additional features may still be under active development.

Similar to the “fastpath” architecture, custom kernels are fully integrated into the PyTorch Transformer API – thus, using the native Transformer and MultiHeadAttention API will enable users to transparently see significant speed improvements. Unlike the “fastpath” architecture, the newly introduced “custom kernels” support many more use cases including models using Cross-Attention, Transformer Decoders, and for training models, in addition to the existing fastpath inference for fixed and variable sequence length Transformer Encoder and Self Attention use cases.

To take full advantage of different hardware models and Transformer use cases, multiple SDPA custom kernels are supported, with custom kernel selection logic that will pick the highest-performance kernel for a given model and hardware type. In particular, the first custom kernels included with the PyTorch 2.0 release are the Flash Attention kernel (sdpa_flash, for 16-bit floating point training and inference on Nvidia GPUs with SM80+ architecture level) and the xFormers memory-efficient attention kernel (sdpa_mem_eff, for 16-bit and 32-bit floating point training and inference on a broad range of Nvidia GPUs). A general-purpose kernel sdpa_math provides an implementation when the custom kernels are not applicable.

As mentioned, custom kernels provide a wider range of support for execution scenarios. To ensure efficient execution (e.g., to use GPU tensor cores), model configurations need to meet a small number of requirements. This list of requirements will evolve over time, prospectively relaxing constraints limiting the usage of currently supported custom kernels, or providing additional kernels in the future.

For the most up to date list of custom kernels and dispatch constraints, you can refer to sdp_utils.h. As of PyTorch 2.0, the existing fused SDPA kernels have the following constraints:

  • Flash Attention only supports 16 bit floating point data types (float16 and bfloat16).
  • The head dimension must be a multiple of 8 for 16-bit floating point numbers and a multiple of 4 for 32-bit floating point numbers. At present, the maximum head_dim support for the Flash Attention custom kernel is 128.
  • The CUDA architecture level must be sm5x or better for the mem_efficient kernel, and sm80 for Flash Attention.
  • Flash Attention supports arbitrary dropout; in PyTorch 2.0 the mem_efficient kernel does not support dropout (i.e., dropout must be set to zero for this kernel to be selected in PyTorch 2.0).
  • To support variable-sequence length batches, all SDPA kernels support Nested Tensor inputs that combine input data and padding information using variable sequence length tensors for forward. (You can find more information about Nested Tensors in the Nested Tensor tutorial.)
  • You can specify both a key_padding_mask and an attn_mask by combining them before passing them to the SDPA operator. In particular, you can use the per-batch-element key padding mask of the nn.Transformer API to implement training for variable-sequence length inputs in a batch.
  • At present, the only attention mask supported by fused kernel implementation is the causal mask commonly used for training. To specify the causal mask in custom kernels, it must be specified with the is_causal boolean and attn_mask must be None.
  • Support for Nested Tensors is still under development. Specifically, in PyTorch 2.0, only the sdpa_math kernel supports training with Nested Tensors. Also, PyTorch 2.0 does not support Nested Tensors as part of code being compiled with torch.compile().
  • The SDPA operator does not support returning averaged attention weights because computing them defeats the optimizations that enable fused kernels to execute more efficiently. The argument need_weights for torch.nn.MultiheadAttention’s forward function defaults to True. In order to use the fused kernels, need_weights must be set to False.

We find that an attention mask is rarely used in real-world applications, except for the causal mask during training. Consequently, we reduce kernel complexity and compute cost by building in the option to use a causal mask as attention mask, and select this new capability with the is_causal parameter introduced in conjunction with the new SDPA operator.

Providing the is_causal Boolean flag for the frequently used causal mask also obviates the expensive and memory-intensive allocation of a causal mask, increasing training memory efficiency by allowing more memory to be used for large batch sizes, and reducing memory bandwidth and cache contention – which are both at a premium in GPU accelerators – by not needing to load an attention mask tensor.

If the constraints of none of the available custom kernels are met, then training falls back to using the default sdpa_math kernel, which implements the mathematical equations for scaled dot product attention using a sequence of PyTorch operators. This is the most general “catch-all” fallback kernel that ensures successful training for all models.

In addition to the existing Transformer API, model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator. This operator may be used to efficiently implement multi-head attention by combining it with in-projection and out-projection, as described in the SDPA tutorial.
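
For illustration, here is a minimal sketch of how multi-head causal self-attention might be assembled from an in-projection, the scaled_dot_product_attention() operator, and an out-projection. The module below is illustrative only (it is not the internal implementation of torch.nn.MultiheadAttention); the class name and shapes are ours.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Minimal multi-head self-attention built on the fused SDPA operator."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Combined in-projection for query, key and value
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq_len, head_dim) as expected by SDPA
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # attn_mask is None and the causal mask is requested via is_causal,
        # so the fused kernels remain eligible for dispatch
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        y = y.transpose(1, 2).reshape(B, T, C)
        return self.out_proj(y)

attn = CausalSelfAttention(embed_dim=512, num_heads=8)
out = attn(torch.randn(4, 1024, 512))  # on CUDA with float16 inputs a fused kernel can be selected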

In addition to adding custom kernels, Accelerated PyTorch 2 Transformers are integrated with PyTorch 2.0 compilation. To use your model while benefiting from the additional acceleration of PT2-compilation (for inference or training), pre-process the model with

model = torch.compile(model)

We have achieved major speedups for training transformer models and in particular large language models with Accelerated PyTorch 2 Transformers using a combination of custom kernels and torch.compile().

Figure: Using scaled dot product attention with custom kernels and torch.compile delivers significant speedups for training large language models, such as for nanoGPT shown here.

Finally, because the custom kernels are much more memory efficient, consider increasing the training batch size to achieve faster training.

In addition to automatic kernel selection, a context manager enables developers to override the kernel selection algorithm – this is not required for day-to-day operation, but it allows developers to debug their code and performance engineers to override kernel selection. The SDPA tutorial provides additional information on using the SDPA context manager.
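
As a sketch of how this might look in practice (assuming a CUDA device, and using the torch.backends.cuda.sdp_kernel context manager described in the SDPA tutorial), the following restricts dispatch to a single kernel while debugging or benchmarking:

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the Flash Attention kernel only; if its constraints are not met,
# SDPA errors out instead of silently falling back to another kernel.
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)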

In addition to availability as part of the nn.Transformer API, Accelerated PyTorch 2 Transformer custom kernels are also available in conjunction with the torchtext, torchvision, and fairseq domain libraries with the launch of PyTorch 2.0.



PyTorch 2.0 & XLA—The Latest Cutting Edge Features

Today, we are excited to share our latest work for PyTorch/XLA 2.0. The release of PyTorch 2.0 is yet another major milestone for this storied community and we are excited to continue to be part of it. When the PyTorch/XLA project started in 2018 between Google and Meta, the focus was on bringing cutting edge Cloud TPUs to help support the PyTorch community. Along the way, others in the community such as Amazon joined the project and very quickly the community expanded. We are excited about XLA’s direction and the benefits this project continues to bring to the PyTorch community. In this blog we’d like to showcase some key features that have been in development, show code snippets, and illustrate the benefit through some benchmarks.

TorchDynamo / torch.compile (Experimental)

TorchDynamo (Dynamo) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in; its biggest feature is to dynamically modify Python bytecode just before execution. In the PyTorch/XLA 2.0 release, an experimental backend for Dynamo is provided for both inference and training.

Dynamo provides a Torch FX (FX) graph when it recognizes a model pattern and PyTorch/XLA uses a Lazy Tensor approach to compile the FX graph and return the compiled function. To get more insight regarding the technical details about PyTorch/XLA’s dynamo implementation, check out this dev-discuss post and dynamo doc.

Here is a small code example of running ResNet18 with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  # Compile once with the XLA inference backend; subsequent calls reuse the compiled binary
  dynamo_resnet18 = torch.compile(
      xla_resnet18, backend='torchxla_trace_once')
  for data, _ in loader:
    # Inputs are assumed to already live on the XLA device
    with torch.no_grad():
      output = dynamo_resnet18(data)

With torch.compile, PyTorch/XLA traces the ResNet18 model only once at init time and executes the compiled binary every time dynamo_resnet18 is invoked, instead of tracing the model at every step. To illustrate the benefits of Dynamo+XLA, below is an inference speedup analysis comparing Dynamo and LazyTensor (without Dynamo) using TorchBench on a Cloud TPU v4-8, where the y-axis is the speedup multiplier.

Inference Speedup - PyTorch/XLA Dynamo on TPU

Dynamo support for training is still in development and at an earlier stage than inference. Developers are welcome to test this early feature; however, in the 2.0 release, PyTorch/XLA supports the forward and backward pass graphs but not the optimizer graph; the optimizer graph is available in the nightly builds and will land in the PyTorch/XLA 2.1 release. Below is an example of what training looks like using the ResNet18 example with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  # Only the forward and backward graphs are captured in the 2.0 release;
  # the optimizer step is not part of the compiled graph.
  loss.backward()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  # Compile once with the experimental XLA training backend
  dynamo_train_model = torch.compile(
        train_model, backend='aot_torchxla_trace_once')
  for data, target in loader:
    output = dynamo_train_model(xla_resnet18, data, target)

Note that the backend for training is aot_torchxla_trace_once (the API will be updated for the stable release) whereas the inference backend is torchxla_trace_once (name subject to change). We expect to extract and execute 3 graphs per training step, instead of one graph per training step with the Lazy Tensor approach. Below is a training speedup analysis comparing Dynamo and Lazy Tensor using TorchBench on a Cloud TPU v4-8.

Training Speedup - PyTorch/XLA Dynamo on TPU

PJRT Runtime (Beta)

PyTorch/XLA is migrating from XRT to the new PJRT runtime. PJRT is a better-maintained stack with demonstrated performance advantages, including, on average, a 35% performance improvement for training on TorchBench 2.0 models. It also supports a richer set of features, enabling technologies like SPMD. In the PyTorch/XLA 2.0 release, PJRT is the default runtime for TPU and CPU; GPU support is experimental. The PJRT features included in the PyTorch/XLA 2.0 release are:

  • TPU runtime implementation in libtpu using the PJRT Plugin API improves performance by up to 30%
  • torch.distributed support for TPU v2 and v3, including pjrt:// init_method (Experimental)
  • Single-host GPU support. Multi-host support coming soon. (Experimental)

Switching to PJRT requires no change (or minimal change for GPUs) to user code (see pjrt.md for more details). Runtime configuration is as simple as setting the PJRT_DEVICE environment variable to the local device type (i.e. TPU, GPU, CPU). Below are examples of using PJRT runtimes on different devices.

# TPU Device
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
# TPU Pod Device
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
# GPU Device (Experimental)
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

Below is a performance comparison between XRT and PJRT by task on TorchBench 2.0 on v4-8 TPU. To learn more about PJRT vs. XRT please review the documentation.

TorchBench Training Time

Parallelization

GSPMD (Experimental)

We are delighted to introduce General and Scalable Parallelization for ML Computation Graphs (GSPMD) in PyTorch as a new experimental data & model sharding solution. GSPMD provides automatic parallelization for common ML workloads, allowing developers to write PyTorch programs as if on a single large device and without custom sharded computation ops and/or collective communication ops. The XLA compiler transforms the single device program into a partitioned one with proper collectives, based on the user provided sharding hints. The API (RFC) will be available in the PyTorch/XLA 2.0 release as an experimental feature on a single TPU VM host.

Next Steps for GSPMD

GSPMD is experimental in 2.0 release. To bring it to Stable status, we plan to address a number of feature gaps and known issues in the following releases, including multi-host support, DTensor integration, partial replication sharding, asynchronous data loading, and checkpointing.

FSDP (Beta)

PyTorch/XLA introduced fully sharded data parallel (FSDP) experimental support in version 1.12. This feature is a parallel representation of PyTorch FSDP, and there are subtle differences in how XLA and upstream CUDA kernels are set up. auto_wrap_policy is a new argument that enables developers to automatically specify conditions for propagating partitioning specifications to neural network submodules. An auto_wrap_policy may simply be passed in as an argument when wrapping a model with FSDP. Two auto_wrap_policy callables worth noting are size_based_auto_wrap_policy and transformer_auto_wrap_policy.

size_based_auto_wrap_policy enables users to wrap submodules with a minimum number of parameters. The example below wraps model submodules having at least 10M parameters.

from functools import partial

auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

transformer_auto_wrap_policy enables users to wrap all submodules that match a specific layer type. The example below wraps model submodules of type torch.nn.Conv2d. To learn more, review this ResNet example by Ronghang Hu.

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
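
Putting the pieces together, here is a hedged sketch of wrapping a model with the XLA FSDP class and an auto_wrap_policy. It assumes the XlaFullyShardedDataParallel wrapper from torch_xla.distributed.fsdp, and it assumes the policy callables are importable from its wrap utilities (mirroring the upstream PyTorch FSDP layout); see the ResNet example above for the exact recipe.

from functools import partial

import torch
import torchvision
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
# Assumption: policy callables live in the FSDP wrap utilities, as in upstream PyTorch
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

device = xm.xla_device()
model = torchvision.models.resnet50().to(device)

# Automatically wrap every submodule holding at least 10M parameters
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)
model = FSDP(model, auto_wrap_policy=auto_wrap_policy)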

PyTorch/XLA FSDP is now integrated into the Hugging Face Trainer class (PR), enabling users to train much larger models on PyTorch/XLA (official Hugging Face documentation). A 16B-parameter GPT2 model trained on Cloud TPU v4-64 with this FSDP configuration achieved 39% hardware utilization.

TPU Accelerator – Num Devices v4-64
GPT2 Parameter Count 16B
Layers Wrapped with FSDP GPT2Block
TFLOPs / Chip 275
PFLOPs / Step 50
Hardware Utilization 39%

Differences Between FSDP & GSPMD

FSDP is a data parallelism technique that reduces device memory footprint by sharding model parameters, optimizer states, and gradients across devices. Note that the actual computation is still local to each device and requires all-gathering the sharded model parameters for both forward and backward passes, hence the name “data parallel”. FSDP is one of the newest additions to PyTorch/XLA for scaling large model training.

GSPMD on the other hand, is a general parallelization system that enables various types of parallelisms, including both data and model parallelisms. PyTorch/XLA provides a sharding annotation API and XLAShardedTensor abstraction, so a user can annotate any tensor with sharding specs in the PyTorch program. Developers don’t need to manually implement sharded computations or inject collective communications ops to get it right. The XLA compiler does the work so that each computation can run in a distributed manner on multiple devices.

Examples & Preliminary Results

To learn about PyTorch/XLA parallelism sharding API, visit our RFC and see the Sample Code references. Below is a simple example to enable data and model parallelism.

model = SimpleLinear().to(xm.xla_device())
# Sharding annotate the linear layer weights.
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  data = data.to(xm.xla_device())
  target = target.to(xm.xla_device())
  # Sharding annotate input data; we can shard any input
  # dimensions. Sharding the batch dimension enables
  # data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  xm.mark_step()
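
The snippet above assumes that mesh and partition_spec were constructed beforehand. Below is a hedged sketch of what that construction might look like with the experimental sharding API; the module path and the Mesh signature follow the RFC and should be treated as assumptions rather than a verbatim recipe.

import numpy as np
import torch_xla.core.xla_model as xm
# Assumption: the experimental sharding API lives in this module, per the RFC
import torch_xla.experimental.xla_sharding as xs

# Arrange the local devices into a 2D logical mesh (assumes an even device count)
num_devices = len(xm.get_xla_supported_devices())
device_ids = np.arange(num_devices)
mesh_shape = (num_devices // 2, 2)
mesh = xs.Mesh(device_ids, mesh_shape, ('data', 'model'))

# Map tensor dim 0 to mesh axis 0 and tensor dim 1 to mesh axis 1
partition_spec = (0, 1)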

The following graph highlights the memory efficiency benefits of PyTorch/XLA FSDP and SPMD on Cloud TPU v4-8 running ResNet50.

Batch Size Scaling with Spatial Partitioning

Closing Thoughts…

We are excited to bring these features to the PyTorch community, and this is really just the beginning. Areas like dynamic shapes, deeper support for OpenXLA and many others are in development and we plan to put out more blogs to dive into the details. PyTorch/XLA is developed fully open source and we invite you to join the community of developers by filing issues, submitting pull requests, and sending RFCs on GitHub. You can try PyTorch/XLA on a variety of XLA devices including TPUs and GPUs. Here is how to get started.

Congratulations again to the PyTorch community on this milestone!

Cheers,

The PyTorch Team at Google


Accelerated Diffusers with PyTorch 2.0

Accelerated Diffusers with PyTorch 2.0

PyTorch 2.0 has just been released. Its flagship new feature is torch.compile(), a one-line code change that promises to automatically improve performance across codebases. We have previously checked on that promise in Hugging Face Transformers and TIMM models, and delved deep into its motivation, architecture and the road ahead.

As important as torch.compile() is, there’s much more to PyTorch 2.0. Notably, PyTorch 2.0 incorporates several strategies to accelerate transformer blocks, and these improvements are very relevant for diffusion models too. Techniques such as FlashAttention, for example, have become very popular in the diffusion community thanks to their ability to significantly speed up Stable Diffusion and achieve larger batch sizes, and they are now part of PyTorch 2.0.

In this post we discuss how attention layers are optimized in PyTorch 2.0 and how these optimizations are applied to the popular 🧨 Diffusers library. We finish with a benchmark that shows how the use of PyTorch 2.0 and Diffusers immediately translates to significant performance improvements across different hardware.

Accelerating transformer blocks

PyTorch 2.0 includes a scaled dot-product attention function as part of torch.nn.functional. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. Before PyTorch 2.0, you had to search for third-party implementations and install separate packages in order to take advantage of memory optimized algorithms, such as FlashAttention. The available implementations are:

  • FlashAttention, from the official FlashAttention project.
  • Memory-Efficient Attention, from the xFormers project.
  • A native C++ implementation suitable for non-CUDA devices or when high-precision is required.

All these methods are available by default, and PyTorch will try to select the optimal one automatically through the new scaled dot-product attention (SDPA) API. You can also toggle them individually for finer-grained control; see the documentation for details.

Using scaled dot-product attention in diffusers

The incorporation of Accelerated PyTorch 2.0 Transformer attention into the Diffusers library was achieved through the set_attn_processor method, which allows pluggable attention modules to be configured. In this case, a new attention processor was created, and it is enabled by default when PyTorch 2.0 is available. For clarity, this is how you could enable it manually (it is usually not necessary since diffusers will automatically take care of it):

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

Stable Diffusion Benchmark

We ran a number of tests using accelerated dot-product attention from PyTorch 2.0 in Diffusers. We installed diffusers from pip and used nightly versions of PyTorch 2.0, since our tests were performed before the official release. We also used torch.set_float32_matmul_precision('high') to enable additional fast matrix multiplication algorithms.

We compared results with the traditional attention implementation in diffusers (referred to as vanilla below) as well as with the best-performing solution in pre-2.0 PyTorch: PyTorch 1.13.1 with the xFormers package (v0.0.16) installed.

Results were measured without compilation (i.e., no code changes at all), and also with a single call to torch.compile() to wrap the UNet module. We did not compile the image decoder because most of the time is spent in the 50 denoising iterations that run UNet evaluations.
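
Concretely, continuing from the pipeline created above, one way to apply that single compile call is a sketch like the following:

import torch

# Compile only the UNet: the 50-step denoising loop calls it repeatedly,
# so that is where compilation pays off.
pipe.unet = torch.compile(pipe.unet)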

Results in float32

Diffusers Speedup vs xFormers float32

The following figures explore performance improvement vs batch size for various representative GPUs belonging to different generations. We collected data for each combination until we reached maximum memory utilization. Vanilla attention runs out of memory earlier than xFormers or PyTorch 2.0, which explains the missing bars for larger batch sizes. Similarly, A100 (we used the 40 GB version) is capable of running batch sizes of 64, but the other GPUs could only reach 32 in our tests.

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

We found very significant performance improvements over vanilla attention across the board, without even using torch.compile(). An out of the box installation of PyTorch 2.0 and diffusers yields about 50% speedup on A100 and between 35% and 50% on 4090 GPUs, depending on batch size. Performance improvements are more pronounced for modern CUDA architectures such as Ada (4090) or Ampere (A100), but they are still very significant for older architectures still heavily in use in cloud services.

In addition to faster speeds, the accelerated transformers implementation in PyTorch 2.0 allows much larger batch sizes to be used. A single 40GB A100 GPU runs out of memory with a batch size of 10, and 24 GB high-end consumer cards such as 3090 and 4090 cannot generate 8 images at once. Using PyTorch 2.0 and diffusers we could achieve batch sizes of 48 for 3090 and 4090, and 64 for A100. This is of great significance for cloud services and applications, as they can efficiently process more images at a time.

When compared with PyTorch 1.13.1 + xFormers, the new accelerated transformers implementation is still faster and requires no additional packages or dependencies. In this case we found moderate speedups of up to 2% on datacenter cards such as A100 or T4, but performance was great on the two last generations of consumer cards: up to 20% speed improvement on 3090 and between 10% and 45% on 4090, depending on batch size.

When torch.compile() is used, we get an additional performance boost of (typically) 2% to 3% over the previous improvements. As compilation takes some time, this is better geared towards user-facing inference services or training.

Results in float16

Diffusers Speedup vs xFormers float16

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

When we consider float16 inference, the performance improvements of the accelerated transformers implementation in PyTorch 2.0 are between 20% and 28% over standard attention, across all the GPUs we tested, except for the 4090, which belongs to the more modern Ada architecture. This GPU benefits from a dramatic performance improvement when using PyTorch 2.0 nightlies. With respect to optimized SDPA vs xFormers, results are usually on par for most GPUs, except again for the 4090. Adding torch.compile() to the mix boosts performance a few more percentage points across the board.

Conclusions

PyTorch 2.0 comes with multiple features to optimize the crucial components of the foundational transformer block, and they can be further improved with the use of torch.compile. These optimizations lead to significant memory and time improvements for diffusion models, and remove the need for third-party library installations.

To take advantage of these speed and memory improvements all you have to do is upgrade to PyTorch 2.0 and use diffusers >= 0.13.0.

For more examples and in-detail benchmark numbers, please also have a look at the Diffusers with PyTorch 2.0 docs.

Acknowledgement

The authors are grateful to the PyTorch team for their insights, assistance and suggestions during the elaboration of this post, and for creating such excellent software. We are particularly indebted to Hamid Shojanazeri, Grigory Sizov, Christian Puhrsch, Driss Guessous, Michael Gschwind and Geeta Chauhan.


PyTorch 2.0: Our next generation release that is faster, more Pythonic and Dynamic as ever

PyTorch 2.0: Our next generation release that is faster, more Pythonic and Dynamic as ever

We are excited to announce the release of PyTorch® 2.0 which we highlighted during the PyTorch Conference on 12/2/22! PyTorch 2.0 offers the same eager-mode development and user experience, while fundamentally changing and supercharging how PyTorch operates at compiler level under the hood with faster performance and support for Dynamic Shapes and Distributed.

This next-generation release includes a Stable version of Accelerated Transformers (formerly called Better Transformers); Beta includes torch.compile as the main API for PyTorch 2.0, the scaled_dot_product_attention function as part of torch.nn.functional, the MPS backend, functorch APIs in the torch.func module; and other Beta/Prototype improvements across various inferences, performance and training optimization features on GPUs and CPUs. For a comprehensive introduction and technical overview of torch.compile, please visit the 2.0 Get Started page.

Along with 2.0, we are also releasing a series of beta updates to the PyTorch domain libraries, including those that are in-tree, and separate libraries including TorchAudio, TorchVision, and TorchText. An update for TorchX is also being released as it moves to community supported mode. More details can be found in this library blog.

This release is composed of over 4,541 commits and 428 contributors since 1.13.1. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.0 and the overall 2-series this year.

Summary:

  • torch.compile is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature and hence 2.0 is 100% backward compatible by definition.
  • As an underpinning technology of torch.compile, TorchInductor with Nvidia and AMD GPUs will rely on the OpenAI Triton deep learning compiler to generate performant code and hide low-level hardware details. OpenAI Triton-generated kernels achieve performance that’s on par with hand-written kernels and specialized CUDA libraries such as cuBLAS.
  • Accelerated Transformers introduce high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA). The API is integrated with torch.compile() and model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator.
  • Metal Performance Shaders (MPS) backend provides GPU accelerated PyTorch training on Mac platforms with added support for Top 60 most used ops, bringing coverage to over 300 operators.
  • Amazon AWS optimizes the PyTorch CPU inference on AWS Graviton3 based C7g instances. PyTorch 2.0 improves inference performance on Graviton compared to the previous releases, including improvements for Resnet50 and Bert.
  • New prototype features and technologies across TensorParallel, DTensor, 2D parallel, TorchDynamo, AOTAutograd, PrimTorch and TorchInductor.
The feature status for this release is summarized below:

  • Stable: Accelerated PT 2 Transformers
  • Beta: torch.compile, PyTorch MPS Backend, Scaled dot product attention, functorch, Dispatchable Collectives, torch.set_default_device & torch.device, X86 quantization backend, GNN inference and training performance
  • Prototype: DTensor, TensorParallel, 2D Parallel, torch.compile(dynamic=True)
  • Performance Improvements: CUDA support for 11.7 & 11.8 (deprecating CUDA 11.6), Python 3.8 (deprecating Python 3.7), AWS Graviton3

*To see a full list of public 2.0, 1.13 and 1.12 feature submissions click here.

Stable Features

[Stable] Accelerated PyTorch 2 Transformers (previously known as “Better Transformer”)

The PyTorch 2.0 release includes a new high-performance implementation of the PyTorch Transformer API, formerly known as the “Better Transformer API,” now renamed Accelerated PyTorch 2 Transformers. In releasing Accelerated PT2 Transformers, our goal is to make training and deployment of state-of-the-art Transformer models affordable across the industry. This release introduces high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA).

Similar to the “fastpath” architecture, custom kernels are fully integrated into the PyTorch Transformer API – thus, using the native Transformer and MultiHeadAttention API will enable users to:

  • transparently see significant speed improvements;
  • support many more use cases including models using Cross-Attention, Transformer Decoders, and for training models; and
  • continue to use fastpath inference for fixed and variable sequence length Transformer Encoder and Self Attention use cases.

To take full advantage of different hardware models and Transformer use cases, multiple SDPA custom kernels are supported (see below), with custom kernel selection logic that will pick the highest-performance kernel for a given model and hardware type. In addition to the existing Transformer API, model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator. Accelerated PyTorch 2 Transformers are integrated with torch.compile(). To use your model while benefiting from the additional acceleration of PT2-compilation (for inference or training), pre-process the model with model = torch.compile(model).

We have achieved major speedups for training transformer models and in particular large language models with Accelerated PyTorch 2 Transformers using a combination of custom kernels and torch.compile().

Figure: Using scaled dot product attention with custom kernels and torch.compile delivers significant speedups for training large language models, such as for nanoGPT shown here.

Beta Features

[Beta] torch.compile

torch.compile is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature and hence 2.0 is 100% backward compatible by definition.

Underpinning torch.compile are new technologies – TorchDynamo, AOTAutograd, PrimTorch and TorchInductor:

  • TorchDynamo captures PyTorch programs safely using Python Frame Evaluation Hooks and is a significant innovation that was a result of 5 years of our R&D into safe graph capture.
  • AOTAutograd overloads PyTorch’s autograd engine as a tracing autodiff for generating ahead-of-time backward traces.
  • PrimTorch canonicalizes ~2000+ PyTorch operators down to a closed set of ~250 primitive operators that developers can target to build a complete PyTorch backend. This substantially lowers the barrier of writing a PyTorch feature or backend.
  • TorchInductor is a deep learning compiler that generates fast code for multiple accelerators and backends. For NVIDIA and AMD GPUs, it uses OpenAI Triton as a key building block. For Intel CPUs, we generate C++ code using multithreading, vectorized instructions, and offloading appropriate operations to mkldnn when possible.

With all the new technologies, torch.compile is able to work 93% of the time across 165 open-source models and runs 20% faster on average at float32 precision and 36% faster on average at AMP precision.

For more information, please refer to https://pytorch.org/get-started/pytorch-2.0/ and for TorchInductor CPU with Intel here.

[Beta] PyTorch MPS Backend

MPS backend provides GPU-accelerated PyTorch training on Mac platforms. This release brings improved correctness, stability, and operator coverage.

MPS backend now includes support for the Top 60 most used ops, along with the most frequently requested operations by the community, bringing coverage to over 300 operators. The major focus of the release was to enable full OpInfo-based forward and gradient mode testing to address silent correctness issues. These changes have resulted in wider adoption of MPS backend by 3rd party networks such as Stable Diffusion, YoloV5, WhisperAI, along with increased coverage for Torchbench networks and Basic tutorials. We encourage developers to update to the latest macOS release to see the best performance and stability on the MPS backend.

Links

  1. MPS Backend
  2. Developer information
  3. Accelerated PyTorch training on Mac
  4. Metal, Metal Performance Shaders & Metal Performance Shaders Graph

[Beta] Scaled dot product attention 2.0

We are thrilled to announce the release of PyTorch 2.0, which introduces a powerful scaled dot product attention function as part of torch.nn.functional. This function includes multiple implementations that can be seamlessly applied depending on the input and hardware in use.

In previous versions of PyTorch, you had to rely on third-party implementations and install separate packages to take advantage of memory-optimized algorithms like FlashAttention. With PyTorch 2.0, all these implementations are readily available by default.

These implementations include FlashAttention from HazyResearch, Memory-Efficient Attention from the xFormers project, and a native C++ implementation that is ideal for non-CUDA devices or when high-precision is required.

PyTorch 2.0 will automatically select the optimal implementation for your use case, but you can also toggle them individually for finer-grained control. Additionally, the scaled dot product attention function can be used to build common transformer architecture components.

Learn more with the documentation and this tutorial.

[Beta] functorch -> torch.func

Inspired by Google JAX, functorch is a library that offers composable vmap (vectorization) and autodiff transforms. It enables advanced autodiff use cases that would otherwise be tricky to express in PyTorch, such as model ensembling and computing per-sample gradients.

We’re excited to announce that, as the final step of upstreaming and integrating functorch into PyTorch, the functorch APIs are now available in the torch.func module. Our function transform APIs are identical to before, but we have changed how the interaction with NN modules works. Please see the docs and the migration guide for more details.

Furthermore, we have added support for torch.autograd.Function: you can now apply function transformations (e.g., vmap, grad, jvp) over torch.autograd.Function.
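
As a small illustration of the torch.func APIs (a sketch of standard usage, not code taken from the docs), the function transforms compose directly:

import torch
from torch.func import grad, vmap

# Gradient of a scalar function via a function transform
g = grad(torch.sin)(torch.tensor(0.0))          # cos(0) = 1

# Per-sample gradients: vmap the gradient over the batch dimension
def loss(w, x):
    return (x @ w).sum()

w = torch.randn(3)
batch = torch.randn(5, 3)
per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(w, batch)
print(per_sample_grads.shape)                   # torch.Size([5, 3])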

[Beta] Dispatchable Collectives

Dispatchable collectives is an improvement to the existing init_process_group() API that makes the backend an optional argument. For users, the main advantage of this feature is that it allows them to write code that can run on both GPU and CPU machines without having to change the backend specification. The dispatchability feature also makes it easier for users to support both GPU and CPU collectives, as they no longer need to specify the backend manually (e.g., “NCCL” or “GLOO”). Existing backend specifications by users will be honored and will not require change.

Usage example:

import torch.distributed as dist
…
# old
dist.init_process_group(backend="nccl", ...)
dist.all_reduce(...)  # with CUDA tensors works
dist.all_reduce(...)  # with CPU tensors does not work

# new
dist.init_process_group(...)  # backend is optional
dist.all_reduce(...)  # with CUDA tensors works
dist.all_reduce(...)  # with CPU tensors works

Learn more here.

[Beta] torch.set_default_device and torch.device as context manager

torch.set_default_device allows users to change the default device that factory functions in PyTorch allocate on. For example, if you call torch.set_default_device('cuda'), a call to torch.empty(2) will allocate on CUDA (rather than on CPU). You can also use torch.device as a context manager to change the default device on a local basis. This resolves a long-standing feature request, dating back to PyTorch’s initial release, for a way to do this.
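
For example (a short sketch assuming a CUDA device is available):

import torch

# Change the default device globally: factory calls now allocate on CUDA
torch.set_default_device("cuda")
print(torch.empty(2).device)     # cuda:0

# Or scope the change locally with torch.device as a context manager
torch.set_default_device("cpu")
with torch.device("cuda"):
    print(torch.ones(3).device)  # cuda:0
print(torch.zeros(3).device)     # cpu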

Learn more here.

[Beta] “X86” as the new default quantization backend for x86 CPU

The new X86 quantization backend, which utilizes the FBGEMM and oneDNN kernel libraries, replaces FBGEMM as the default quantization backend for x86 CPU platforms. It offers improved int8 inference performance compared to the original FBGEMM backend by leveraging the strengths of both libraries, with a 1.3x – 2x inference performance speedup measured on 40+ deep learning models. The new backend is functionally compatible with the original FBGEMM backend.

Table: Geomean Speedup of X86 Quantization Backend vs. FBGEMM Backend

1 core/instance 2 cores/instance 4 cores/instance 1 socket (32 cores)/instance
Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz 1.76X 1.80X 2.04X 1.34X

By default, users on x86 platforms will utilize the x86 quantization backend, and their PyTorch programs will remain unchanged when using the default backend. Alternatively, users have the option to specify “X86” as the quantization backend explicitly. Example code is shown below:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx
 
# get default configuration
qconfig_mapping = get_default_qconfig_mapping()
 
# or explicitly specify the backend
# qengine = 'x86'
# torch.backends.quantized.engine = qengine
# qconfig_mapping = get_default_qconfig_mapping(qengine)
 
# construct fp32 model
model_fp32 = ...
 
# prepare (x stands for example inputs used to trace the model)
prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs=x)
 
# calibrate
...
 
# convert
quantized_model = convert_fx(prepared_model)

Find more information: https://github.com/pytorch/pytorch/issues/83888 and https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-pytorch-int8-inf-with-new-x86-backend.html.

[Beta] GNN inference and training optimization on CPU

PyTorch 2.0 includes several critical optimizations to improve GNN inference and training performance on CPU. Before 2.0, GNN models from PyG suffered from low efficiency on CPU due to the lack of performance tuning for several critical kernels (scatter/gather, etc.) and the lack of GNN-related sparse matrix multiplication ops. Specifically, the optimizations include:

  • scatter_reduce: performance hotspot in Message Passing when the edge index is stored in Coordinate format (COO); see the sketch after this list.
  • gather: backward of scatter_reduce, specially tuned for the GNN compute when the index is an expanded tensor.
  • torch.sparse.mm with reduce flag: performance hotspot in Message Passing when the edge index is stored in Compressed Sparse Row (CSR) format. Supported reduce flags: sum, mean, amax, amin.
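
To make the scatter_reduce pattern concrete, here is a small self-contained sketch of the COO-style aggregation it accelerates; the toy edge_index and feature tensors are hypothetical and not taken from PyG.

import torch

# Toy graph: 4 nodes, 5 directed edges in COO format (source -> destination)
edge_index = torch.tensor([[0, 1, 2, 3, 0],   # source nodes
                           [1, 2, 3, 0, 2]])  # destination nodes
x = torch.randn(4, 8)                         # node features

# Message passing step: gather messages from source nodes, then reduce them
# into destination nodes using the optimized scatter_reduce kernel.
messages = x[edge_index[0]]                                   # (num_edges, 8)
index = edge_index[1].unsqueeze(-1).expand_as(messages)
out = torch.zeros(4, 8).scatter_reduce(0, index, messages, reduce="sum", include_self=False)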

On PyG benchmarks/examples and OGB benchmarks, a 1.12x – 4.07x performance speedup is measured (1.13.1 compared with 2.0) for single-node inference and training.

Model-Dataset Option Speedup Ratio
GCN-Reddit (inference) 512-2-64-dense 1.22x
1024-3-128-dense 1.25x
512-2-64-sparse 1.31x
1024-3-128-sparse 1.68x
512-2-64-dense 1.22x
GraphSage-ogbn-products (inference) 1024-3-128-dense 1.15x
512-2-64-sparse 1.20x
1024-3-128-sparse 1.33x
full-batch-sparse 4.07x
GCN-PROTEINS (training) 3-32 1.67x
GCN-REDDIT-BINARY (training) 3-32 1.67x
GCN-Reddit (training) 512-2-64-dense 1.20x
1024-3-128-dense 1.12x

Learn more: PyG CPU Performance Optimization.

[Beta] Accelerating inference on CPU with PyTorch by leveraging oneDNN Graph

oneDNN Graph API extends oneDNN with a flexible graph API to maximize the optimization opportunity for generating efficient code on AI hardware.

  • It automatically identifies the graph partitions to be accelerated via fusion.
  • The fusion patterns focus on fusing compute-intensive operations such as convolution, matmul and their neighbor operations for both inference and training use cases.
  • Although work is ongoing to integrate oneDNN Graph with TorchDynamo as well, its integration with the PyTorch JIT Fuser attained beta status in PyTorch 2.0 for Float32 & BFloat16 inference (on machines that support AVX512_BF16 ISA).

From a developer’s/researcher’s perspective, the usage is quite simple & intuitive, with the only change in code being an API invocation:

  • To leverage oneDNN Graph with JIT tracing, a model is profiled with an example input.
  • The context manager with torch.jit.fuser(“fuser3”): can also be used instead of invoking torch.jit.enable_onednn_fusion(True).
  • For accelerating BFloat16 inference, we rely on eager-mode AMP (Automatic Mixed Precision) support in PyTorch & disable JIT mode’s AMP, as both of them are currently divergent:
import torch

# Assuming we have a model of the name 'model'

example_input = torch.rand(1, 3, 224, 224)

# enable oneDNN Graph
torch.jit.enable_onednn_fusion(True)
# Disable AMP for JIT
torch._C._jit_set_autocast_mode(False)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # 2 warm-ups (2 for tracing/scripting with an example, 3 without an example)
    model(example_input)
    model(example_input)

    # speedup would be observed in subsequent runs.
    model(example_input)

Learn more here.

Prototype Features

Distributed API

[Prototype] DTensor

PyTorch DistributedTensor (DTensor) is a prototyping effort with distributed tensor primitives to allow easier distributed computation authoring in the SPMD (Single Program Multiple Devices) paradigm. The primitives are simple but powerful when used to express tensor distributions with both sharded and replicated parallelism strategies. PyTorch DTensor empowered PyTorch Tensor Parallelism along with other advanced parallelism explorations. In addition, it also offers a uniform way to save/load state_dict for distributed checkpointing purposes, even when there are complex tensor distribution strategies such as combining tensor parallelism with parameter sharding in FSDP. More details can be found in this RFC and the DTensor examples notebook.

[Prototype] TensorParallel

We now support DTensor-based Tensor Parallel, with which users can distribute their model parameters across different GPU devices. We also support Pairwise Parallel, which shards two concatenated linear layers column-wise and row-wise respectively, so that only one collective (all-reduce/reduce-scatter) is needed at the end. More details can be found in this example.

[Prototype] 2D Parallel

We implemented the integration of the aforementioned TP with FullyShardedDataParallel(FSDP) as 2D parallel to further scale large model training. More details can be found in this slide and code example.

[Prototype] torch.compile(dynamic=True)

Experimental support for PT2 compilation with dynamic shapes is available in this release. Inference compilation with inductor for simple models is supported, but there are a lot of limitations:

  • Training will be available in a future release (this is partially fixed in nightlies!)
  • The minifier will be available in a future release.
  • It is easy to end up in a situation where the dimension you wanted to be dynamic gets specialized anyway. Some of these issues are fixed in nightlies, others are not.
  • We do not appropriately propagate Inductor guards to the top level; this is tracked at #96296.
  • Data-dependent operations like nonzero still require a graph break.
  • Dynamic shapes do not work with non-standard modes like reduce-overhead or max-autotune.
  • There are many bugs in Inductor compilation. To track known bugs, check the dynamic shapes label on the PyTorch issue tracker.

For the latest and greatest news about dynamic shapes support on master, check out our status reports.

Highlights/Performance Improvements

Deprecation of CUDA 11.6 and Python 3.7 support for PyTorch 2.0

If you are still using or depending on CUDA 11.6 or Python 3.7 builds, we strongly recommend moving to at least CUDA 11.7 and Python 3.8, as it would be the minimum versions required for PyTorch 2.0. For more detail, please refer to the Release Compatibility Matrix for PyTorch releases.

Python 3.11 support on Anaconda Platform

Due to the lack of Python 3.11 support for packages that PyTorch depends on, including NumPy, SciPy, SymPy, Pillow, and others on the Anaconda platform, we will not be releasing Conda binaries compiled with Python 3.11 for PyTorch Release 2.0. Pip packages with Python 3.11 support will be released, so if you intend to use PyTorch 2.0 with Python 3.11, please use our Pip packages. Please note: Conda packages with Python 3.11 support will be made available on our nightly channel. We also plan on releasing Conda Python 3.11 binaries as part of a future release once Anaconda provides these key dependencies. More information and instructions on how to download the Pip packages can be found here.

Optimized PyTorch Inference with AWS Graviton processors

The optimizations focused on several key areas: GEMM kernels, bfloat16 support, primitive caching, and the memory allocator. For aarch64 platforms, PyTorch supports Arm Compute Library (ACL) GEMM kernels via the Mkldnn (oneDNN) backend. The ACL library provides Neon/SVE GEMM kernels for fp32 and bfloat16 formats. The bfloat16 support on c7g allows efficient deployment of bfloat16-trained, AMP (Automatic Mixed Precision) trained, or even standard fp32-trained models. The standard fp32 models leverage bfloat16 kernels via oneDNN fast math mode, without any model quantization. Next, we implemented primitive caching for conv, matmul, and inner product operators. More information on the updated PyTorch user guide with the upcoming 2.0 release improvements and TorchBench benchmark details can be found here.


New Library Updates in PyTorch 2.0

Summary

We are bringing a number of improvements to the current PyTorch libraries, alongside the PyTorch 2.0 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch.

Along with 2.0, we are also releasing a series of beta updates to the PyTorch domain libraries, including those that are in-tree, and separate libraries including TorchAudio, TorchVision, and TorchText. An update for TorchX is also being released as it moves to community supported mode. Please find the list of the latest stable versions and updates below.

Latest Stable Library Versions (Full List)

  • TorchArrow 0.1.0
  • TorchAudio 2.0
  • TorchData 0.6.0
  • TorchRec 0.4.0
  • TorchServe 0.7.1
  • TorchText 0.15.0
  • TorchVision 0.15
  • TorchX 0.4.0
  • PyTorch on XLA Devices 1.14

*To see prior versions or (unstable) nightlies, click on versions in the top left menu above ‘Search Docs’.

TorchAudio

[Beta] Data augmentation operators

The release adds several data augmentation operators under torchaudio.functional and torchaudio.transforms:

  • torchaudio.functional.add_noise
  • torchaudio.functional.convolve
  • torchaudio.functional.deemphasis
  • torchaudio.functional.fftconvolve
  • torchaudio.functional.preemphasis
  • torchaudio.functional.speed
  • torchaudio.transforms.AddNoise
  • torchaudio.transforms.Convolve
  • torchaudio.transforms.Deemphasis
  • torchaudio.transforms.FFTConvolve
  • torchaudio.transforms.Preemphasis
  • torchaudio.transforms.Speed
  • torchaudio.transforms.SpeedPerturbation

The operators can be used to synthetically diversify training data to improve the generalizability of downstream models.

For usage details, please refer to the functional and transform documentation and Audio Data Augmentation tutorial.
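
As a quick sketch of the functional variants (with hypothetical tensors standing in for real recordings; see the documentation linked above for the exact signatures):

import torch
import torchaudio.functional as F

waveform = torch.randn(1, 16000)              # 1 second of audio at 16 kHz
noise = torch.randn(1, 16000)
snr = torch.tensor([10.0])                    # target signal-to-noise ratio in dB

noisy = F.add_noise(waveform, noise, snr)     # mix noise at the requested SNR
impulse = torch.randn(1, 400)                 # hypothetical room impulse response
reverberant = F.fftconvolve(waveform, impulse)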

[Beta] WavLM and XLS-R models

The release adds two self-supervised learning models for speech and audio.

  • WavLM, which is robust to noise and reverberation.
  • XLS-R, which is trained on cross-lingual datasets.

Besides the model architectures, torchaudio also supports corresponding pre-trained pipelines:

  • torchaudio.pipelines.WAVLM_BASE
  • torchaudio.pipelines.WAVLM_BASE_PLUS
  • torchaudio.pipelines.WAVLM_LARGE
  • torchaudio.pipelines.WAV2VEC_XLSR_300M
  • torchaudio.pipelines.WAV2VEC_XLSR_1B
  • torchaudio.pipelines.WAV2VEC_XLSR_2B

For usage details, please refer to the factory function and pre-trained pipelines documentation.

TorchRL

The initial release of torchrl includes several features that span across the entire RL domain. TorchRL can already be used in online, offline, multi-agent, multi-task and distributed RL settings, among others. See below:

[Beta] Environment wrappers and transforms

torchrl.envs includes several wrappers around common environment libraries. This allows users to swap one library for another without effort. These wrappers build an interface between the simulators and torchrl:

  • dm_control
  • Gym
  • Brax
  • EnvPool
  • Jumanji
  • Habitat

It also comes with many commonly used transforms and vectorized environment utilities that allow for a fast execution across simulation libraries. Please refer to the documentation for more detail.

[Beta] Datacollectors

Data collection in RL is made easy via the usage of single process or multiprocessed/distributed data collectors that execute the policy in the environment over a desired duration and deliver samples according to the user’s needs. These can be found in torchrl.collectors and are documented here.

[Beta] Objective modules

Several objective functions are included in torchrl.objectives, among which:

  • A generic PPOLoss class and derived ClipPPOLoss and KLPPOLoss
  • SACLoss and DiscreteSACLoss
  • DDPGLoss
  • DQNLoss
  • REDQLoss
  • A2CLoss
  • TD3Loss
  • ReinforceLoss
  • Dreamer

Vectorized value function operators also appear in the library. Check the documentation here.

[Beta] Models and exploration strategies

We provide multiple models, modules and exploration strategies. Get a detailed description in the doc.

[Beta] Composable replay buffer

A composable replay buffer class is provided that can be used to store data in multiple contexts, including single- and multi-agent, on- and off-policy, and many more. Components include:

  • Storages (list, physical or memory-based contiguous storages)
  • Samplers (Prioritized, sampler without repetition)
  • Writers
  • Possibility to add transforms

Replay buffers and other data utilities are documented here.

[Beta] Logging tools and trainer

We support multiple logging tools including tensorboard, wandb and mlflow.

We provide a generic Trainer class that allows for easy code recycling and checkpointing.

These features are documented here.

TensorDict

TensorDict is a new data carrier for PyTorch.

[Beta] TensorDict: specialized dictionary for PyTorch

TensorDict allows you to execute many common operations across batches of tensors carried by a single container. TensorDict supports many shape and device or storage operations, and can readily be used in distributed settings. Check the documentation to know more.
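
As a brief sketch (assuming the standalone tensordict package), a TensorDict groups tensors that share leading batch dimensions and lets common operations apply to all of them at once:

import torch
from tensordict import TensorDict

# A batch of 3 samples, each carrying an observation and an action
data = TensorDict(
    {"observation": torch.randn(3, 4), "action": torch.randn(3, 2)},
    batch_size=[3],
)

first = data[0]                               # indexing slices every entry
data_cpu = data.to("cpu")                     # device/storage ops apply container-wide
stacked = torch.stack([data, data], dim=0)    # shape-aware stacking
print(stacked.batch_size)                     # torch.Size([2, 3])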

[Beta] @tensorclass: a dataclass for PyTorch

Like TensorDict, tensorclass provides the opportunity to write dataclasses with built-in torch features such as shape or device operations.

[Beta] tensordict.nn: specialized modules for TensorDict

The tensordict.nn module provides specialized nn.Module subclasses that make it easy to build arbitrarily complex graphs that can be executed with TensorDict inputs. It is compatible with the latest PyTorch features such as functorch, torch.fx and torch.compile.

TorchRec

[Beta] KeyedJaggedTensor All-to-All Redesign and Input Dist Fusion

We observed performance regression due to a bottleneck in sparse data distribution for models that have multiple, large KJTs to redistribute.

To combat this, we altered the comms pattern to transport only the minimum data required in the initial collective to support the collective calls for the actual KJT tensor data. Sending this data, the ‘splits’, in the initial collective means more data is transmitted over the comms stream overall, but the CPU is blocked for significantly shorter amounts of time, leading to better overall QPS.

Furthermore, we altered the TorchRec train pipeline to group the initial collective calls for the splits together before launching the more expensive KJT tensor collective calls. This fusion minimizes the CPU blocked time as launching each subsequent input distribution is no longer dependent on the previous input distribution.

With this feature, variable batch sizes are now natively supported across ranks. These features are documented here.

TorchVision

[Beta] Extending TorchVision’s Transforms to Object Detection, Segmentation & Video tasks

TorchVision is extending its Transforms API! Here is what’s new:

  • You can use them not only for Image Classification but also for Object Detection, Instance & Semantic Segmentation and Video Classification.
  • You can use new functional transforms for transforming Videos, Bounding Boxes and Segmentation Masks.

Learn more about these new transforms from our docs, and submit any feedback in our dedicated issue.

TorchText

[Beta] Adding scriptable T5 and Flan-T5 to the TorchText library with incremental decoding support!

TorchText has added the T5 model architecture with pre-trained weights for both the original T5 paper and Flan-T5. The model is fully torchscriptable and features an optimized multiheaded attention implementation. We include several examples of how to utilize the model including summarization, classification, and translation.

For more details, please refer to our docs.

TorchX

TorchX is moving to community supported mode. More details will be coming in at a later time.
