ExecuTorch Alpha: Taking LLMs and AI to the Edge with Our Community and Partners

We are excited to announce the release of ExecuTorch alpha, focused on deploying large language models (LLMs) and large ML models to the edge, stabilizing the API surface, and improving our installation processes. It has been an exciting few months since our 0.1 (preview) release in collaboration with our partners at Arm, Apple, and Qualcomm Technologies, Inc.

In this post we’ll discuss our full support for Meta’s Llama 2, early support for Meta’s Llama 3, broad model support in ExecuTorch, and highlight the important work our partners have done to move us forward.

Large Language Models on Mobile

Mobile devices are highly constrained for compute, memory, and power. To bring LLMs to these devices, we heavily leverage quantization and other techniques to pack these models appropriately.

ExecuTorch alpha supports 4-bit post-training quantization using GPTQ. We’ve provided broad device support on CPU by landing dynamic shape support and new dtypes in XNNPACK. We’ve also made significant improvements in export and lowering, reduced memory overhead, and improved runtime performance. This enables running Llama 2 7B efficiently on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S22, S23, and S24 phones and other edge devices. Early support for Llama 3 8B is also included. We are always improving the tokens/sec on various edge devices, and you can visit GitHub for the latest performance numbers.
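
To make the memory savings concrete, here is a minimal sketch of group-wise 4-bit weight quantization using plain round-to-nearest. It is not GPTQ (which additionally uses calibration data to compensate for rounding error) and it is not an ExecuTorch API; the helper names, the tiny group, and the 7e9 parameter count are illustrative assumptions. The size estimate ignores the per-group scales.

```python
def quantize_group_int4(weights):
    """Symmetric round-to-nearest quantization of one group to the int4 range [-8, 7]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    q = [max(-8, min(7, round(w / scale))) for w in weights]
    return q, scale


def dequantize_group(q, scale):
    """Map int4 codes back to approximate floating-point weights."""
    return [v * scale for v in q]


def weight_bytes(num_params, bits_per_weight):
    """Storage for the weights alone, excluding scales and other metadata."""
    return num_params * bits_per_weight / 8


group = [0.7, -0.3, 0.12, 0.0]          # hypothetical weights in one group
q, scale = quantize_group_int4(group)
restored = dequantize_group(q, scale)

fp16_gb = weight_bytes(7e9, 16) / 1e9    # roughly 14 GB for a 7B-parameter model
int4_gb = weight_bytes(7e9, 4) / 1e9     # roughly 3.5 GB at 4 bits per weight
```

Under these assumptions, going from 16-bit to 4-bit weights cuts weight storage by about 4x, which is the kind of reduction that makes a 7B model plausible on a phone.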

We’re working closely with our partners at Apple, Arm, and Qualcomm Technologies to delegate to GPU and NPU for performance through Core ML, MPS, TOSA, and Qualcomm AI Stack backends respectively.

Supported Models

We remain committed to supporting an ever-expanding list of models with ExecuTorch. Since preview, we have significantly expanded our tested models across NLP, vision and speech, with full details in our release notes. Although support for on-device LLMs is early, we expect most traditional models to work seamlessly out of the box, with delegation to XNNPACK, Core ML, MPS, TOSA, and HTP for performance. If you encounter any problems, please open a GitHub issue with us.

Productivity

Deploying performant models tuned for specific platforms often requires deep visibility into the on-device runtime data to determine the right changes to make in the original PyTorch model. With ExecuTorch alpha, we provide a powerful SDK with observability throughout the process from model authoring to deployment, including delegate and hardware-level information.

The ExecuTorch SDK was enhanced to include better debugging and profiling tools. Because ExecuTorch is built on PyTorch, the debugging capabilities include the ability to map from operator nodes back to original Python source code for more efficient anomaly resolution and performance tuning for both delegated and non-delegated model instances. You can learn more about the ExecuTorch SDK here.

Partnerships

ExecuTorch has only been possible because of strong collaborations across Arm, Apple, and Qualcomm Technologies. The collaboration for the initial launch of ExecuTorch continues as we support LLMs and large AI models on the edge for PyTorch. As we’ve seen with this early work for ExecuTorch alpha, there are unique challenges with these larger models and we’re excited to develop in the open.

We also want to highlight the great partnership with Google on XNNPACK for CPU performance. The teams continue to work together, upstreaming our changes and collaborating across the TensorFlow and PyTorch teams, to make sure we can all support generative AI models on the edge with SOTA performance.

Lastly, our hardware partner MediaTek has been doing work enabling the Llama collection of models with ExecuTorch on their SoCs. We’ll have more to share in the future.

Alpha and Production Usage

With our alpha release, we have production-tested ExecuTorch. Meta is using ExecuTorch for hand tracking on Meta Quest 3 and a variety of models on Ray-Ban Meta Smart Glasses. In addition, we have begun the rollout of ExecuTorch with Instagram and are integrating with other Meta products. We are excited to see how ExecuTorch can be used for other edge experiences.

Community

We are excited to see various efforts in the community to adopt or contribute to ExecuTorch. For instance, Unity recently shared their work at the Game Developers Conference (GDC) on leveraging ExecuTorch and Edge IR to run PyTorch models with their neural network inference library Sentis. Leveraging ExecuTorch’s hackability and extensibility, Unity introduced their own custom backend that serializes ExecuTorch’s Edge Dialect IR into Sentis’ native serialized format enabling developers to begin using PyTorch models easily in their games and apps.

We’ve been building and innovating with ExecuTorch in the open. Our north star is to empower the community to deploy any ML model on edge devices painlessly and efficiently. Whether you are a hobbyist or this is your day job, we’d love for you to jump in to bring your ML models to the edge. We are looking for your help to:

  1. Use ExecuTorch to run your LLM models locally on various deployment targets and share your feedback
  2. Expand our supported models, including bug reports
  3. Expand our quantization schemes
  4. Help us build out delegates to GPU and NPU

To all individual contributors and early adopters of ExecuTorch, a big thank you as well. We can’t wait to have more of you join us!

PyTorch 2.3 Release Blog

We are excited to announce the release of PyTorch® 2.3 (release note)! PyTorch 2.3 offers support for user-defined Triton kernels in torch.compile, allowing users to migrate their own Triton kernels from eager without experiencing performance regressions or graph breaks. Tensor Parallelism improves the experience for training Large Language Models using native PyTorch functions, and has been validated on training runs for 100B parameter models. In addition, semi-structured sparsity is implemented as a Tensor subclass, with observed speedups of up to 1.6x over dense matrix multiplication.

This release is composed of 3393 commits and 426 contributors since PyTorch 2.2. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.3. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

| Beta | Prototype | Performance Improvements |
|---|---|---|
| User-defined Triton kernels in torch.compile | torch.export adds new API to specify dynamic_shapes | Weight-Only-Quantization introduced into Inductor CPU backend |
| Tensor parallelism within PyTorch Distributed | Asynchronous checkpoint generation | |
| Support for semi-structured sparsity | | |

*To see a full list of public feature submissions click here.

Beta Features

[Beta] Support for User-defined Triton kernels in torch.compile

Allows for PyTorch code that contains Triton kernels to be executed natively using torch.compile. This enables users to migrate code containing Triton kernels from eager PyTorch to torch.compile without running into performance regressions or graph breaks. Native support also creates an opportunity for Torch Inductor to precompile the user-defined Triton kernel, as well as to better organize code around the Triton kernel, allowing for further optimizations.

You can find more information about how to utilize user defined Triton kernels in torch.compile within this tutorial.

[Beta] Tensor Parallelism introduces more efficient ways to train LLMs

The Tensor Parallel API facilitates various tensor manipulations across GPUs/hosts and integrates with FSDP for 2D Parallelism (Tensor parallelism across devices + Data Parallelism across hosts). It also offers a low-level API for constructing higher-level Tensor parallel APIs. This API has been validated to support the training of transformer models with over 100 billion parameters.

You can find more information on how to utilize this within your workflows within this tutorial.
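
As a rough illustration of the idea behind tensor parallelism (not the PyTorch Tensor Parallel API itself), the sketch below shards the columns of a weight matrix across two hypothetical devices, lets each compute its slice of the output, and concatenates the slices. The helper names and the plain-list "devices" are assumptions made for the example.

```python
def linear(x, W):
    """y = x @ W for one input vector x; W has one row per input feature."""
    return [sum(x[k] * W[k][n] for k in range(len(x))) for n in range(len(W[0]))]


def shard_columns(W, num_shards):
    """Split W column-wise into equally sized shards, one per device."""
    n = len(W[0])
    assert n % num_shards == 0
    width = n // num_shards
    return [[row[s * width:(s + 1) * width] for row in W] for s in range(num_shards)]


x = [1.0, 2.0, 3.0]
W = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 3.0],
     [2.0, 1.0, 0.0, 1.0]]

full = linear(x, W)                                   # single-device reference
shards = shard_columns(W, num_shards=2)
partials = [linear(x, shard) for shard in shards]     # one partial result per device
gathered = [v for part in partials for v in part]     # stands in for an all-gather
```

Each device only holds half of the weight matrix, yet the gathered result matches the single-device output.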

[Beta] Semi-structured sparsity provides users with a way to take advantage of accelerated sparse inference and memory savings

torch.sparse.SparseSemiStructuredTensor implements semi-structured sparsity as a Tensor subclass, with observed speedups of up to 1.6x over dense matrix multiplication.

In particular it adds:

  • Additional support for quantization composability (mixed dtype, dequant fusion)
  • Updated cuSPARSELt and CUTLASS kernels
  • torch.compile support

You can find more information on how to take advantage of semi-structured sparsity here.
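
For intuition, here is a minimal sketch of the pruning pattern that semi-structured sparsity relies on, assuming the 2:4 layout (at most two nonzero values in every four contiguous elements). It only shows magnitude-based pruning of a single row; it does not use or reproduce the SparseSemiStructuredTensor API, and `prune_2_of_4` is a hypothetical helper.

```python
def prune_2_of_4(row):
    """Zero the two smallest-magnitude values in every group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this group.
        keep = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(group))
    return pruned


dense = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
sparse = prune_2_of_4(dense)
# sparse == [0.9, 0.0, 0.0, -0.7, 0.0, 0.3, -0.4, 0.0]
```

Because exactly half of each group is zero, the nonzero values and their positions can be stored compactly, which is where the memory savings come from.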

Prototype Features

[PROTOTYPE] torch.export adds new API to specify dynamic_shapes

You can now use torch.export.Dim to better represent dynamic shapes by enabling developers to specify ranges (min and max values) that can be reused across different input dimensions that are constrained to be equal.

To learn more about torch.export.Dim as well as how it can be used to express more interesting relationships (such as linear arithmetic expressions) check out the tutorial here.

[PROTOTYPE] Asynchronous checkpoint generation

Asynchronous checkpoint generation allows users to continue their training loops while checkpoints are being generated, essentially offloading much of the checkpointing cost.

You can find out how to utilize this within your own workflows with this example.
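
The general pattern can be sketched with the standard library alone: take a snapshot of the training state, hand the slow write to a background thread, and keep training. This is only an illustration of the idea, not the PyTorch distributed checkpointing API; the JSON file, the toy state, and `save_checkpoint_async` are assumptions.

```python
import copy
import json
import os
import tempfile
import threading


def save_checkpoint_async(state, path):
    """Snapshot `state` now, then write it to `path` on a background thread."""
    snapshot = copy.deepcopy(state)  # training may mutate `state` after we return

    def _write():
        with open(path, "w") as f:
            json.dump(snapshot, f)

    writer = threading.Thread(target=_write)
    writer.start()
    return writer


state = {"step": 0, "weights": [0.0, 0.0]}
path = os.path.join(tempfile.mkdtemp(), "ckpt_step2.json")
writer = None
for step in range(1, 5):
    state["step"] = step
    state["weights"] = [w + 0.1 for w in state["weights"]]  # stand-in for a training step
    if step == 2:
        writer = save_checkpoint_async(state, path)  # returns immediately

writer.join()  # wait for the write to finish before reading the file back
with open(path) as f:
    saved = json.load(f)
```

The snapshot is what makes this safe: the checkpoint on disk reflects step 2 even though the loop has already advanced to step 4 by the time the file is read.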

Performance Improvements

[PROTOTYPE] Weight-Only-Quantization introduced into Inductor CPU backend

PyTorch 2.3 enhances LLM inference performance on the Torch Inductor CPU backend. The gpt-fast project offers simple and efficient PyTorch-native acceleration for transformer text generation with torch.compile. Prior to 2.3, only CUDA devices were supported; this feature enables the CPU counterpart by providing highly optimized kernels for int4 and int8 weight-only quantized Linear layers.

For more information / how to utilize this feature please refer to the gpt-fast README.

torchtune: Easily fine-tune LLMs using PyTorch

We’re pleased to announce the alpha release of torchtune, a PyTorch-native library for easily fine-tuning large language models.

Staying true to PyTorch’s design principles, torchtune provides composable and modular building blocks along with easy-to-extend training recipes to fine-tune popular LLMs on a variety of consumer-grade and professional GPUs.

torchtune supports the full fine-tuning workflow from start to finish, including

  • Downloading and preparing datasets and model checkpoints.
  • Customizing the training with composable building blocks that support different model architectures, parameter-efficient fine-tuning (PEFT) techniques, and more.
  • Logging progress and metrics to gain insight into the training process.
  • Quantizing the model post-tuning.
  • Evaluating the fine-tuned model on popular benchmarks.
  • Running local inference for testing fine-tuned models.
  • Checkpoint compatibility with popular production inference systems.

To get started, jump right into the code or walk through our many tutorials!

Why torchtune?

Over the past year there has been an explosion of interest in open LLMs. Fine-tuning these state-of-the-art models has emerged as a critical technique for adapting them to specific use cases. This adaptation can require extensive customization, from dataset and model selection all the way through to quantization, evaluation and inference. Moreover, the size of these models poses a significant challenge when trying to fine-tune them on consumer-level GPUs with limited memory.

Existing solutions make it hard to add these customizations or optimizations by hiding the necessary pieces behind layers of abstractions. It’s unclear how different components interact with each other and which of these need to be updated to add new functionality. torchtune empowers developers to adapt LLMs to their specific needs and constraints with full control and visibility.

torchtune’s Design

torchtune was built with the following principles in mind

  • Easy extensibility – New techniques emerge all the time and everyone’s fine-tuning use case is different. torchtune’s recipes are designed around easily composable components and hackable training loops, with minimal abstraction getting in the way of fine-tuning your fine-tuning. Each recipe is self-contained – no trainers or frameworks, and is designed to be easy to read – less than 600 lines of code!
  • Democratize fine-tuning – Users, regardless of their level of expertise, should be able to use torchtune. Clone and modify configs, or get your hands dirty with some code! You also don’t need beefy data center GPUs. Our memory efficient recipes have been tested on machines with a single 24GB gaming GPU.
  • Interoperability with the OSS LLM ecosystem – The open source LLM ecosystem is absolutely thriving, and torchtune takes advantage of this to provide interoperability with a wide range of offerings. This flexibility puts you firmly in control of how you train and use your fine-tuned models.

Over the next year, open LLMs will become even more powerful, with support for more languages (multilingual), more modalities (multimodal) and more tasks. As the complexity of these models increases, we need to pay the same attention to “how” we design our libraries as we do to the features provided or performance of a training run. Flexibility will be key to ensuring the community can maintain the current pace of innovation, and many libraries/tools will need to play well with each other to power the full spectrum of use cases. torchtune is built from the ground up with this future in mind.

In the true PyTorch spirit, torchtune makes it easy to get started by providing integrations with some of the most popular tools for working with LLMs.

  • Hugging Face Hub – Hugging Face provides an expansive repository of open source models and datasets for fine-tuning. torchtune seamlessly integrates through the tune download CLI command so you can get started right away with fine-tuning your first model.
  • PyTorch FSDP – Scale your training using PyTorch FSDP. It is very common for people to invest in machines with multiple consumer-level cards like the 3090/4090 by NVIDIA. torchtune allows you to take advantage of these setups by providing distributed recipes powered by FSDP.
  • Weights & Biases – torchtune uses the Weights & Biases AI platform to log metrics and model checkpoints during training. Track your configs, metrics and models from your fine-tuning runs all in one place!
  • EleutherAI’s LM Evaluation Harness – Evaluating fine-tuned models is critical to understanding whether fine-tuning is giving you the results you need. torchtune includes a simple evaluation recipe powered by EleutherAI’s LM Evaluation Harness to provide easy access to a comprehensive suite of standard LLM benchmarks. Given the importance of evaluation, we will be working with EleutherAI very closely in the next few months to build an even deeper and more “native” integration.
  • ExecuTorch – Models fine-tuned with torchtune can be easily exported to ExecuTorch, enabling efficient inference to be run on a wide variety of mobile and edge devices.
  • torchao – Easily and efficiently quantize your fine-tuned models into 4-bit or 8-bit using a simple post-training recipe powered by the quantization APIs from torchao.

What’s Next?

This is just the beginning and we’re really excited to put this alpha version in front of a vibrant and energetic community. In the coming weeks, we’ll continue to augment the library with more models, features and fine-tuning techniques. We’d love to hear any feedback, comments or feature requests in the form of GitHub issues on our repository, or on our Discord channel. As always, we’d love any contributions from this awesome community. Happy Tuning!

Accelerating MoE model inference with Locality-Aware Kernel Design

1.0 Summary

We show that by implementing column-major scheduling to improve data locality, we can accelerate the core Triton GEMM (General Matrix-Matrix Multiply) kernel for MoEs (Mixture of Experts) up to 4x on A100, and up to 4.4x on H100 Nvidia GPUs. This post demonstrates several different work decomposition and scheduling algorithms for MoE GEMMs and shows, at the hardware level, why column-major scheduling produces the highest speedup.

Repo and code available at: https://github.com/pytorch-labs/applied-ai/tree/main/triton/.

Figure 1A. Optimized Fused MoE GEMM Kernel TFLOPs on A100 for varying Batch Sizes M

Figure 1B. Optimized Fused MoE GEMM Kernel TFLOPs on H100 for varying Batch Sizes M

2.0 Background

OpenAI’s Triton is a hardware-agnostic language and compiler that, as our prior blog post has shown, can be used to accelerate quantization workflows. We also showed that, in terms of kernel development, many of the same learnings and performance analysis tools from CUDA can be leveraged to provide similar insights into how Triton kernels work under the hood, and to guide subsequent measures to speed up these kernels in latency-sensitive environments. As Triton becomes increasingly adopted in production settings, it is important that developers understand the common tips and tricks for developing performant kernels, as well as how well these methods generalize to different architectures and workflows. Thus, this post will explore how we optimized the Triton kernel developed by vLLM for the popular Mixture of Experts (MoE) Mixtral model using classical techniques, and how these techniques can be implemented in Triton to achieve performance gains.

Mixtral 8x7B is a sparse Mixture of Experts Language Model. Unlike the classical dense transformer architecture, each transformer block houses 8 MLP layers where each MLP is an ‘expert’. As a token flows through, a router network selects which 2 of the 8 experts should process that token and the results are then combined. The selected experts for the same token vary at each layer. As a result, while Mixtral 8x7B has a total of 47B params, during inference only 13B params are active.

The MoE GEMM (General Matrix-Matrix Multiply) kernel receives a stacked weight matrix containing all the experts, and must subsequently route each token to the TopK (2 for Mixtral) experts by utilizing a mapping array produced by the resultant scores of the router network. In this post, we provide methods to efficiently parallelize this computation during inference time, specifically during autoregression (or decoding stages).
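
The routing step described above can be sketched in a few lines. This is a toy, single-token illustration with scalar "experts", not the vLLM kernel or Mixtral’s implementation; the helper names and the example logits are assumptions.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_top_k(router_logits, k=2):
    """Return [(expert_index, weight), ...] for the k highest-scoring experts."""
    probs = softmax(router_logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in top)  # renormalize so the selected weights sum to 1
    return [(i, probs[i] / norm) for i in top]


def moe_forward(token, router_logits, experts, k=2):
    """Weighted sum of the outputs of the selected experts only."""
    return sum(w * experts[i](token) for i, w in route_top_k(router_logits, k))


# Eight toy experts: expert s simply multiplies its input by s.
experts = [lambda x, s=s: s * x for s in range(8)]
logits = [0.1, 2.0, -1.0, 0.3, 1.5, 0.0, -0.5, 0.2]  # hypothetical router scores

selected = route_top_k(logits)
out = moe_forward(1.0, logits, experts)
```

Only two of the eight experts run for this token, which is why the number of active parameters is much smaller than the total parameter count.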

3.0 Work Decomposition – SplitK

We have previously shown that for the matrix problem sizes found in LLM inference, specifically in the context of W4A16 quantized inference, GEMM kernels can be accelerated by applying a SplitK work decomposition. Thus, we started our MoE acceleration research by implementing SplitK in the vLLM MoE Kernel, which produced speedups of approximately 18-20% over the Data Parallel approach.

This result shows that the SplitK optimization can be used as a part of a more formulaic approach to improving/developing Triton kernels in inference settings. To build intuition about these different work decompositions, let’s consider a simple example for the multiplication of two 4×4 matrices and SplitK=2.

In the data parallel GEMM kernel shown below, the computation for a single block of the output matrix will be handled by 1 threadblock, TB0.

Figure 2. Data Parallel GEMM

In contrast, in the SplitK kernel, the work required to compute 1 block in the output matrix is “split” or shared amongst 2 thread blocks, TB0 and TB1. This provides better load balancing and increased parallelism.

Figure 3. SplitK GEMM

The key idea is that we’ve increased our parallelism from M*N to M*N*SplitK. This approach does incur some costs, such as adding inter-threadblock communication via atomic operations. However, these costs are minimal compared to the savings in other constrained GPU resources like shared memory and registers. Most importantly, the SplitK strategy provides superior load balancing characteristics for skinny matrices, which are the common matrix profile during decoding and inference (as is the case in MoE inference).
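
The two decompositions can be compared with a small sequential sketch on the 4×4, SplitK=2 example. It models each independent unit of work as one loop iteration and uses a plain `+=` where a GPU kernel would use an atomic add; it is not the Triton kernel, and the function names are illustrative.

```python
def matmul_data_parallel(A, B):
    """One work item per output element: each reduces over the full K dimension."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            C[m][n] = sum(A[m][k] * B[k][n] for k in range(K))
    return C


def matmul_split_k(A, B, split_k=2):
    """split_k work items per output element: each reduces over a slice of K."""
    M, K, N = len(A), len(B), len(B[0])
    assert K % split_k == 0
    chunk = K // split_k
    C = [[0.0] * N for _ in range(M)]
    work_items = 0
    for m in range(M):
        for n in range(N):
            for s in range(split_k):  # each (m, n, s) triple is independent work
                partial = sum(A[m][k] * B[k][n]
                              for k in range(s * chunk, (s + 1) * chunk))
                C[m][n] += partial    # stands in for an atomic add
                work_items += 1
    return C, work_items


A = [[float(i + j) for j in range(4)] for i in range(4)]
B = [[float(i * j) for j in range(4)] for i in range(4)]

C_ref = matmul_data_parallel(A, B)
C_split, work_items = matmul_split_k(A, B, split_k=2)
```

Both versions produce the same output, but the SplitK version exposes 4 * 4 * 2 = 32 independent work items instead of 16.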

4.0 GEMM Hardware Scheduling – Column Major

To improve upon the ~20% speedup with SplitK, we focused our investigation on the logic that controls the hardware scheduling of the GEMM in Triton kernels. Our profiling of the vLLM MoE kernel showed a low L2 cache hit rate, so we investigated three scheduling options: column-major, row-major and grouped launch. Due to some intrinsic properties of MoE models, such as large expert matrices and having to dynamically load TopK (2 for Mixtral) matrices during kernel execution, cache reuse/hit rate becomes a bottleneck that this optimization targets.

For background, in our previous blog, we touched on the concept of “tile swizzling”, a method to achieve greater L2 cache hit rate. This concept relates to how the software schedules the GEMM onto the SMs of a GPU. In Triton, this schedule is determined by the pid_m and pid_n calculations. Our key insight is that for skinny matrix multiplications, a column-major ordering ensures optimal reuse of the columns of the weight matrix, B. To illustrate this, let’s take a look at a snippet of what a column major computation of pid_m, and pid_n would look like:

Figure 4. Column Major ordering in PyTorch

From above, we note that with this mapping, we schedule the GEMM such that we calculate the output blocks of C in the following order: C(0, 0), C(1, 0), C(2, 0),… etc. To understand the implications we provide the following illustration:
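
The ordering can be reproduced with a small plain-Python sketch of the index arithmetic. It is not the kernel snippet from Figure 4; `grid_m` and `grid_n` are assumed names for the number of output tiles along M and N, and in a Triton kernel the linear id would come from `tl.program_id(0)` rather than a Python loop.

```python
def column_major(pid, grid_m, grid_n):
    """Consecutive program ids walk down a column of output tiles."""
    return pid % grid_m, pid // grid_m


def row_major(pid, grid_m, grid_n):
    """Consecutive program ids walk along a row of output tiles."""
    return pid // grid_n, pid % grid_n


grid_m, grid_n = 3, 2
col_order = [column_major(p, grid_m, grid_n) for p in range(grid_m * grid_n)]
row_order = [row_major(p, grid_m, grid_n) for p in range(grid_m * grid_n)]
# col_order == [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
# row_order == [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
```

With the column-major mapping, the first `grid_m` programs all share `pid_n == 0`, so they all read the same column tile of B.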

Figure 5. Cache Reuse Pattern for a Column-Major GEMM Schedule (panels: Activation matrix / Weight matrix, L1/L2 Cache, C - Output Matrix)

In the above simplified view of a column-major schedule, let’s assume that for a GEMM with a skinny activation matrix A, the entire matrix can fit in the GPU cache, which is a reasonable assumption for the problem sizes we encounter in MoE inference. This allows for maximal reuse of the columns of the weight matrix B, because the B column can be re-used for the corresponding output tile calculations C(0,0), C(1,0) and C(2,0). Consider instead a row-major schedule: C(0,0), C(0,1), C(0,2), etc. We would have to evict the column of B and issue multiple load instructions to DRAM to calculate the same number of output blocks.
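
A deliberately simplified model makes the difference countable. The sketch assumes that output tiles are processed one after another, that A is always resident, and that the cache holds only the most recently used column tile of B; real GPU caches and concurrent scheduling are far more complex, so treat the numbers as an illustration only.

```python
def count_b_column_loads(schedule):
    """Count loads of B column tiles when only the most recent one stays cached."""
    loads, cached = 0, None
    for _, n in schedule:
        if n != cached:  # cache miss: fetch column tile n of B
            loads += 1
            cached = n
    return loads


grid_m, grid_n = 3, 4
column_major = [(m, n) for n in range(grid_n) for m in range(grid_m)]
row_major = [(m, n) for m in range(grid_m) for n in range(grid_n)]

col_loads = count_b_column_loads(column_major)  # one load per column of C
row_loads = count_b_column_loads(row_major)     # one load per output tile
```

In this toy setting the column-major order loads each B column tile once (4 loads), while the row-major order reloads a B tile for every output tile (12 loads).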

An important design consideration when optimizing kernels is a memory access pattern that results in the least amount of global load instructions. This optimal memory access pattern is achieved with the column-major schedule. The results below showcase the performance of the three schedules we investigated:

Figure 6. Comparison of GEMM Schedules on A100 for varying Batch Sizes M

The column-major schedule provides up to a 4x speedup over the other patterns, and as we’ll show in the next section, provides an optimal memory access pattern due to greatly improved data locality.

5.0 Nsight Compute Analysis – Throughput and Memory Access Pattern

For performance analysis, we focus on the M = 2 case for the H100. A similar study can be done for the A100, as many of the same observations carry over. We note the following salient results that showcase the impact of our optimizations.

Figure 7. H100 Memory Throughput Chart for M = 2. Note the very large increase in the cache hit rates: L1 cache hit rate (+2696%) and L2 cache hit rate (+254%).

Figure 8. H100 Memory Instruction Statistics M = 2. Note the 49% reduction in global memory loads.

These statistics show that our optimizations had the intended effect, which can be seen in the reduced cache misses, reduced memory accesses and the resultant 2.7x speedup. More concretely, the trace shows us a 2.54x increase in L2 hit rate (Figure 7), and a ~50% reduction in DRAM accesses (Figure 8).

These improvements ultimately yield the reduced latency, with the optimized kernel being 2.7x faster for bs=2 and 4.4x for bs=512.

6.0 Future Work

Our kernel was tested in FP16, which showcases the numerics and performance of column-major scheduling for MoE, but most production models use BFloat16. We encountered a limitation in Triton in that tl.atomic_add does not support BFloat16, and we hit launch latency concerns that would require CUDA graph support for column-major production use. In initial testing this translated to a 70% end-to-end speedup, but we encountered some expert mapping inconsistencies in an end-to-end environment that are not reflected in the test environment, so further work is needed to fully realize these speedups.

For future work, we intend to move this into a CUDA kernel which will ensure full BFloat16 support and reduced launch latency relative to Triton, and potentially resolve the expert routing inconsistency. We’ve also previously published work on enabling GPTQ W4A16 with Triton GEMM kernels, so natural follow-on work would include fusing dequantization into this kernel to allow for a GPTQ quantized inference path.

7.0 Reproducibility

We have open sourced the Triton kernel code along with an easy to run performance benchmark for readers interested in comparing or verifying the performance on their own GPU.

Acknowledgements

We want to thank Daniel Han, Raghu Ganti, Mudhakar Srivatsa, Bert Maher, Gregory Chanan, Eli Uriegas, and Geeta Chauhan for their review of the presented material and Woo Suk from the vLLM team as we built on his implementation of the Fused MoE kernel.

Maximizing training throughput using PyTorch FSDP

In this blog, we demonstrate the scalability of FSDP with a pre-training exemplar, a 7B model trained for 2T tokens, and share various techniques we used to achieve a rapid training speed of 3,700 tokens/sec/GPU, or 40B tokens/day on 128 A100 GPUs. This translates to a model FLOPS utilization (MFU) and hardware FLOPS utilization (HFU) of 57%. Additionally, we have observed near linear scaling of FSDP to 512 GPUs, implying that training a 7B model on 512 GPUs to 2T tokens using this method would take just under two weeks.
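
These figures can be checked with a few lines of arithmetic. The sketch assumes perfectly linear scaling from 128 to 512 GPUs and no downtime, which the post describes only as "near linear", so the result is an approximation.

```python
TOKENS_PER_SEC_PER_GPU = 3_700
SECONDS_PER_DAY = 86_400


def tokens_per_day(num_gpus):
    """Cluster-wide tokens processed per day at the measured per-GPU rate."""
    return TOKENS_PER_SEC_PER_GPU * num_gpus * SECONDS_PER_DAY


def days_to_train(total_tokens, num_gpus):
    """Days needed to process total_tokens, assuming linear scaling."""
    return total_tokens / tokens_per_day(num_gpus)


per_day_128 = tokens_per_day(128)      # about 40.9B tokens/day
days_512 = days_to_train(2e12, 512)    # about 12.2 days

print(f"{per_day_128 / 1e9:.1f}B tokens/day on 128 GPUs")
print(f"{days_512:.1f} days for 2T tokens on 512 GPUs")
```

Both results are consistent with the quoted 40B tokens/day and "just under two weeks".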

IBM researchers trained a Meta Llama 2 7B architecture to 2T tokens, which we will refer to as LlamaT(est). This model demonstrates model quality comparable to Llama 2 on various academic benchmarks. All of the training code, along with our methodology to achieve this throughput, can be found in this blog. We also share the configuration knobs that work well for the Llama 2 models (7B, 13B, 34B, and 70B) on A100s and H100s.

In this process, we also propose a new selective activation checkpointing mechanism that applies to FSDP and gives us a 10% boost beyond out-of-the-box FSDP. We have open sourced the training code base and an associated scalable data loader as the methodology to achieve this throughput.

One key benefit of a PyTorch native pathway for training is the ability to seamlessly train on multiple hardware backends. For example, the recent end-to-end stack for training that was released by AllenAI through OLMo also leverages PyTorch FSDP for training on AMD and NVIDIA GPUs. There are three main components that we leverage from FSDP to achieve our throughput:

  1. SDPA Flash attention, that enables fused attention kernels and efficient attention computation
  2. Overlap in computation and communication allows for better utilization of the GPU
  3. Selective activation checkpointing enables us to tradeoff between GPU memory and compute

IBM has been working closely with Team PyTorch at Meta on PyTorch FSDP for nearly two years: introducing the rate limiter for achieving better throughput on Ethernet interconnects, distributed checkpointing to improve the checkpoint times by an order of magnitude, and implementing the early version of checkpointing for the hybrid sharding mode of FSDP. Late last year, we used FSDP to train a model end-to-end.

Training Details

The 7B model is trained on 128 A100 GPUs with 400Gbps network connectivity and GPU direct RDMA. We use SDPA FlashAttention v2 for attention computation. For this model we turned off activation checkpointing, which limits the batch size but provides the highest throughput: the batch size is 1 million tokens per batch for 128 GPUs, and throughput improves by about 10% compared to running with activation checkpointing. With these parameters, we have an almost full overlap in computation and communication. We use the AdamW optimizer in 32-bit with beta1 of 0.9 and beta2 of 0.95, weight decay of 0.1, and a learning rate that warms up to a maximum of 3e-4 and then follows a cosine schedule down to 3e-5 over 2T tokens. The training was performed using mixed precision bf16 on an internal dataset. The training stack uses IBM’s Foundation Model Stack for model architecture and PyTorch nightlies post-2.2 release for FSDP and SDPA. We tried a few different nightlies between Nov 2023 and Feb 2024 and observed an improvement in throughput.
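
The learning-rate schedule described above (warmup to 3e-4, then cosine decay to 3e-5) might look roughly like the sketch below. The linear shape of the warmup and the number of warmup steps are assumptions, since the post does not state them; the total step count follows from 2T tokens at about 1M tokens per batch.

```python
import math

MAX_LR, MIN_LR = 3e-4, 3e-5


def learning_rate(step, total_steps, warmup_steps):
    """Linear warmup to MAX_LR, then cosine decay to MIN_LR at total_steps."""
    if step < warmup_steps:
        return MAX_LR * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return MIN_LR + (MAX_LR - MIN_LR) * cosine


TOTAL_STEPS = 2_000_000  # 2T tokens / ~1M tokens per batch
WARMUP_STEPS = 2_000     # assumption, not stated in the post

peak = learning_rate(WARMUP_STEPS, TOTAL_STEPS, WARMUP_STEPS)
final = learning_rate(TOTAL_STEPS, TOTAL_STEPS, WARMUP_STEPS)
```

The schedule peaks at 3e-4 right after warmup and ends at 3e-5 on the final step.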

Selective activation checkpointing

We jointly implemented a simple and effective mechanism of selective activation checkpointing (AC). In FSDP, the common practice is to checkpoint each transformer block. A simple extension is to checkpoint every n blocks, which reduces the amount of recomputation while increasing the memory needed. This is quite effective for the 13B model size, increasing the throughput by 10%. For the 7B model size, we did not need activation checkpointing at all. Future versions of FSDP will provide selective activation checkpointing at an operator level, enabling an optimal compute-memory tradeoff. The code for the above is implemented here.
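
The selection rule can be illustrated with a small sketch that picks which block indices get checkpointed. It only shows the "every n blocks" idea; the linked implementation may choose blocks differently, `blocks_to_checkpoint` is a hypothetical helper, and the 40-block count assumes the Llama 2 13B layout.

```python
def blocks_to_checkpoint(num_blocks, every_n):
    """Indices of transformer blocks whose activations are recomputed in backward."""
    if every_n < 1:
        raise ValueError("every_n must be >= 1")
    return [i for i in range(num_blocks) if i % every_n == 0]


every_block = blocks_to_checkpoint(40, 1)  # the common practice: all 40 blocks
every_third = blocks_to_checkpoint(40, 3)  # selective: blocks 0, 3, 6, ...
```

With n = 3, only 14 of the 40 blocks are recomputed during the backward pass, trading extra activation memory for less recomputation.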

Throughput and MFU, HFU computation

While we only trained the 7B model to 2T tokens, we performed numerous experiments on the other model sizes to provide the best configuration options. This is summarized in the table below for two types of infrastructure — an A100 cluster with 128 GPUs and 400Gbps inter-node interconnect, and an H100 cluster with 96 GPUs and 800Gbps inter-node interconnect.

| Model size | Batch size | Activation checkpoint | Throughput tokens/sec/GPU (A100 80GB and 400Gbps interconnect) | MFU (A100 80GB) | HFU (A100 80GB) | Throughput tokens/sec/GPU (H100 80GB and 800Gbps interconnect) | MFU (H100 80GB) | HFU (H100 80GB) |
|---|---|---|---|---|---|---|---|---|
| 7B | 2 | No | 3700 | 0.57 | 0.57 | 7500 | 0.37 | 0.37 |
| 13B | 2 | Selective | 1800 | 0.51 | 0.59 | 3800 | 0.35 | 0.40 |
| 34B | 2 | Yes | 700 | 0.47 | 0.64 | 1550 | 0.32 | 0.44 |
| 70B | 2 | Yes | 370 | 0.50 | 0.67 | 800 | 0.34 | 0.45 |

Table 1: Model and Hardware FLOPS utilization of various model sizes on A100 and H100 GPUs

HFU numbers are computed using the PyTorch FLOP counter and the theoretical bf16 performance of A100 and H100 GPUs, whereas MFU numbers are computed using the methodology outlined in NanoGPT and the PaLM paper. We also note that the batch sizes for the larger models are intentionally kept at 2 per GPU to mimic choices made in training models of 4k sequence length; this lets us scale up to 512 GPUs without exceeding the popular batch size of 4M tokens. Beyond that, we would need tensor parallelism or sequence parallelism.
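
As a rough cross-check of the MFU column, the sketch below applies the PaLM-style estimate of training FLOPs per token (6N plus an attention term) to the 7B row. The architecture constants (6.74B parameters, 32 layers, 32 heads, head dimension 128) are assumptions based on the Llama 2 7B architecture rather than values given in this post, so the results only land near the table values.

```python
def flops_per_token(n_params, n_layers, n_heads, head_dim, seq_len):
    """PaLM-style training FLOPs per token: 6N for matmuls plus attention."""
    return 6 * n_params + 12 * n_layers * n_heads * head_dim * seq_len


def mfu(tokens_per_sec_per_gpu, flops_tok, peak_flops):
    """Model FLOPS utilization: achieved FLOPS divided by theoretical peak."""
    return tokens_per_sec_per_gpu * flops_tok / peak_flops


# Assumed Llama 2 7B shape, 4k sequence length.
flops_7b = flops_per_token(n_params=6.74e9, n_layers=32, n_heads=32,
                           head_dim=128, seq_len=4096)

mfu_a100 = mfu(3_700, flops_7b, 312e12)  # bf16 peak of 312 TFLOPS
mfu_h100 = mfu(7_500, flops_7b, 989e12)  # bf16 peak of 989 TFLOPS

# Global batch at 2 sequences per GPU, 4k tokens each, on 512 GPUs.
tokens_per_batch = 2 * 4096 * 512
```

Under these assumptions the estimate comes out at roughly 0.56 on A100 and 0.36 on H100, close to the 0.57 and 0.37 reported in the table, and the global batch is exactly 4M tokens at 512 GPUs.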

We note in the table above that for A100s, activation recomputation causes the MFU to drop while HFU increases. With the introduction of better activation checkpointing schemes, we expect MFU to increase and catch up with HFU. However, we observe that for H100s, both MFU and HFU are relatively low. We analyzed the PyTorch profile traces on H100 and observed a 10% gap due to network “peeking” out. In addition, we hypothesize that the HBM bandwidth of H100s is the cause of the reduced HFU/MFU on H100s and of not obtaining the 3x improvement (H100s are theoretically 3x faster than A100s, 989 vs 312 TFLOPS, but have less than 2x the HBM bandwidth of A100s, 3.35 vs 2.0 TBps). We plan to try out other configuration options like Tensor Parallel to improve the knobs for the 70B model on H100s.
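
The ratios behind this hypothesis are easy to compute from the numbers quoted in this post; the snippet is plain arithmetic and makes no claim about the actual cause.

```python
compute_ratio = 989 / 312        # H100 vs A100 theoretical bf16 TFLOPS
bandwidth_ratio = 3.35 / 2.0     # H100 vs A100 HBM bandwidth in TBps
observed_ratio = 7_500 / 3_700   # H100 vs A100 measured tokens/sec/GPU for 7B

print(f"compute {compute_ratio:.2f}x, bandwidth {bandwidth_ratio:.2f}x, "
      f"observed {observed_ratio:.2f}x")
```

The observed speedup of about 2x sits between the bandwidth ratio (about 1.7x) and the compute ratio (about 3.2x).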

Model details

The loss curve for training is shown in the figure below.

Figure 1: LlamaT training loss curve

The 2T checkpoint is converted to Hugging Face format by a script provided in the repository. We then use lm-evaluation-harness to compute key academic benchmarks and compare the results against the same evaluation run on Llama2-7B. These results are captured in the table below.

Evaluation metric | Llama2-7B (baseline) | LlamaT-7B
MMLU (zero shot) | 0.41 | 0.43
MMLU (5-shot weighted avg) | 0.47 | 0.50
Arc challenge | 0.46 | 0.44
Arc easy | 0.74 | 0.71
Boolq | 0.78 | 0.76
Copa | 0.87 | 0.83
Hellaswag | 0.76 | 0.74
Openbookqa | 0.44 | 0.42
Piqa | 0.79 | 0.79
Sciq | 0.91 | 0.91
Winogrande | 0.69 | 0.67
Truthfulqa | 0.39 | 0.39
GSM8k (8-shot) | 0.13 | 0.11

Table 2: LM eval harness scores

We observe that the model performs competitively with Llama2 (higher is better).

Training chronicles

Training was stable with no crashes, though we did observe a few hiccups:

0-200B tokens: We observed a slowdown in the iteration time (time taken to execute one training step). We stopped the job to ensure that the data loader was not causing any slowdowns and the checkpointing was performant and accurate. We did not find any issues. By this time, HSDP checkpointing code was available in PyTorch, and we took this opportunity to make the switch to PyTorch checkpointing code.

200B tokens-1.9T: We did not intervene manually in the job in late December. When we came back in early January, disk space had been exhausted and checkpoints were failing to be written, although the training job continued. The last known checkpoint was at 1.5T.

1.5T-1.7T: We evaluated the 1.5T checkpoint with lm-evaluation-harness and discovered that the model had been trained with an extra special token between two documents, because the Hugging Face tokenizer introduced a separator token and our dataloader also appended its own document separator. We modified the dataloader to eliminate the extra special token, and continued training with the modified dataloader from 1.7T tokens onwards.

1.7T-2T: The loss initially spiked due to the change in the special tokens, but recovered within a few billion tokens. The training finished without any other manual intervention!

Key takeaways and even more speed

We demonstrated how one can use FSDP to train a model to 2T tokens with an excellent performance of 3700 tokens/sec/GPU and that generates a good quality model. As part of this exercise, we open sourced all our code for training and the knobs to achieve this throughput. These knobs can be leveraged by not only large-scale runs, but also smaller scale tuning runs. You can find the code here.

FSDP APIs implement the ZeRO algorithms in a PyTorch native manner and allow for tuning and training of large models. In the past, we have seen FSDP proof points (Stanford Alpaca, Hugging Face, Llama 2 recipes) on tuning a variety of LLMs (such as Meta Llama 2 7B to 70B) using simple training loops and achieving good throughputs and training times.

Finally, we note that there are several levers for speeding up training:

  1. Node optimizations that can speed up specific operations (e.g., attention computation using Flash Attention V2)
  2. Graph optimizations (e.g., fusing kernels, torch.compile)
  3. Overlap in compute-communications
  4. Activation recomputation

We have leveraged 1, 3, and a variation of 4 in this blog and are working closely with Team PyTorch at Meta to enable torch.compile (2) as well as a more advanced version of 4 with per-operator selective activation recomputation. We plan to share simple formatting code and example data to ingest into our data loader to enable others to use the code base for training models.
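To make lever 4 concrete, here is a minimal sketch of activation recomputation using `torch.utils.checkpoint`. Recomputing every second block is an illustrative policy chosen for this example, not the selective scheme used in this work, and the toy `Block` and `Stack` modules are hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy residual MLP block standing in for a transformer layer (hypothetical)."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class Stack(nn.Module):
    def __init__(self, dim: int, n_layers: int, recompute_every: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(n_layers)])
        self.recompute_every = recompute_every  # 0 disables recomputation

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if self.recompute_every and i % self.recompute_every == 0:
                # Do not keep this block's activations; rerun its forward during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


torch.manual_seed(0)
model = Stack(dim=16, n_layers=4)
x = torch.randn(8, 16, requires_grad=True)
model(x).sum().backward()  # gradients match the non-checkpointed model
```

The recomputed forward pass adds hardware FLOPs without adding model FLOPs, which is why HFU exceeds MFU in the rows of Table 1 that use activation checkpointing.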

Acknowledgements

There are several teams that have been involved in reaching this proof point and we would like to thank the teams across Meta and IBM. Specifically, we extend our gratitude to the PyTorch distributed team, Facebook Research and Applied AI teams that built the FSDP APIs and made enhancements based on our feedback. We also wish to thank the data team at IBM Research that curated the data corpus used in this exercise and the infrastructure team at IBM Research (especially, Claudia Misale, Shweta Salaria, and Seetharami Seelam) that optimized NCCL and network configurations. By building and leveraging all of these components, we have successfully demonstrated the LlamaT proof point.

The selective activation checkpointing was conceptualized at IBM by Linsong Chu, Davis Wertheimer, Mudhakar Srivatsa, and Raghu Ganti and implemented by Less Wright at Meta.

Special thanks to Stas Bekman and Minjia Zhang, who provided extensive feedback and helped improve the blog. Their insights have been invaluable in highlighting key aspects of optimizing the training and exploring further enhancements.

Appendix

Communication computation overlap

Another key aspect of training in a multi-node setting is the ability to overlap communication and computation. In FSDP, there are multiple opportunities for overlapping: during the FSDP unit gathering phase of the forward pass as well as during the backward pass computation. Overlapping the gather of the next unit with the computation of the previous unit in the forward pass, and overlapping backward computation with the next unit's gathering and gradient scattering, helps improve GPU utilization by nearly 2x. We illustrate this on the 400Gbps network interconnect with A100 80GB GPUs. In the case of HSDP, there is no inter-node traffic during the pre-fetch stage of the forward pass, and the overlap is only for the backward gradient computation phase. Of course, HSDP is feasible only when the model can be sharded within a single node, limiting the size of models to around 30B parameters.

The figure below shows three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half of the image. For the 7B model with no activation recomputation, we observe the overlap to be complete. In practice, the overlap percentage possible is 90% since the first block during forward pass and the last block during backward pass are not able to overlap.

three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half

A zoomed in view of the above three-step process is shown below for a single step. We can clearly see the granularity of the computation and communication and how they overlap in an interleaved manner.

zoomed in view of the above three-step process

PyTorch 2 paper and tutorial @ ASPLOS 2024

The PyTorch team is excited to share that our paper on PyTorch 2 has been accepted for presentation at the ACM International Conference on Architectural Support for Programming Languages and Operating Systems (ASPLOS), scheduled to take place from April 27 to May 1, 2024, in San Diego, CA, USA.

The paper delves into the implementation of torch.compile and highlights the key technologies driving it, including TorchDynamo (graph capture), TorchInductor (backend compiler), and Dynamic Shape support.

During the ASPLOS conference, we’ll be conducting a tutorial on Saturday, April 27, focusing on the inner workings of PyTorch 2 and how systems researchers can leverage and build upon it. Stay tuned for more details as the event approaches – we look forward to your participation!

A preview of the paper is attached below:

Title: PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation. Full Paper PDF

Abstract

This paper introduces two extensions to the popular PyTorch machine learning framework, TorchDynamo and TorchInductor, which implement the torch.compile feature released in PyTorch 2. TorchDynamo is a Python-level just-in-time (JIT) compiler that enables graph compilation in PyTorch programs without sacrificing the flexibility of Python. It achieves this by dynamically modifying Python bytecode before execution and extracting sequences of PyTorch operations into an FX graph, which is then JIT compiled using one of many extensible backends. TorchInductor is the default compiler backend for TorchDynamo, which translates PyTorch programs into OpenAI’s Triton for GPUs and C++ for CPUs. Results show that TorchDynamo is able to capture graphs more robustly than prior approaches while adding minimal overhead, and TorchInductor is able to provide a 2.27x inference and 1.41x training geometric mean speedup on an NVIDIA A100 GPU across 180+ real-world models, which outperforms six other compilers. These extensions provide a new way to apply optimizations through compilers in eager mode frameworks like PyTorch.

Authors

Jason Ansel (Meta); Edward Yang (Meta); Horace He (Meta); Natalia Gimelshein (OpenAI); Animesh Jain (Meta); Michael Voznesensky (Meta); Bin Bao (Meta); David Berard (Meta); Geeta Chauhan (Meta); Anjali Chourdia (Meta); Will Constable (Meta); Alban Desmaison (Meta); Zachary DeVito (Meta); Elias Ellison (Meta); Will Feng (Meta); Jiong Gong (Intel); Michael Gschwind (Meta); Brian Hirsh (Meta); Sherlock Huang (Meta); Laurent Kirsch (Meta); Michael Lazos (Meta); Yanbo Liang (Meta); Jason Liang (Meta); Yinghai Lu (Meta); CK Luk (Meta); Bert Maher (Meta); Yunjie Pan (University of Michigan); Christian Puhrsch (Meta); Matthias Reso (Meta); Mark Saroufim (Meta); Helen Suk (Meta); Michael Suo (Meta); Phil Tillet (OpenAI); Eikan Wang (Intel); Xiaodong Wang (Meta); William Wen (Meta); Shunting Zhang (Meta); Xu Zhao (Meta); Keren Zhou (OpenAI & George Mason University); Richard Zou (Meta); Ajit Mathews (Meta); Gregory Chanan (Meta); Peng Wu (Meta); Soumith Chintala (Meta)

What’s New in PyTorch Documentation

Greetings to the PyTorch community! Here is a quick update on PyTorch docs.

In November 2023, we successfully conducted a PyTorch Docathon, a community event where PyTorch community members gathered together to improve PyTorch documentation and tutorials. This event saw a global participation of contributors who dedicated their time and effort to enhance our docs. We extend our sincere gratitude to everyone involved.

A key accomplishment of the Docathon was the comprehensive work carried out on docstrings. Our community contributors meticulously reviewed and improved the docstrings based on the provided tasks.

In addition to that, we’ve added three new tutorials that showcase real-world applications of PyTorch. We are particularly proud that two of these tutorials were contributed by PyTorch ecosystem partners.

Here are the new tutorials for you to explore:

  • Whole Slide Image Classification Using PyTorch and TIAToolbox – This tutorial demonstrates how to classify Whole Slide Images (WSIs), which are images of human tissue samples used by pathologists and researchers to study diseases like cancer at the microscopic level, using PyTorch deep learning models with TIAToolbox.
  • Semi-Supervised Learning using USB built upon PyTorch – This tutorial introduces USB, a flexible and modular semi-supervised learning framework based on PyTorch, demonstrating its ease of use in training a FreeMatch/SoftMatch model on CIFAR-10 using pre-trained ViT and its adaptability to various algorithms and imbalanced datasets.
  • Deploying a PyTorch Stable Diffusion model as a Vertex AI Endpoint – This tutorial provides a step-by-step guide on how to streamline the deployment of a PyTorch Stable Diffusion model (v1.5) using Vertex AI, a fully-managed machine learning platform, by creating a custom TorchServe handler, uploading model artifacts to Google Cloud Storage, creating a Vertex AI model with the model artifacts and a prebuilt PyTorch container image, and finally deploying the model onto an endpoint.

We’re planning more community events this year, so stay tuned!

And finally, we just published new 2.2 PyTorch documentation and tutorials. Check it out!

Best regards,
The PyTorch Team

PyTorch 2.2: FlashAttention-v2 integration, AOTInductor

We are excited to announce the release of PyTorch® 2.2 (release note)! PyTorch 2.2 offers ~2x performance improvements to scaled_dot_product_attention via FlashAttention-v2 integration, as well as AOTInductor, a new ahead-of-time compilation and deployment tool built for non-Python server-side deployments.

This release also includes improved torch.compile support for Optimizers, a number of new inductor optimizations, and a new logging mechanism called TORCH_LOGS.

Please note that we are deprecating macOS x86 support, and PyTorch 2.2.x will be the last version that supports macOS x64.

Along with 2.2, we are also releasing a series of updates to the PyTorch domain libraries. More details can be found in the library updates blog.

This release is composed of 3,628 commits and 521 contributors since PyTorch 2.1. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.2. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Summary:

  • scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups compared to previous versions.
  • PyTorch 2.2 introduces a new ahead-of-time extension of TorchInductor called AOTInductor, designed to compile and deploy PyTorch programs for non-Python server-side deployments.
  • torch.distributed supports a new abstraction for initializing and representing ProcessGroups called device_mesh.
  • PyTorch 2.2 ships a standardized, configurable logging mechanism called TORCH_LOGS.
  • A number of torch.compile improvements are included in PyTorch 2.2, including improved support for compiling Optimizers and improved TorchInductor fusion and layout optimizations.
  • Please note that we are deprecating macOS x86 support, and PyTorch 2.2.x will be the last version that supports macOS x64.

Stable | Beta | Performance Improvements
 | FlashAttention-2 Integration | Inductor optimizations
 | AOTInductor | aarch64 optimizations
 | TORCH_LOGS |
 | device_mesh |
 | Optimizer compilation |

*To see a full list of public feature submissions click here.

Beta Features

[Beta] FlashAttention-2 support in torch.nn.functional.scaled_dot_product_attention

torch.nn.functional.scaled_dot_product_attention (SDPA) now supports FlashAttention-2, yielding around 2x speedups (compared to the previous version) and reaching ~50-73% of theoretical maximum FLOPs/s on A100 GPUs.

More information is available on FlashAttention-2 in this paper.

For a tutorial on how to use SDPA please see this tutorial.
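As a minimal sketch of the call itself, the example below runs SDPA with a causal mask on small CPU tensors and compares it against a hand-written reference. The tensor sizes are arbitrary choices for illustration. On CPU a different backend is used; FlashAttention-2 is typically selected only on supported CUDA GPUs with half-precision inputs.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Layout: (batch, heads, sequence, head_dim). Small CPU tensors so this runs anywhere.
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))

# Fused attention with a causal mask; the backend is chosen automatically.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Unfused reference implementation for comparison.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.ones(16, 16, dtype=torch.bool).tril()
ref = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1) @ v

print(torch.allclose(out, ref, atol=1e-4))
```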

[Beta] AOTInductor: ahead-of-time compilation and deployment for torch.export-ed programs

AOTInductor is an extension of TorchInductor, designed to process exported PyTorch models, optimize them, and produce shared libraries as well as other relevant artifacts. These compiled artifacts can be deployed in non-Python environments, which are frequently employed for inference on the server-side. Note that AOTInductor supports the same backends as Inductor, including CUDA, ROCm, and CPU.

For more information please see the AOTInductor tutorial.

[Beta] Fine-grained configurable logging via TORCH_LOGS

PyTorch now ships a standardized, configurable logging mechanism that can be used to analyze the status of various subsystems such as compilation and distributed operations.

Logs can be enabled via the TORCH_LOGS environment variable. For example, to set the log level of TorchDynamo to logging.ERROR and the log level of TorchInductor to logging.DEBUG, pass TORCH_LOGS="-dynamo,+inductor" to PyTorch.

For more information, please see the logging documentation and tutorial.
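When the variable cannot be exported in the shell, one option is to set it from Python before PyTorch is imported. This is a sketch that assumes the variable is read when `torch` is imported; the commented-out alternative uses `torch._logging.set_logs`.

```python
import os

# Set TORCH_LOGS before importing torch (assumption: it is read at import time).
#   "-dynamo"   -> TorchDynamo at logging.ERROR
#   "+inductor" -> TorchInductor at logging.DEBUG
os.environ["TORCH_LOGS"] = "-dynamo,+inductor"

# import torch  # import only after the variable is set

# Programmatic alternative once torch is imported:
# import logging
# torch._logging.set_logs(dynamo=logging.ERROR, inductor=logging.DEBUG)
```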

[Beta] torch.distributed.device_mesh

PyTorch 2.2 introduces a new abstraction for representing the ProcessGroups involved in distributed parallelisms called torch.distributed.device_mesh. This abstraction allows users to represent inter-node and intra-node process groups via an N-dimensional array where, for example, one dimension can represent data parallelism in FSDP while another could represent tensor parallelism within FSDP.

For more information, see the device_mesh tutorial.
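To visualize what such a mesh represents, the plain-Python sketch below lays out 8 ranks as a 2 × 4 array and lists the groups along each dimension. It is an illustration only: `build_mesh` is a hypothetical helper, and in a real distributed job the mesh would be created with `init_device_mesh` from `torch.distributed.device_mesh`.

```python
# Illustration only: arrange ranks in a 2-D mesh and derive the groups per dimension.
def build_mesh(world_size, shape):
    rows, cols = shape
    assert rows * cols == world_size, "mesh shape must cover every rank exactly once"
    return [[r * cols + c for c in range(cols)] for r in range(rows)]


mesh = build_mesh(world_size=8, shape=(2, 4))

# Dimension 1 (within a row): e.g. the 4 GPUs of one node.
intra_node_groups = mesh
# Dimension 0 (within a column): the same local GPU index across nodes.
inter_node_groups = [list(col) for col in zip(*mesh)]

print(mesh)               # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(inter_node_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```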

[Beta] Improvements to torch.compile-ing Optimizers

A number of improvements have been made to torch.compile-ing Optimizers, including less overhead and support for CUDA graphs.

More technical details of the improvements are available on dev-discuss, and a recipe for torch.compile-ing optimizers is available here.

Performance Improvements

Inductor Performance Optimizations

A number of performance optimizations have been added to TorchInductor including horizontal fusion support for torch.concat, improved convolution layout optimizations, and improved scaled_dot_product_attention pattern matching.

For a complete list of inductor optimizations, please see the Release Notes.

aarch64 Performance Optimizations

PyTorch 2.2 includes a number of performance enhancements for aarch64 including support for mkldnn weight pre-packing, improved ideep primitive caching, and improved inference speed via fixed format kernel improvements to OneDNN.

For a complete list of aarch64 optimizations, please see the Release Notes.

New Library Updates in PyTorch 2.2

Summary

We are bringing a number of improvements to the current PyTorch libraries, alongside the PyTorch 2.2 release. These updates demonstrate our focus on developing common and extensible APIs across all domains to make it easier for our community to build ecosystem projects on PyTorch.

Latest Stable Library Versions (Full List)*
TorchArrow 0.1.0 | TorchRec 0.6.0 | TorchVision 0.17
TorchAudio 2.2.0 | TorchServe 0.9.0 | TorchX 0.7.0
TorchData 0.7.1 | TorchText 0.17.0 | PyTorch on XLA Devices 2.1

*To see prior versions or (unstable) nightlies, click on versions in the top left menu above ‘Search Docs’.

TorchRL

Feature: TorchRL’s Offline RL Data Hub

TorchRL now provides one of the largest dataset hubs for offline RL and imitation learning, and it all comes under a single data format (TED, for TorchRL Episode Data format). This makes it possible to easily swap between different sources in a single training loop. It is also now possible to easily combine datasets from different sources through the ReplayBufferEnsemble class. The data processing is fully customizable. Sources include simulated tasks (Minari, D4RL, VD4RL), robotic datasets (Roboset, OpenX Embodied dataset) and gaming (GenDGRL/ProcGen, Atari/DQN). Check these out in the documentation.

Aside from these changes, our replay buffers can now be dumped to disk using the .dumps() method, which serializes the buffers using the TensorDict API. This is faster, safer and more efficient than using torch.save.

Finally, replay buffers can now be read and written from separate processes on the same machine without any extra code needed from the user!

TorchRL2Gym environment API

To facilitate TorchRL’s integration in existing code-bases and enjoy all the features of TorchRL’s environment API (execution on device, batched operations, transforms…) we provide a TorchRL-to-gym API that allows users to register any environment they want in gym or gymnasium. This can be used in turn to make TorchRL a universal lib-to-gym converter that works across stateful (e.g., dm_control) and stateless (Brax, Jumanji) environments. The feature is thoroughly detailed in the doc. The info_dict reading API has also been improved.

Environment speedups

We added the option of executing environments on a different device than the one used to deliver data in ParallelEnv. We also sped up the GymLikeEnv class to a level that now makes it competitive with gym itself.

Scaling objectives

The most popular objectives for RLHF and training at scale (PPO and A2C) are now compatible with FSDP and DDP models!

TensorDict

Feature: MemoryMappedTensor to replace MemmapTensor

We provide a much more efficient mmap backend for TensorDict: MemoryMappedTensor, which directly subclasses torch.Tensor. It comes with several construction utilities, such as from_tensor, empty and many more. MemoryMappedTensor is now much safer and faster than its predecessor, MemmapTensor. The library remains fully compatible with the previous class to facilitate transition.

We also introduce a new set of multithreaded serialization methods that make tensordict serialization highly competitive with torch.save, with serialization and deserialization speeds for LLMs more than 3x faster than with torch.save.

Feature: Non-tensor data within TensorDict

It is now possible to carry non-tensor data through the NonTensorData tensorclass. This makes it possible to build tensordicts with metadata. The memmap-API is fully compatible with these values, allowing users to seamlessly serialize and deserialize such objects. To store non-tensor data in a tensordict, simply assign it using the __setitem__ method.

Efficiency improvements

The runtime of several methods has been improved, such as unbind, split, map or even TensorDict instantiation. Check our benchmarks!

TorchRec/fbgemm_gpu

VBE

TorchRec now natively supports VBE (variable batched embeddings) within the EmbeddingBagCollection module. This allows variable batch size per feature, unlocking sparse input data deduplication, which can greatly speed up embedding lookup and all-to-all time. To enable, simply initialize KeyedJaggedTensor with stride_per_key_per_rank and inverse_indices fields, which specify batch size per feature and inverse indices to reindex the embedding output respectively.

In addition to the TorchRec library changes, fbgemm_gpu has added support for variable batch size per feature in TBE. VBE is enabled on split TBE training for both weighted and unweighted cases. To use VBE, please make sure to use the latest fbgemm_gpu version.

Embedding offloading

This technique refers to using CUDA UVM to cache ‘hot’ embeddings (i.e. store embedding tables on host memory with cache on HBM memory), and prefetching the cache. Embedding offloading allows running a larger model with fewer GPUs, while maintaining competitive performance. Use the prefetching pipeline (PrefetchTrainPipelineSparseDist) and pass in per-table cache load factor and the prefetch_pipeline flag through constraints in the planner to use this feature.

Fbgemm_gpu has introduced UVM cache pipeline prefetching in v0.5.0 for TBE performance speedup. This allows cache-insert to be executed in parallel with TBE forward/backward. To enable this feature, please be sure to use the latest fbgemm_gpu version.

Trec.shard/shard_modules

These APIs replace embedding submodules with their sharded variants. The shard API applies to an individual embedding module while the shard_modules API replaces all embedding modules and won’t touch other non-embedding submodules.

Embedding sharding follows similar behavior to the prior TorchRec DistributedModelParallel behavior, except the ShardedModules have been made composable, meaning the modules are backed by TableBatchedEmbeddingSlices which are views into the underlying TBE (including .grad). This means that fused parameters are now returned with named_parameters(), including in DistributedModelParallel.

TorchVision

The V2 transforms are now stable!

The torchvision.transforms.v2 namespace was still in BETA stage until now. It is now stable! Whether you’re new to Torchvision transforms, or you’re already experienced with them, we encourage you to start with Getting started with transforms v2 in order to learn more about what can be done with the new v2 transforms.

Browse our main docs for general information and performance tips. The available transforms and functionals are listed in the API reference. Additional information and tutorials can also be found in our example gallery, e.g. Transforms v2: End-to-end object detection/segmentation example or How to write your own v2 transforms.

Towards torch.compile() support

We are progressively adding support for torch.compile() to torchvision interfaces, reducing graph breaks and allowing dynamic shapes.

The torchvision ops (nms, [ps_]roi_align, [ps_]roi_pool and deform_conv2d) are now compatible with torch.compile and dynamic shapes.

On the transforms side, the majority of low-level kernels (like resize_image() or crop_image()) should compile properly without graph breaks and with dynamic shapes. We are still addressing the remaining edge-cases, moving up towards full functional support and classes, and you should expect more progress on that front with the next release.

Accelerating Generative AI with PyTorch IV: Seamless M4T, fast

This post is the fourth part of a multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch. To skip to the code, check out our GitHub (seamless_communication, fairseq2). We are excited to share a breadth of newly released PyTorch performance features alongside practical examples to see how far we can push PyTorch native performance. In part one, we showed how to accelerate Segment Anything over 8x using only pure, native PyTorch. In part two, we showed how to accelerate Llama-7B by almost 10x using only native PyTorch optimizations. In part three, we showed how to accelerate text-to-image diffusion models up to 3x using only native PyTorch optimizations.

In this blog, we’ll focus on speeding up FAIR’s Seamless M4T-v2 model. Using CUDA Graph and native PyTorch optimizations, we achieve a 2x speedup for the text decoder module and 30x for the vocoder module, resulting in a 2.7x speedup for end-to-end inference with no loss of accuracy:

End to End Inference Speedup

Introduction

Seamless M4T is an open-source foundational speech/text translation and transcription technology developed by FAIR. Seamless M4T is a massively multilingual and multimodal machine translation model, with the latest version (Seamless M4T-v2) released on November 30th, 2023. The high-level model architecture of Seamless M4T-v2 is illustrated in Figure 1.

Figure 1. Model Architecture of Seamless M4T-v2.

Accelerating inference latency is crucial for translation models to improve user experience through faster communication across languages. In particular, batch_size=1 is often used for fast translation where latency matters a lot in applications such as chatbots, speech translation, and live subtitling. Therefore, we conducted the performance analysis on inference with batch_size=1, as shown in Figure 2 to understand the Amdahl’s Law bottleneck. Our results indicate that the text decoder and vocoder are the most time-consuming modules, accounting for 61% and 23% of the inference time, respectively.

Figure 2. The text decoder and vocoder are the most time-consuming modules. Breakdown of inference time by module for the English-Spanish S2ST (speech-to-speech translation) task with batch_size=1 on an A100 GPU.
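Amdahl's Law gives a back-of-the-envelope feel for what accelerating these two modules can buy. The sketch below assumes the Figure 2 fractions (61% and 23%) hold, that all other modules are unchanged, and it plugs in the per-module speedups quoted in the introduction. The function name is illustrative, and the measured end-to-end number need not match this estimate, since the time breakdown varies by sample and task.

```python
def end_to_end_speedup(parts):
    """Amdahl's Law for several accelerated parts.

    parts: list of (fraction_of_runtime, speedup) for the accelerated modules.
    """
    accelerated = sum(fraction for fraction, _ in parts)
    remaining = 1.0 - accelerated  # share of runtime left untouched
    new_time = remaining + sum(fraction / speedup for fraction, speedup in parts)
    return 1.0 / new_time


# Text decoder: 61% of runtime, 2x faster; vocoder: 23% of runtime, 30x faster.
estimate = end_to_end_speedup([(0.61, 2.0), (0.23, 30.0)])

# Upper bound if both modules took no time at all.
ceiling = end_to_end_speedup([(0.61, float("inf")), (0.23, float("inf"))])

print(f"estimated: {estimate:.2f}x, upper bound: {ceiling:.2f}x")
```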

To take a closer look at the performance bottleneck of the text decoder and vocoder, we analyzed GPU traces for both modules on the 8th sample of the English-Spanish translation example of the FLEURS dataset, as shown in Figure 3. The traces revealed that the text decoder and vocoder are heavily CPU-bound modules. We observed a significant gap incurred by CPU overhead that delayed the launch of GPU kernels, resulting in a substantial increase in the execution time for both modules.

(a) CPU and GPU trace for Text Decoder

(b) CPU and GPU trace for Vocoder

Figure 3. Text Decoder and Vocoder are heavily CPU-bound modules. CPU and GPU trace for (a) Text Decoder (b) Vocoder for the 8th sample for English-Spanish translation example of FLEURS dataset. The trace is obtained by running inference with batch_size=1 on A100 GPU.

Based on the real-system performance analysis showing that text_decoder and vocoder are heavily CPU-bound modules in Seamless M4T-v2, we enabled torch.compile + CUDA Graph for those modules. In this post, we share the modifications required to enable torch.compile + CUDA Graph on each module for the batch_size=1 inference scenario, a discussion of CUDA Graph, and our planned next steps.

Torch.compile with CUDA Graph

torch.compile is a PyTorch API that just-in-time (JIT) compiles PyTorch models into optimized kernels; it is generally used to improve model performance by removing unnecessary overhead.

CUDA Graph is a feature provided by NVIDIA that allows for the optimization of kernel launches in CUDA applications. It creates an execution graph of CUDA kernels, which can be pre-processed and optimized by the driver before being executed on the GPU. The main advantage of using CUDA Graph is that it reduces the overhead associated with launching individual kernels, as the graph can be launched as a single unit, reducing the number of API calls and data transfers between the host and device. This can lead to significant performance improvements, especially for applications that have a large number of small kernels or repeat the same set of kernels multiple times. If this is something you are interested in learning more about, check out this paper that highlights the important role of data for accelerated computing: Where is the data? Why you cannot debate CPU vs. GPU performance without the answer by our own Kim Hazelwood! The paper dates from when NVIDIA was heavily investing in general-purpose GPUs (GPGPUs), before deep learning revolutionized the computing industry!

However, CUDA Graph operates on 1) fixed memory pointers and 2) fixed tensor shapes, both of which are recorded at compile time. We therefore introduced the following improvements so that a CUDA Graph can be reused across multiple input sizes, which avoids generating a new CUDA Graph for each iteration, and so that the data inside the CUDA Graph can be reused across runs to share the KV cache across multiple decoding steps.

Text Decoder

The Text Decoder in Seamless is a decoder from NLLB [1] that performs T2TT (Text to Text Translation). This module is CPU-bound: auto-regressive generation processes tokens sequentially, which limits the parallelism achievable on the GPU, so GPU execution time is not long enough to hide the CPU overhead. Based on this observation, we enabled torch.compile + CUDA Graph for the text decoder to reduce the dominating CPU overhead, as shown in Figure 4.

Figure 4. CPU and GPU trace for Text Decoder after torch.compile + CUDA Graph are enabled.

1. Updating and retrieving KV cache

During inference, the text decoder has two computation phases: a prefill phase that consumes the prompt and an incremental generation phase that generates output tokens one by one. Given a high enough batch size or input length, prefill operates on a sufficiently high number of tokens in parallel; GPU performance is the bottleneck and the CPU overheads do not impact performance significantly. On the other hand, incremental token generation is always executed with sequence length 1 and it is often executed with a small batch size (even 1), e.g. for interactive use cases. Incremental generation can therefore be limited by CPU speed, which makes it a good candidate for torch.compile + CUDA Graph.

However, during the incremental token generation phase, the sequence_length dimension of key and value involved in the attention computation increases by one with each step while the sequence length of query always remains 1. Specifically, key/value are generated by appending the newly computed key/value of sequence length 1 to the key/value stored in the KV cache so far. But as mentioned above, CUDA Graph records all tensor shapes during compilation and replays with the recorded shapes. Thus, a few modifications have been made to address this issue following the great work here.

a) We modify the KV-cache handling to take the indices at which to write new values as a CUDA tensor (i.e., valid_seq_pos) rather than a Python integer.

Figure 5. Modification to KV cache append and get
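Since the figure itself is an image, here is a minimal sketch of the idea, not the Seamless implementation. It assumes a cache laid out as (batch, heads, max_seq_len, head_dim); the class name StaticKVCache and its method names are hypothetical. The write position arrives as a tensor, and index_copy_ updates the preallocated buffers in place, so their shapes and memory addresses never change between steps.

```python
import torch


class StaticKVCache:
    """Hypothetical fixed-shape KV cache indexed by a tensor position."""

    def __init__(self, batch, heads, max_seq_len, head_dim, device="cpu"):
        shape = (batch, heads, max_seq_len, head_dim)
        # Buffers are allocated once at the maximum length.
        self.cache_k = torch.zeros(shape, device=device)
        self.cache_v = torch.zeros(shape, device=device)

    def append(self, k, v, valid_seq_pos):
        # valid_seq_pos: 1-D long tensor holding the sequence position(s) to write.
        # k, v: (batch, heads, len(valid_seq_pos), head_dim).
        # index_copy_ writes in place along the sequence dimension (dim 2).
        self.cache_k.index_copy_(2, valid_seq_pos, k)
        self.cache_v.index_copy_(2, valid_seq_pos, v)

    def get(self):
        # Always return the full max_seq_len buffers; a mask hides unused slots.
        return self.cache_k, self.cache_v
```

Because the position is tensor data rather than a Python integer, changing it from step to step does not change any shape that a recorded graph depends on.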

b) We also modify attention to work with the fixed shape of key and value over max_seq_length. We only compute softmax over the sequence positions up to the current decoding step (i.e., valid_seq_pos). To mask out sequence positions beyond the current decoding step, we create a boolean mask tensor (i.e., mask) in which sequence positions > valid_seq_pos are set to False.

Figure 6. Helper function to generate valid_seq_pos and mask
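A helper of this kind might look like the sketch below. The function name make_pos_and_mask and its signature are assumptions for illustration; only the names valid_seq_pos and mask come from the text above.

```python
import torch


def make_pos_and_mask(step, max_seq_len, device="cpu"):
    # Position at which this step's key/value will be written, held as a
    # tensor rather than a Python int.
    valid_seq_pos = torch.tensor([step], dtype=torch.long, device=device)
    # True for positions up to and including the current step,
    # False for positions beyond it.
    mask = torch.arange(max_seq_len, device=device) <= valid_seq_pos
    return valid_seq_pos, mask
```

For example, step=2 with max_seq_len=5 marks the first three positions as valid and the last two as masked. In a CUDA Graph setting one would likely keep these tensors preallocated and refresh their contents in place between replays; that detail is omitted here.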

It’s important to note that these modifications increase the amount of computation required, as we compute attention over more sequence positions than necessary (up to max_seq_length). However, despite this drawback, our results demonstrate that torch.compile + CUDA Graph still provide significant performance benefits compared to standard PyTorch code.
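To make the trade-off concrete, the sketch below computes single-query attention over the full fixed-length cache with a boolean mask and checks that it matches attention over only the valid positions. It is written with plain tensor operations for clarity and is not the Seamless attention code; the function name and tensor layout are assumptions.

```python
import math
import torch


def masked_decode_attention(q, cache_k, cache_v, mask):
    # q: (B, H, 1, D); cache_k, cache_v: (B, H, max_seq_len, D);
    # mask: (max_seq_len,) bool, True for valid positions.
    scores = q @ cache_k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Masked positions get -inf, so softmax assigns them zero weight.
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ cache_v


torch.manual_seed(0)
q = torch.randn(1, 2, 1, 4)
cache_k, cache_v = torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)
step = 3
mask = torch.arange(8) <= step
out = masked_decode_attention(q, cache_k, cache_v, mask)

# Reference: attention restricted to the first step + 1 positions.
k_valid, v_valid = cache_k[:, :, : step + 1], cache_v[:, :, : step + 1]
ref = torch.softmax(q @ k_valid.transpose(-2, -1) / math.sqrt(4), dim=-1) @ v_valid
```

The masked version does the matrix products over all 8 positions even though only 4 contribute, which is the extra computation described above.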

c) Because different inference samples have different sequence lengths, the inputs that are projected to key and value for the cross-attention layers also have different shapes. Thus, we pad the input to a static shape and generate a padding mask to mask out the padded output.
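A minimal sketch of this padding step follows, assuming an encoder output of shape (batch, src_len, dim) in which every sample in the batch shares the same src_len. The function name pad_encoder_output and the max_src_len parameter are hypothetical.

```python
import torch
import torch.nn.functional as F


def pad_encoder_output(enc_out, max_src_len):
    # enc_out: (B, src_len, D). Pad the sequence dimension with zeros.
    src_len = enc_out.size(1)
    if src_len > max_src_len:
        raise ValueError("input is longer than the static shape")
    # F.pad takes (left, right) pairs starting from the last dimension:
    # (0, 0) leaves D untouched, (0, n) appends n zero rows to the sequence.
    padded = F.pad(enc_out, (0, 0, 0, max_src_len - src_len))
    # True marks real positions, False marks padding to be masked out.
    padding_mask = torch.arange(max_src_len, device=enc_out.device) < src_len
    return padded, padding_mask
```

The padded tensor always has shape (B, max_src_len, D), so the cross-attention key and value projections see the same input shape for every sample.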

2. Memory Pointer Management

Because CUDA Graph records memory pointers along with tensor shapes, different inference samples must reference the same recorded memory pointers (e.g., for the KV cache); otherwise a new CUDA Graph has to be compiled for each inference sample. However, some parts of the Seamless codebase caused different inference samples to refer to different memory addresses, so we modified them to keep the addresses stable.

e) Seamless adopts beam search as its text decoding strategy. In the beam search process, we need to reorder the KV cache of every attention layer at each incremental decoding step so that each selected beam uses its corresponding KV cache, as shown in the code snippet below.

Figure 8. KV cache reordering operation for beam search decoding strategy.

The above code allocates new memory space and overwrites the original memory pointers for cache_k and cache_v. Thus, we modified KV cache reordering to keep the memory pointer of each cache as it was recorded during compilation, by using the copy_ operator.

Figure 9. In-place update for KV cache using copy_ operator
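The difference between the two variants can be sketched as follows. This is an illustration rather than the Seamless code: it assumes the beam dimension is dimension 0 of the cache, and both function names are hypothetical.

```python
import torch


def reorder_out_of_place(cache_k, new_order):
    # index_select returns a freshly allocated tensor; rebinding the cache
    # to this result changes the memory pointer.
    return cache_k.index_select(0, new_order)


def reorder_in_place_(cache_k, new_order):
    # The reordered rows are built in temporary storage and then copied back
    # into the existing buffer, so the recorded pointer stays valid.
    cache_k.copy_(cache_k.index_select(0, new_order))
    return cache_k
```

Comparing Tensor.data_ptr() before and after each call is a quick way to confirm which variant preserves the address.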

f) After enabling torch.compile + CUDA Graph for the text decoder by modifying the code as described above, the overhead of the text decoder shifts to KV cache reordering, as shown in Figure 10. KV cache reordering calls index_select 96 times (assuming 24 decoder layers, each with two types of attention layers, each of which caches both key and value).

Figure 10. CPU and GPU trace for Text Decoder after enabling torch.compile + CUDA Graph.

As part of accelerating the text decoder, we additionally applied torch.compile to KV cache reordering to benefit from kernel fusion, as shown in Figure 11. Note that we cannot use CUDA Graph (mode='max-autotune') here, because the copy_ operation modifies its inputs, which violates the static input requirement of the CUDA Graph version of torch.compile.

Figure 11. Applying torch.compile to KV Cache reordering.
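As a rough sketch of this step, the reordering loop over all layers can be wrapped in a single function and passed to torch.compile. The function name, the list-of-pairs cache layout, and the use of the default compile mode are assumptions; the exact mode used in the original figure is not stated in the text.

```python
import torch


def reorder_all_layers_(caches, new_order):
    # caches: list of (cache_k, cache_v) pairs, one pair per attention layer.
    for cache_k, cache_v in caches:
        # In-place update keeps each cache at its original address.
        cache_k.copy_(cache_k.index_select(0, new_order))
        cache_v.copy_(cache_v.index_select(0, new_order))


# torch.compile is lazy: compilation happens on the first call.
# The default mode does not enable CUDA Graph.
compiled_reorder_ = torch.compile(reorder_all_layers_)
```

Compiling the whole loop gives the compiler a chance to fuse the many small index_select and copy kernels instead of launching each one separately.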

As a result of applying torch.compile to KV cache reordering, the GPU kernels that were launched separately (Figure 12(a)) are now fused, so there are far fewer GPU kernels to launch (Figure 12(b)).

(a) CPU and GPU trace for KV cache reordering before enabling torch.compile

(b) CPU and GPU trace for KV cache reordering after enabling torch.compile

Figure 12. CPU and GPU trace for KV cache reordering (a) before and (b) after enabling torch.compile

Vocoder

The vocoder in Seamless is a HiFi-GAN unit vocoder that converts generated units to waveform output, where a unit is a representation of speech that combines different aspects such as phonemes and syllables and can be used to generate sounds that are audible to humans. The vocoder is a relatively simple module that consists of Conv1d and ConvTranspose1d layers and is a CPU-bound module, as shown in Figure 3. Based on this observation, we decided to enable torch.compile + CUDA Graph for the vocoder to reduce the disproportionately large CPU overhead, as shown in Figure 13. Several fixes were needed first.

Figure 13. CPU and GPU trace for Vocoder after torch.compile + CUDA Graph are enabled.

a) The input tensor shape of the vocoder differs across inference samples. But as CUDA Graph records the shapes of tensors and replays them, we had to pad the input to a fixed size with zeros. Since the vocoder consists only of convolutional layers, we do not need an additional padding mask, and padding with zeros is sufficient.

b) The vocoder consists of Conv1d layers wrapped with torch.nn.utils.weight_norm (see here). However, applying torch.compile directly to the vocoder incurs a graph break, as shown below, which limits the performance improvement. This graph break happens inside the hook-handling part of the PyTorch weight_norm code.

[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7fac8f483c10> from user code at:
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/mnt/fsx-home/yejinlee/yejinlee/seamless_communication/src/seamless_communication/models/vocoder/vocoder.py", line 49, in forward
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     return self.code_generator(x, dur_prediction)  # type: ignore[no-any-return]
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/data/home/yejinlee/mambaforge/envs/fairseq2_12.1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     return forward_call(*args, **kwargs)
[2023-12-13 04:26:16,822] [1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/mnt/fsx-home/yejinlee/yejinlee/seamless_communication/src/seamless_communication/models/vocoder/codehifigan.py", line 101, in forward
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     return super().forward(x)
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/mnt/fsx-home/yejinlee/yejinlee/seamless_communication/src/seamless_communication/models/vocoder/hifigan.py", line 185, in forward
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     x = self.ups[i](x)
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/data/home/yejinlee/mambaforge/envs/fairseq2_12.1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1550, in _call_impl
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     args_result = hook(self, args)
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/data/home/yejinlee/mambaforge/envs/fairseq2_12.1/lib/python3.8/site-packages/torch/nn/utils/weight_norm.py", line 65, in __call__
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     setattr(module, self.name, self.compute_weight(module))
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] 

Since the weights of the layers do not change during inference, we do not need weight normalization. So we simply removed weight normalization for the vocoder, as shown in Figure 14, by using the remove_weight_norm function that is already provided in the Seamless codebase (here).

Figure 14. Removing weight_norm for Vocoder
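The same effect can be illustrated on a toy module with PyTorch's own torch.nn.utils.remove_weight_norm. This is a sketch, not the Seamless remove_weight_norm helper referenced above; the function name strip_weight_norm and the toy model are assumptions.

```python
import torch
from torch.nn.utils import weight_norm, remove_weight_norm


def strip_weight_norm(module):
    # Fold weight_g / weight_v back into a plain `weight` parameter on every
    # submodule that carries the weight_norm hook.
    for m in module.modules():
        try:
            remove_weight_norm(m)
        except ValueError:
            # Raised when this submodule was never wrapped with weight_norm.
            pass
    return module


torch.manual_seed(0)
conv = weight_norm(torch.nn.Conv1d(2, 3, kernel_size=3))
model = torch.nn.Sequential(conv, torch.nn.ReLU())
x = torch.randn(1, 2, 10)
before = model(x)
strip_weight_norm(model)
after = model(x)  # same values, but no forward pre-hook runs anymore
```

After removal, the layer no longer recomputes its weight through a hook on every forward call, which is the code path where the graph break above was reported.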

Performance Evaluation + Impact of CUDA Graph

Figure 15 shows the speedup from enabling torch.compile(mode="max-autotune") + CUDA Graph on the text decoder and vocoder. We achieve a 2x speedup for the text decoder and a 30x speedup for the vocoder, leading to 2.7x faster end-to-end inference.

Figure 15. Inference time speedup of text decoder and vocoder of applying torch.compile and torch.compile + CUDA Graph

We also report the speedups for the text decoder and vocoder using torch.compile without CUDA Graph, which is supported by torch.compile’s API (i.e., torch.compile(mode="max-autotune-no-cudagraphs")), to identify the impact of CUDA Graph on performance. Without CUDA Graph, the speedups for the text decoder and vocoder drop to 1.17x and 18.4x, respectively. While still quite significant, this indicates the important role of CUDA Graph. We conclude that Seamless M4T-v2 spends a lot of its time launching CUDA kernels, especially with a small batch size (e.g., 1), where the GPU kernel execution time is not long enough to amortize the GPU kernel launch time.

Figure 16. End-to-end inference speedup of applying torch.compile and CUDA Graph incrementally.
a) "Inc. Decoding": Apply torch.compile only to the text decoder.
b) "Inc. Decoding w/ CUDA Graph": Apply torch.compile + CUDA Graph to the text decoder.
c) "+KV Cache Reordering": Additionally apply torch.compile to the KV cache reordering operation on top of b).
d) "+Vocoder": Additionally apply torch.compile to the vocoder on top of c).
e) "+Vocoder w/ CUDA Graph": Additionally apply torch.compile + CUDA Graph to the vocoder on top of d).

Figure 16 shows the cumulative effect of applying torch.compile with and without CUDA Graph to the modules. The results indicate a significant improvement in end-to-end inference speed, demonstrating the effectiveness of these techniques in reducing overall latency. As a result, we gain a 2.7x end-to-end inference speedup for Seamless M4T-v2 with batch_size=1.

Acknowledgements

We thank the PyTorch team and Seamless team for their tremendous support with this work.
