Experience the power of PyTorch 2.0 on AMD Solutions

PyTorch 2.0 represents a significant step forward for the PyTorch machine learning framework. The stable release of PyTorch 2.0 brings new features that unlock even higher performance, while remaining backward compatible with prior releases and retaining the Pythonic focus that has made PyTorch so widely adopted by the AI/ML community. AMD has long been a strong proponent of PyTorch, and we are delighted that the PyTorch 2.0 stable release includes support for AMD Instinct™ and Radeon™ GPUs supported by the ROCm™ software platform.

Alongside the stable PyTorch 2.0 release, the Beta torch.compile feature, underpinned by TorchInductor, supports AMD Instinct and Radeon GPUs through the OpenAI Triton deep learning compiler. Through TorchInductor, developers can now generate low-level code using Triton that is portable and performs comparably to hand-written kernels built with native, hardware-centric kernel programming models.

Compilers like Triton can optimize the code generated by machine learning frameworks such as PyTorch for multiple AI accelerators, including AMD Instinct GPU accelerators, by leveraging hardware-specific features of the AMD CDNA™ GPU architecture. This makes it easy for developers and users to move from other hardware to AMD Instinct GPU accelerators and get strong out-of-the-box performance.

In addition, compilers like Triton enable developers to use high-level programming languages, such as Python, to write machine learning code that can be efficiently compiled and executed on specialized hardware. This can greatly improve the productivity of machine learning developers, who can focus on the algorithmic aspects of their models and rely on the compiler to generate efficient code.

OpenAI Triton is a Python-based programming language and just-in-time (JIT) compiler for writing efficient GPU kernels. In PyTorch 2.0, torch.compile uses TorchInductor to generate Triton kernels, and Triton compiles them for the target GPU. Here is a high-level overview of that flow:

  1. Graph Capture: When a model is wrapped with torch.compile, TorchDynamo captures the PyTorch program as a torch.fx graph.
  2. Graph Optimization: TorchInductor optimizes the captured graph. This includes transformations such as common subexpression elimination, dead code elimination, and operator fusion, which can help reduce memory usage and computational overhead.
  3. Tensor Memory Allocation: TorchInductor plans memory for the tensors used by the model, including input and output tensors as well as intermediate tensors created during computation.
  4. Hardware-Specific Optimization: TorchInductor emits Triton kernels for the fused operations, and the Triton compiler applies hardware-specific optimizations to them. These can include using low-level hardware instructions, optimizing data movement between different types of memory, and leveraging domain-specific architectures such as CDNA on AMD Instinct GPUs.
  5. Code Generation: Triton generates efficient machine code for the target GPU from each kernel.
  6. Execution: Compilation happens just in time, the first time a kernel is needed; compiled kernels are cached and reused on subsequent calls. Triton can also autotune kernel parameters, such as block sizes, to maximize performance.
  7. Result Return: The outputs are returned to the calling PyTorch program as ordinary tensors.
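To make the benefit of operator fusion in step 2 concrete, here is a minimal pure-Python sketch. It is not Inductor or Triton code: the hypothetical functions `unfused` and `fused` both compute `relu(x * scale + bias)`, the first as three separate passes that each write a full intermediate buffer (like three separate kernels), the second as one pass that writes only the output (like one fused kernel).

```python
def unfused(x, scale, bias):
    """Three separate 'kernels'; each writes a full-size buffer to memory."""
    buffers_written = 0
    t1 = [v * scale for v in x]        # kernel 1: multiply
    buffers_written += 1
    t2 = [v + bias for v in t1]        # kernel 2: add
    buffers_written += 1
    out = [max(v, 0.0) for v in t2]    # kernel 3: ReLU
    buffers_written += 1
    return out, buffers_written


def fused(x, scale, bias):
    """One 'kernel': reads each input once and writes only the output."""
    out = [max(v * scale + bias, 0.0) for v in x]
    return out, 1


x = [-2.0, -0.5, 0.0, 1.5, 3.0]
ref, n_unfused = unfused(x, 2.0, 1.0)
res, n_fused = fused(x, 2.0, 1.0)
print(ref == res, n_unfused, n_fused)  # True 3 1
```

The results are identical, but the fused version avoids two intermediate buffers, which on a GPU translates into less memory traffic and fewer kernel launches.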

By design, PyTorch 2.0 is backward compatible with earlier PyTorch releases. This holds true for the ROCm build of PyTorch 2.0 as well. Developers using PyTorch with AMD GPUs can migrate to PyTorch 2.0 with confidence that their existing code will continue to work without changes, so there is no penalty for accessing the improvements in this release. At the same time, opting into torch.compile with TorchInductor can deliver significant performance improvements over the default eager-mode, as shown below.

The initial results on AMD Instinct MI250 GPUs already show strong performance improvements with TorchInductor over the default eager-mode, with minimal optimization. We see an average speedup of up to 1.54X on 44 of the 45 models in the HuggingFace benchmark suite, with CamemBert, DistillGPT2, and T5Small among the standout models at 1.5X or more over eager-mode. We look forward to continued engagement with members of the PyTorch team at Meta to enable further optimization of the ROCm software stack and additional performance improvements in future PyTorch releases.

Image 1: AMD MI250 GPU performance improvement for TorchInductor vs eager-mode on the HuggingFace benchmark suite (see endnote MI200-89).
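A per-model speedup is the eager-mode time divided by the TorchInductor time. The post does not say how its average is computed; benchmark suites commonly report a geometric mean, so the sketch below assumes that. The model names and timings are made-up placeholders, not AMD's measurements.

```python
import math


def speedup(eager_s, compiled_s):
    """Speedup of the compiled model over eager-mode (higher is better)."""
    return eager_s / compiled_s


def geomean(values):
    """Geometric mean, the usual way to average ratios such as speedups."""
    values = list(values)
    return math.exp(sum(math.log(v) for v in values) / len(values))


# Hypothetical (eager, compiled) step times in seconds, for illustration only.
timings = {"model_a": (3.0, 2.0), "model_b": (2.0, 1.0), "model_c": (1.2, 1.2)}
speedups = {name: speedup(e, c) for name, (e, c) in timings.items()}
print(speedups)                            # per-model speedups
print(round(geomean(speedups.values()), 3))  # suite-level average
```

A geometric mean keeps one outlier model from dominating the summary the way an arithmetic mean of ratios can.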

PyTorch 2.0 offers the same install options as before for AMD GPUs: an installable Python package hosted at pytorch.org, AMD’s public PyTorch docker image, and, of course, building from source using the upstream PyTorch repository. As with PyTorch builds for other platforms, the exact pip install command is provided by the configurator at https://pytorch.org/get-started/locally/.

The GPUs supported by the ROCm software platform which forms the basis for PyTorch support on AMD GPUs are documented at https://docs.amd.com/bundle/Hardware_and_Software_Reference_Guide/page/Hardware_and_Software_Support.html

Conclusion

PyTorch 2.0 represents a major step in continuing to broaden support for ML developers by increasing performance while maintaining a simple, Pythonic interface. This performance uplift is made possible in large part by the new TorchInductor infrastructure, which in turn harnesses the Triton ML programming language and just-in-time compiler. AMD’s support for these technologies allows users to realize the full promise of the new PyTorch architecture. Our GPU support in PyTorch 2.0 is just one manifestation of a larger vision around AI and machine learning. AI/ML plays an important role in multiple AMD product lines, including Instinct and Radeon GPUs, Alveo™ data center accelerators, and both Ryzen™ and EPYC processors. These hardware and software initiatives are all part of AMD’s Pervasive AI vision, and we look forward to addressing the many new challenges and opportunities of this dynamic space.

MI200-89 – PyTorch Inductor mode HuggingFace Transformers training speedup, running the standard PyTorch 2.0 test suite, over PyTorch eager-mode comparison based on AMD internal testing on a single GCD as of 3/10/2023 using a 2P AMD EPYC™ 7763 production server with 4x AMD Instinct™ MI250 (128GB HBM2e) 560W GPUs with Infinity Fabric™ technology; host ROCm™ 5.3, guest ROCm™ 5.4.4, PyTorch 2.0.0, Triton 2.0. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.

© 2023 Advanced Micro Devices, Inc. All rights reserved. AMD, the AMD Arrow logo, AMD CDNA, AMD Instinct, EPYC, Radeon, ROCm, Ryzen, and combinations thereof are trademarks of Advanced Micro Devices, Inc. Other product names used in this publication are for identification purposes only and may be trademarks of their respective owners.

Accelerated PyTorch 2 Transformers

The PyTorch 2.0 release includes a new high-performance implementation of the PyTorch Transformer API with the goal of making training and deployment of state-of-the-art Transformer models affordable. Following the successful release of “fastpath” inference execution (“Better Transformer”), this release introduces high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA).

You can take advantage of the new fused SDPA kernels either by calling the new SDPA operator directly (as described in the SDPA tutorial) or transparently via integration into the pre-existing PyTorch Transformer API. All features of the PyTorch Transformer API will continue to work compatibly. Many features are mapped to high-performance SDPA kernels; some cannot be supported with higher performance (e.g., need_weights, as described below); and expanded high-performance support for other features may still be under active development.

Similar to the “fastpath” architecture, custom kernels are fully integrated into the PyTorch Transformer API, so using the native Transformer and MultiHeadAttention API lets users transparently see significant speed improvements. Unlike the “fastpath” architecture, the newly introduced “custom kernels” support many more use cases, including models using Cross-Attention, Transformer Decoders, and training, in addition to the existing fastpath inference for fixed and variable sequence length Transformer Encoder and Self Attention use cases.

To take full advantage of different hardware models and Transformer use cases, multiple SDPA custom kernels are supported, with custom kernel selection logic that will pick the highest-performance kernel for a given model and hardware type. In particular, the first custom kernels included with the PyTorch 2.0 release are the Flash Attention kernel (sdpa_flash, for 16-bit floating point training and inference on Nvidia GPUs with SM80+ architecture level) and the xFormers memory-efficient attention kernel (sdpa_mem_eff, for 16-bit and 32-bit floating point training and inference on a broad range of Nvidia GPUs). A general-purpose kernel sdpa_math provides an implementation when the custom kernels are not applicable.

As mentioned, custom kernels provide a wider range of support for execution scenarios. To ensure efficient execution (e.g., to use GPU tensor cores), model configurations need to meet a small number of requirements. This list of requirements will evolve over time, prospectively relaxing the constraints that limit use of the currently supported custom kernels, or adding kernels in the future.

For the most up to date list of custom kernels and dispatch constraints, you can refer to sdp_utils.h. As of PyTorch 2.0, the existing fused SDPA kernels have the following constraints:

  • Flash Attention only supports 16 bit floating point data types (float16 and bfloat16).
  • The head dimension must be a multiple of 8 for 16-bit floating point numbers and a multiple of 4 for 32-bit floating point numbers. At present, the maximum head_dim supported by the Flash Attention custom kernel is 128.
  • The CUDA architecture level must be sm5x or better for the mem_efficient kernel, and sm80 or better for Flash Attention.
  • Flash Attention supports arbitrary dropout. In PyTorch 2.0, the mem_efficient kernel does not support dropout (i.e., dropout must be set to zero for this kernel to be selected in PyTorch 2.0).
  • To support variable-sequence length batches, all SDPA kernels support Nested Tensor inputs that combine input data and padding information using variable sequence length tensors for forward. (You can find more information about Nested Tensors in the Nested Tensor tutorial.)
  • You can specify both a key_padding_mask and an attn_mask by combining them before passing them to the SDPA operator. In particular, you can use the per-batch-element key padding mask of the nn.Transformer API to implement training for variable-sequence length inputs in a batch.
  • At present, the only attention mask supported by fused kernel implementation is the causal mask commonly used for training. To specify the causal mask in custom kernels, it must be specified with the is_causal boolean and attn_mask must be None.
  • Support for Nested Tensors is still under development. Specifically, in PyTorch 2.0, only the sdpa_math kernel supports training with Nested Tensors. Also, PyTorch 2.0 does not support Nested Tensors as part of code being compiled with torch.compile().
  • The SDPA operator does not support returning averaged attention weights, because computing them defeats the optimizations that enable fused kernels to execute more efficiently. The need_weights argument of torch.nn.MultiheadAttention’s forward function defaults to True; to use the fused kernels, set need_weights=False.
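The constraints above can be read as a simple decision procedure. The sketch below is a simplified, hypothetical model of that selection logic, not PyTorch's actual dispatcher (the authoritative rules live in sdp_utils.h and change between releases). `AttnConfig`, `select_kernel`, and the string dtype names are illustrative assumptions, and `sm` encodes the CUDA compute capability as an integer (e.g. 80 for sm80).

```python
from dataclasses import dataclass


@dataclass
class AttnConfig:
    dtype: str            # "float16", "bfloat16", or "float32"
    head_dim: int
    sm: int               # CUDA compute capability, e.g. 80 for sm80
    dropout_p: float
    has_attn_mask: bool   # an explicit attn_mask tensor was passed
    need_weights: bool = False


def select_kernel(cfg: AttnConfig) -> str:
    """Hypothetical, simplified version of the PyTorch 2.0 dispatch rules."""
    # Fused kernels cannot return attention weights and only support the
    # causal mask via is_causal, so an explicit mask forces the fallback.
    if cfg.need_weights or cfg.has_attn_mask:
        return "sdpa_math"
    is_16bit = cfg.dtype in ("float16", "bfloat16")
    # Flash Attention: 16-bit only, sm80+, head_dim a multiple of 8 and <= 128.
    if is_16bit and cfg.sm >= 80 and cfg.head_dim % 8 == 0 and cfg.head_dim <= 128:
        return "sdpa_flash"
    # Memory-efficient kernel: sm5x+, head_dim alignment by dtype, no dropout.
    align = 8 if is_16bit else 4
    if cfg.sm >= 50 and cfg.head_dim % align == 0 and cfg.dropout_p == 0.0:
        return "sdpa_mem_eff"
    return "sdpa_math"


print(select_kernel(AttnConfig("float16", 64, 80, 0.1, False)))  # sdpa_flash
print(select_kernel(AttnConfig("float32", 64, 70, 0.0, False)))  # sdpa_mem_eff
print(select_kernel(AttnConfig("float32", 64, 70, 0.1, False)))  # sdpa_math
```

Ordering the checks from most to least specialized mirrors the idea of picking the fastest applicable kernel and keeping sdpa_math as the catch-all.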

We find that an attention mask is rarely used in real-world applications, except for the causal mask during training. Consequently, we reduce kernel complexity and compute cost by building in the option to use a causal mask as attention mask, and select this new capability with the is_causal parameter introduced in conjunction with the new SDPA operator.

Providing the is_causal Boolean flag for the frequently used causal mask also avoids the expensive and memory-intensive allocation of a causal mask. This increases training memory efficiency, leaving more memory for larger batch sizes, and reduces memory bandwidth use and cache contention (both at a premium in GPU accelerators), because no attention mask tensor needs to be loaded.
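As a rough illustration of what is_causal saves, the sketch below builds the causal pattern (position i may attend only to positions j <= i) and computes how much memory a dense mask would take if it were materialized. The helper names are hypothetical, and the byte count covers only a single seq_len x seq_len mask.

```python
def causal_mask(seq_len):
    """Boolean mask: row i allows attention to columns j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def mask_bytes(seq_len, bytes_per_element):
    """Memory for one dense (seq_len x seq_len) attention mask."""
    return seq_len * seq_len * bytes_per_element


for row in causal_mask(4):
    print("".join("1" if allowed else "0" for allowed in row))
# 1000
# 1100
# 1110
# 1111

print(mask_bytes(4096, 2) / 2**20)  # 32.0 (MiB for one float16 mask)
```

With is_causal, the kernel can apply this pattern from the loop indices alone, so none of those bytes need to be allocated or read.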

If the constraints of none of the available custom kernels are met, training falls back to the default sdpa_math kernel, which implements the mathematical equations for scaled dot product attention as a sequence of PyTorch operators. This is the most general “catch-all” fallback kernel, ensuring successful training for all models.
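For reference, the equation such a fallback implements is softmax(Q Kᵀ / sqrt(d)) V, optionally with future positions masked out. Below is a minimal single-head sketch in plain Python. It is an illustration of the math, not the sdpa_math source, and `sdpa_math` here is a hypothetical function name.

```python
import math


def sdpa_math(q, k, v, is_causal=False):
    """Reference scaled dot-product attention for one head.

    q, k, v are lists of row vectors (seq_len x head_dim) of Python floats.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        scores = []
        for j, kj in enumerate(k):
            if is_causal and j > i:
                scores.append(float("-inf"))  # future position: masked out
            else:
                scores.append(sum(a * b for a, b in zip(qi, kj)) * scale)
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of the value rows.
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


v = [[1.0, 2.0], [3.0, 4.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
print(sdpa_math(eye, eye, v, is_causal=True)[0])  # [1.0, 2.0]
```

With is_causal=True, the first query can only see the first key, so its output is exactly the first value row.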

In addition to the existing Transformer API, model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator. This operator may be used to efficiently implement multi-head attention by combining it with in-projection and out-projection, as described in the SDPA tutorial.
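In a multi-head layer built on a per-head attention operator, the in-projected queries, keys, and values are split along the embedding dimension into heads, attention runs per head, and the heads are merged back before the out-projection. The hypothetical helpers below show only that reshaping bookkeeping for a single sequence. They omit the projection matrices and the attention call itself, and are not the tutorial's code.

```python
def split_heads(x, num_heads):
    """(seq_len, embed_dim) -> (num_heads, seq_len, head_dim)."""
    embed_dim = len(x[0])
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    return [[row[h * head_dim:(h + 1) * head_dim] for row in x]
            for h in range(num_heads)]


def merge_heads(heads):
    """(num_heads, seq_len, head_dim) -> (seq_len, embed_dim)."""
    seq_len = len(heads[0])
    return [[v for head in heads for v in head[i]] for i in range(seq_len)]


x = [[1, 2, 3, 4],
     [5, 6, 7, 8]]           # seq_len=2, embed_dim=4
heads = split_heads(x, 2)
print(heads)                 # [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
print(merge_heads(heads) == x)  # True
```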

In addition to adding custom kernels, Accelerated PyTorch 2 Transformers are integrated with PyTorch 2.0 compilation. To use your model while benefiting from the additional acceleration of PT2-compilation (for inference or training), pre-process the model with

model = torch.compile(model)

We have achieved major speedups for training transformer models and in particular large language models with Accelerated PyTorch 2 Transformers using a combination of custom kernels and torch.compile().

Figure: Using scaled dot product attention with custom kernels and torch.compile delivers significant speedups for training large language models, such as nanoGPT, shown here.

Finally, because the custom kernels are much more memory efficient, consider increasing the training batch size to achieve faster training.

In addition to automatic kernel selection, a context manager enables developers to override the kernel selection algorithm – this is not required for day to day operation, but enables developers to debug their code as well as enable performance engineers to override kernel selection. The SDPA tutorial provides additional information on using the SDPA context manager.

In addition to availability as part of the nn.Transformer API, Accelerated PyTorch 2 Transformer custom kernels are also available in conjunction with the torchtext, torchvision, and fairseq domain libraries with the launch of PyTorch 2.0.

PyTorch 2.0 & XLA—The Latest Cutting Edge Features

Today, we are excited to share our latest work for PyTorch/XLA 2.0. The release of PyTorch 2.0 is yet another major milestone for this storied community and we are excited to continue to be part of it. When the PyTorch/XLA project started in 2018 between Google and Meta, the focus was on bringing cutting edge Cloud TPUs to help support the PyTorch community. Along the way, others in the community such as Amazon joined the project and very quickly the community expanded. We are excited about XLA’s direction and the benefits this project continues to bring to the PyTorch community. In this blog we’d like to showcase some key features that have been in development, show code snippets, and illustrate the benefit through some benchmarks.

TorchDynamo / torch.compile (Experimental)

TorchDynamo (Dynamo) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in; its biggest feature is to dynamically modify Python bytecode just before execution. In the PyTorch/XLA 2.0 release, an experimental backend for Dynamo is provided for both inference and training.

Dynamo provides a Torch FX (FX) graph when it recognizes a model pattern and PyTorch/XLA uses a Lazy Tensor approach to compile the FX graph and return the compiled function. To get more insight regarding the technical details about PyTorch/XLA’s dynamo implementation, check out this dev-discuss post and dynamo doc.

Here is a small code example of running ResNet18 with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  # 'torchxla_trace_once' is the PyTorch/XLA 2.0 Dynamo inference backend.
  dynamo_resnet18 = torch.compile(
      xla_resnet18, backend='torchxla_trace_once')
  for data, _ in loader:
    # Inputs must live on the same XLA device as the model.
    output = dynamo_resnet18(data.to(device))

With torch.compile, PyTorch/XLA traces the ResNet18 model only once, at initialization, and executes the compiled binary every time dynamo_resnet18 is invoked, instead of tracing the model at every step. To illustrate the benefits of Dynamo+XLA, below is an inference speedup analysis comparing Dynamo and LazyTensor (without Dynamo) using TorchBench on a Cloud TPU v4-8, where the y-axis is the speedup multiplier.

Inference Speedup - PyTorch/XLA Dynamo on TPU
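The “trace once, then replay” behavior described above amounts to caching a compiled artifact per input signature. The toy sketch below models it in plain Python: `TraceOnce` is a hypothetical stand-in (not a PyTorch/XLA class), and the input-length signature is a simplification of the shape and dtype guards a real backend uses. By contrast, LazyTensor without Dynamo re-traces the Python program on every step.

```python
class TraceOnce:
    """Hypothetical stand-in for 'trace once, then run the compiled binary'."""

    def __init__(self, fn):
        self.fn = fn
        self.cache = {}        # input signature -> "compiled" callable
        self.trace_count = 0

    def __call__(self, *args):
        key = tuple(len(a) for a in args)   # toy signature: input lengths
        if key not in self.cache:
            self.trace_count += 1           # the expensive trace + compile
            self.cache[key] = self.fn       # a real backend stores a compiled graph
        return self.cache[key](*args)


def model(xs):
    return [2 * v + 1 for v in xs]


compiled = TraceOnce(model)
for step in range(5):
    compiled([step, step + 1])
print(compiled.trace_count)  # 1: five steps, one trace
```

A new input signature (here, a different length) triggers one more trace, which is why stable input shapes matter for this approach.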

Dynamo support for training is at an earlier stage of development than inference. Developers are welcome to test this early feature; however, in the 2.0 release, PyTorch/XLA supports the forward and backward pass graphs but not the optimizer graph. The optimizer graph is available in the nightly builds and will land in the PyTorch/XLA 2.1 release. Below is an example of what training looks like using the ResNet18 example with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  # 'aot_torchxla_trace_once' is the experimental Dynamo training backend in 2.0.
  dynamo_train_model = torch.compile(
      train_model, backend='aot_torchxla_trace_once')
  for data, target in loader:
    # Forward and backward go through Dynamo; the optimizer graph is not
    # supported in the 2.0 release, so no optimizer step is shown here.
    output = dynamo_train_model(
        xla_resnet18, data.to(device), target.to(device))

Note that the training backend is aot_torchxla_trace_once (the API will be updated for the stable release), whereas the inference backend is torchxla_trace_once (name subject to change). We expect Dynamo to extract and execute 3 graphs per training step, compared with a single graph per training step when using Lazy Tensor. Below is a training speedup analysis comparing Dynamo and Lazy Tensor using TorchBench on a Cloud TPU v4-8.

Training Speedup - PyTorch/XLA Dynamo on TPU

PJRT Runtime (Beta)

PyTorch/XLA is migrating from XRT to the new PJRT runtime. PJRT is a better-maintained stack with demonstrated performance advantages, including an average 35% performance improvement for training on TorchBench 2.0 models. It also supports a richer set of features, enabling technologies like SPMD. In the PyTorch/XLA 2.0 release, PJRT is the default runtime for TPU and CPU; GPU support is experimental. The PJRT features included in the PyTorch/XLA 2.0 release are:

  • TPU runtime implementation in libtpu using the PJRT Plugin API improves performance by up to 30%
  • torch.distributed support for TPU v2 and v3, including pjrt:// init_method (Experimental)
  • Single-host GPU support. Multi-host support coming soon. (Experimental)

Switching to PJRT requires no change (or minimal change for GPUs) to user code (see pjrt.md for more details). Runtime configuration is as simple as setting the PJRT_DEVICE environment variable to the local device type (TPU, GPU, or CPU). Below are examples of using PJRT runtimes on different devices.

# TPU Device
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
# TPU Pod Device
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
# GPU Device (Experimental)
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1

Below is a performance comparison between XRT and PJRT by task on TorchBench 2.0 on v4-8 TPU. To learn more about PJRT vs. XRT please review the documentation.

TorchBench Training Time

Parallelization

GSPMD (Experimental)

We are delighted to introduce General and Scalable Parallelization for ML Computation Graphs (GSPMD) in PyTorch as a new experimental data & model sharding solution. GSPMD provides automatic parallelization for common ML workloads, allowing developers to write PyTorch programs as if on a single large device and without custom sharded computation ops and/or collective communication ops. The XLA compiler transforms the single device program into a partitioned one with proper collectives, based on the user provided sharding hints. The API (RFC) will be available in the PyTorch/XLA 2.0 release as an experimental feature on a single TPU VM host.

Next Steps for GSPMD

GSPMD is experimental in 2.0 release. To bring it to Stable status, we plan to address a number of feature gaps and known issues in the following releases, including multi-host support, DTensor integration, partial replication sharding, asynchronous data loading, and checkpointing.

FSDP (Beta)

PyTorch/XLA introduced experimental support for fully sharded data parallel (FSDP) in version 1.12. This feature is the XLA counterpart of PyTorch FSDP, with subtle differences in how the XLA and upstream CUDA kernels are set up. auto_wrap_policy is a new argument that lets developers automatically specify conditions for propagating partitioning specifications to neural network submodules. An auto_wrap_policy can simply be passed in as an argument when wrapping a model with FSDP. Two auto_wrap_policy callables worth noting are size_based_auto_wrap_policy and transformer_auto_wrap_policy.

size_based_auto_wrap_policy enables users to wrap submodules with a minimum number of parameters. The example below wraps model submodules having at least 10M parameters.

auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

transformer_auto_wrap_policy enables users to wrap all submodules that match a specific layer type. The example below wraps model submodules of type torch.nn.Conv2d. To learn more, review this ResNet example by Ronghang Hu.

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
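To build intuition for what the two policies decide, here is a simplified, hypothetical model in plain Python. The toy `Module` class, `plan_wrapping`, and the string layer names are illustrations, not torch_xla APIs, and the real wrapping code may differ in details. The walk is bottom-up: children are decided first, and parameters already placed in a wrapped child no longer count toward the parent's size.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Module:
    """Hypothetical toy module: a type name, its own parameter count, children."""
    kind: str
    own_params: int = 0
    children: List["Module"] = field(default_factory=list)


Policy = Callable[[Module, int], bool]


def size_based(min_num_params) -> Policy:
    # Wrap once the not-yet-wrapped parameters reach the threshold.
    return lambda module, remaining: remaining >= min_num_params


def type_based(kinds) -> Policy:
    # Wrap every submodule whose type name is in `kinds`.
    return lambda module, remaining: module.kind in kinds


def plan_wrapping(root: Module, policy: Policy) -> List[str]:
    """Return the kinds of submodules (excluding root) the policy would wrap."""
    wrapped = []

    def visit(m: Module) -> int:
        remaining = m.own_params
        for child in m.children:
            remaining += visit(child)
        if m is not root and policy(m, remaining):
            wrapped.append(m.kind)
            return 0  # this subtree now forms its own FSDP unit
        return remaining

    visit(root)
    return wrapped


net = Module("Net", children=[
    Module("Conv2d", 2_000_000),
    Module("Sequential", children=[Module("Conv2d", 6_000_000),
                                   Module("Conv2d", 6_000_000)]),
    Module("Linear", 15_000_000),
])
print(plan_wrapping(net, size_based(1e7)))         # ['Sequential', 'Linear']
print(plan_wrapping(net, type_based({"Conv2d"})))  # ['Conv2d', 'Conv2d', 'Conv2d']
```

The size-based policy groups the two 6M-parameter convolutions into one unit through their parent, while the type-based policy wraps every Conv2d regardless of size.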

PyTorch/XLA FSDP is now integrated into the Hugging Face Trainer class (PR), enabling users to train much larger models on PyTorch/XLA (official Hugging Face documentation). A 16B-parameter GPT2 model trained on Cloud TPU v4-64 with this FSDP configuration achieved 39% hardware utilization.

Setting                          Value
TPU accelerator (num devices)    v4-64
GPT2 parameter count             16B
Layers wrapped with FSDP         GPT2Block
TFLOPs / chip                    275
PFLOPs / step                    50
Hardware utilization             39%

Differences Between FSDP & GSPMD

FSDP is a data parallelism technique that reduces the device memory footprint by sharding model parameters, optimizer states, and gradients across devices. Note that the actual computation is still local to each device and requires all-gathering the sharded model parameters for both the forward and backward passes, hence the name “data parallel”. FSDP is one of the newest additions to PyTorch/XLA for scaling large model training.
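A back-of-envelope estimate shows why sharding this state matters. The sketch below assumes fp32 weights (4 bytes), fp32 gradients (4 bytes), and two fp32 Adam moments (8 bytes), for 16 bytes per parameter. It ignores activations and the temporarily all-gathered parameters, and uses a hypothetical count of 64 devices; `per_device_state_gb` is an illustrative helper, not part of any library.

```python
def per_device_state_gb(num_params, num_devices, bytes_per_param=16, sharded=True):
    """Parameter + gradient + optimizer-state memory per device, in GB (1e9 bytes).

    bytes_per_param=16 assumes fp32 weights (4) + fp32 grads (4) + two fp32
    Adam moments (8). Activations and all-gather buffers are ignored.
    """
    total_bytes = num_params * bytes_per_param
    per_device = total_bytes / num_devices if sharded else total_bytes
    return per_device / 1e9


print(per_device_state_gb(16_000_000_000, 64, sharded=False))  # 256.0 (replicated)
print(per_device_state_gb(16_000_000_000, 64))                 # 4.0 (fully sharded)
```

Under these assumptions, a 16B-parameter model's training state would not fit on any single accelerator if replicated, but becomes a few GB per device once fully sharded.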

GSPMD, on the other hand, is a general parallelization system that enables various types of parallelism, including both data and model parallelism. PyTorch/XLA provides a sharding annotation API and an XLAShardedTensor abstraction, so a user can annotate any tensor in the PyTorch program with sharding specs. Developers don’t need to manually implement sharded computations or inject collective communication ops to get it right; the XLA compiler does the work so that each computation can run in a distributed manner on multiple devices.

Examples & Preliminary Results

To learn about PyTorch/XLA parallelism sharding API, visit our RFC and see the Sample Code references. Below is a simple example to enable data and model parallelism.

import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# Assumes SimpleLinear, mesh, partition_spec, loader, optimizer and loss_fn
# are defined as in the RFC's sample code.
model = SimpleLinear().to(xm.xla_device())
# Sharding annotate the linear layer weights.
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  data = data.to(xm.xla_device())
  target = target.to(xm.xla_device())
  # Sharding annotate input data; we can shard any input
  # dimensions. Sharding the batch dimension enables
  # data parallelism, sharding the feature dimension enables
  # spatial partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  xm.mark_step()
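To see what a sharding annotation implies for each device, the sketch below computes per-device shard shapes from a device mesh and a partition spec, where entry i names the mesh axis that tensor dimension i is split across (None means replicated). It is a simplified, hypothetical helper modeled loosely on the annotation idea above, not the XLA partitioner, and it rejects uneven splits that a real compiler could handle with padding.

```python
def shard_shape(tensor_shape, mesh_shape, partition_spec):
    """Per-device shard shape for a tensor under a sharding annotation.

    partition_spec[i] is the mesh axis that tensor dim i is split across,
    or None if that dim is replicated on every device.
    """
    shape = []
    for dim_size, mesh_axis in zip(tensor_shape, partition_spec):
        if mesh_axis is None:
            shape.append(dim_size)
        else:
            parts = mesh_shape[mesh_axis]
            if dim_size % parts != 0:
                raise ValueError(f"dim of size {dim_size} not divisible by {parts}")
            shape.append(dim_size // parts)
    return tuple(shape)


mesh_shape = (2, 4)  # hypothetical: 8 devices arranged as a 2 x 4 mesh
print(shard_shape((16, 128), mesh_shape, (0, None)))  # (8, 128): batch split (data parallel)
print(shard_shape((16, 128), mesh_shape, (0, 1)))     # (8, 32): batch and feature split
```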

The following graph highlights the memory efficiency benefits of PyTorch/XLA FSDP and SPMD on Cloud TPU v4-8 running ResNet50.

Batch Size Scaling with Spatial Partitioning

Closing Thoughts…

We are excited to bring these features to the PyTorch community, and this is really just the beginning. Areas like dynamic shapes, deeper support for OpenXLA and many others are in development and we plan to put out more blogs to dive into the details. PyTorch/XLA is developed fully open source and we invite you to join the community of developers by filing issues, submitting pull requests, and sending RFCs on GitHub. You can try PyTorch/XLA on a variety of XLA devices including TPUs and GPUs. Here is how to get started.

Congratulations again to the PyTorch community on this milestone!

Cheers,

The PyTorch Team at Google

Accelerated Diffusers with PyTorch 2.0

PyTorch 2.0 has just been released. Its flagship new feature is torch.compile(), a one-line code change that promises to automatically improve performance across codebases. We have previously checked on that promise in Hugging Face Transformers and TIMM models, and delved deep into its motivation, architecture and the road ahead.

As important as torch.compile() is, there’s much more to PyTorch 2.0. Notably, PyTorch 2.0 incorporates several strategies to accelerate transformer blocks, and these improvements are very relevant for diffusion models too. Techniques such as FlashAttention, for example, have become very popular in the diffusion community thanks to their ability to significantly speed up Stable Diffusion and achieve larger batch sizes, and they are now part of PyTorch 2.0.

In this post we discuss how attention layers are optimized in PyTorch 2.0 and how these optimizations are applied to the popular 🧨 Diffusers library. We finish with a benchmark that shows how using PyTorch 2.0 with Diffusers immediately translates into significant performance improvements across different hardware.

Accelerating transformer blocks

PyTorch 2.0 includes a scaled dot-product attention function as part of torch.nn.functional. This function encompasses several implementations that can be applied depending on the inputs and the hardware in use. Before PyTorch 2.0, you had to search for third-party implementations and install separate packages in order to take advantage of memory optimized algorithms, such as FlashAttention. The available implementations are:

  • FlashAttention, from the official FlashAttention project.
  • Memory-Efficient Attention, from the xFormers project.
  • A native C++ implementation suitable for non-CUDA devices or when high precision is required.

All these methods are available by default, and PyTorch will try to select the optimal one automatically through the new scaled dot-product attention (SDPA) API. You can also toggle them individually for finer-grained control; see the documentation for details.

Using scaled dot-product attention in diffusers

The incorporation of Accelerated PyTorch 2.0 Transformer attention to the Diffusers library was achieved through the use of the set_attn_processor method, which allows for pluggable attention modules to be configured. In this case, a new attention processor was created, which is enabled by default when PyTorch 2.0 is available. For clarity, this is how you could enable it manually (but it’s usually not necessary since diffusers will automatically take care of it):

from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.to("cuda")
pipe.unet.set_attn_processor(AttnProcessor2_0())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

Stable Diffusion Benchmark

We ran a number of tests using accelerated dot-product attention from PyTorch 2.0 in Diffusers. We installed diffusers from pip and used nightly versions of PyTorch 2.0, since our tests were performed before the official release. We also used torch.set_float32_matmul_precision('high') to enable additional fast matrix multiplication algorithms.

We compared results with the traditional attention implementation in diffusers (referred to as vanilla below) as well as with the best-performing solution in pre-2.0 PyTorch: PyTorch 1.13.1 with the xFormers package (v0.0.16) installed.

Results were measured without compilation (i.e., no code changes at all), and also with a single call to torch.compile() to wrap the UNet module. We did not compile the image decoder because most of the time is spent in the 50 denoising iterations that run UNet evaluations.

Results in float32

Diffusers Speedup vs xFormers float32

The following figures explore performance improvement vs batch size for various representative GPUs belonging to different generations. We collected data for each combination until we reached maximum memory utilization. Vanilla attention runs out of memory earlier than xFormers or PyTorch 2.0, which explains the missing bars for larger batch sizes. Similarly, A100 (we used the 40 GB version) is capable of running batch sizes of 64, but the other GPUs could only reach 32 in our tests.

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float32)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (V100, float32)

We found very significant performance improvements over vanilla attention across the board, without even using torch.compile(). An out of the box installation of PyTorch 2.0 and diffusers yields about 50% speedup on A100 and between 35% and 50% on 4090 GPUs, depending on batch size. Performance improvements are more pronounced for modern CUDA architectures such as Ada (4090) or Ampere (A100), but they are still very significant for older architectures still heavily in use in cloud services.

In addition to faster speeds, the accelerated transformers implementation in PyTorch 2.0 allows much larger batch sizes to be used. A single 40GB A100 GPU runs out of memory with a batch size of 10, and 24 GB high-end consumer cards such as 3090 and 4090 cannot generate 8 images at once. Using PyTorch 2.0 and diffusers we could achieve batch sizes of 48 for 3090 and 4090, and 64 for A100. This is of great significance for cloud services and applications, as they can efficiently process more images at a time.

When compared with PyTorch 1.13.1 + xFormers, the new accelerated transformers implementation is still faster and requires no additional packages or dependencies. Here we found modest speedups of up to 2% on datacenter cards such as A100 or T4, but larger gains on the last two generations of consumer cards: up to 20% on the 3090 and between 10% and 45% on the 4090, depending on batch size.

When torch.compile() is used, we get an additional performance boost of typically 2% to 3% on top of these improvements. Because compilation takes some time up front, this is better suited to long-running, user-facing inference services or to training.

Results in float16

Diffusers Speedup vs xFormers float16

Diffusers Inference Speedup vs Vanilla and xFormers Attention (A100, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (4090, float16)

Diffusers Inference Speedup vs Vanilla and xFormers Attention (3090, float16)

When we consider float16 inference, the performance improvements of the accelerated transformers implementation in PyTorch 2.0 are between 20% and 28% over standard attention, across all the GPUs we tested, except for the 4090, which belongs to the more modern Ada architecture. This GPU benefits from a dramatic performance improvement when using PyTorch 2.0 nightlies. With respect to optimized SDPA vs xFormers, results are usually on par for most GPUs, except again for the 4090. Adding torch.compile() to the mix boosts performance a few more percentage points across the board.

Conclusions

PyTorch 2.0 comes with multiple features to optimize the crucial components of the foundational transformer block, and they can be further improved with the use of torch.compile. These optimizations lead to significant memory and time improvements for diffusion models, and remove the need for third-party library installations.

To take advantage of these speed and memory improvements all you have to do is upgrade to PyTorch 2.0 and use diffusers >= 0.13.0.

For more examples and detailed benchmark numbers, please also have a look at the Diffusers with PyTorch 2.0 docs.

Acknowledgement

The authors are grateful to the PyTorch team for their insights, assistance and suggestions during the elaboration of this post, and for creating such excellent software. We are particularly indebted to Hamid Shojanazeri, Grigory Sizov, Christian Puhrsch, Driss Guessous, Michael Gschwind and Geeta Chauhan.


PyTorch 2.0: Our next generation release that is faster, more Pythonic and Dynamic as ever


We are excited to announce the release of PyTorch® 2.0 which we highlighted during the PyTorch Conference on 12/2/22! PyTorch 2.0 offers the same eager-mode development and user experience, while fundamentally changing and supercharging how PyTorch operates at compiler level under the hood with faster performance and support for Dynamic Shapes and Distributed.

This next-generation release includes a Stable version of Accelerated Transformers (formerly called Better Transformers). The Beta includes torch.compile as the main API for PyTorch 2.0, the scaled_dot_product_attention function as part of torch.nn.functional, the MPS backend, and functorch APIs in the torch.func module. There are also other Beta/Prototype improvements across various inference, performance and training optimization features on GPUs and CPUs. For a comprehensive introduction and technical overview of torch.compile, please visit the 2.0 Get Started page.

Along with 2.0, we are also releasing a series of beta updates to the PyTorch domain libraries, including those that are in-tree, and separate libraries including TorchAudio, TorchVision, and TorchText. An update for TorchX is also being released as it moves to community supported mode. More details can be found in this library blog.

This release is composed of over 4,541 commits and 428 contributors since 1.13.1. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.0 and the overall 2-series this year.

Summary:

  • torch.compile is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature and hence 2.0 is 100% backward compatible by definition.
  • As an underpinning technology of torch.compile, TorchInductor relies on the OpenAI Triton deep learning compiler for NVIDIA and AMD GPUs to generate performant code and hide low-level hardware details. OpenAI Triton-generated kernels achieve performance on par with hand-written kernels and specialized CUDA libraries such as cuBLAS.
  • Accelerated Transformers introduce high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA). The API is integrated with torch.compile(), and model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator.
  • Metal Performance Shaders (MPS) backend provides GPU-accelerated PyTorch training on Mac platforms with added support for the top 60 most-used ops, bringing coverage to over 300 operators.
  • AWS optimized PyTorch CPU inference on AWS Graviton3-based C7g instances. PyTorch 2.0 improves inference performance on Graviton compared to previous releases, including improvements for ResNet-50 and BERT.
  • New prototype features and technologies across TensorParallel, DTensor, 2D parallel, TorchDynamo, AOTAutograd, PrimTorch and TorchInductor.

Stable: Accelerated PT 2 Transformers
Beta: torch.compile; PyTorch MPS Backend; Scaled dot product attention; functorch; Dispatchable Collectives; torch.set_default_device & torch.device; X86 quantization backend; GNN inference and training performance
Prototype: DTensor; TensorParallel; 2D Parallel; torch.compile (dynamic=True)
Performance Improvements: CUDA support for 11.7 & 11.8 (deprecating CUDA 11.6); Python 3.8 (deprecating Python 3.7); AWS Graviton3

*To see a full list of public 2.0, 1.13 and 1.12 feature submissions click here.

Stable Features

[Stable] Accelerated PyTorch 2 Transformers (previously known as “Better Transformer”)

The PyTorch 2.0 release includes a new high-performance implementation of the PyTorch Transformer API, formerly known as the "Better Transformer API" and now renamed Accelerated PyTorch 2 Transformers. In releasing accelerated PT2 Transformers, our goal is to make training and deployment of state-of-the-art Transformer models affordable across the industry. This release introduces high-performance support for training and inference using a custom kernel architecture for scaled dot product attention (SDPA).

Similar to the “fastpath” architecture, custom kernels are fully integrated into the PyTorch Transformer API – thus, using the native Transformer and MultiHeadAttention API will enable users to:

  • transparently see significant speed improvements;
  • support many more use cases including models using Cross-Attention, Transformer Decoders, and for training models; and
  • continue to use fastpath inference for fixed and variable sequence length Transformer Encoder and Self Attention use cases.

To take full advantage of different hardware models and Transformer use cases, multiple SDPA custom kernels are supported (see below), with custom kernel selection logic that picks the highest-performance kernel for a given model and hardware type. In addition to the existing Transformer API, model developers may also use the scaled dot product attention kernels directly by calling the new scaled_dot_product_attention() operator. Accelerated PyTorch 2 Transformers are integrated with torch.compile(). To use your model while benefiting from the additional acceleration of PT2 compilation (for inference or training), pre-process the model with model = torch.compile(model).
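The sketch below shows what using the native Transformer API looks like. It assumes a small encoder whose sizes were chosen arbitrarily for illustration. `build_encoder` and `encode` are hypothetical helpers. `nn.TransformerEncoderLayer`, `nn.TransformerEncoder` and `src_key_padding_mask` are standard PyTorch APIs. In eval mode with autograd disabled, the encoder may take the fastpath (the exact conditions depend on the PyTorch version). The resulting module could then be passed to torch.compile as described above.

```python
import torch
import torch.nn as nn


def build_encoder(d_model: int = 64, nhead: int = 4, num_layers: int = 2) -> nn.TransformerEncoder:
    # batch_first=True is one of the conditions for the inference fastpath.
    layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, batch_first=True)
    return nn.TransformerEncoder(layer, num_layers=num_layers)


@torch.inference_mode()
def encode(encoder: nn.TransformerEncoder, x: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # Eval mode plus no autograd lets the native encoder use its fused paths;
    # padding_mask marks padded (right-aligned) positions with True.
    encoder.eval()
    return encoder(x, src_key_padding_mask=padding_mask)
```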

We have achieved major speedups for training transformer models and in particular large language models with Accelerated PyTorch 2 Transformers using a combination of custom kernels and torch.compile().

Figure: Using scaled dot product attention with custom kernels and torch.compile delivers significant speedups for training large language models, such as nanoGPT shown here.

Beta Features

[Beta] torch.compile

torch.compile is the main API for PyTorch 2.0, which wraps your model and returns a compiled model. It is a fully additive (and optional) feature and hence 2.0 is 100% backward compatible by definition.

Underpinning torch.compile are new technologies – TorchDynamo, AOTAutograd, PrimTorch and TorchInductor:

  • TorchDynamo captures PyTorch programs safely using Python Frame Evaluation Hooks and is a significant innovation that was a result of 5 years of our R&D into safe graph capture.
  • AOTAutograd overloads PyTorch’s autograd engine as a tracing autodiff for generating ahead-of-time backward traces.
  • PrimTorch canonicalizes ~2000+ PyTorch operators down to a closed set of ~250 primitive operators that developers can target to build a complete PyTorch backend. This substantially lowers the barrier of writing a PyTorch feature or backend.
  • TorchInductor is a deep learning compiler that generates fast code for multiple accelerators and backends. For NVIDIA and AMD GPUs, it uses OpenAI Triton as a key building block. For Intel CPUs, we generate C++ code using multithreading and vectorized instructions, and offload appropriate operations to mkldnn when possible.

With all these new technologies, torch.compile works 93% of the time across 165 open-source models and runs 20% faster on average at float32 precision and 36% faster on average at AMP precision.
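Here is a minimal sketch of the basic `torch.compile` workflow. It assumes a toy MLP, and `MLP` and `make_compiled` are illustrative names, not PyTorch APIs. The wrapper is opt-in: the original module remains usable in eager mode, and the compiled wrapper shares its parameters. Other options, such as `mode="reduce-overhead"` or `mode="max-autotune"`, can also be passed to `torch.compile`.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, dim: int = 32) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pointwise ops such as GELU are candidates for fusion by TorchInductor.
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))


def make_compiled(model: nn.Module, backend: str = "inductor") -> nn.Module:
    # torch.compile is additive: the original module is left untouched and
    # the returned wrapper shares its parameters.
    return torch.compile(model, backend=backend)
```

Usage: `compiled = make_compiled(MLP())`, then call `compiled(x)` exactly as you would call the eager model. The first call triggers compilation.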

For more information, please refer to https://pytorch.org/get-started/pytorch-2.0/ and for TorchInductor CPU with Intel here.

[Beta] PyTorch MPS Backend

MPS backend provides GPU-accelerated PyTorch training on Mac platforms. This release brings improved correctness, stability, and operator coverage.

MPS backend now includes support for the top 60 most-used ops, along with the operations most frequently requested by the community, bringing coverage to over 300 operators. The major focus of the release was to enable full OpInfo-based forward and gradient mode testing to address silent correctness issues. These changes have resulted in wider adoption of the MPS backend by third-party networks such as Stable Diffusion, YoloV5 and WhisperAI, along with increased coverage for Torchbench networks and Basic tutorials. We encourage developers to update to the latest macOS release to see the best performance and stability on the MPS backend.

Links

  1. MPS Backend
  2. Developer information
  3. Accelerated PyTorch training on Mac
  4. Metal, Metal Performance Shaders & Metal Performance Shaders Graph

[Beta] Scaled dot product attention 2.0

We are thrilled to announce the release of PyTorch 2.0, which introduces a powerful scaled dot product attention function as part of torch.nn.functional. This function includes multiple implementations that can be seamlessly applied depending on the input and hardware in use.

In previous versions of PyTorch, you had to rely on third-party implementations and install separate packages to take advantage of memory-optimized algorithms like FlashAttention. With PyTorch 2.0, all these implementations are readily available by default.

These implementations include FlashAttention from HazyResearch, Memory-Efficient Attention from the xFormers project, and a native C++ implementation that is ideal for non-CUDA devices or when high-precision is required.

PyTorch 2.0 will automatically select the optimal implementation for your use case, but you can also toggle them individually for finer-grained control. Additionally, the scaled dot product attention function can be used to build common transformer architecture components.
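As an example of building a transformer component on top of this function, here is a sketch of a GPT-style causal self-attention block. `CausalSelfAttention` is an illustrative class, not a PyTorch module, and the sizes are arbitrary. It passes `is_causal=True` and leaves the choice of implementation to PyTorch. In 2.0, individual implementations can be toggled with the `torch.backends.cuda.sdp_kernel` context manager (later releases provide `torch.nn.attention.sdpa_kernel`); that toggling is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """Illustrative GPT-style attention block built on SDPA (not a PyTorch class)."""

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c = x.shape
        q, k, v = self.qkv(x).split(c, dim=-1)
        # (B, T, C) -> (B, heads, T, head_dim), the layout SDPA expects.
        q, k, v = (z.view(b, t, self.n_heads, c // self.n_heads).transpose(1, 2) for z in (q, k, v))
        # is_causal=True applies the lower-triangular mask without building it here.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge heads back to (B, T, C) and project.
        return self.proj(y.transpose(1, 2).reshape(b, t, c))
```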

Learn more with the documentation and this tutorial.

[Beta] functorch -> torch.func

Inspired by Google JAX, functorch is a library that offers composable vmap (vectorization) and autodiff transforms. It enables advanced autodiff use cases that would otherwise be tricky to express in PyTorch, such as model ensembling, efficiently computing Jacobians and Hessians, and computing per-sample gradients.

We’re excited to announce that, as the final step of upstreaming and integrating functorch into PyTorch, the functorch APIs are now available in the torch.func module. Our function transform APIs are identical to before, but we have changed how the interaction with NN modules works. Please see the docs and the migration guide for more details.

Furthermore, we have added support for torch.autograd.Function: one is now able to apply function transformations (e.g. vmap, grad, jvp) over torch.autograd.Function.
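Per-sample gradients are a typical case where these transforms compose well. The sketch below assumes a toy `nn.Linear` model with an MSE loss; the `per_sample_grads` helper is hypothetical. It uses `torch.func.functional_call` to run the module with an explicit parameter dict, `grad` to differentiate with respect to that dict, and `vmap` to map over the batch.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, grad, vmap


def per_sample_grads(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> dict[str, torch.Tensor]:
    params = {name: p.detach() for name, p in model.named_parameters()}

    def loss_fn(p: dict[str, torch.Tensor], xi: torch.Tensor, yi: torch.Tensor) -> torch.Tensor:
        # functional_call runs the module with explicit parameters, which is
        # how torch.func transforms interact with nn.Modules.
        out = functional_call(model, p, (xi.unsqueeze(0),))
        return nn.functional.mse_loss(out, yi.unsqueeze(0))

    # grad differentiates w.r.t. the first argument (the params dict);
    # vmap maps over the batch dimension of x and y but not over params.
    return vmap(grad(loss_fn), in_dims=(None, 0, 0))(params, x, y)
```

Each returned tensor gains a leading batch dimension. For example, the weight gradient has shape (batch, out_features, in_features).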

[Beta] Dispatchable Collectives

Dispatchable collectives is an improvement to the existing init_process_group() API that makes the backend an optional argument. For users, the main advantage is that the same code can run on both GPU and CPU machines, and can issue both GPU and CPU collectives, without manually specifying a backend (e.g. “NCCL” or “GLOO”). Existing backend specifications by users will be honored and will not require change.

Usage example:

import torch.distributed as dist
…
# old
dist.init_process_group(backend="nccl", ...)
dist.all_reduce(...) # with CUDA tensors works
dist.all_reduce(...) # with CPU tensors does not work

# new
dist.init_process_group(...) # backend is optional
dist.all_reduce(...) # with CUDA tensors works
dist.all_reduce(...) # with CPU tensors works

Learn more here.

[Beta] torch.set_default_device and torch.device as context manager

torch.set_default_device allows users to change the default device that factory functions in PyTorch allocate on. For example, after torch.set_default_device('cuda'), a call to torch.empty(2) will allocate on CUDA rather than on CPU. You can also use torch.device as a context manager to change the default device on a local basis. This resolves a long-standing feature request, dating back to PyTorch's initial release, for a way to do this.
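The following small sketch illustrates both mechanisms. It uses the `meta` device (tensors with a shape and dtype but no storage) so that it runs on any machine; on a GPU machine, `"cuda"` could be substituted. The helper names `build_on` and `empty_on_default` are illustrative.

```python
import torch
import torch.nn as nn


def build_on(device: str) -> nn.Module:
    # Factory calls inside the block (including nn.Module parameter init)
    # allocate on `device` instead of the global default.
    with torch.device(device):
        return nn.Linear(1024, 1024)


def empty_on_default(device: str) -> torch.Tensor:
    # set_default_device changes the global default; restore CPU afterwards.
    torch.set_default_device(device)
    try:
        return torch.empty(2)
    finally:
        torch.set_default_device("cpu")
```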

Learn more here.

[Beta] “X86” as the new default quantization backend for x86 CPU

The new X86 quantization backend, which uses both the FBGEMM and oneDNN kernel libraries, replaces FBGEMM as the default quantization backend for x86 CPU platforms. By leveraging the strengths of both libraries, it improves int8 inference performance over the original FBGEMM backend, with 1.3X – 2X inference speedups measured on 40+ deep learning models. The new backend is functionally compatible with the original FBGEMM backend.

Table: Geomean Speedup of X86 Quantization Backend vs. FBGEMM Backend on Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz

  • 1 core/instance: 1.76X
  • 2 cores/instance: 1.80X
  • 4 cores/instance: 2.04X
  • 1 socket (32 cores)/instance: 1.34X

By default, users on x86 platforms will use the x86 quantization backend, and their PyTorch programs will remain unchanged when using the default backend. Alternatively, users can specify “X86” as the quantization backend explicitly. Example code is shown below:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# get default configuration
qconfig_mapping = get_default_qconfig_mapping()

# or explicitly specify the backend
# qengine = 'x86'
# torch.backends.quantized.engine = qengine
# qconfig_mapping = get_default_qconfig_mapping(qengine)

# construct fp32 model (in eval mode for post-training quantization)
model_fp32 = ...

# prepare; example_inputs is a tuple of representative inputs, e.g. (x,)
prepared_model = prepare_fx(model_fp32, qconfig_mapping, example_inputs=(x,))

# calibrate
...

# convert
quantized_model = convert_fx(prepared_model)

Find more information: https://github.com/pytorch/pytorch/issues/83888 and https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-pytorch-int8-inf-with-new-x86-backend.html.
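For a self-contained variant of the snippet above, here is a sketch of FX post-training quantization on a toy two-layer MLP. The model and the `quantize_toy_model` helper are hypothetical. The sketch assumes the `x86` engine is listed in `torch.backends.quantized.supported_engines`, as expected on x86 builds of PyTorch 2.0+. On other builds it falls back to the currently active engine so that the sketch still runs. Random tensors stand in for real calibration data.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx


def quantize_toy_model(calib_batches: int = 4) -> tuple[nn.Module, str]:
    # Prefer the x86 engine when this build supports it; otherwise keep the
    # build's current engine (e.g. qnnpack on Arm) so the sketch still runs.
    supported = torch.backends.quantized.supported_engines
    engine = "x86" if "x86" in supported else torch.backends.quantized.engine
    torch.backends.quantized.engine = engine
    qconfig_mapping = get_default_qconfig_mapping(engine)

    model_fp32 = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
    example_inputs = (torch.randn(1, 16),)
    prepared = prepare_fx(model_fp32, qconfig_mapping, example_inputs=example_inputs)

    # Calibration: run representative data through the inserted observers.
    with torch.no_grad():
        for _ in range(calib_batches):
            prepared(torch.randn(8, 16))

    # convert_fx swaps observed modules for int8 kernels; inputs and outputs stay float.
    return convert_fx(prepared), engine
```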

[Beta] GNN inference and training optimization on CPU

PyTorch 2.0 includes several critical optimizations to improve GNN inference and training performance on CPU. Before 2.0, PyG's GNN models suffered from low efficiency on CPU due to a lack of performance tuning for several critical kernels (scatter/gather, etc.) and the lack of GNN-related sparse matrix multiplication ops. Specifically, the optimizations include:

  • scatter_reduce: performance hotspot in Message Passing when the edge index is stored in Coordinate format (COO).
  • gather: backward of scatter_reduce, specially tuned for the GNN compute when the index is an expanded tensor.
  • torch.sparse.mm with reduce flag: performance hotspot in Message Passing when the edge index is stored in Compressed Sparse Row (CSR). Supported reduce flag of: sum, mean, amax, amin.
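The two message-passing formulations above compute the same aggregation. The sketch below assumes a toy 4-node graph given as an `edge_index` with a row of sources and a row of targets. It sums neighbor features once with `scatter_reduce` on the COO edge list, and once with `torch.sparse.mm(..., reduce)` on a CSR adjacency matrix. `aggregate_coo` and `aggregate_csr` are illustrative helpers, not PyG APIs. The reduce flag with CSR inputs requires PyTorch 2.0 or later on CPU.

```python
import torch


def aggregate_coo(x: torch.Tensor, edge_index: torch.Tensor, reduce: str = "sum") -> torch.Tensor:
    # COO path: gather source features, then scatter_reduce them onto targets.
    src, dst = edge_index
    messages = x[src]
    index = dst.unsqueeze(-1).expand_as(messages)
    out = torch.zeros_like(x)
    # include_self=False: nodes with no incoming edges keep their initial 0.
    return out.scatter_reduce(0, index, messages, reduce=reduce, include_self=False)


def aggregate_csr(x: torch.Tensor, edge_index: torch.Tensor, reduce: str = "sum") -> torch.Tensor:
    # CSR path: adjacency A[dst, src] = 1, then a sparse-dense matmul with a reduction.
    src, dst = edge_index
    n = x.size(0)
    values = torch.ones(src.numel(), dtype=x.dtype)
    adj = torch.sparse_coo_tensor(torch.stack([dst, src]), values, (n, n)).coalesce().to_sparse_csr()
    return torch.sparse.mm(adj, x, reduce)
```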

On PyG benchmarks/examples and OGB benchmarks, we measured a 1.12x – 4.07x speedup (2.0 compared with 1.13.1) for single-node inference and training.

Model-Dataset Option Speedup Ratio
GCN-Reddit (inference) 512-2-64-dense 1.22x
1024-3-128-dense 1.25x
512-2-64-sparse 1.31x
1024-3-128-sparse 1.68x
512-2-64-dense 1.22x
GraphSage-ogbn-products (inference) 1024-3-128-dense 1.15x
512-2-64-sparse 1.20x
1024-3-128-sparse 1.33x
full-batch-sparse 4.07x
GCN-PROTEINS (training) 3-32 1.67x
GCN-REDDIT-BINARY (training) 3-32 1.67x
GCN-Reddit (training) 512-2-64-dense 1.20x
1024-3-128-dense 1.12x

Learn more: PyG CPU Performance Optimization.

[Beta] Accelerating inference on CPU with PyTorch by leveraging oneDNN Graph

oneDNN Graph API extends oneDNN with a flexible graph API to maximize the optimization opportunity for generating efficient code on AI hardware.

  • It automatically identifies the graph partitions to be accelerated via fusion.
  • The fusion patterns focus on fusing compute-intensive operations such as convolution, matmul and their neighbor operations for both inference and training use cases.
  • Although work is ongoing to integrate oneDNN Graph with TorchDynamo as well, its integration with the PyTorch JIT Fuser attained beta status in PyTorch 2.0 for Float32 & BFloat16 inference (on machines that support AVX512_BF16 ISA).

From a developer’s/researcher’s perspective, the usage is quite simple & intuitive, with the only change in code being an API invocation:

  • To leverage oneDNN Graph with JIT tracing, a model is profiled with an example input.
  • The context manager with torch.jit.fuser("fuser3"): can also be used instead of invoking torch.jit.enable_onednn_fusion(True).
  • For accelerating BFloat16 inference, we rely on eager-mode AMP (Automatic Mixed Precision) support in PyTorch & disable JIT mode’s AMP, as both of them are currently divergent:
import torch

# Assuming we have a model named 'model', already in eval mode

example_input = torch.rand(1, 3, 224, 224)

# enable oneDNN Graph
torch.jit.enable_onednn_fusion(True)
# Disable AMP for JIT
torch._C._jit_set_autocast_mode(False)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # 2 warm-ups (2 for tracing/scripting with an example, 3 without an example)
    model(example_input)
    model(example_input)

    # speedup would be observed in subsequent runs.
    model(example_input)

Learn more here.

Prototype Features

Distributed API

[Prototype] DTensor

PyTorch DistributedTensor (DTensor) is a prototyping effort with distributed tensor primitives to allow easier distributed computation authoring in the SPMD (Single Program Multiple Devices) paradigm. The primitives are simple but powerful when used to express tensor distributions with both sharded and replicated parallelism strategies. PyTorch DTensor powers PyTorch Tensor Parallelism along with other advanced parallelism explorations. In addition, it offers a uniform way to save/load state_dict for distributed checkpointing purposes, even when there are complex tensor distribution strategies such as combining tensor parallelism with parameter sharding in FSDP. More details can be found in this RFC and the DTensor examples notebook.

[Prototype] TensorParallel

We now support DTensor-based Tensor Parallel, which lets users distribute their model parameters across different GPU devices. We also support Pairwise Parallel, which shards two concatenated linear layers column-wise and row-wise respectively, so that only one collective (all-reduce/reduce-scatter) is needed at the end. More details can be found in this example.
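To see why pairwise parallel needs only one collective, the sketch below simulates it in a single process, without `torch.distributed` or DTensor. It assumes a bias-free two-layer MLP, y = relu(x W1ᵀ) W2ᵀ. The first layer's weight is split by output features (column-wise) and the second layer's weight by input features (row-wise). Because ReLU is elementwise, each simulated rank can compute its slice of the hidden layer locally. Summing the partial outputs plays the role of the all-reduce. `pairwise_parallel_mlp` is a hypothetical helper for illustration only.

```python
import torch


def pairwise_parallel_mlp(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, world_size: int) -> torch.Tensor:
    # Single-process simulation of y = relu(x @ w1.T) @ w2.T sharded over `world_size` ranks.
    w1_shards = w1.chunk(world_size, dim=0)  # column-wise: split fc1's output features
    w2_shards = w2.chunk(world_size, dim=1)  # row-wise: split fc2's input features
    partials = []
    for w1_i, w2_i in zip(w1_shards, w2_shards):
        h_i = torch.relu(x @ w1_i.T)  # this rank's hidden slice, no communication needed
        partials.append(h_i @ w2_i.T)  # this rank's partial sum of the full output
    # The single collective: an all-reduce (sum) of the partial outputs.
    return torch.stack(partials).sum(dim=0)
```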

[Prototype] 2D Parallel

We implemented the integration of the aforementioned TP with FullyShardedDataParallel(FSDP) as 2D parallel to further scale large model training. More details can be found in this slide and code example.

[Prototype] torch.compile(dynamic=True)

Experimental support for PT2 compilation with dynamic shapes is available in this release. Inference compilation with inductor for simple models is supported, but there are a lot of limitations:

  • Training available in a future release (This is partially fixed in nightlies!)
  • Minifier available in a future release.
  • It is easy to end up in a situation where the dimension you wanted to be dynamic gets specialized anyway. Some of these issues are fixed in nightlies, others are not.
  • We do not appropriately propagate Inductor guards to the top-level, this is tracked at #96296.
  • Data-dependent operations like nonzero still require a graph break.
  • Dynamic does not work with non-standard modes like reduce-overhead or max-autotune.
  • There are many bugs in Inductor compilation. To track known bugs, check the dynamic shapes label on the PyTorch issue tracker.

For the latest and greatest news about dynamic shapes support on master, check out our status reports.
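Here is a minimal sketch of requesting dynamic shapes. It assumes a toy RMS-normalization function; `rms_norm` and `compile_dynamic` are illustrative names. Given the limitations listed above, some dimensions may still be specialized and trigger recompilation. The results should nonetheless match eager mode.

```python
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # A small function whose batch and sequence dimensions vary between calls.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def compile_dynamic(fn, backend: str = "inductor"):
    # dynamic=True asks the compiler to treat sizes symbolically up front
    # instead of specializing on the first shapes it sees (prototype in 2.0).
    return torch.compile(fn, dynamic=True, backend=backend)
```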

Highlights/Performance Improvements

Deprecation of CUDA 11.6 and Python 3.7 support for PyTorch 2.0

If you are still using or depending on CUDA 11.6 or Python 3.7 builds, we strongly recommend moving to at least CUDA 11.7 and Python 3.8, as these are the minimum versions required for PyTorch 2.0. For more detail, please refer to the Release Compatibility Matrix for PyTorch releases.

Python 3.11 support on Anaconda Platform

Due to the lack of Python 3.11 support on the Anaconda platform for packages that PyTorch depends on, including NumPy, SciPy, SymPy, Pillow and others, we will not be releasing Conda binaries compiled with Python 3.11 for PyTorch Release 2.0. Pip packages with Python 3.11 support will be released, so if you intend to use PyTorch 2.0 with Python 3.11, please use our Pip packages. Please note: Conda packages with Python 3.11 support will be made available on our nightly channel. We also plan to release Conda Python 3.11 binaries as part of a future release once Anaconda provides these key dependencies. More information and instructions on how to download the Pip packages can be found here.

Optimized PyTorch Inference with AWS Graviton processors

The optimizations focused on four key areas: GEMM kernels, bfloat16 support, primitive caching and the memory allocator. For aarch64 platforms, PyTorch supports Arm Compute Library (ACL) GEMM kernels via the Mkldnn (oneDNN) backend. The ACL library provides Neon/SVE GEMM kernels for fp32 and bfloat16 formats. The bfloat16 support on C7g allows efficient deployment of bfloat16-trained, AMP (Automatic Mixed Precision)-trained, or even standard fp32-trained models. The standard fp32 models leverage bfloat16 kernels via oneDNN fast math mode, without any model quantization. Next, we implemented primitive caching for conv, matmul and inner product operators. More information on the updated PyTorch user guide with the upcoming 2.0 release improvements and TorchBench benchmark details can be found here.


New Library Updates in PyTorch 2.0

Summary

We are bringing a number of improvements to the current PyTorch libraries, alongside the PyTorch 2.0 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch.

Along with 2.0, we are also releasing a series of beta updates to the PyTorch domain libraries, including those that are in-tree, and separate libraries including TorchAudio, TorchVision, and TorchText. An update for TorchX is also being released as it moves to community supported mode. Please find the list of the latest stable versions and updates below.

Latest Stable Library Versions (Full List)

  • TorchArrow 0.1.0
  • TorchRec 0.4.0
  • TorchVision 0.15
  • TorchAudio 2.0
  • TorchServe 0.7.1
  • TorchX 0.4.0
  • TorchData 0.6.0
  • TorchText 0.15.0
  • PyTorch on XLA Devices 1.14

*To see prior versions or (unstable) nightlies, click on versions in the top left menu above ‘Search Docs’.

TorchAudio

[Beta] Data augmentation operators

The release adds several data augmentation operators under torchaudio.functional and torchaudio.transforms:

  • torchaudio.functional.add_noise
  • torchaudio.functional.convolve
  • torchaudio.functional.deemphasis
  • torchaudio.functional.fftconvolve
  • torchaudio.functional.preemphasis
  • torchaudio.functional.speed
  • torchaudio.transforms.AddNoise
  • torchaudio.transforms.Convolve
  • torchaudio.transforms.Deemphasis
  • torchaudio.transforms.FFTConvolve
  • torchaudio.transforms.Preemphasis
  • torchaudio.transforms.Speed
  • torchaudio.transforms.SpeedPerturbation

The operators can be used to synthetically diversify training data to improve the generalizability of downstream models.

For usage details, please refer to the functional and transform documentation and Audio Data Augmentation tutorial.
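As a hedged illustration of what adding noise at a target signal-to-noise ratio involves, here is a hand-rolled helper. `add_noise_at_snr` is our own sketch, not torchaudio's implementation; see `torchaudio.functional.add_noise` for the library version. The helper scales the noise so that 10·log10(P_signal / P_noise) equals the requested SNR in decibels, with power measured as the mean of squared samples.

```python
import torch


def add_noise_at_snr(waveform: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    # Mean power per channel (last dimension is time).
    p_signal = waveform.pow(2).mean(dim=-1, keepdim=True)
    p_noise = noise.pow(2).mean(dim=-1, keepdim=True)
    # Choose scale so that p_signal / (scale**2 * p_noise) == 10 ** (snr_db / 10).
    scale = torch.sqrt(p_signal / (p_noise * 10 ** (snr_db / 10)))
    return waveform + scale * noise
```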

[Beta] WavLM and XLS-R models

The release adds two self-supervised learning models for speech and audio.

  • WavLM that is robust to noise and reverberation.
  • XLS-R that is trained on cross-lingual datasets.

Besides the model architectures, torchaudio also supports corresponding pre-trained pipelines:

  • torchaudio.pipelines.WAVLM_BASE
  • torchaudio.pipelines.WAVLM_BASE_PLUS
  • torchaudio.pipelines.WAVLM_LARGE
  • torchaudio.pipelines.WAV2VEC_XLSR_300M
  • torchaudio.pipelines.WAV2VEC_XLSR_1B
  • torchaudio.pipelines.WAV2VEC_XLSR_2B

For usage details, please refer to the factory function and pre-trained pipelines documentation.

TorchRL

The initial release of torchrl includes several features that span across the entire RL domain. TorchRL can already be used in online, offline, multi-agent, multi-task and distributed RL settings, among others. See below:

[Beta] Environment wrappers and transforms

torchrl.envs includes several wrappers around common environment libraries. This allows users to swap one library with another without effort. These wrappers build an interface between these simulators and torchrl:

  • dm_control
  • Gym
  • Brax
  • EnvPool
  • Jumanji
  • Habitat

It also comes with many commonly used transforms and vectorized environment utilities that allow for a fast execution across simulation libraries. Please refer to the documentation for more detail.

[Beta] Datacollectors

Data collection in RL is made easy via the usage of single process or multiprocessed/distributed data collectors that execute the policy in the environment over a desired duration and deliver samples according to the user’s needs. These can be found in torchrl.collectors and are documented here.

[Beta] Objective modules

Several objective functions are included in torchrl.objectives, among which:

  • A generic PPOLoss class and derived ClipPPOLoss and KLPPOLoss
  • SACLoss and DiscreteSACLoss
  • DDPGLoss
  • DQNLoss
  • REDQLoss
  • A2CLoss
  • TD3Loss
  • ReinforceLoss
  • Dreamer

Vectorized value function operators also appear in the library. Check the documentation here.

[Beta] Models and exploration strategies

We provide multiple models, modules and exploration strategies. Get a detailed description in the doc.

[Beta] Composable replay buffer

A composable replay buffer class is provided that can be used to store data in multiple contexts, including single and multi-agent, on and off-policy, and many more. Components include:

  • Storages (list, physical or memory-based contiguous storages)
  • Samplers (Prioritized, sampler without repetition)
  • Writers
  • Possibility to add transforms

Replay buffers and other data utilities are documented here.

[Beta] Logging tools and trainer

We support multiple logging tools including tensorboard, wandb and mlflow.

We provide a generic Trainer class that allows for easy code recycling and checkpointing.

These features are documented here.

TensorDict

TensorDict is a new data carrier for PyTorch.

[Beta] TensorDict: specialized dictionary for PyTorch

TensorDict allows you to execute many common operations across batches of tensors carried by a single container. TensorDict supports many shape and device or storage operations, and can readily be used in distributed settings. Check the documentation to know more.

[Beta] @tensorclass: a dataclass for PyTorch

Like TensorDict, tensorclass provides the opportunity to write dataclasses with built-in torch features such as shape or device operations.

[Beta] tensordict.nn: specialized modules for TensorDict

The tensordict.nn module provides specialized nn.Module subclasses that make it easy to build arbitrarily complex graphs that can be executed with TensorDict inputs. It is compatible with the latest PyTorch features such as functorch, torch.fx and torch.compile.

TorchRec

[Beta] KeyedJaggedTensor All-to-All Redesign and Input Dist Fusion

We observed a performance regression due to a bottleneck in sparse data distribution for models that have multiple large KJTs to redistribute.

To combat this, we altered the comms pattern to transport the minimum data required in an initial collective to support the collective calls for the actual KJT tensor data. Sending this data (the ‘splits’) in an initial collective means more data is transmitted over the comms stream overall, but the CPU is blocked for significantly shorter periods, leading to better overall QPS.

Furthermore, we altered the TorchRec train pipeline to group the initial collective calls for the splits together before launching the more expensive KJT tensor collective calls. This fusion minimizes the CPU blocked time as launching each subsequent input distribution is no longer dependent on the previous input distribution.

With this feature, variable batch sizes are now natively supported across ranks. These features are documented here.

TorchVision

[Beta] Extending TorchVision’s Transforms to Object Detection, Segmentation & Video tasks

TorchVision is extending its Transforms API! Here is what’s new:

  • You can use them not only for Image Classification but also for Object Detection, Instance & Semantic Segmentation and Video Classification.
  • You can use new functional transforms for transforming Videos, Bounding Boxes and Segmentation Masks.

Learn more about these new transforms from our docs, and submit any feedback in our dedicated issue.

TorchText

[Beta] Adding scriptable T5 and Flan-T5 to the TorchText library with incremental decoding support!

TorchText has added the T5 model architecture with pre-trained weights for both the original T5 paper and Flan-T5. The model is fully torchscriptable and features an optimized multiheaded attention implementation. We include several examples of how to utilize the model including summarization, classification, and translation.

For more details, please refer to our docs.

TorchX

TorchX is moving to community supported mode. More details will be coming in at a later time.


Democratizing AI with PyTorch Foundation and ROCm™ support for PyTorch


AMD Founding Member

Last year, Meta announced that PyTorch joined the Linux Foundation as a neutral home for growing the machine learning project and community with AMD representation as a part of the founding membership and governing board.

PyTorch Foundation’s mission is to drive AI adoption by democratizing its software ecosystem through open source principles, aligning with the AMD core principle of an open software ecosystem. AMD strives to foster innovation through support for the latest generations of hardware, tools, libraries, and other components to simplify and accelerate adoption of AI across a broad range of scientific discoveries.

AMD, along with key PyTorch codebase developers (including those at Meta AI), delivered a set of updates to the ROCm™ open software ecosystem that brings stable support for AMD Instinct™ accelerators as well as many Radeon™ GPUs. This gives PyTorch developers the ability to build their next great AI solutions on AMD GPU accelerators and ROCm. The support from the PyTorch community in identifying gaps, prioritizing key updates, providing feedback on performance optimization, and supporting our journey from “Beta” to “Stable” was immensely helpful, and we deeply appreciate the strong collaboration between the AMD and PyTorch teams. The move of ROCm support from “Beta” to “Stable” came in the PyTorch 1.12 release (June 2022) and makes it easy to run PyTorch in a native environment without having to configure custom Docker containers. This is a sign of confidence in the quality of support and performance of PyTorch on AMD Instinct and ROCm. The results of these collaborative efforts are evident in the performance measured on key industry benchmarks like Microsoft’s SuperBench, shown below in Graph 1.

“We are excited to see the significant impact of developers at AMD to contribute to and extend features within PyTorch to make AI models run in a more performant, efficient, and scalable way. A great example of this is the thought-leadership around unified memory approaches between the framework and future hardware systems, and we look forward to seeing that feature progress.”

– Soumith Chintala, PyTorch lead-maintainer and Director of Engineering, Meta AI

The progressive improvements to the AMD CDNA™ architecture, ROCm, and PyTorch show an increase in single-GPU model throughput from the AMD Instinct MI100 to the latest-generation AMD Instinct MI200 family GPUs, going from ROCm 4.2 to ROCm 5.3 and from PyTorch 1.7 to PyTorch 1.12.

Graph 1: ML model performance over generation using Microsoft SuperBench Suite (endnotes 1, 2, 3)

Below are a few of the key updates for ROCm support since the PyTorch 1.12 release:

Full Continuous Integration (CI) for ROCm on PyTorch

With the move of ROCm support for PyTorch from “Beta” to “Stable,” all function and feature commits are now verified through a full Continuous Integration (CI) process. The CI process helps ensure proper build and test steps ahead of the expected Docker and pip wheel releases, with stable commits forthcoming.

Support for Kineto Profiler

The addition of Kineto profiler support to ROCm helps developers and users understand performance bottlenecks through effective diagnosis and profiling tools. The tool also provides recommendations for addressing known issues, and visualization through the TensorBoard UI.
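A minimal profiling sketch, assuming a ROCm (or CUDA) build of PyTorch; it falls back to CPU-only profiling when no GPU is visible. On ROCm builds AMD GPUs are still addressed through the "cuda" device type, and we assume GPU kernels are recorded under ProfilerActivity.CUDA there as well. The matmul loop is an arbitrary workload chosen for illustration. The trace files written by tensorboard_trace_handler can be opened in TensorBoard once the PyTorch profiler plugin (torch-tb-profiler) is installed.

```python
import tempfile

import torch
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

activities = [ProfilerActivity.CPU]
device = "cpu"
if torch.cuda.is_available():
    # ROCm builds expose AMD GPUs through the "cuda" device type; we assume GPU
    # kernels are then collected under ProfilerActivity.CUDA as well
    activities.append(ProfilerActivity.CUDA)
    device = "cuda"

log_dir = tempfile.mkdtemp()  # point `tensorboard --logdir` at this directory

x = torch.randn(512, 512, device=device)
with profile(activities=activities,
             on_trace_ready=tensorboard_trace_handler(log_dir)) as prof:
    for _ in range(5):
        x = torch.relu(x @ x.T) / 512.0  # arbitrary workload for illustration

# Summarize the most expensive operators recorded by the profiler
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```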

Key PyTorch library support added

PyTorch ecosystem libraries like TorchText (text classification), TorchRec (libraries for recommender systems, RecSys), TorchVision (computer vision), and TorchAudio (audio and signal processing) have been fully supported since ROCm 5.1 and were upstreamed with PyTorch 1.12.

Key libraries provided with the ROCm software stack including MIOpen (Convolution models), RCCL (ROCm Collective Communications) and rocBLAS (BLAS for transformers) were further optimized to offer new potential efficiencies and higher performance.

MIOpen innovates on several fronts, such as implementing fusion to optimize for memory bandwidth and GPU launch overheads, providing an auto-tuning infrastructure to overcome the large design space of problem configurations, and implementing different algorithms to optimize convolutions for different filter and input sizes. MIOpen is one of the first libraries to publicly support the bfloat16 data type for convolutions, allowing efficient training at lower precision while maintaining expected accuracy.

RCCL (pronounced “Rickle”) is a stand-alone library of standard collective communication routines for GPUs, implementing all-reduce, all-gather, reduce, broadcast, reduce-scatter, gather, scatter, and all-to-all. There is support for direct GPU-to-GPU send and receive operations. It has been optimized to achieve high bandwidth on platforms using PCIe®, Infinity Fabric™ (GPU to GPU) as well as networking using InfiniBand Verbs or TCP/IP sockets. RCCL supports an arbitrary number of GPUs installed in single or multiple nodes and can be used in either single- or multi-process (e.g., MPI) applications.
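The collectives listed above are easy to mix up. The following pure-Python model spells out their semantics on a toy set of ranks, each represented by an equally sized list of floats. It says nothing about how RCCL implements them (ring or tree algorithms, PCIe or Infinity Fabric transports), and the function names are our own. From PyTorch, on ROCm builds these operations are reached through torch.distributed using the "nccl" backend name, which is served by RCCL.

```python
from typing import List

Buffers = List[List[float]]  # one equally sized buffer per rank


def all_reduce(bufs: Buffers) -> Buffers:
    # Every rank ends up with the element-wise sum over all ranks
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]


def all_gather(bufs: Buffers) -> Buffers:
    # Every rank ends up with all ranks' buffers concatenated in rank order
    gathered = [v for buf in bufs for v in buf]
    return [list(gathered) for _ in bufs]


def reduce_scatter(bufs: Buffers) -> Buffers:
    # Element-wise sum, then rank r keeps only the r-th equal-sized chunk
    n = len(bufs)
    total = all_reduce(bufs)[0]
    chunk = len(total) // n
    return [total[r * chunk:(r + 1) * chunk] for r in range(n)]


def all_to_all(bufs: Buffers) -> Buffers:
    # Rank src sends its dst-th chunk to rank dst; rank dst concatenates
    # the chunks it receives in source-rank order
    n = len(bufs)
    chunk = len(bufs[0]) // n
    return [[v for src in bufs for v in src[dst * chunk:(dst + 1) * chunk]]
            for dst in range(n)]


ranks = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
print(all_reduce(ranks))      # [[11.0, 22.0, 33.0, 44.0], [11.0, 22.0, 33.0, 44.0]]
print(reduce_scatter(ranks))  # [[11.0, 22.0], [33.0, 44.0]]
print(all_to_all(ranks))      # [[1.0, 2.0, 10.0, 20.0], [3.0, 4.0, 30.0, 40.0]]
```

Note that reduce-scatter followed by all-gather yields the same result as all-reduce, which is one common way all-reduce is decomposed in practice.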

Along with the above key highlights, over 50 features and functionality improvements were completed jointly between AMD and PyTorch to add stable support for ROCm. These include improvements to tools, compilers, runtime, graph optimizations through TorchScript, INT8 quantization path usage, and ONNX Runtime integration, including support for the Navi 21-based Radeon™ PRO datacenter graphics card, to name a few.

AITemplate Inference Engine

MetaAI recently published a blog announcing the release of its open source AITemplate, a unified inference system supporting AMD Instinct GPU accelerators using the AMD ROCm stack. This Python-based framework can help significantly improve performance through increased utilization of AMD matrix cores for transformer blocks. This is achieved through the AMD Composable Kernel (CK) library, which provides performance-critical kernels for ML/AI workloads across multiple architectures, including GPUs and CPUs, through HIP and C++.

AITemplate also provides out-of-the-box support for widely used AI models like BERT, ResNet, Vision Transformer, and Stable Diffusion, simplifying the deployment of these pretrained models.

What’s coming with future ROCm releases?

Unified memory models for CPU + GPU

As system architecture evolves to address the complexity of large problem sizes and data sets, memory management becomes a key performance bottleneck that needs a cohesive strategy, addressed through innovations at both the hardware and software levels. AMD is uniquely positioned to address this problem with its effective data center solutions integrating AMD EPYC™ CPU cores with AMD Instinct GPU compute units in a truly unified datacenter APU (Accelerated Processing Unit) form factor set to be launched in 2H 2023.

The software work to leverage the unified CPU + GPU memory has already started in collaboration with the PyTorch team, to enable a fast, low-latency, synchronized memory model that enables not only AMD but also other AI accelerators to address today's complex memory management problem. We look forward to this joint effort and to an announcement soon.

Acknowledgement

The content in this blog highlights the joint work between AMD and key PyTorch contributors including Meta, working on many of the core features, as well as Microsoft enabling ONNX Runtime support. We are looking forward to working with the other founding members at the PyTorch Foundation on the next steps and improvements to democratize and grow adoption of PyTorch across the industry.

CAUTIONARY STATEMENT


This blog contains forward-looking statements concerning Advanced Micro Devices, Inc. (AMD) such as the availability, timing and expected benefits of an AMD datacenter APU form factor, which are made pursuant to the Safe Harbor provisions of the Private Securities Litigation Reform Act of 1995. Forward-looking statements are commonly identified by words such as “would,” “may,” “expects,” “believes,” “plans,” “intends,” “projects” and other terms with similar meaning. Investors are cautioned that the forward-looking statements in this blog are based on current beliefs, assumptions and expectations, speak only as of the date of this blog and involve risks and uncertainties that could cause actual results to differ materially from current expectations. Such statements are subject to certain known and unknown risks and uncertainties, many of which are difficult to predict and generally beyond AMD’s control, that could cause actual results and other future events to differ materially from those expressed in, or implied or projected by, the forward-looking information and statements. Investors are urged to review in detail the risks and uncertainties in AMD’s Securities and Exchange Commission filings, including but not limited to AMD’s most recent reports on Forms 10-K and 10-Q. AMD does not assume, and hereby disclaims, any obligation to update forward-looking statements made in this blog, except as may be required by law.

Endnotes

  1. MI100D-01: SuperBench v0.5 model training results based on AMD internal testing as of 11/09/2022 measuring the total training throughput, at half precision, using a 2P AMD EPYC™ 7763 CPU server tested with 1x AMD Instinct™ MI100 (32GB HBM2e) 300W GPU, SBIOS 2.2, Ubuntu® 20.04.5 LTS, host ROCm™ 5.2.0, guest ROCm 4.2, PyTorch 1.7.0. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.
  2. MI200D-01: SuperBench v0.6 model training results based on AMD internal testing as of 11/09/2022 measuring the total training throughput, at half precision, using a 2P AMD EPYC™ 7763 CPU server tested with 1x AMD Instinct™ MI210 (64GB HBM2e) 300W GPU, SBIOS 2.2, Ubuntu 20.04.5 LTS, host ROCm 5.3.0, guest ROCm 5.3, PyTorch 1.12. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.
  3. MI200D-02: SuperBench v0.6 model training results based on AMD internal testing as of 11/09/2022 measuring the total training throughput, at half precision, using a 2P AMD EPYC™ 7763 CPU server tested with 1x AMD Instinct™ MI250 (128GB HBM2e) 560W GPU, SBIOS M12, Ubuntu 20.04 LTS, host ROCm 5.3.0, guest ROCm 5.3, PyTorch 1.12. Server manufacturers may vary configurations, yielding different results. Performance may vary based on factors including use of latest drivers and optimizations.

Deprecation of CUDA 11.6 and Python 3.7 Support

For the upcoming PyTorch 2.0 feature release (targeted for March 2023), we will target CUDA 11.7 as the stable CUDA version, CUDA 11.8 as the experimental CUDA version, and Python >=3.8, <=3.11.

If you are still using or depending on CUDA 11.6 or Python 3.7 builds, we strongly recommend moving to at least CUDA 11.7 and Python 3.8, as these will be the minimum versions required for PyTorch 2.0.

Please note that as of Feb 1, CUDA 11.6 and Python 3.7 are no longer included in the nightlies.

Please refer to the Release Compatibility Matrix for PyTorch releases:

PyTorch Version | Python | Stable CUDA | Experimental CUDA
2.0 | >=3.8, <=3.11 | CUDA 11.7, CUDNN 8.5.0.96 | CUDA 11.8, CUDNN 8.7.0.84
1.13 | >=3.7, <=3.10 | CUDA 11.6, CUDNN 8.3.2.44 | CUDA 11.7, CUDNN 8.5.0.96
1.12 | >=3.7, <=3.10 | CUDA 11.3, CUDNN 8.3.2.44 | CUDA 11.6, CUDNN 8.3.2.44

As of 2/1/2023
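If you script environment checks, the matrix above can be encoded directly. The sketch below transcribes the table (as of 2/1/2023) into a dictionary; MATRIX, python_supported and cuda_builds are hypothetical helpers, and the data will go stale as new releases ship, so treat the Readme as the source of truth.

```python
import sys

# Transcribed from the Release Compatibility Matrix above (as of 2/1/2023)
MATRIX = {
    "2.0": {"python": ((3, 8), (3, 11)), "stable_cuda": "11.7", "experimental_cuda": "11.8"},
    "1.13": {"python": ((3, 7), (3, 10)), "stable_cuda": "11.6", "experimental_cuda": "11.7"},
    "1.12": {"python": ((3, 7), (3, 10)), "stable_cuda": "11.3", "experimental_cuda": "11.6"},
}


def python_supported(torch_version: str, python_version) -> bool:
    """True if (major, minor) of python_version lies in the release's supported range."""
    low, high = MATRIX[torch_version]["python"]
    return low <= tuple(python_version[:2]) <= high


def cuda_builds(torch_version: str):
    """(stable, experimental) CUDA versions for a release."""
    entry = MATRIX[torch_version]
    return entry["stable_cuda"], entry["experimental_cuda"]


print(python_supported("2.0", sys.version_info))
print(cuda_builds("2.0"))  # ('11.7', '11.8')
```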

For more information on PyTorch releases, updated compatibility matrix and release policies, please see (and bookmark) Readme.

Performance experiments with Stable Diffusion

This is a companion to the main blog “Accelerated Stable Diffusion with PyTorch 2”, containing detailed information on the benchmarking setup and the results of individual experiments. It is mainly aimed at hands-on readers who want to reproduce or further develop the work we described in the main text. Please see the main text for all the context and the summary of results.

Appendix 1: benchmarked versions definition

Here we define precisely what we mean by “original code” and “optimized code” in the main text.

Original code

The original code lives in https://github.com/sgrigory/stablediffusion2 on the original-benchmark branch, specifically in this commit. It is almost the same code as in https://github.com/Stability-AI/stablediffusion, with minimal modifications necessary for benchmarking. In particular, the code can turn off xFormers attention when the environment variable USE_XFORMERS is set to False.

This code uses PyTorch 1.12 and the original custom implementation of attention.

Optimized code

The optimized version is the code living here. It has all the optimizations we mentioned in the main text:

  • nn.MultiheadAttention in CrossAttention instead of custom attention implementation
  • Compilation with torch.compile
  • Other minor optimizations in PyTorch-related code.

The first optimization (using nn.MultiheadAttention in CrossAttention) schematically boils down to the following pseudocode:

class CrossAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # Create matrices: Q, K, V, out_proj
        ...
    def forward(self, x, context=None, mask=None):
        # Compute out = SoftMax(Q*K^T/sqrt(d))V
        # Return out_proj(out)
        ...

gets replaced with

class CrossAttention(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.mha = nn.MultiheadAttention(...)
    def forward(self, x, context):
        out, _ = self.mha(x, context, context)  # returns (output, attention weights)
        return out

See the full diff here.
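For readers who want something executable, here is a self-contained version of the replacement class. The constructor arguments (query_dim, context_dim, heads) and the sizes in the usage lines are our own illustrative choices, loosely modeled on SD's U-Net, and are not taken from the diff. kdim and vdim let keys and values come from a text context whose width differs from the image features, and batch_first=True keeps tensors in (batch, sequence, channels) layout. nn.MultiheadAttention returns an (output, weights) tuple; passing need_weights=False skips computing the averaged attention weights, which, as we understand it, also lets recent PyTorch versions use the fused scaled-dot-product kernels.

```python
import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    """Cross-attention built on nn.MultiheadAttention (illustrative sketch)."""

    def __init__(self, query_dim: int, context_dim: int, heads: int = 8, dropout: float = 0.0):
        super().__init__()
        # kdim/vdim: keys and values are projected from a context of a different width
        self.mha = nn.MultiheadAttention(
            embed_dim=query_dim, num_heads=heads, dropout=dropout,
            kdim=context_dim, vdim=context_dim, batch_first=True,
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Queries come from x; keys and values come from the context
        out, _ = self.mha(x, context, context, need_weights=False)
        return out


attn = CrossAttention(query_dim=320, context_dim=1024, heads=8).eval()
x = torch.randn(2, 64, 320)         # (batch, image tokens, channels)
context = torch.randn(2, 77, 1024)  # (batch, text tokens, text-embedding width)
with torch.no_grad():
    y = attn(x, context)
print(y.shape)  # torch.Size([2, 64, 320])
```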

We have also introduced the following CLI flags:

  • --disable_math, --disable_mem_efficient, --disable_flash to allow turning specific attention backends off
  • --compile to turn on PyTorch compilation
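One plausible way to wire the attention switches is to translate them into the enable_* keyword arguments of the torch.backends.cuda.sdp_kernel context manager mentioned in the main text. This is a hypothetical sketch (parse_args and sdp_kernel_kwargs are our names), not the code in the repository, and it leaves out --compile. Backend selection only applies to CUDA tensors, so the attention call is guarded. In later PyTorch releases sdp_kernel is deprecated in favor of torch.nn.attention.sdpa_kernel.

```python
import argparse

import torch
import torch.nn.functional as F


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    parser.add_argument("--disable_math", action="store_true")
    parser.add_argument("--disable_mem_efficient", action="store_true")
    parser.add_argument("--disable_flash", action="store_true")
    return parser.parse_args(argv)


def sdp_kernel_kwargs(args):
    # Map each --disable_* switch onto the matching enable_* argument
    return dict(
        enable_math=not args.disable_math,
        enable_mem_efficient=not args.disable_mem_efficient,
        enable_flash=not args.disable_flash,
    )


args = parse_args(["--disable_mem_efficient", "--disable_flash"])
kwargs = sdp_kernel_kwargs(args)
print(kwargs)  # {'enable_math': True, 'enable_mem_efficient': False, 'enable_flash': False}

if torch.cuda.is_available():
    q = k = v = torch.randn(1, 8, 64, 40, device="cuda", dtype=torch.float16)
    # Inside this block only the backends left enabled may be selected
    with torch.backends.cuda.sdp_kernel(**kwargs):
        out = F.scaled_dot_product_attention(q, k, v)
```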

The optimized version uses PyTorch 2.0.0.dev20230111+cu117.

Flags added to both code versions

In both code versions we have added the following CLI options to txt2img.py.

  • --skip_first to use a “warm-up” iteration before starting to measure time. See the end of section “Benchmarking setup and results summary” in the main text for why this was necessary.
  • --time_file <FILENAME> to write runtime in seconds in text format to the specified file
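The two flags boil down to a small timing harness. The sketch below is a hypothetical stand-in for the modified txt2img.py (timed_run and generate_batch are our names): one untimed warm-up call absorbs one-off costs such as compilation, then n_iter calls are timed and the total is written to the file as plain text. When timing GPU work, you would also need to call torch.cuda.synchronize() before reading the clock, since CUDA kernels run asynchronously.

```python
import os
import tempfile
import time
from typing import Callable, Optional


def timed_run(generate_batch: Callable[[], None], n_iter: int,
              skip_first: bool = False, time_file: Optional[str] = None) -> float:
    """Time n_iter calls of generate_batch, optionally after one untimed warm-up call."""
    if skip_first:
        generate_batch()  # warm-up: absorbs one-off costs such as compilation
    start = time.perf_counter()
    for _ in range(n_iter):
        generate_batch()
    elapsed = time.perf_counter() - start
    if time_file is not None:
        with open(time_file, "w") as f:
            f.write(f"{elapsed}\n")  # runtime in seconds, plain text
    return elapsed


# Usage with a stand-in for "generate one batch of images"
calls = []
path = os.path.join(tempfile.mkdtemp(), "runtime.txt")
elapsed = timed_run(lambda: calls.append(1), n_iter=2, skip_first=True, time_file=path)
print(len(calls))  # 3: one warm-up call plus two timed iterations
```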

Prompts

Now it should already be clear how to run the 5 configurations mentioned in the main text. For completeness we provide the prompts which can be used to run each of them. This assumes you have

  • installed dependencies from the original version into conda environment ldm-original
  • installed dependencies from the optimized version into conda environment ldm
  • downloaded model weights into /tmp/model.ckpt
  • converted model weights to the new architecture and saved them into /tmp/model_native_mha.ckpt

(see Colab for a bash script which does that)

Prompts for 5 configurations:

# Run optimized with memory-efficient attention and compilation
conda activate ldm
git checkout optmize-w-compile
python scripts/txt2img.py --prompt "A photo" --seed 1 --plms --config configs/stable-diffusion/v2-inference_native_mha.yaml --ckpt /tmp/model_native_mha.ckpt --n_iter 2 --n_samples 1 --compile --skip_first

# Run optimized with memory-efficient attention
conda activate ldm
git checkout optmize-w-compile
python stable-diffusion/scripts/txt2img.py --prompt "A photo" --seed 1 --plms --config stable-diffusion/configs/stable-diffusion/v2-inference_native_mha.yaml --ckpt /tmp/model_native_mha.ckpt --n_iter 2 --n_samples 1 --skip_first

# Run optimized without memory-efficient or flash attention
conda activate ldm
git checkout optmize-w-compile
python stable-diffusion/scripts/txt2img.py --prompt "A photo" --seed 1 --plms --config stable-diffusion/configs/stable-diffusion/v2-inference_native_mha.yaml --ckpt /tmp/model_native_mha.ckpt --n_iter 2 --n_samples 1 --disable_mem_efficient --disable_flash --skip_first 

# Run original code with xFormers
conda activate ldm-original
git checkout original-benchmark
python stable-diffusion-original/scripts/txt2img.py --prompt "A photo" --seed 1 --plms --config stable-diffusion-original/configs/stable-diffusion/v2-inference.yaml --ckpt /tmp/model.ckpt --n_iter 2 --n_samples 1 --skip_first

# Run original code without xFormers
conda activate ldm-original
git checkout original-benchmark
USE_XFORMERS=False python stable-diffusion-original/scripts/txt2img.py --prompt "A photo" --seed 1 --plms --config stable-diffusion-original/configs/stable-diffusion/v2-inference.yaml --ckpt /tmp/model.ckpt --n_iter 2 --n_samples 1 --skip_first

Appendix 2: per-run data

Plots with per-run benchmark data can be found here. Each plot shows all the runs for a particular GPU (P100, V100, T4, A10, A100) and batch size (1, 2, or 4). The bar charts in the main text are obtained from this data by averaging. The file names are self-explanatory, for example “original_vs_optimized_A10_n_samples_2_n_iter_2_sd2.png” contains runs for A10 GPU, batch size 2 and number of iterations 2.

Appendix 3: Accelerated Stable Diffusion 1

Before the release of Stable Diffusion 2, we applied similar optimizations to Stable Diffusion 1 by CompVis. The original implementation of SD1 does not integrate with xFormers yet, so the speedup from simply using PyTorch's optimized attention instead of the custom implementation is significant. Note that the HuggingFace Diffusers port of SD1 does allow integration with xFormers, so an interesting open question, which we didn’t explore, is how SD1 with PyTorch optimized attention compares to HuggingFace SD1+xFormers.

We benchmarked two versions of SD1, original and optimized:

  • As the original version we took the first SD release, and placed it here with minimal modifications to simplify benchmarking. It uses PyTorch 1.11 and custom implementation of attention.
  • The optimized version is the code living here. It uses nn.MultiheadAttention in CrossAttention and PyTorch 2.0.0.dev20221220+cu117.

Here are the results for different GPU architectures and batch size 2:

Version | T4 | P100 | V100 | A100
Original SD1 (runtime in s) | 70.9 | 71.5 | 20.3 | 14.4
Optimized SD1 (runtime in s) | 52.7 (-25.6%) | 57.5 (-19.5%) | 14.3 (-29.3%) | 10.4 (-27.9%)

Same as for SD2, we used Meta hardware for P100, V100, A100 benchmarks. The T4 benchmark was done in Google Colab here.

We didn’t apply compilation to SD1, so, unlike for SD2, we didn’t include a “warm-up” iteration in these benchmarks.

Both applying torch.compile to SD1 and benchmarking the HuggingFace version of SD1 with PyTorch 2 optimizations would be a great exercise for the reader – try it and let us know if you get interesting results.

Accelerated Stable Diffusion with PyTorch 2

TL;DR: PyTorch 2.0 nightly offers out-of-the-box performance improvement for Stable Diffusion 2.1 by using the new torch.compile() compiler and optimized implementations of Multihead Attention integrated with PyTorch 2.

Introduction

Stable Diffusion (SD) is a great example of Generative AI, producing high-quality images from text prompts. However, as with other diffusion-based models, its generation is rather slow due to the iterative nature of the sampling process that produces the images. This makes it important to optimize the code running inside the sampling loop.

We took SD 2.1 from Stability AI as a starting point and accelerated its text-to-image generation using two optimizations available in PyTorch 2: compilation and a fast attention implementation. Together with a few minor memory processing improvements in the code, these optimizations give up to 49% inference speedup relative to the original SD implementation without xFormers, and up to 39% inference speedup relative to using SD with xFormers (excluding the compilation time), depending on the GPU architecture and batch size. Importantly, the speedup comes without a need to install xFormers or any other extra dependencies.

The table below shows the improvement in runtime between the original implementation with xFormers installed and our optimized version with PyTorch-integrated memory efficient attention (originally developed for and released in the xFormers library) and PyTorch compilation. The compilation time is excluded.

Runtime improvement in % compared to original+xFormers

See the absolute runtime numbers in section “Benchmarking setup and results summary”

GPU | Batch size 1 | Batch size 2 | Batch size 4
P100 (no compilation) | -3.8 | 0.44 | 5.47
T4 | 2.12 | 10.51 | 14.2
A10 | -2.34 | 8.99 | 10.57
V100 | 18.63 | 6.39 | 10.43
A100 | 38.5 | 20.33 | 12.17

One can notice the following:

  • The improvements are significant for powerful GPUs like A100 and V100. For those GPUs the improvement is most pronounced for batch size 1
  • For less powerful GPUs we observe smaller speedups (or in two cases slight regressions). The batch size trend is reversed here: improvement is larger for larger batches

In the following sections we describe the applied optimizations and provide detailed benchmarking data, comparing SD performance with various optimization features on/off.

Specifically, we benchmark 5 configurations, and the plots below compare their absolute performance for different GPUs and batch sizes. For definitions of these configurations, see section “Benchmarking setup and results summary”.

Benchmark of Stable Diffusion 2 versions across GPU architectures, batch size 1

Benchmark of Stable Diffusion 2 versions across GPU architectures, batch size 2

Benchmark of Stable Diffusion 2 versions across GPU architectures, batch size 4

If you prefer looking directly at the code, see the Google Colab which runs the benchmark on T4.

Optimizations

Here we’ll go into more detail about the optimizations introduced into the SD code. At the moment they rely on features only available in the nightlies, so we pinned the PyTorch version to a recent nightly (see here). Once the PyTorch 2.0 release comes out, these optimizations won’t have to rely on nightlies any more.

Optimized Attention

One part of the code which we optimized was the scaled dot-product attention. Attention is known to be a heavy operation: a naive implementation materializes the attention matrix, leading to time and memory complexity quadratic in sequence length. In Stable Diffusion, attention (CrossAttention) appears as part of Transformer blocks in multiple parts of the U-Net. Since the U-Net runs at every sampling step, this becomes a critical point to optimize. In PyTorch 2, an optimized attention implementation is integrated into torch.nn.MultiheadAttention, so we used it to replace the custom attention implementation in CrossAttention.

The optimized implementation of attention was available already in PyTorch 1.13 (see here) and widely adopted (see e.g. HuggingFace transformers library example). In particular, it integrates memory-efficient attention from the xFormers library and flash attention from https://arxiv.org/abs/2205.14135. PyTorch 2.0 expands this to additional attention functions such as cross attention and custom kernels for further acceleration, making it applicable to SD.
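To see what the fused implementation replaces, compare a naive attention function that materializes the full score matrix with torch.nn.functional.scaled_dot_product_attention, the functional entry point to the optimized kernels added in PyTorch 2.0. This is an illustrative check on random CPU tensors with arbitrary shapes; on a GPU the same call may dispatch to the memory-efficient or flash backend instead. The two results should agree up to floating-point tolerance.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    # Materializes the full (query_len x key_len) score matrix for every head
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
# Layout: (batch, heads, sequence length, head dim); cross-attention, so key length differs
q = torch.randn(2, 8, 64, 40)
k = torch.randn(2, 8, 77, 40)
v = torch.randn(2, 8, 77, 40)

reference = naive_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)  # PyTorch 2.0+: backend chosen automatically
print(torch.allclose(reference, fused, atol=1e-4))
```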

Flash attention is available on GPUs with compute capability SM 7.5 or SM 8.x – for example, on T4, A10, and A100, which are included in our benchmark (you can check compute capability of each NVIDIA GPU here). However, in our tests on A100 the memory efficient attention performed better than flash attention for the particular case of SD, due to the small number of attention heads and small batch size. PyTorch understands this and chooses memory efficient attention over flash attention for SD when both are available (see the logic here). For full control over the attention backends (memory-efficient attention, flash attention, “vanilla math”, or any future ones), power users can enable and disable them manually with the help of the context manager torch.backends.cuda.sdp_kernel.

Compilation

Compilation is a new feature of PyTorch 2.0, enabling significant speedups with a very simple user experience. To invoke the default behavior, simply wrap a PyTorch module or a function into torch.compile:

model = torch.compile(model)

The PyTorch compiler then turns Python code into a set of instructions which can be executed efficiently without Python overhead. The compilation happens dynamically the first time the code is executed. With the default settings, PyTorch uses TorchDynamo under the hood to capture the code and TorchInductor to further optimize it. See this tutorial for more details.

Although the one-liner above is enough for compilation, certain modifications in the code can squeeze out a larger speedup. In particular, one should avoid so-called graph breaks – places in the code which PyTorch can’t compile. Unlike previous PyTorch compilation approaches (such as TorchScript), the PyTorch 2 compiler doesn’t fail in this case. Instead, it falls back to eager execution, so the code runs, but with reduced performance. We introduced a few minor changes to the SD code to eliminate graph breaks (here and here). See this doc to learn more about graph breaks and how to eliminate them.

Note that compilation requires GPU compute capability >= SM 7.0 to run in non-eager mode. This covers all GPUs in our benchmarks – T4, V100, A10, A100 – except for P100 (see the full list).
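The two hardware thresholds mentioned in this post can be encoded as a small lookup. The capability values for the benchmarked GPUs are NVIDIA's published ones; supports_compile and supports_flash are hypothetical helpers that simply restate this post's rules for the PyTorch 2.0 nightlies used here, and requirements may differ in later releases. If PyTorch and a GPU are available, torch.cuda.get_device_capability() returns the (major, minor) pair for the current device.

```python
# Compute capability (major, minor) of the GPUs benchmarked in this post
CAPABILITY = {
    "P100": (6, 0),
    "V100": (7, 0),
    "T4": (7, 5),
    "A10": (8, 6),
    "A100": (8, 0),
}


def supports_compile(cc):
    # Rule stated in this post: compilation needs SM 7.0 or newer
    return cc >= (7, 0)


def supports_flash(cc):
    # Rule stated in this post: flash attention needs SM 7.5 or SM 8.x
    return cc == (7, 5) or cc[0] == 8


for name, cc in CAPABILITY.items():
    print(f"{name}: compile={supports_compile(cc)}, flash={supports_flash(cc)}")

# Optional: check the GPU in the current machine, if PyTorch and CUDA are present
try:
    import torch

    if torch.cuda.is_available():
        cc = torch.cuda.get_device_capability()
        print("current GPU:", cc, supports_compile(cc), supports_flash(cc))
except ImportError:
    pass
```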

Other optimizations

In addition, we improved the efficiency of some memory operations, e.g. creating a tensor directly on the GPU rather than creating it on the CPU and later moving it to the GPU (see here and here). The places where such optimizations were necessary were identified by line profiling and by examining CPU/GPU traces and Flame Graphs.
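A sketch of the kind of change described above, assuming latent-noise and timestep tensors like the ones created in the sampling loop; the shapes are illustrative, and we do not reproduce the exact lines changed in the SD code. Passing device= to the factory function allocates and fills the tensor on the target device, avoiding a CPU allocation followed by a host-to-device copy.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
shape = (2, 4, 96, 96)  # illustrative latent shape: (batch, channels, height, width)

# Slower pattern: allocate and fill on the CPU, then copy host-to-device
noise_slow = torch.randn(shape).to(device)

# Faster pattern: allocate and fill directly on the target device
noise_fast = torch.randn(shape, device=device)

# The same applies to small constant tensors, e.g. a batch of timesteps
timesteps = torch.full((shape[0],), 999, device=device, dtype=torch.long)
```

On a CPU-only machine both patterns behave the same; the difference only shows up when device is a GPU.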

Benchmarking setup and results summary

We have two versions of SD code to compare: original and optimized. On top of this, several optimization features (xFormers, PyTorch memory efficient attention, compilation) can be turned on/off. Overall, as mentioned in the introduction, we will be benchmarking 5 configurations:

  • Original code without xFormers
  • Original code with xFormers
  • Optimized code with vanilla math attention backend and no compilation
  • Optimized code with memory-efficient attention backend and no compilation
  • Optimized code with memory-efficient attention backend and compilation

As the original version we took the SD 2.1 release, and placed it here with minimal modifications necessary for benchmarking. It uses PyTorch 1.12 and a custom implementation of attention.

The optimized version is the code living here. It uses nn.MultiheadAttention in CrossAttention and PyTorch 2.0.0.dev20230111+cu117. It also has a few other minor optimizations in PyTorch-related code.

Please see the appendix “Benchmarked versions definition” in the companion page for the precise definition of the 5 configurations and prompts triggering each of them.

The table below shows runtime of each version of the code in seconds, and the percentage improvement compared to the original with xFormers. The compilation time is excluded.

Runtimes for batch size 1. In parentheses – relative improvement with respect to the “Original with xFormers” row

Configuration | P100 | T4 | A10 | V100 | A100
Original without xFormers | 30.4s (-19.3%) | 29.8s (-77.3%) | 13.0s (-83.9%) | 10.9s (-33.1%) | 8.0s (-19.3%)
Original with xFormers | 25.5s (0.0%) | 16.8s (0.0%) | 7.1s (0.0%) | 8.2s (0.0%) | 6.7s (0.0%)
Optimized with vanilla math attention, no compilation | 27.3s (-7.0%) | 19.9s (-18.7%) | 13.2s (-87.2%) | 7.5s (8.7%) | 5.7s (15.1%)
Optimized with mem. efficient attention, no compilation | 26.5s (-3.8%) | 16.8s (0.2%) | 7.1s (-0.8%) | 6.9s (16.0%) | 5.3s (20.6%)
Optimized with mem. efficient attention and compilation | n/a | 16.4s (2.1%) | 7.2s (-2.3%) | 6.6s (18.6%) | 4.1s (38.5%)

Runtimes for batch size 2

Configuration | P100 | T4 | A10 | V100 | A100
Original without xFormers | 58.0s (-21.6%) | 57.6s (-84.0%) | 24.4s (-95.2%) | 18.6s (-63.0%) | 12.0s (-50.6%)
Original with xFormers | 47.7s (0.0%) | 31.3s (0.0%) | 12.5s (0.0%) | 11.4s (0.0%) | 8.0s (0.0%)
Optimized with vanilla math attention, no compilation | 49.3s (-3.5%) | 37.9s (-21.0%) | 17.8s (-42.2%) | 12.7s (-10.7%) | 7.8s (1.8%)
Optimized with mem. efficient attention, no compilation | 47.5s (0.4%) | 31.2s (0.5%) | 12.2s (2.6%) | 11.5s (-0.7%) | 7.0s (12.6%)
Optimized with mem. efficient attention and compilation | n/a | 28.0s (10.5%) | 11.4s (9.0%) | 10.7s (6.4%) | 6.4s (20.3%)

Runtimes for batch size 4

Configuration | P100 | T4 | A10 | V100 | A100
Original without xFormers | 117.9s (-20.0%) | 112.4s (-81.8%) | 47.2s (-101.7%) | 35.8s (-71.9%) | 22.8s (-78.9%)
Original with xFormers | 98.3s (0.0%) | 61.8s (0.0%) | 23.4s (0.0%) | 20.8s (0.0%) | 12.7s (0.0%)
Optimized with vanilla math attention, no compilation | 101.1s (-2.9%) | 73.0s (-18.0%) | 28.3s (-21.0%) | 23.3s (-11.9%) | 14.5s (-13.9%)
Optimized with mem. efficient attention, no compilation | 92.9s (5.5%) | 61.1s (1.2%) | 23.9s (-1.9%) | 20.8s (-0.1%) | 12.8s (-0.9%)
Optimized with mem. efficient attention and compilation | n/a | 53.1s (14.2%) | 20.9s (10.6%) | 18.6s (10.4%) | 11.2s (12.2%)

To minimize fluctuations and external influence on the performance of the benchmarked code, we ran each version of the code one after another, and then repeated this sequence 10 times: A, B, C, D, E, A, B, … So the results of a typical run would look like the one in the picture below. For results of all runs please see appendix “Per-run data” in the companion page. Note that one shouldn’t rely on comparison of absolute run times between different graphs, but comparison of run times inside one graph is pretty reliable, thanks to our benchmarking setup.
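The interleaving scheme can be expressed as a small driver. In this sketch (interleaved_benchmark is a hypothetical name), each configuration is a callable that performs one run and returns its runtime in seconds, for example by launching txt2img.py with --time_file and reading the file back; the driver runs A, B, C, ... once per round and averages per configuration. Interleaving means slow drifts in machine state (thermal throttling, background load) affect all configurations roughly equally rather than penalizing whichever ran last.

```python
import statistics
from collections import defaultdict
from typing import Callable, Dict, List


def interleaved_benchmark(configs: Dict[str, Callable[[], float]], rounds: int = 10) -> Dict[str, float]:
    """Run each configuration once per round, in a fixed order, and average per configuration."""
    runtimes: Dict[str, List[float]] = defaultdict(list)
    for _ in range(rounds):
        for name, run_once in configs.items():  # dicts preserve insertion order: A, B, C, ...
            runtimes[name].append(run_once())
    return {name: statistics.mean(times) for name, times in runtimes.items()}


# Usage with stand-in runs that record the execution order and return fixed runtimes
order: List[str] = []


def fake_run(name: str, seconds: float) -> Callable[[], float]:
    def run_once() -> float:
        order.append(name)
        return seconds
    return run_once


means = interleaved_benchmark({"A": fake_run("A", 1.0), "B": fake_run("B", 2.0)}, rounds=3)
print(order)  # ['A', 'B', 'A', 'B', 'A', 'B']
print(means)  # {'A': 1.0, 'B': 2.0}
```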

Stable Diffusion 2.1 benchmarks

Each run of txt2img.py generates several batches, which is regulated by the CLI parameter --n_iter. In the benchmarks we used n_iter = 2, but introduced an additional “warm-up” iteration, which doesn’t contribute to the run time. This was necessary for the runs with compilation, because compilation happens the first time the code runs, so the first iteration takes much longer than all subsequent ones. To make the comparison fair, we also introduced this additional “warm-up” iteration to all other runs; it is turned on by the CLI option --skip_first provided to the modified txt2img.py.

The numbers in the tables above are for 2 iterations (plus a “warm-up” one), prompt “A photo”, seed 1, PLMS sampler, and autocast turned on. See the companion page for precise CLI commands in appendix “Benchmarked versions definition” and detailed results of individual runs in appendix “Per-run data”.

The P100, V100, and A100 benchmarks were done on Meta internal infrastructure. The T4 benchmarks were done in Google Colab Pro (see the Google Colab notebook). The A10 benchmarks were done on g5.4xlarge AWS instances with 1 GPU.

Conclusions and next steps

We have shown that new features of PyTorch 2 – the compiler and the optimized attention implementation – give performance improvements exceeding or comparable with what previously required installation of an external dependency (xFormers). PyTorch achieved this, in particular, by integrating memory efficient attention from xFormers into its codebase. This is a significant improvement for user experience, given that xFormers, being a state-of-the-art library, in many scenarios requires a custom installation process and long builds.

There are a few natural directions in which this work can be continued:

  • There are new implementations of SD, including a port to the HuggingFace diffusers library. It would be interesting to benchmark against them. Note that diffusers also requires installing xFormers in order to use memory efficient attention
  • The optimizations we implemented and described here are only benchmarked for text-to-image inference so far. It would be interesting to see how they affect training. PyTorch compilation can be directly applied to training; enabling training with PyTorch optimized attention is on the roadmap
  • We intentionally minimized changes to the original SD code. Further profiling and optimization can probably bring more improvements
  • At the moment compilation is applied only to the U-Net model inside the sampler. Since there is a lot happening outside of U-Net (e.g. operations directly in the sampling loop), it would be beneficial to compile the whole sampler. However, this would require analysis of the compilation process to avoid recompilation at every sampling step
  • Current code only applies compilation within the PLMS sampler, but it should be trivial to extend it to other samplers
  • Besides text-to-image generation, SD 2.1 has other pipelines – image-to-image and inpainting. It would be interesting to measure how their performance improves from PyTorch 2 optimizations

Try some of this in the Colab or on a GPU of your choice. See if you can further increase the performance of SD, and share the results! This is your chance to get a preview of PyTorch 2.0 and experience the features coming in the next release.

As a note, if you want access to new PyTorch features which come after this post is published, just tweak the PyTorch and TorchVision versions in environment.yaml.

Resources

Acknowledgements

We would like to thank Geeta Chauhan, Natalia Gimelshein, Patrick Labatut, Bert Maher, Mark Saroufim, Michael Voznesensky and Francisco Massa for their valuable advice and early feedback on the text.

Special thanks to Yudong Tao for creating the first version of Stable Diffusion with PyTorch native attention.

For more information, visit this page with additional resources.