Accelerating Generative AI with PyTorch IV: Seamless M4T, fast


This post is the fourth part of a multi-part blog series focused on how to accelerate generative AI models with pure, native PyTorch. To skip to the code, check out our GitHub (seamless_communication, fairseq2). We are excited to share a breadth of newly released PyTorch performance features alongside practical examples to see how far we can push PyTorch native performance. In part one, we showed how to accelerate Segment Anything over 8x using only pure, native PyTorch. In part two, we showed how to accelerate Llama-7B by almost 10x using only native PyTorch optimizations. In part three, we showed how to accelerate text-to-image diffusion models up to 3x using only native PyTorch optimizations.

In this blog, we’ll focus on speeding up FAIR’s Seamless M4T-v2 model, achieving a 2x speedup for the text decoder module and a 30x speedup for the vocoder module. Together these yield a 2.7x speedup for end-to-end inference with no loss of accuracy, using CUDA Graph and native PyTorch optimization:

End to End Inference Speedup

Introduction

Seamless M4T is an open-source foundational speech/text translation and transcription technology developed by FAIR. Seamless M4T is a massively multilingual and multimodal machine translation model, with the latest version (Seamless M4T-v2) released on November 30th, 2023. The high-level model architecture of Seamless M4T-v2 is illustrated in Figure 1.


Figure 1. Model Architecture of Seamless M4T-v2.

Accelerating inference latency is crucial for translation models to improve user experience through faster communication across languages. In particular, batch_size=1 is often used for fast translation where latency matters a lot in applications such as chatbots, speech translation, and live subtitling. Therefore, we conducted the performance analysis on inference with batch_size=1, as shown in Figure 2 to understand the Amdahl’s Law bottleneck. Our results indicate that the text decoder and vocoder are the most time-consuming modules, accounting for 61% and 23% of the inference time, respectively.

Figure 2. Text decoder and vocoder are the most time-consuming modules. Breakdown of inference time by module for the English-Spanish S2ST (Speech-to-Speech Translation) task for batch_size=1 on an A100 GPU.

To take a closer look at the performance bottleneck of the text decoder and vocoder, we analyzed GPU traces of both modules for the 8th sample of the English-Spanish translation example from the FLEURS dataset, as shown in Figure 3. The traces revealed that the text decoder and vocoder are heavily CPU-bound modules. We observed significant gaps caused by CPU overhead that delayed the launch of GPU kernels, substantially increasing the execution time of both modules.

(a) CPU and GPU trace for Text Decoder

(b) CPU and GPU trace for Vocoder

Figure 3. Text Decoder and Vocoder are heavily CPU-bound modules. CPU and GPU trace for (a) Text Decoder (b) Vocoder for the 8th sample of the English-Spanish translation example from the FLEURS dataset. The trace is obtained by running inference with batch_size=1 on an A100 GPU.

Because the real-system performance analysis showed that the text decoder and vocoder are heavily CPU-bound modules in Seamless M4T-v2, we enabled torch.compile + CUDA Graph for those modules. In this post, we share the modifications required to enable torch.compile + CUDA Graph on each module for the batch_size=1 inference scenario, a discussion of CUDA Graph, and our plans for next steps.

Torch.compile with CUDA Graph

torch.compile is a PyTorch API that just-in-time compiles PyTorch code into optimized kernels, and is generally used to improve model performance by removing unnecessary overhead.

CUDA Graph is a feature provided by NVIDIA that allows for the optimization of kernel launches in CUDA applications. It creates an execution graph of CUDA kernels, which can be pre-processed and optimized by the driver before being executed on the GPU. The main advantage of using CUDA Graph is that it reduces the overhead associated with launching individual kernels, as the graph can be launched as a single unit, reducing the number of API calls and data transfers between the host and device. This can lead to significant performance improvements, especially for applications that have a large number of small kernels or repeat the same set of kernels multiple times. If this is something you are interested in learning more about, check out this paper that highlights the important role of data for accelerated computing: Where is the data? Why you cannot debate CPU vs. GPU performance without the answer by our own Kim Hazelwood! This was when NVIDIA was heavily investing in general-purpose GPUs (GPGPUs) and before deep learning revolutionized the computing industry!

However, CUDA Graph operates on 1) fixed memory pointers and 2) fixed tensor shapes, both of which are recorded at compile time. We therefore introduced the following improvements so that a CUDA Graph can be reused across inputs of multiple sizes, which avoids generating a new CUDA Graph for each iteration, and so that the data inside the CUDA Graph can be reused across runs, which lets multiple decoding steps share the KV cache.

Text Decoder

The Text Decoder in Seamless is a decoder from NLLB [1] that performs T2TT (Text to Text Translation). This module is CPU-bound: the GPU execution time is not long enough to hide the CPU overhead, because auto-regressive generation processes tokens sequentially, which limits the amount of parallelism that can be achieved on the GPU. Based on this observation, we enabled torch.compile + CUDA Graph for the text decoder to reduce the dominating CPU overhead, as shown in Figure 4.


Figure 4. CPU and GPU trace for Text Decoder after torch.compile + CUDA Graph are enabled.

1. Updating and retrieving KV cache

During inference, the text decoder has two computation phases: a prefill phase that consumes the prompt and an incremental generation phase that generates output tokens one by one. Given a high enough batch size or input length, prefill operates on a sufficiently high number of tokens in parallel, so GPU performance is the bottleneck and the CPU overheads do not impact performance significantly. On the other hand, incremental token generation is always executed with sequence length 1 and is often executed with a small batch size (even 1), e.g. for interactive use cases. Incremental generation can therefore be limited by CPU speed, which makes it a good candidate for torch.compile + CUDA Graph.

However, during the incremental token generation phase, the sequence_length dimension of the key and value involved in the attention computation increases by one with each step, while the sequence length of the query always remains 1. Specifically, the key/value are generated by appending the newly computed key/value of sequence length 1 to the key/value stored in the KV cache so far. But as mentioned above, CUDA Graph records the shapes of all tensors during compilation and replays with the recorded shapes. Thus, a few modifications have been made to address this issue, following the great work here.

a) We modify the KV-cache handling to take the indices in which to write new values in a CUDA Tensor (i.e., valid_seq_pos) rather than a Python integer.


Figure 5. Modification to KV cache append and get
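Since Figure 5 is an image, the idea behind modification a) can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions and not the fairseq2 implementation: the class name StaticKVCache, the method append_and_get, and the tensor layout (batch, heads, max_seq_len, head_dim) are hypothetical, and the example runs on CPU for simplicity.

```python
import torch


class StaticKVCache:
    """Hypothetical fixed-size KV cache; not the fairseq2 implementation."""

    def __init__(self, batch, heads, max_seq_len, head_dim):
        shape = (batch, heads, max_seq_len, head_dim)
        # Buffers are allocated once so their addresses and shapes stay fixed.
        self.cache_k = torch.zeros(shape)
        self.cache_v = torch.zeros(shape)

    def append_and_get(self, k, v, valid_seq_pos):
        # valid_seq_pos is a 1-D index tensor, so the write position is data
        # read at run time rather than a Python int fixed at capture time.
        self.cache_k.index_copy_(2, valid_seq_pos, k)
        self.cache_v.index_copy_(2, valid_seq_pos, v)
        # Always return the full-length buffers so shapes never change.
        return self.cache_k, self.cache_v


cache = StaticKVCache(batch=1, heads=2, max_seq_len=8, head_dim=4)
ptr_before = cache.cache_k.data_ptr()
for step in range(3):
    pos = torch.tensor([step])
    k_new = torch.full((1, 2, 1, 4), float(step + 1))
    v_new = torch.full((1, 2, 1, 4), float(step + 1))
    keys, values = cache.append_and_get(k_new, v_new, pos)
```

After three steps the returned key tensor still has the full max_seq_len length and the same storage address, with only the first three positions filled in.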

b) We also modify attention to work with the fixed shape of key and value over the max_seq_length. We only compute softmax over the sequence positions up to the current decoding step (i.e., valid_seq_pos). To mask out sequence positions > current decoding step (i.e., valid_seq_pos), we create a boolean mask tensor (i.e., mask) where sequence positions > valid_seq_pos are set to False.


Figure 6. Helper function to generate valid_seq_pos and mask
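To make modification b) concrete, here is a small sketch of masked attention over a fixed-length key/value buffer. It is an illustration only: make_mask and masked_attention are hypothetical helpers rather than the functions shown in Figure 6, the attention is written out by hand, and everything runs on CPU. The comparison at the end shows that masking the unused tail gives the same result as attending over only the valid prefix.

```python
import math

import torch


def make_mask(valid_seq_pos, max_seq_len):
    # True for positions up to the current decoding step, False for the rest.
    positions = torch.arange(max_seq_len)
    return positions <= valid_seq_pos


def masked_attention(q, k, v, mask):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Masked positions get -inf so they contribute zero weight after softmax.
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
max_seq_len, head_dim, step = 8, 4, 2
q = torch.randn(1, 1, 1, head_dim)            # the query always has length 1
k = torch.randn(1, 1, max_seq_len, head_dim)  # fixed-shape key buffer
v = torch.randn(1, 1, max_seq_len, head_dim)  # fixed-shape value buffer

mask = make_mask(torch.tensor(step), max_seq_len)
fixed_shape_out = masked_attention(q, k, v, mask)

# Reference: attend over only the valid prefix (a dynamic shape).
prefix = step + 1
all_valid = torch.ones(prefix, dtype=torch.bool)
prefix_out = masked_attention(q, k[:, :, :prefix], v[:, :, :prefix], all_valid)
```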

It’s important to note that these modifications increase the amount of computation required, as we compute attention over more sequence positions than necessary (up to max_seq_length). However, despite this drawback, our results demonstrate that torch.compile + CUDA Graph still provide significant performance benefits compared to standard PyTorch code.

c) As different inference samples have different sequence lengths, the inputs that are projected to key and value for the cross attention layers also have different shapes. Thus, we pad the input to have a static shape and generate a padding mask to mask out the padded output.

2. Memory Pointer Management

As CUDA Graph records memory pointers along with the shapes of tensors, different inference samples must correctly reference the recorded memory pointers (e.g., the KV cache) to avoid recompiling the CUDA Graph for each inference sample. However, some parts of the Seamless codebase made different inference samples refer to different memory addresses, so we made modifications to keep the memory addresses stable.

e) Seamless adopts beam search as a text decoding strategy. In the beam search process, we need to perform KV cache reordering for all the attention layers for each incremental decoding step to make sure each selected beam performs with corresponding KV cache as shown in the code snippet below.


Figure 8. KV cache reordering operation for beam search decoding strategy.

The above code allocates new memory space and overwrites the original memory pointer for cache_k and cache_v. Thus we modified KV cache reordering to keep the memory pointer of each cache as was recorded during compilation by using copy_ operator.


Figure 9. In-place update for KV cache using copy_ operator
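The difference between the two reordering styles can be checked with a toy tensor. This is a sketch rather than the Seamless code: the cache here is a tiny 2-D tensor with one row per beam, and new_order is a made-up beam selection. It shows that index_select returns a tensor at a new address, while copy_ writes the reordered values back into the original buffer.

```python
import torch

cache_k = torch.arange(12.0).reshape(3, 4)  # 3 beams, toy cache rows
new_order = torch.tensor([2, 2, 0])         # beams selected at this step
ptr_before = cache_k.data_ptr()

# Out-of-place: index_select allocates a new tensor, so rebinding the name to
# it would point at a different address than the one recorded in the graph.
reordered = cache_k.index_select(0, new_order)
out_of_place_moved = reordered.data_ptr() != ptr_before

# In-place: copy the reordered values back into the original buffer.
cache_k.copy_(reordered)
```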

f) After enabling torch.compile + CUDA Graph for the text decoder by modifying the code as described above, the overhead of the text decoder shifts to KV cache reordering, as shown in Figure 10. KV cache reordering calls index_select 96 times (assuming 24 decoder layers where each layer consists of two types of attention layers with cache for key and value).


Figure 10. CPU and GPU trace for Text Decoder after enabling torch.compile + CUDA Graph.

As part of accelerating the text decoder, we additionally applied torch.compile to KV cache reordering to benefit from kernel fusion, as shown in Figure 11. Note that we cannot use CUDA Graph (mode='max-autotune') here, because the copy_ operation modifies the inputs, which violates the static input requirement of the CUDA Graph version in torch.compile.


Figure 11. Applying torch.compile to KV Cache reordering.

As a result of enabling torch.compile for KV cache reordering, the GPU kernels that were launched separately (Figure 12(a)) are now fused, so there are far fewer GPU kernels to launch (Figure 12(b)).

(a) CPU and GPU trace for KV cache reordering before enabling torch.compile

(b) CPU and GPU trace for KV cache reordering after enabling torch.compile

Figure 12. CPU and GPU trace for KV cache reordering (a) before and (b) after enabling torch.compile

Vocoder

The Vocoder in Seamless is a HiFi-GAN unit-vocoder that converts generated units to waveform output, where a unit is a representation of speech that combines different aspects such as phonemes and syllables and can be used to generate sounds that are audible to humans. The Vocoder is a relatively simple module that consists of Conv1d and ConvTranspose1d layers and is a CPU-bound module, as shown in Figure 3. Based on this observation, we decided to enable torch.compile + CUDA Graph for the vocoder to reduce the disproportionately large CPU overhead, as shown in Figure 13. But there were several fixes to be made.


Figure 13. CPU and GPU trace for Vocoder after torch.compile + CUDA Graph are enabled.

a) The input tensor shape of the vocoder differs across inference samples. But as CUDA Graph records the shapes of tensors and replays them, we had to pad the input to a fixed size with zeros. Since the vocoder consists only of convolution layers, we do not need an additional padding mask, and padding with zeros is sufficient.
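Padding to a fixed size can be sketched as follows. This is an illustration under assumptions: MAX_LEN, the helper name pad_to_fixed_length, and the (batch, channels, time) float layout are hypothetical and chosen only to show the shape handling, not the actual vocoder input format.

```python
import torch
import torch.nn.functional as F

MAX_LEN = 16  # hypothetical fixed length chosen for this example


def pad_to_fixed_length(x, max_len=MAX_LEN):
    # x has shape (batch, channels, time); right-pad the time axis with zeros.
    pad = max_len - x.size(-1)
    if pad < 0:
        raise ValueError("input is longer than the fixed length")
    return F.pad(x, (0, pad), value=0.0)


short = pad_to_fixed_length(torch.ones(1, 2, 5))
longer = pad_to_fixed_length(torch.ones(1, 2, 11))
```

Both inputs now have the same shape, so a graph recorded for one can be replayed for the other.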

b) The Vocoder consists of conv1d layers wrapped with torch.nn.utils.weight_norm (see here). However, applying torch.compile directly to the Vocoder incurs a graph break, as shown below, which leads to a suboptimal performance improvement. This graph break happens inside the hook handling part of the PyTorch weight_norm code.

[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] Graph break: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7fac8f483c10> from user code at:
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/mnt/fsx-home/yejinlee/yejinlee/seamless_communication/src/seamless_communication/models/vocoder/vocoder.py", line 49, in forward
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     return self.code_generator(x, dur_prediction)  # type: ignore[no-any-return]
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/data/home/yejinlee/mambaforge/envs/fairseq2_12.1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     return forward_call(*args, **kwargs)
[2023-12-13 04:26:16,822] [1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/mnt/fsx-home/yejinlee/yejinlee/seamless_communication/src/seamless_communication/models/vocoder/codehifigan.py", line 101, in forward
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     return super().forward(x)
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/mnt/fsx-home/yejinlee/yejinlee/seamless_communication/src/seamless_communication/models/vocoder/hifigan.py", line 185, in forward
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     x = self.ups[i](x)
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/data/home/yejinlee/mambaforge/envs/fairseq2_12.1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1550, in _call_impl
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     args_result = hook(self, args)
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]   File "/data/home/yejinlee/mambaforge/envs/fairseq2_12.1/lib/python3.8/site-packages/torch/nn/utils/weight_norm.py", line 65, in __call__
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG]     setattr(module, self.name, self.compute_weight(module))
[1/0_2] torch._dynamo.symbolic_convert.__graph_breaks: [DEBUG] 

Since the weights of the layers do not change during inference, we do not need weight normalization. So we simply removed weight normalization from the Vocoder, as shown in Figure 14, by using the remove_weight_norm function that is already provided in the Seamless codebase (here).


Figure 14. Removing weight_norm for Vocoder
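The effect of removing weight normalization can be demonstrated on a single layer. The sketch below uses PyTorch's own torch.nn.utils.weight_norm and torch.nn.utils.remove_weight_norm on a standalone Conv1d; it is not the Seamless remove_weight_norm helper, and the layer sizes are arbitrary. It checks that the layer's output is unchanged once the normalized weight is folded into a plain weight parameter.

```python
import torch
from torch import nn
from torch.nn.utils import remove_weight_norm, weight_norm

torch.manual_seed(0)
conv = weight_norm(nn.Conv1d(4, 4, kernel_size=3, padding=1))
x = torch.randn(1, 4, 10)

with torch.no_grad():
    y_before = conv(x)

# With weight_norm, the weight is recomputed from weight_g and weight_v by a
# forward pre-hook on every call.
had_hook_params = hasattr(conv, "weight_g") and hasattr(conv, "weight_v")

# Fold the normalization into a single static `weight` and drop the hook.
remove_weight_norm(conv)

with torch.no_grad():
    y_after = conv(x)
```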

Performance Evaluation + Impact of CUDA Graph

Figure 15 shows the speedup result when enabling torch.compile(mode="max-autotune") + CUDA Graph on the text decoder and vocoder. We achieve a 2x speedup for the text decoder and a 30x speedup for the vocoder, leading to 2.7x faster end-to-end inference time.


Figure 15. Inference time speedup of text decoder and vocoder of applying torch.compile and torch.compile + CUDA Graph

We also report the speedups for the text decoder and vocoder using torch.compile without CUDA Graph, which is supported by torch.compile’s API (i.e., torch.compile(mode="max-autotune-no-cudagraphs")), to identify the impact of CUDA Graph on performance. Without CUDA Graph, the speedups for the text decoder and vocoder drop to 1.17x and 18.4x, respectively. While still quite significant, this indicates the important role of CUDA Graph. We conclude that Seamless M4T-v2 spends a lot of time launching CUDA kernels, especially with a small batch size (e.g., 1), where the GPU kernel execution time is not long enough to amortize the GPU kernel launch time.

End-to-end inference speedup of applying torch.compile and CUDA graph incrementally

Figure 16. End-to-end inference speedup of applying torch.compile and CUDA graph incrementally. a) “Inc. Decoding”: Apply torch.compile only to the text decoder b) “Inc. Decoding w/ CUDA Graph”: Apply torch.compile + CUDA Graph to the text decoder c) “+KV Cache Reordering”: Additionally apply torch.compile to KV cache reordering operation upon b) d) “+Vocoder”: Additionally apply torch.compile to the vocoder upon c) e) “+Vocoder w/ CUDA Graph”: Additionally apply torch.compile + CUDA Graph to the vocoder upon d).

Figure 16 represents the cumulative effect of applying torch.compile with and without CUDA Graph to the modules. The results indicate a significant improvement in the end-to-end inference speedup, demonstrating the effectiveness of these techniques in optimizing the overall latency. As a result, we gain 2.7x end-to-end inference speedup for Seamless M4T-v2 with batch_size=1.

Acknowledgements

We thank the PyTorch team and Seamless team for their tremendous support with this work.


Accelerate PyTorch Models Using Quantization Techniques with Intel Extension for PyTorch

Overview

PyTorch is a Python-based framework for developing deep learning models. It is one of the most popular industry-standard AI frameworks and is used for a wide variety of computer vision and natural language processing applications. PyTorch was developed by Meta and is now part of The Linux Foundation. Intel works with the open source PyTorch project to optimize the PyTorch framework for Intel® hardware. The newest optimizations and features are first released in Intel® Extension for PyTorch before upstreaming them into PyTorch. The Intel extension provides quantization features to deliver good accuracy results for large deep learning models.

This article introduces quantization, describes the types of quantization, and walks through a code sample that shows how to accelerate PyTorch-based models by applying Intel Extension for PyTorch quantization.

What Is Quantization?

Quantization is a systematic reduction of the precision of all or several layers within the model. This means a higher-precision type (like single precision floating-point (FP32) that is mostly used in deep learning) is converted into a lower-precision type, such as FP16 (16 bits) or int8 (8 bits).

This helps to achieve:

  • Lower memory bandwidth
  • Lower storage
  • Higher performance with minimum to zero accuracy loss

Quantization is especially important with large models such as those based on the Transformer architecture (like BERT or GPT).

There are two types of quantization:

  • Static: This quantizes the weights and activations of the model, and is used when memory bandwidth and compute savings are important.
  • Dynamic: The weights are quantized ahead of time, but the activations are dynamically quantized during inference.
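The difference between the two modes can be illustrated with plain Python arithmetic on activations. This is a simplified sketch, assuming symmetric int8 quantization with a single scale per tensor; the helper names and the toy calibration data are hypothetical and it does not reflect how Intel Extension for PyTorch implements quantization internally. In the static case the activation scale is fixed ahead of time from calibration data, while in the dynamic case it is recomputed from each input at inference time.

```python
def symmetric_scale(values, qmax=127):
    # Map the largest magnitude onto the largest int8 value.
    return max(abs(v) for v in values) / qmax


def quantize(values, scale, qmin=-128, qmax=127):
    # Round to the nearest integer step and clamp to the int8 range.
    return [max(qmin, min(qmax, round(v / scale))) for v in values]


def dequantize(qvalues, scale):
    return [q * scale for q in qvalues]


def max_error(original, restored):
    return max(abs(a - b) for a, b in zip(original, restored))


# Static: the scale comes from a calibration set seen before deployment.
calibration_batches = [[-1.0, 0.5, 2.0], [0.25, -4.0, 1.0]]
static_scale = symmetric_scale([v for batch in calibration_batches for v in batch])

# Dynamic: the scale is computed from the activation observed at inference.
activation = [0.1, -0.3, 0.4]
dynamic_scale = symmetric_scale(activation)

static_q = quantize(activation, static_scale)
dynamic_q = quantize(activation, dynamic_scale)

static_err = max_error(activation, dequantize(static_q, static_scale))
dynamic_err = max_error(activation, dequantize(dynamic_q, dynamic_scale))
```

In this toy case the dynamic scale fits the observed activation range more tightly, at the cost of computing the range at inference time; the static scale costs nothing at inference but depends on how representative the calibration data is.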

How to Perform Static Quantization and Dynamic Quantization

The Intel extension extends PyTorch with up-to-date features and optimizations for an extra performance boost on Intel hardware.

Installation Instructions for Intel Extension for PyTorch

The extension can be loaded as a Python module or linked as a C++ library. Python users can enable it dynamically by importing intel_extension_for_pytorch. The extension provides built-in quantization to deliver good statistical accuracy for most popular deep learning workloads including convolutional neural networks (CNN), natural language processing (NLP), and recommendation models. The quantization functionality in the Intel extension currently supports post-training quantization.

To quantize the existing FP32 model to an int8 model using static quantization:

  1. Prepare the quantization configuration. For default static quantization configuration, use ipex.quantization.default_static_qconfig.
  2. Prepare the model for calibration using the ipex.quantization.prepare method.
  3. Perform calibration against the dataset. This calibration is specific for static quantization as it needs the representative dataset to determine the optimal quantization parameters, so the user should provide data to the model in batches to calibrate it.
  4. Convert the model from FP32 to int8 using the ipex.quantization.convert method. This function converts the FP32 model to int8 based on the applied calibration and configuration.

To quantize the existing FP32 model to an int8 model using dynamic quantization, which is similar to static quantization:

  1. Prepare the quantization configuration. For default dynamic quantization configuration, use ipex.quantization.default_dynamic_qconfig.
  2. Prepare the FP32 model by using the ipex.quantization.prepare method. Provide the parameters, such as FP32 model to quantize, the prepared configuration, example inputs, and information.
  3. Convert the model from FP32 to int8 using the ipex.quantization.convert method. The input model is the model prepared in Step 2.

Code Sample

Dataset

For static quantization, the model is calibrated with the CIFAR-10 dataset. The CIFAR-10 is a subset of the 80 million tiny images dataset collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.

This dataset contains 60,000 images in 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck). Every class has exactly 6,000 images. All images are 32 x 32 pixels and are in color. Also, the classes are completely mutually exclusive, which means there is no overlap between classes.

Implementation

The code sample demonstrates how to quantize (using static and dynamic quantization) a ResNet*-50 model using Intel Extension for PyTorch. The following steps are implemented in the code sample:

Download and Prepare the Dataset

Here, we use the CIFAR-10 dataset available in torchvision.

  1. To make the data fit the model:
  • Transform the data.
  • Change the size of the images from 32 x 32 pixels to 224 x 224 pixels.
  • Convert them to tensors.
  • Normalize them.
  2. Prepare transformations of the dataset as shown:
import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

  3. Initialize the dataset.
# DATA is the directory the dataset is downloaded to (defined elsewhere in the sample).
test_dataset = torchvision.datasets.CIFAR10(root=DATA, train=False, transform=transform, download=True)

Prepare the Data Loader

To load a dataset for static quantization calibration in specific size batches, create the loader as shown:

calibration_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=128
)

Create the Model

Use the pretrained ResNet-50 model available in the Torchvision library with default weights. The prepared model is FP32.

model_fp32 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)

Apply Static Quantization

Create a staticQuantize function that implements the steps described previously.

  1. To perform static quantization, we need:
  • FP32 model loaded earlier
  • Example data
  • Calibration dataset
  2. Prepare the quantization configuration:
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

qconfig_static = ipex.quantization.default_static_qconfig

In this code sample, we are using the default quantization configuration, but you can also define your own.

  3. Prepare the model using the declared configuration:
# data is an example input with the shape the model expects
prepared_model_static = prepare(model_fp32,
                                qconfig_static,
                                example_inputs=data,
                                inplace=False)
  4. Calibrate the model with the calibration dataset. Feed the model with successive batches of data from the dataset.
for batch_idx, (data, target) in enumerate(calibration_data_loader):
    prepared_model_static(data)
    if batch_idx % 10 == 0:
        print("Batch %d/%d complete, continue ..." % (batch_idx + 1, len(calibration_data_loader)))
  5. Convert the model.
converted_model_static = convert(prepared_model_static)

Apply Dynamic Quantization

Create the dynamicQuantize function similar to the staticQuantize function.

  1. To perform dynamic quantization, we only need:
  • The FP32 model loaded earlier
  • Example data
  2. Prepare the quantization configuration:
qconfig_dynamic = ipex.quantization.default_dynamic_qconfig
  3. Prepare the model.
prepared_model_dynamic = prepare(model_fp32,
                                 qconfig_dynamic,
                                 example_inputs=data,
                                 inplace=False)
  4. Convert the model from FP32 to int8.
converted_model_dynamic = convert(prepared_model_dynamic)

In this way, two functions are created to take advantage of the optimizations that quantization offers:

  • dynamicQuantize for dynamic quantization of models
  • staticQuantize for static quantization of models

Next Steps

Get started with Intel Extension for PyTorch quantization today and use it to achieve better accuracy results for deep learning workloads. Additionally, Intel® Neural Compressor provides quantization to improve the speed of inference.

Check out and incorporate Intel’s other AI and machine learning framework optimizations and end-to-end portfolio of tools into your AI workflow.

Learn about the unified, open, standards-based oneAPI programming model that forms the foundation of Intel’s AI Software Portfolio to help you prepare, build, deploy, and scale your AI solutions.

For more details about the 4th gen Intel® Xeon® Scalable processors, visit the Intel® AI platform overview where you can learn how Intel is empowering developers to run end-to-end AI pipelines on these powerful CPUs.


Accelerating Triton Dequantization Kernels for GPTQ


TL;DR

Leveraging a first principles approach, we showcase a step by step process undertaken to accelerate the current Triton GPTQ kernels by 3x (core GPTQ) and 6x (AutoGPTQ). Example: 275us to 47us on a typical Llama style inference input. The goal is to provide a helpful template for accelerating any given Triton kernel. We provide a background on Triton and GPTQ quantization and dequantization process, showcase the impact of coalesced memory access to improve shared and global memory throughput, highlight changes made to reduce warp stalling to improve total throughput, and an overview on integrating Triton kernels into PyTorch code. Longer term, we hope to surpass the existing CUDA native GPTQ kernel with our Triton kernel.

Fig 1: Performance benchmarking the optimized AutoGPTQ kernel vs the current AutoGPTQ kernel on H100

Fig 2: Performance benchmarking the newly optimized AutoGPTQ kernel vs the current AutoGPTQ kernel on A100

Fig 3: Even with these improvements, there remains a gap between our optimized Triton kernel and the CUDA native AutoGPTQ kernel on A100. More to come…

1.0 Introduction to Triton

The Triton framework provides a hardware agnostic way of programming and targeting GPUs, currently supporting both NVIDIA and AMD, with support for additional hardware vendors in progress. Triton is now a mainstay for PyTorch 2.0 as torch.compile decomposes eager PyTorch and re-assembles it into a high percentage of Triton kernels with PyTorch connecting code.

As Triton becomes more widely adopted, it will be essential that programmers understand how to systematically step through the Triton stack (from the high level Python down to the low-level SASS) to address performance bottlenecks in order to optimize GPU efficiency for algorithms that go beyond torch.compile generated kernels.

In this post, we will introduce some core concepts of the Triton programming language, how to identify common performance limiters in GPU kernels, and in parallel, tune a quantization kernel used in AutoGPTQ that can be used for high throughput inference applications.

Intro to GPTQ Quantization and Dequantization

GPTQ is a quantization algorithm that can efficiently compress ultra-large (175B+) LLMs to a 4-bit integer (int4) representation using approximate second-order information (Hessian inverse). AutoGPTQ is a framework built on GPTQ, allowing for rapid dequantization and inference/serving of LLMs that have been quantized with GPTQ.

As part of its stack, AutoGPTQ provides a Triton GPTQ kernel to handle the dequantization of a model for inference.

The basic process for INT quantization is shown below. It involves determining the scale and zero point, and then computing the quantized 4-bit weight using the scale and zero point:

The basic process for INT quantization

We thus store the 4 Bit weights along with the meta information of Scale and ZeroPoint for each group of weights.

To ‘dequant’ these weights, we do the following:

To ‘dequant’ these weights

And then proceed to Matrix Multiply the dequantized weights with the dense input feature matrix for this linear layer.
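The quantization and dequantization steps described above can be sketched in plain Python. This uses one common asymmetric formulation, assumed here for illustration; the exact rounding and zero-point conventions in GPTQ and AutoGPTQ may differ, and quantize_group and dequantize_group are hypothetical names. The group of weights is toy data.

```python
def quantize_group(weights, bits=4):
    # One scale and one zero point are shared by the whole group of weights.
    qmax = 2 ** bits - 1
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / qmax          # assumes w_max > w_min
    zero_point = round(-w_min / scale)
    q = [max(0, min(qmax, round(w / scale) + zero_point)) for w in weights]
    return q, scale, zero_point


def dequantize_group(q, scale, zero_point):
    # Invert the mapping; the result differs from the input by rounding error.
    return [(v - zero_point) * scale for v in q]


group = [-0.8, -0.33, 0.0, 0.21, 0.7]
q, scale, zero_point = quantize_group(group)
restored = dequantize_group(q, scale, zero_point)
max_abs_error = max(abs(a - b) for a, b in zip(group, restored))
```

Each stored value fits in 4 bits (0 to 15), and the reconstruction error per weight is bounded by about half a quantization step.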

2.0 Identify the Bottlenecks – Optimizing Matrix Multiplication

As it turns out, making a fast matrix multiplication kernel is not trivial. A naively implemented matrix multiply will rarely reach peak throughput performance on highly parallel machines like GPUs. So we need to tackle the compute and memory subsystems of our GPU in a hierarchical fashion to make sure we are maximally utilizing each resource.

We start our optimization process by running the unoptimized Triton kernel through the Nvidia Nsight Compute tool and taking note of some important metrics and warnings:

some important metrics and warnings

Fig xy (todo)


We notice first that both compute and memory throughput are low, 7.40% and 21.19% respectively (fig xy). Knowing that for typical inference matrix problem sizes we are in the memory-bound regime, we will attempt to optimize the kernel by applying code changes that target the memory subsystem of our A100 GPU.

The three topics this post will cover are:

  1. L2 Optimization
  2. Vectorized Load
  3. Warp Stalling

Let’s walk through each topic, make the appropriate changes, and see its corresponding impact on our Triton Kernel. This Triton kernel is a fused dequantization kernel that dequantizes a packed int32 weight (we will refer to this as the B Matrix) Tensor into int4 weights, performs matrix multiplication with the activation tensor (refer to as the A matrix) in FP16 mode, and then stores the results back to a matrix C.

The above is referred to as W4A16 quantization. Keep in mind that the process we describe can and should be used for the development of any GPU kernel, as these are common bottlenecks in any unoptimized kernel.
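To illustrate what "packed int32 weight" means, the sketch below packs eight 4-bit values into one 32-bit word and unpacks them again with shifts and masks. The nibble order (lowest nibble first) and the function names are assumptions made for this example; the actual layout used by the AutoGPTQ kernel is not shown in this post.

```python
def pack_int4(values):
    # Pack eight 4-bit values into one 32-bit word, lowest nibble first.
    if len(values) != 8 or any(not 0 <= v <= 15 for v in values):
        raise ValueError("expected eight values in the range 0..15")
    word = 0
    for i, v in enumerate(values):
        word |= v << (4 * i)
    return word


def unpack_int4(word):
    # Shift each nibble down to the low bits and mask off everything else.
    return [(word >> (4 * i)) & 0xF for i in range(8)]


nibbles = [0, 5, 8, 10, 15, 1, 2, 3]
packed = pack_int4(nibbles)
unpacked = unpack_int4(packed)
```

A fused kernel performs the equivalent of unpack_int4 on the fly, applies the scale and zero point, and feeds the resulting FP16 values into the matrix multiply without writing the dequantized weights back to global memory.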

3.0 L2 Optimization

This optimization already exists in the AutoGPTQ kernel, but we’d like to dedicate a section to this to help readers better understand how mapping and execution order of thread blocks is handled in Triton. Thus, we will step through a naive mapping and then a more optimal mapping to see its corresponding impact.

Let’s build up our kernel naively, starting with a “linear” load from global memory, and then compare it to a more optimized “swizzled” load. Linear vs. swizzled determines the execution order of our grid of work on the GPU. Let’s take a look at the hints that the Nvidia Nsight Compute tool provides regarding our kernel’s shared memory access pattern in the naive case:

the hints from the Nvidia Nsight Compute Tool

To tackle this issue we can use an approach referred to as “tile-swizzling.” The idea of this method is to launch our thread blocks in a more L2 cache friendly order.

Let’s take a step back and familiarize ourselves with some Triton semantics and make a simple CUDA analogy to understand the concept better. Triton kernels launch “programs”. These so-called programs map to the concept of a Thread Block in CUDA, and a program is the basic unit of parallelism in a Triton kernel. Every program has an associated “pid”, and all the threads in a program are guaranteed to be executing the same instruction.

The Triton programs will be distributed onto your SMs in a naive way if you do a simple linear mapping of “pid” to a 2D grid location of your output matrix C.

This 2D grid location is determined by pid_m and pid_n in Triton. We would like to exploit data and cache locality in the L2 cache of our GPU, when we distribute our grid of work. To do this in Triton we can make the following changes:

To do this in Triton

The code highlighted in red would be the naive “linear” tile ordering, and the code highlighted in green is the “swizzled” tile ordering. This type of launch promotes a sense of locality. Here is a visual to help understand this better.

a sense of locality
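The two orderings can also be compared outside of Triton. The following is a minimal pure-Python sketch, assuming the grouped ordering commonly used in Triton matrix multiplication kernels; the function names, the 4x4 grid, and the group size of 2 are illustrative assumptions, not the exact AutoGPTQ code. It maps each pid to a (pid_m, pid_n) tile and counts how many distinct row blocks of A and column blocks of B the first four programs need.

```python
def linear_tile(pid, num_pid_n):
    """Row-major ("linear") mapping from a program id to an output tile."""
    return pid // num_pid_n, pid % num_pid_n


def swizzled_tile(pid, num_pid_m, num_pid_n, group_size_m):
    """Grouped ("swizzled") mapping: walk down group_size_m rows of tiles
    before moving on to the next column of tiles."""
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group holds fewer rows when num_pid_m is not a multiple of group_size_m.
    rows_in_group = min(num_pid_m - first_pid_m, group_size_m)
    pid_in_group = pid % num_pid_in_group
    pid_m = first_pid_m + pid_in_group % rows_in_group
    pid_n = pid_in_group // rows_in_group
    return pid_m, pid_n


def blocks_touched(tiles):
    """Distinct row blocks of A plus column blocks of B needed by a set of output tiles."""
    return len({m for m, _ in tiles}) + len({n for _, n in tiles})


# Illustrative grid: 4 x 4 output tiles, groups of 2 tile rows.
NUM_PID_M, NUM_PID_N, GROUP_SIZE_M = 4, 4, 2
num_programs = NUM_PID_M * NUM_PID_N
linear = [linear_tile(p, NUM_PID_N) for p in range(num_programs)]
swizzled = [swizzled_tile(p, NUM_PID_M, NUM_PID_N, GROUP_SIZE_M) for p in range(num_programs)]

print("linear  :", linear[:4], "blocks touched:", blocks_touched(linear[:4]))
print("swizzled:", swizzled[:4], "blocks touched:", blocks_touched(swizzled[:4]))
```

In this toy grid the first four programs touch five distinct input blocks under the linear ordering and four under the swizzled ordering, which is the kind of reuse the L2 cache can exploit. Both orderings still cover every output tile exactly once.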

After incorporating this change, the profiler no longer complains about uncoalesced memory accesses. Let’s take a look at how our memory throughput has changed:

how our memory throughput has changed

This change was tested on a simple load store kernel. Looking at the GPU speed of light statistics section in the profiler we also see a 112.07% increase in the memory throughput of the simple load kernel, which is what we were after with this optimization. Again, this optimization already exists in the AutoGPTQ kernel, but is the boilerplate logic that every Triton Kernel programmer will have to write in the beginning of their kernel, before any of the exciting dequantization or matrix multiply logic. It is thus important to understand that:

  1. This mapping is not unique

  2. Triton does not automatically handle this kind of optimization for the programmer, and careful thought must be taken to ensure your kernel is optimally handling shared memory accesses

These are not obvious for those new to Triton, as much of the shared memory access optimization is handled by the Triton compiler. However, in the cases where these are not handled by the compiler, it is important to be able to understand what tools and methods are available to us to be able to influence memory behavior.

4.0 Vectorized Load

Now, back to the original complaints of our unoptimized kernel. We want to optimize the global memory access pattern of our kernel. From the details page of the Nvidia Nsight compute tool, we see the following note, where the profiler is complaining about uncoalesced global memory accesses.

Let’s dig deeper and take a look at the SASS (Assembly) Code load for an unoptimized memory read:

an unoptimized memory read

This load operation resulted in 32 global load operations that are each 16 bits wide. This is not optimal.

We would like to do our global memory loads in a vectorized way so that they result in the fewest possible load instructions. To achieve this, we can give the Triton compiler some help.

code block

The green highlighted lines above act as a compiler hint. It tells the compiler that these elements are contiguous in memory and that this load operation can be coalesced.

Let’s see the effect in assembly after adding these lines.

the effect in assembly after adding these lines

The load is now performed in 4 global load operations that are each 128 bits wide, instead of 32 global load operations that are each 16 bits wide. This means 28 fewer memory fetch instructions and, importantly, a coalesced memory access. This can be seen from the fact that a single thread is not accessing consecutive memory addresses, which, without the compiler hint, was the behavior.
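The instruction counts above follow from simple arithmetic on the load width. As a small sketch (the constant and function names are illustrative, and the 32-element figure is taken from the unoptimized load described above):

```python
ELEMENTS = 32      # 16-bit values fetched by the load in the example above
ELEMENT_BITS = 16  # width of one FP16 value


def load_instructions(vector_bits):
    """Number of load instructions needed when each one moves vector_bits bits."""
    total_bits = ELEMENTS * ELEMENT_BITS
    return -(-total_bits // vector_bits)  # ceiling division


scalar_loads = load_instructions(16)       # one 16-bit value per instruction
vectorized_loads = load_instructions(128)  # eight 16-bit values per instruction

print(scalar_loads, vectorized_loads, scalar_loads - vectorized_loads)
```

The same 512 bits are moved in both cases; only the number of instructions needed to move them changes.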

The resulting effect is a 73x speedup in an isolated load operation, and after incorporating it in the full dequantization kernel we were able to see another 6% speedup. Another step in the right direction!

5.0 Warp Stalling

performance limiter, warp stalling

Now putting all the changes back into our full dequantization kernel, we see the following performance limiter, warp stalling.

These warp stalls are mostly caused by ‘Long Scoreboard’ stalls, accounting for 92.63% of the total.

At a high level, long scoreboard stalls happen when a warp requires data that may not be ready yet in order to be in the “issued” state. In other words GPUs are throughput machines, and we need to hide the latency of load instructions with compute instructions. By loading more data and rearranging where the load instructions are in the script we can take care of this problem.

In an ideal scenario, each warp scheduler would be able to issue 1 instruction every clock cycle. Note – Every SM on an A100 GPU has 4 warp schedulers.

However – our kernel has bottlenecks and is spending 4.4 cycles in the stall state with the block size that AutoGPTQ Triton kernel deems as optimal given the presets it has.

How do we improve this?

We want to be able to increase our memory throughput so that we can increase the chance that when a warp issues an instruction, we won’t be waiting for loads to be stored in SRAM so that they can be used for computation. We played around with multiple parameters (such as number of pipeline stages, and number of warps) and the one that had the biggest impact was increasing the block size by a factor of 2 in the k dimension.

These changes yield an immediate impact on both compute and memory throughput.

an immediate impact on both compute and memory throughput

We also see the long scoreboard wait time drop significantly at the step where we shift and scale the quantized weights, which is what we identified as the original bottleneck in the source code. While there are still stalls at this line, only 68% of them are caused by long scoreboard stalls, compared to the original 92%. Ideally, we would not observe ANY stalls, so there is still work to be done here, but a reduction in the number of stalls caused by long scoreboard tells us that our data is ready to be used (in L1TEX memory) by an instruction that a warp wants to execute at a higher frequency than in the original kernel.

1.4x speedup in the execution time of our kernel

The corresponding impact is a 1.4x speedup in the execution time of our kernel.

6.0 Results

By tackling all these problem areas methodically our resulting kernel is 6x faster on the Nvidia A100 GPU than if you were to use the Triton kernel AutoGPTQ provides out-of-the-box.

Taking a relevant Llama inference sample data point, the Triton kernel we’ve developed takes 47us to perform dequantization and matrix multiplication compared to the 275us taken by the AutoGPTQ kernel for the same matrix size.

By replicating this step-by-step approach it should be possible to get similar speedups in other kernels, and help build understanding on common GPU bottlenecks and how to tackle them.

It is important to note that while strides have been made in improving the performance of the AutoGPTQ Triton Kernel, we have still not closed the gap on the current exllamaV2 CUDA kernels found in AutoGPTQ.

More research is required to understand how we can further optimize this kernel to match equivalent custom CUDA kernel performance.

Summary and Future work

Triton extends PyTorch by allowing low level GPU optimizations to be done at a higher level of abstraction than CUDA programming, with the net result that adding optimized Triton kernels can help PyTorch models run faster.

Our goal in this post was to show an example of accelerating the GPTQ dequant kernel and provide a template workflow for how the accelerations were achieved.

For future work, SplitK work decomposition for the matrix multiplication is a potential speed up we’ll investigate.

Integrating custom Triton Kernels into PyTorch

Given the acceleration shown above, a common question is how to actually use a custom kernel in a given PyTorch codebase.

A Triton kernel will contain at least two parts. The first is the actual Triton kernel code, which will be compiled by the Triton compiler:

the actual Triton kernel code which will be compiled by the Triton compiler

Along with the actual kernel code is a Python wrapper, which may or may not subclass the PyTorch autograd class, depending on whether it’s going to support a backward pass (i.e. for training purposes or only for inference purposes).

You simply import the python class into your PyTorch code where you want to use it much like any other Python / PyTorch function.

import the python class into your PyTorch code

In this case, simply importing and then using ‘fast_qlinear’ would invoke the underlying Triton kernel with the speed-ups we’ve shown above applied to your PyTorch model.

Acknowledgements

Thanks to Jamie Yang and Hao Yu from IBM Research for their technical guidance in the collection of these results.


Finetune LLMs on your own consumer hardware using tools from PyTorch and Hugging Face ecosystem


We demonstrate how to finetune a 7B parameter model on a typical consumer GPU (NVIDIA T4 16GB) with LoRA and tools from the PyTorch and Hugging Face ecosystem with complete reproducible Google Colab notebook.

Introduction

Large Language Models (LLMs) have shown impressive capabilities in industrial applications. Often, developers seek to tailor these LLMs to specific use-cases and applications by fine-tuning them for better performance. However, LLMs are large by design and require a large number of GPUs to be fine-tuned.

Let’s focus on a specific example by trying to fine-tune a Llama model on a free-tier Google Colab instance (1x NVIDIA T4 16GB). Llama-2 7B has 7 billion parameters, for a total of 28GB when the model is loaded in full-precision. Given our GPU memory constraint (16GB), the model cannot even be loaded, much less trained on our GPU. Loading the model in half-precision divides this memory requirement by two with negligible performance degradation. You can read more about running models in half-precision and mixed precision for training here.

What makes our Llama fine-tuning expensive?

In the case of full fine-tuning with Adam optimizer using a half-precision model and mixed-precision mode, we need to allocate per parameter:

  • 2 bytes for the weight
  • 2 bytes for the gradient
  • 4 + 8 bytes for the Adam optimizer states

→ With a total of 16 bytes per trainable parameter, this makes a total of 112GB (excluding the intermediate hidden states). Given that the largest GPU available today can have up to 80GB GPU VRAM, it makes fine-tuning challenging and less accessible to everyone. To bridge this gap, Parameter Efficient Fine-Tuning (PEFT) methods are largely adopted today by the community.
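The figures above can be reproduced with a few lines of arithmetic. This is only a sketch of the bookkeeping described in the text: it uses 1GB = 10^9 bytes, ignores intermediate hidden states, and the function and constant names are illustrative.

```python
GB = 1e9          # bytes per gigabyte, as used for the round numbers in the text
NUM_PARAMS = 7e9  # Llama-2 7B


def memory_gb(bytes_per_param, num_params=NUM_PARAMS):
    """Memory needed when every parameter costs bytes_per_param bytes."""
    return num_params * bytes_per_param / GB


full_precision_load = memory_gb(4)  # 32-bit weights only
half_precision_load = memory_gb(2)  # 16-bit weights only

# Full fine-tuning with Adam in mixed precision, per trainable parameter:
# 2 bytes (weight) + 2 bytes (gradient) + 4 + 8 bytes (optimizer states).
full_finetuning = memory_gb(2 + 2 + 4 + 8)

print(full_precision_load, half_precision_load, full_finetuning)
```

This gives 28GB and 14GB just to load the model in full and half precision, and 112GB for full fine-tuning, before any activations are stored.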

Parameter Efficient Fine-Tuning (PEFT) methods

PEFT methods aim at drastically reducing the number of trainable parameters of a model while keeping the same performance as full fine-tuning.

They can be differentiated by their conceptual framework: does the method fine-tune a subset of existing parameters, introduce new parameters, introduce trainable prompts, etc.? We recommend readers to have a look at the paper shared below that extensively compares existing PEFT methods.

Venn diagram

Image taken from the paper: Scaling Down to Scale Up: A Guide to Parameter-Efficient Fine-Tuning

For this blog post, we will focus on Low-Rank Adaptation for Large Language Models (LoRA), as it is one of the most adopted PEFT methods by the community.

Low-Rank Adaptation for Large Language Models (LoRA) using 🤗 PEFT

The LoRA method by Hu et al. from the Microsoft team came out in 2021, and works by attaching extra trainable parameters to a model (which we will refer to as the base model).

To make fine-tuning more efficient, LoRA decomposes a large weight matrix into two smaller, low-rank matrices (called update matrices). These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn’t receive any further adjustments. To produce the final results, both the original and the adapted weights are combined.

This approach has several advantages:

  • LoRA makes fine-tuning more efficient by drastically reducing the number of trainable parameters.
  • The original pre-trained weights are kept frozen, which means you can have multiple lightweight and portable LoRA models for various downstream tasks built on top of them.
  • LoRA is orthogonal to many other parameter-efficient methods and can be combined with many of them.
  • The performance of models fine-tuned using LoRA is comparable to the performance of fully fine-tuned models.
  • LoRA does not add any inference latency when adapter weights are merged with the base model

In principle, LoRA can be applied to any subset of weight matrices in a neural network to reduce the number of trainable parameters. However, for simplicity and further parameter efficiency, in Transformer models LoRA is typically applied to attention blocks only. The resulting number of trainable parameters in a LoRA model depends on the size of the low-rank update matrices, which is determined mainly by the rank r and the shape of the original weight matrix.
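To make the dependence on the rank r concrete: for a frozen weight matrix of shape d_out x d_in, the two update matrices have shapes d_out x r and r x d_in, so LoRA trains r * (d_out + d_in) parameters instead of d_out * d_in. The sketch below assumes an illustrative 4096 x 4096 projection and r = 8; these values and the function name are examples, not the configuration used later in this post.

```python
def lora_trainable_params(d_out, d_in, r):
    """Parameters in the two low-rank update matrices B (d_out x r) and A (r x d_in)."""
    return r * (d_out + d_in)


# Illustrative shapes: one square attention projection and a small rank.
d_out, d_in, r = 4096, 4096, 8

full_params = d_out * d_in
lora_params = lora_trainable_params(d_out, d_in, r)

print(f"full matrix : {full_params:,} parameters")
print(f"LoRA update : {lora_params:,} parameters "
      f"({100 * lora_params / full_params:.2f}% of the full matrix)")
```

Doubling r doubles the number of trainable parameters, while the frozen matrix itself contributes none.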

Animated diagram that show how LoRA works in practice

Animated diagram that shows how LoRA works in practice (original content adapted from Figure 1 of the original LoRA paper)

Below is a code snippet showing how to train LoRA model using Hugging Face PEFT library:

code snippet showing how to train LoRA model using  Hugging Face PEFT library

The base model can be in any dtype: leveraging SOTA LLM quantization and loading the base model in 4-bit precision

According to the LoRA formulation, the base model can be compressed in any data type (‘dtype’) as long as the hidden states from the base model are in the same dtype as the output hidden states from the LoRA matrices.

Compressing and quantizing large language models has recently become an exciting topic as SOTA models become larger and more difficult to serve and use for end users. Many people in the community proposed various approaches for effectively compressing LLMs with minimal performance degradation.

This is where the bitsandbytes library comes in. Its purpose is to make cutting-edge research by Tim Dettmers, a leading academic expert on quantization and the use of deep learning hardware accelerators, accessible to the general public.

QLoRA: One of the core contributions of bitsandbytes towards the democratization of AI

Quantization of LLMs has largely focused on quantization for inference, but the QLoRA (Quantized model weights + Low-Rank Adapters) paper showed the breakthrough utility of using backpropagation through frozen, quantized weights at large model scales.

With QLoRA we are matching 16-bit fine-tuning performance across all scales and models, while reducing fine-tuning memory footprint by more than 90%— thereby allowing fine-tuning of SOTA models on consumer-grade hardware.

In this approach, LoRA is pivotal both for purposes of fine-tuning and the correction of minimal, residual quantization errors. Due to the significantly reduced size of the quantized model it becomes possible to generously place low-rank adaptors at every network layer, which together still make up just 0.2% of the original model’s weight memory footprint. Through such usage of LoRA, we achieve performance that has been shown to be equivalent to 16-bit full model finetuning.

System diagram

In addition to generous use of LoRA, to achieve high-fidelity fine-tuning of 4-bit models, QLoRA uses 3 further algorithmic tricks:

  1. 4-bit NormalFloat (NF4) quantization, a custom data type exploiting the property of the normal distribution of model weights and distributing an equal number of weights (per block) to each quantization bin—thereby enhancing information density.
  2. Double Quantization, quantization of the quantization constants (further savings).
  3. Paged Optimizers, preventing memory spikes during gradient checkpointing from causing out-of-memory errors.
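The idea behind the first trick can be illustrated with a simplified, pure-Python sketch. It builds 16 levels from equal-probability bins of a standard normal distribution and quantizes one block of weights against its absolute maximum. This is an assumption-laden illustration of quantile-based 4-bit quantization, not the exact NF4 code book or the bitsandbytes implementation, and all names are hypothetical.

```python
from statistics import NormalDist


def normal_quantile_levels(bits=4):
    """Levels at the median of each of 2**bits equal-probability bins of N(0, 1),
    rescaled so that the outermost levels are -1 and +1."""
    n = 2 ** bits
    quantiles = [NormalDist().inv_cdf((i + 0.5) / n) for i in range(n)]
    scale = max(abs(q) for q in quantiles)
    return [q / scale for q in quantiles]


def quantize_block(weights, levels):
    """Store one 4-bit code per weight plus a single absmax constant for the block."""
    absmax = max(abs(w) for w in weights) or 1.0  # avoid dividing by zero
    codes = [
        min(range(len(levels)), key=lambda i: abs(levels[i] - w / absmax))
        for w in weights
    ]
    return codes, absmax


def dequantize_block(codes, absmax, levels):
    """Recover approximate weights from the codes and the block constant."""
    return [levels[c] * absmax for c in codes]


levels = normal_quantile_levels()
block = [0.31, -0.02, 0.5, -1.2, 0.07, 0.0, 0.9, -0.45]
codes, absmax = quantize_block(block, levels)
restored = dequantize_block(codes, absmax, levels)

for original, code, approx in zip(block, codes, restored):
    print(f"{original:+.2f} -> code {code:2d} -> {approx:+.3f}")
```

The per-block absmax value is the kind of quantization constant that Double Quantization compresses further.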

An interesting aspect is the dequantization of 4-bit weights in the GPU cache, with matrix multiplication performed as a 16-bit floating point operation. In other words, we use a low-precision storage data type (in our case 4-bit, but in principle interchangeable) and one normal precision computation data type. This is important because the latter defaults to 32-bit for hardware compatibility and numerical stability reasons, but should be set to the optimal BFloat16 for newer hardware supporting it to achieve the best performance.

To conclude, through combining these refinements to the quantization process and generous use of LoRA, we compress the model by over 90% and retain full model performance without the usual quantization degradation, while also retaining full fine-tuning capabilities with 16-bit LoRA adapters at every layer.

Using QLoRA in practice

These SOTA quantization methods come packaged in the bitsandbytes library and are conveniently integrated with HuggingFace 🤗 Transformers. For instance, to use LLM.int8 and QLoRA algorithms, respectively, simply pass load_in_8bit and load_in_4bit to the from_pretrained method.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-125m"
# For LLM.int8()
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)

# For QLoRA
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

You can read more about quantization features in this specific section of the documentation: https://huggingface.co/docs/transformers/main_classes/quantization

When using QLoRA with Adam optimizer using a 4-bit base model and mixed-precision mode, we need to allocate per parameter:

  • ~0.5 bytes for the weight
  • 2 bytes for the gradient
  • 4 + 8 bytes for the Adam optimizer states

This gives a total of 14 bytes per trainable parameter. Because only 0.29% of the parameters are trainable with QLoRA, the QLoRA training setup needs around 4.5GB to fit, but in practice it requires ~7-10GB once the intermediate hidden states, which are always in half-precision, are included (7GB for a sequence length of 512 and 10GB for a sequence length of 1024) in the Google Colab demo shared in the next section.

Below is the code snippet showing how to train QLoRA model using Hugging Face PEFT:

code snippet showing how to train QLoRA model using Hugging Face PEFT

Using TRL for LLM training

Models such as ChatGPT, GPT-4, and Claude are powerful language models that have been fine-tuned using a method called Reinforcement Learning from Human Feedback (RLHF) to be better aligned with how we expect them to behave and would like to use them. The finetuning goes through 3 steps:

  • Supervised Fine-tuning (SFT)
  • Reward / preference modeling (RM)
  • Reinforcement Learning from Human Feedback (RLHF)

Process diagram

From InstructGPT paper: Ouyang, Long, et al. “Training language models to follow instructions with human feedback.” arXiv preprint arXiv:2203.02155 (2022).

Here, we will only focus on the supervised fine-tuning step. We train the model on the new dataset following a process similar to that of pretraining. The objective is to predict the next token (causal language modeling). Multiple techniques can be applied to make the training more efficient:

  • Packing: Instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with an End-Of-Sentence (EOS) token in between and cut chunks of the context size to fill the batch without any padding. This approach significantly improves training efficiency as each token processed by the model contributes to training.

Sample diagram

  • Train on completion only: We want the model to be able to understand the prompt and generate an answer. Instead of training the model on the whole input (prompt + answer), the training will be more efficient if we only train the model on the completion.
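Both techniques can be sketched on plain lists of token ids. The snippet below is a simplified illustration and not the TRL implementation: the EOS id, the helper names, and the choice to drop the trailing partial chunk are assumptions, while -100 is the label value that PyTorch's cross-entropy loss ignores by default.

```python
EOS = 0              # hypothetical end-of-sequence token id
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss by default


def pack(tokenized_texts, context_size):
    """Concatenate texts with EOS in between and cut fixed-size chunks, with no padding."""
    stream = []
    for tokens in tokenized_texts:
        stream.extend(tokens)
        stream.append(EOS)
    # Only full chunks are kept; the trailing remainder is dropped in this sketch.
    n_chunks = len(stream) // context_size
    return [stream[i * context_size:(i + 1) * context_size] for i in range(n_chunks)]


def completion_only_labels(prompt_tokens, answer_tokens):
    """Build labels so that only the answer tokens contribute to the loss."""
    input_ids = prompt_tokens + answer_tokens
    labels = [IGNORE_INDEX] * len(prompt_tokens) + answer_tokens
    return input_ids, labels


packed = pack([[5, 6, 7], [8, 9], [10, 11, 12, 13]], context_size=4)
print(packed)

input_ids, labels = completion_only_labels([1, 2, 3], [4, 5])
print(input_ids, labels)
```

With packing, every position in every chunk holds a real token, so no compute is spent on padding. With completion-only labels, the prompt is still fed to the model but is masked out of the loss.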

You can perform supervised fine-tuning with these techniques using SFTTrainer:

from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    dataset_text_field="text",
    max_seq_length=1024,
    packing=True,
)

Since SFTTrainer back-end is powered by 🤗accelerate, you can easily adapt the training to your hardware setup in one line of code!

For example, if you have 2 GPUs, you can perform Distributed Data Parallel training using the following command:

accelerate launch --num_processes=2 training_llama_script.py

Putting all the pieces together

We made a complete reproducible Google Colab notebook that you can check through this link. We use all the components shared in the sections above and fine-tune a llama-7b model on the UltraChat dataset using QLoRA. As the screenshot below shows, when using a sequence length of 1024 and a batch size of 4, the memory usage remains very low (around 10GB).

Memory usage diagram


Accelerate AI models on GPU using Amazon SageMaker multi-model endpoints with TorchServe, saving up to 75% on inference costs


Multi-model endpoints (MMEs) are a powerful feature of Amazon SageMaker designed to simplify the deployment and operation of machine learning (ML) models. With MMEs, you can host multiple models on a single serving container and host all the models behind a single endpoint. The SageMaker platform automatically manages the loading and unloading of models and scales resources based on traffic patterns, reducing the operational burden of managing a large quantity of models. This feature is particularly beneficial for deep learning and generative AI models that require accelerated compute. The cost savings achieved through resource sharing and simplified model management make SageMaker MMEs an excellent choice for you to host models at scale on AWS.

Recently, generative AI applications have captured widespread attention and imagination. Customers want to deploy generative AI models on GPUs but at the same time are conscious of costs. SageMaker MMEs support GPU instances and are a great option for these types of applications. Today, we are excited to announce TorchServe support for SageMaker MMEs. This new model server support gives you all the benefits of MMEs while still using the serving stack that TorchServe customers are most familiar with. In this post, we demonstrate how to host generative AI models, such as Stable Diffusion and Segment Anything Model, on SageMaker MMEs using TorchServe and build a language-guided editing solution that can help artists and content creators develop and iterate on their artwork faster.

Solution overview

Language-guided editing is a common cross-industry generative AI use case. It can help artists and content creators work more efficiently to meet content demand by automating repetitive tasks, optimizing campaigns, and providing a hyper-personalized experience for the end customer. Businesses can benefit from increased content output, cost savings, improved personalization, and enhanced customer experience. In this post, we demonstrate how you can build language-assisted editing features using MME TorchServe that allow you to erase any unwanted object from an image and modify or replace any object in an image by supplying a text instruction.

The user experience flow for each use case is as follows:

  • To remove an unwanted object, select the object from the image to highlight it. This action sends the pixel coordinates and the original image to a generative AI model, which generates a segmentation mask for the object. After confirming the correct object selection, you can send the original and mask images to a second model for removal. This user flow is illustrated below.

Dog on a bench with mouse pointer clicking the dog

Dog on a bench highlighted

A bench without the dog

Step 1: Select an object (“dog”) from the image
Step 2: Confirm the correct object is highlighted
Step 3: Erase the object from the image

  • To modify or replace an object, select and highlight the desired object, following the same process as described above. Once you confirm the correct object selection, you can modify the object by supplying the original image, the mask, and a text prompt. The model will then change the highlighted object based on the provided instructions. This second user flow is illustrated below.

A vase with a cactus and mouse pointer

A vase highlighted

A rounded vase with a cactus

Step 1: Select an object (“vase”) from the image
Step 2: Confirm the correct object is highlighted
Step 3: Provide a text prompt (“futuristic vase”) to modify the object

To power this solution, we use three generative AI models: Segment Anything Model (SAM), Large Mask Inpainting Model (LaMa), and Stable Diffusion Inpaint (SD). Here is how these models are used in the user experience workflow:

To remove an unwanted object

flow diagram

To modify or replace an object

flow diagram

  1. Segment Anything Model (SAM) is used to generate a segment mask of the object of interest. Developed by Meta Research, SAM is an open-source model that can segment any object in an image. This model has been trained on a massive dataset known as SA-1B, which comprises over 11 million images and 1.1 billion segmentation masks. For more information on SAM, refer to their website and research paper.
  2. LaMa is used to remove any undesired objects from an image. LaMa is a Generative Adversarial Network (GAN) model that specializes in filling in missing parts of images using irregular masks. The model architecture incorporates image-wide global context and a single-step architecture that uses Fourier convolutions, enabling it to achieve state-of-the-art results at a faster speed. For more details on LaMa, visit their website and research paper.
  3. SD 2 inpaint model from Stability AI is used to modify or replace objects in an image. This model allows us to edit the object in the mask area by providing a text prompt. The inpaint model is based on the text-to-image SD model, which can create high-quality images with a simple text prompt. It provides additional arguments such as original and mask images, allowing for quick modification and restoration of existing content. To learn more about Stable Diffusion models on AWS, refer to Create high-quality images with Stable Diffusion models and deploy them cost-efficiently with Amazon SageMaker.

All three models are hosted on SageMaker MMEs, which reduces the operational burden from managing multiple endpoints. In addition to that, using MME eliminates concerns about certain models being underutilized because resources are shared. You can observe the benefit from improved instance saturation, which ultimately leads to cost savings. The following architecture diagram illustrates how all three models are served using SageMaker MMEs with TorchServe.

flow diagram

We have published the code to implement this solution architecture in our GitHub repository. To follow along with the rest of the post, use the notebook file. It is recommended to run this example on a SageMaker notebook instance using the conda_python3 (Python 3.10.10) kernel.

Extend the TorchServe container

The first step is to prepare the model hosting container. SageMaker provides a managed PyTorch Deep Learning Container (DLC) that you can retrieve using the following code snippet:

# Use SageMaker PyTorch DLC as base image
baseimage = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=region,
    py_version="py310",
    image_scope="inference",
    version="2.0.0",
    instance_type="ml.g5.2xlarge",
)
print(baseimage)

Because the models require resources and additional packages that are not in the base PyTorch DLC, you need to build a Docker image. This image is then uploaded to Amazon Elastic Container Registry (Amazon ECR) so that it can be accessed directly from SageMaker. The custom installed libraries are listed in the Dockerfile:

ARG BASE_IMAGE

FROM $BASE_IMAGE

#Install any additional libraries
RUN pip install segment-anything-py==1.0
RUN pip install opencv-python-headless==4.7.0.68
RUN pip install matplotlib==3.6.3
RUN pip install diffusers
RUN pip install tqdm
RUN pip install easydict
RUN pip install scikit-image
RUN pip install xformers
RUN pip install tensorflow
RUN pip install joblib
RUN pip install matplotlib
RUN pip install albumentations==0.5.2
RUN pip install hydra-core==1.1.0
RUN pip install pytorch-lightning
RUN pip install tabulate
RUN pip install kornia==0.5.0
RUN pip install webdataset
RUN pip install omegaconf==2.1.2
RUN pip install transformers==4.28.1
RUN pip install accelerate
RUN pip install ftfy

Run the shell command file to build the custom image locally and push it to Amazon ECR:

%%capture build_output

reponame = "torchserve-mme-demo"
versiontag = "genai-0.1"

# Build our own docker image
!cd workspace/docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}

Prepare the model artifacts

The main difference for the new MMEs with TorchServe support is how you prepare your model artifacts. The code repo provides a skeleton folder for each model (models folder) to house the required files for TorchServe. We follow the same four-step process to prepare each model .tar file. The following code is an example of the skeleton folder for the SD model:

workspace
|--sd
   |-- custom_handler.py
   |-- model-config.yaml

The first step is to download the pre-trained model checkpoints in the models folder:

import diffusers
import torch
import transformers

pipeline = diffusers.StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16
)

sd_dir = "workspace/sd/model"
pipeline.save_pretrained(sd_dir)

The next step is to define a custom_handler.py file. This is required to define the behavior of the model when it receives a request, such as loading the model, preprocessing the input, and postprocessing the output. The handle method is the main entry point for requests, and it accepts a request object and returns a response object. It loads the pre-trained model checkpoints and applies the preprocess and postprocess methods to the input and output data. The following code snippet illustrates a simple structure of the custom_handler.py file. For more detail, refer to the TorchServe handler API.

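from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler


# Illustrative class name; TorchServe handlers commonly subclass BaseHandler.
class CustomHandler(BaseHandler):
    def initialize(self, ctx: Context):
        ...  # load the pre-trained model checkpoints

    def preprocess(self, data):
        ...  # turn the raw request into model inputs

    def inference(self, data):
        ...  # run the model and postprocess the output

    def handle(self, data, context):
        requests = self.preprocess(data)
        responses = self.inference(requests)

        return responses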

The last required file for TorchServe is model-config.yaml. The file defines the configuration of the model server, such as number of workers and batch size. The configuration is at a per-model level, and an example config file is shown in the following code. For a complete list of parameters, refer to the GitHub repo.

minWorkers: 1
maxWorkers: 1
batchSize: 1
maxBatchDelay: 200
responseTimeout: 300

The final step is to package all the model artifacts into a single .tar.gz file using the torch-model-archiver module:

!torch-model-archiver --model-name sd --version 1.0 --handler workspace/sd/custom_handler.py --extra-files workspace/sd/model --config-file workspace/sam/model-config.yaml --archive-format no-archive
!cd sd && tar cvzf sd.tar.gz .

Create the multi-model endpoint

The steps to create a SageMaker MME are the same as before. In this particular example, you spin up an endpoint using the SageMaker SDK. Start by defining an Amazon Simple Storage Service (Amazon S3) location and the hosting container. This S3 location is where SageMaker will dynamically load the models based on invocation patterns. The hosting container is the custom container you built and pushed to Amazon ECR in the earlier step. See the following code:

# This is where our MME will read models from on S3.
multi_model_s3uri = output_path

Then define a MultiDataModel that captures all the attributes, such as model location, hosting container, and permission access:

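from datetime import datetime

from sagemaker.model import Model
from sagemaker.multidatamodel import MultiDataModel

print(multi_model_s3uri)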
model = Model(
    model_data=f"{multi_model_s3uri}/sam.tar.gz",
    image_uri=container,
    role=role,
    sagemaker_session=smsess,
    env={"TF_ENABLE_ONEDNN_OPTS": "0"},
)

mme = MultiDataModel(
    name="torchserve-mme-genai-" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
    model_data_prefix=multi_model_s3uri,
    model=model,
    sagemaker_session=smsess,
)
print(mme)

The deploy() function creates an endpoint configuration and hosts the endpoint:

mme.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In the example we provided, we also show how you can list models and dynamically add new models using the SDK. The add_model() function copies your local model .tar files into the MME S3 location:

# Only sam.tar.gz visible!
list(mme.list_models())

models = ["sd/sd.tar.gz", "lama/lama.tar.gz"]
for model in models:
    mme.add_model(model_data_source=model)

Invoke the models

Now that we have all three models hosted on an MME, we can invoke each model in sequence to build our language-assisted editing features. To invoke each model, provide a target_model parameter in the predictor.predict() function. The model name is just the name of the model .tar file we uploaded. The following is an example code snippet for the SAM model that takes in a pixel coordinate, a point label, and dilate kernel size, and generates a segmentation mask of the object in the pixel location:

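import base64
import io
import json

from PIL import Image

# encode_image is assumed to be a helper (not shown here) that returns the image as a base64-encoded string.
img_file = "workspace/test_data/sample1.png"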
img_bytes = None

with Image.open(img_file) as f:
    img_bytes = encode_image(f)

gen_args = json.dumps(dict(point_coords=[750, 500], point_labels=1, dilate_kernel_size=15))

payload = json.dumps({"image": img_bytes, "gen_args": gen_args}).encode("utf-8")

response = predictor.predict(data=payload, target_model="/sam.tar.gz")
encoded_masks_string = json.loads(response.decode("utf-8"))["generated_image"]
base64_bytes_masks = base64.b64decode(encoded_masks_string)

with Image.open(io.BytesIO(base64_bytes_masks)) as f:
    generated_image_rgb = f.convert("RGB")
    generated_image_rgb.show()

To remove an unwanted object from an image, take the segmentation mask generated from SAM and feed that into the LaMa model with the original image. The following images show an example.

Sample image | Segmentation mask from SAM | Erase the dog using LaMa
[Image: dog on a bench] | [Image: white mask of dog on black background] | [Image: just a bench]

To modify or replace any object in an image with a text prompt, take the segmentation mask from SAM and feed it into SD model with the original image and text prompt, as shown in the following example.

Sample image | Segmentation mask from SAM | Replace using SD model with text prompt “a hamster on a bench”
[Image: dog on a bench] | [Image: white mask of dog on black background] | [Image: hamster on a bench]

Cost savings

The benefits of SageMaker MMEs increase based on the scale of model consolidation. The following table shows the GPU memory usage of the three models in this post. They are deployed on one g5.2xlarge instance by using one SageMaker MME.

Model | GPU Memory (MiB)
Segment Anything Model | 3,362
Stable Diffusion In Paint | 3,910
LaMa | 852

You can see cost savings when hosting the three models with one endpoint, and for use cases with hundreds or thousands of models, the savings are much greater.

For example, consider 100 Stable Diffusion models. Each of the models on its own could be served by an ml.g5.2xlarge endpoint (4 GiB memory), costing $1.52 per instance hour in the US East (N. Virginia) Region. Providing all 100 models with their own endpoint would cost $218,880 per month (100 endpoints × 2 instances × $1.52 per hour × 720 hours in a 30-day month). With a SageMaker MME, a single endpoint using ml.g5.2xlarge instances can host four models simultaneously, so only 25 endpoints are needed. This reduces production inference costs by 75%, to only $54,720 per month. The following table summarizes the differences between single-model and multi-model endpoints for this example. Given an endpoint configuration with sufficient memory for your target models, steady-state invocation latency after all models have been loaded will be similar to that of a single-model endpoint.

Attribute | Single-model endpoint | Multi-model endpoint
Total endpoint price per month | $218,880 | $54,720
Endpoint instance type | ml.g5.2xlarge | ml.g5.2xlarge
CPU Memory capacity (GiB) | 32 | 32
GPU Memory capacity (GiB) | 24 | 24
Endpoint price per hour | $1.52 | $1.52
Number of instances per endpoint | 2 | 2
Endpoints needed for 100 models | 100 | 25
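
The monthly totals in the table can be reproduced with a few lines of arithmetic. The sketch below assumes a 30-day month (720 hours), which is the value that makes the quoted figures match; the constant and function names are chosen here for illustration.

```python
# Reproduce the monthly cost figures from the table.
# Assumption: a 30-day month (720 hours).
PRICE_PER_INSTANCE_HOUR = 1.52   # ml.g5.2xlarge, US East (N. Virginia)
INSTANCES_PER_ENDPOINT = 2
HOURS_PER_MONTH = 24 * 30
NUM_MODELS = 100
MODELS_PER_MME = 4


def monthly_cost(num_endpoints: int) -> float:
    """Total monthly price for a fleet of identical endpoints."""
    instance_hours = num_endpoints * INSTANCES_PER_ENDPOINT * HOURS_PER_MONTH
    return instance_hours * PRICE_PER_INSTANCE_HOUR


single_model_endpoints = NUM_MODELS                       # one endpoint per model
multi_model_endpoints = -(-NUM_MODELS // MODELS_PER_MME)  # ceiling division

single_cost = monthly_cost(single_model_endpoints)
multi_cost = monthly_cost(multi_model_endpoints)
savings = 1 - multi_cost / single_cost

print(f"Single-model endpoints: ${single_cost:,.0f} per month")
print(f"Multi-model endpoints:  ${multi_cost:,.0f} per month")
print(f"Savings: {savings:.0%}")
```

Changing MODELS_PER_MME shows how the savings scale with the number of models that fit on one endpoint.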

Clean up

After you are done, please follow the instructions in the cleanup section of the notebook to delete the resources provisioned in this post to avoid unnecessary charges. Refer to Amazon SageMaker Pricing for details on the cost of the inference instances.

Conclusion

This post demonstrates the language-assisted editing capabilities made possible through the use of generative AI models hosted on SageMaker MMEs with TorchServe. The example we shared illustrates how we can use resource sharing and simplified model management with SageMaker MMEs while still utilizing TorchServe as our model serving stack. We utilized three deep learning foundation models: SAM, SD 2 Inpainting, and LaMa. These models enable us to build powerful capabilities, such as erasing any unwanted object from an image and modifying or replacing any object in an image by supplying a text instruction. These features can help artists and content creators work more efficiently and meet their content demands by automating repetitive tasks, optimizing campaigns, and providing a hyper-personalized experience. We invite you to explore the example provided in this post and build your own UI experience using TorchServe on a SageMaker MME.

To get started, see Supported algorithms, frameworks, and instances for multi-model endpoints using GPU backed instances.


About the authors

James Wu is a Senior AI/ML Specialist Solution Architect at AWS, helping customers design and build AI/ML solutions. James’s work covers a wide range of ML use cases, with a primary interest in computer vision, deep learning, and scaling ML across the enterprise. Prior to joining AWS, James was an architect, developer, and technology leader for over 10 years, including 6 years in engineering and 4 years in marketing & advertising industries.

Li Ning is a senior software engineer at AWS with a specialization in building large-scale AI solutions. As a tech lead for TorchServe, a project jointly developed by AWS and Meta, her passion lies in leveraging PyTorch and AWS SageMaker to help customers embrace AI for the greater good. Outside of her professional endeavors, Li enjoys swimming, traveling, following the latest advancements in technology, and spending quality time with her family.

Ankith Gunapal is an AI Partner Engineer at Meta (PyTorch). He is passionate about model optimization and model serving, with experience ranging from RTL verification, embedded software, computer vision, to PyTorch. He holds a Master’s in Data Science and a Master’s in Telecommunications. Outside of work, Ankith is also an electronic dance music producer.

Saurabh Trikande is a Senior Product Manager for Amazon SageMaker Inference. He is passionate about working with customers and is motivated by the goal of democratizing machine learning. He focuses on core challenges related to deploying complex ML applications, multi-tenant ML models, cost optimizations, and making deployment of deep learning models more accessible. In his spare time, Saurabh enjoys hiking, learning about innovative technologies, following TechCrunch and spending time with his family.

Subhash Talluri is a Lead AI/ML solutions architect of the Telecom Industry business unit at Amazon Web Services. He’s been leading development of innovative AI/ML solutions for Telecom customers and partners worldwide. He brings interdisciplinary expertise in engineering and computer science to help build scalable, secure, and compliant AI/ML solutions via cloud-optimized architectures on AWS.

Accelerating Generative AI Part III: Diffusion, Fast

This post is the third part of a multi-series blog focused on how to accelerate generative AI models with pure, native PyTorch. We are excited to share a breadth of newly released PyTorch performance features alongside practical examples to see how far we can push PyTorch native performance. In part one, we showed how to accelerate Segment Anything over 8x using only pure, native PyTorch. In part two, we showed how to accelerate Llama-7B by almost 10x using only native PyTorch optimizations. In this blog, we’ll focus on speeding up text-to-image diffusion models by up to 3x.

We will leverage an array of optimizations including:

  • Running with the bfloat16 precision
  • scaled_dot_product_attention (SDPA)
  • torch.compile
  • Combining q,k,v projections for attention computation
  • Dynamic int8 quantization

We will primarily focus on Stable Diffusion XL (SDXL), demonstrating a latency improvement of 3x. These techniques are PyTorch-native, which means you don’t have to rely on any third-party libraries or any C++ code to take advantage of them.

Enabling these optimizations with the 🤗Diffusers library takes just a few lines of code. If you’re already feeling excited and cannot wait to jump to the code, check out the accompanying repository here: https://github.com/huggingface/diffusion-fast.

SDXL Chart

(The discussed techniques are not SDXL-specific and can be used to speed up other text-to-image diffusion systems, as shown later.)

Below, you can find some blog posts on similar topics:

Setup

We will demonstrate the optimizations and their respective speed-up gains using the 🤗Diffusers library. Apart from that, we will make use of the following PyTorch-native libraries and environments:

  • Torch nightly (to benefit from the fastest kernels for efficient attention; 2.3.0.dev20231218+cu121)
  • 🤗 PEFT (version: 0.7.1)
  • torchao (commit SHA: 54bcd5a10d0abbe7b0c045052029257099f83fd9)
  • CUDA 12.1

For an easier reproduction environment, you can also refer to this Dockerfile. The benchmarking numbers presented in this post come from a 400W 80GB A100 GPU (with its clock rate set to its maximum capacity).

Since we use an A100 GPU (Ampere architecture) here, we can specify torch.set_float32_matmul_precision("high") to benefit from the TF32 precision format.

Run inference using a reduced precision

Running SDXL in Diffusers just takes a few lines of code:

from diffusers import StableDiffusionXLPipeline

## Load the pipeline in full-precision and place its model components on CUDA.
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

But this isn’t very practical as it takes 7.36 seconds to generate a single image with 30 steps. This is our baseline which we will try to optimize one step at a time.

SDXL Chart

Here, we’re running the pipeline with the full precision. We can immediately cut down the inference time by using a reduced precision such as bfloat16. Besides, modern GPUs come with dedicated cores for running accelerated computation benefiting from reduced precision. To run the computations of the pipeline in the bfloat16 precision, we just need to specify the data type while initializing the pipeline:

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Run the attention ops without efficiency.
pipe.unet.set_default_attn_processor()
pipe.vae.set_default_attn_processor()
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDXL Chart

By using a reduced precision, we’re able to cut down the inference latency from 7.36 seconds to 4.63 seconds.

Some notes on the use of bfloat16

  • Using a reduced numerical precision (such as float16, bfloat16) to run inference doesn’t affect the generation quality but significantly improves latency.
  • The benefits of using the bfloat16 numerical precision as compared to float16 are hardware-dependent. Modern generations of GPUs tend to favor bfloat16.
  • Furthermore, in our experiments, we found bfloat16 to be much more resilient when used with quantization in comparison to float16.

(We later ran the experiments in float16 and found out that the recent versions of torchao do not incur numerical problems from float16.)
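
As background on how the two 16-bit formats differ, the following standard-library sketch contrasts their dynamic range. It emulates bfloat16 by truncating a float32 encoding to its top 16 bits (real hardware typically rounds rather than truncates), and the helper names are made up for this illustration. It is general background, not an explanation of the specific quantization behavior we observed.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding.

    Simplification: this truncates instead of rounding to nearest.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def to_float16(x: float) -> float:
    """Round-trip through IEEE float16; out-of-range values become infinity."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")


for value in (1.0, 3.14159, 70000.0):
    print(value, "bfloat16:", to_bfloat16(value), "float16:", to_float16(value))
```

bfloat16 keeps the 8-bit exponent of float32 and gives up mantissa bits, so large values stay finite but coarse. float16 has more mantissa bits but overflows beyond roughly 65,504.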

Use SDPA for performing attention computations

By default, Diffusers uses scaled_dot_product_attention (SDPA) for performing attention-related computations when using PyTorch 2. SDPA provides faster and more efficient kernels to run intensive attention-related operations. To run the pipeline with SDPA, we simply don’t set any attention processor, like so:

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt, num_inference_steps=30).images[0]

SDPA gives a nice boost from 4.63 seconds to 3.31 seconds.

SDXL Chart

Compiling the UNet and VAE

We can ask PyTorch to perform some low-level optimizations (such as operator fusion and launching faster kernels with CUDA graphs) by using torch.compile. For the StableDiffusionXLPipeline, we compile the denoiser (UNet) and the VAE:

from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cuda")

## Compile the UNet and VAE.
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

## First call to `pipe` will be slow, subsequent ones will be faster.
image = pipe(prompt, num_inference_steps=30).images[0]

Using SDPA attention and compiling both the UNet and VAE reduces the latency from 3.31 seconds to 2.54 seconds.

SDXL Chart

Notes on torch.compile

torch.compile offers different backends and modes. As we’re aiming for maximum inference speed, we opt for the inductor backend with the “max-autotune” mode. “max-autotune” uses CUDA graphs and optimizes the compilation graph specifically for latency. Using CUDA graphs greatly reduces the overhead of launching GPU operations. It saves time by using a mechanism to launch multiple GPU operations through a single CPU operation.

Specifying fullgraph to be True ensures that there are no graph breaks in the underlying model, ensuring the fullest potential of torch.compile. In our case, the following compiler flags were also important to be explicitly set:

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

For the full list of compiler flags, refer to this file.

We also change the memory layout of the UNet and the VAE to “channels_last” when compiling them to ensure maximum speed:

pipe.unet.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

In the next section, we’ll show how to improve the latency even further.

Additional optimizations

No graph breaks during torch.compile

Ensuring that the underlying model/method can be fully compiled is crucial for performance (torch.compile with fullgraph=True). This means having no graph breaks. We did this for the UNet and VAE by changing how we access the returning variables. Consider the following example:

code example

Getting rid of GPU syncs after compilation

During the iterative reverse diffusion process, we call step() on the scheduler each time after the denoiser predicts the less noisy latent embeddings. Inside step(), the sigmas variable is indexed. If the sigmas array is placed on the GPU, indexing causes a communication sync between the CPU and GPU. This adds latency, and it becomes more evident when the denoiser has already been compiled.

But if the sigmas array always stays on the CPU (refer to this line), this sync doesn’t take place, which improves latency. In general, CPU <-> GPU communication syncs should be eliminated or kept to a bare minimum, as they can impact inference latency.

Using combined projections for attention ops

Both the UNet and the VAE used in SDXL make use of Transformer-like blocks. A Transformer block consists of attention blocks and feed-forward blocks.

In an attention block, the input is projected into three sub-spaces using three different projection matrices – Q, K, and V. In the naive implementation, these projections are performed separately on the input. But we can horizontally combine the projection matrices into a single matrix and perform the projection in one shot. This increases the size of the matmuls of the input projections and improves the impact of quantization (to be discussed next).
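
The equivalence between separate and combined projections can be checked with a small pure-Python sketch. The helpers matmul, hconcat, and random_matrix are written here only for illustration, and the shapes are tiny; this is not how Diffusers implements the fusion.

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def hconcat(*mats):
    """Concatenate matrices column-wise (horizontally)."""
    return [[value for row in rows for value in row] for rows in zip(*mats)]


def random_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
tokens, dim = 2, 4
x = random_matrix(tokens, dim)
w_q, w_k, w_v = (random_matrix(dim, dim) for _ in range(3))

# Naive: three separate projections of the same input.
q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)

# Combined: one (dim x 3*dim) matrix, one larger matmul, then split the columns.
w_qkv = hconcat(w_q, w_k, w_v)
qkv = matmul(x, w_qkv)
q_f = [row[:dim] for row in qkv]
k_f = [row[dim:2 * dim] for row in qkv]
v_f = [row[2 * dim:] for row in qkv]

print(q == q_f and k == k_f and v == v_f)
```

The combined version performs the same arithmetic in a single, wider matmul, which is the property the fused projection relies on.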

Enabling this kind of computation in Diffusers just takes a single line of code:

pipe.fuse_qkv_projections()

This will make the attention operations for both the UNet and the VAE take advantage of the combined projections. For the cross-attention layers, we only combine the key and value matrices. To learn more, you can refer to the official documentation here. It’s worth noting that we leverage PyTorch’s scaled_dot_product_attention here internally.

These additional techniques improved the inference latency from 2.54 seconds to 2.52 seconds.

SDXL Chart

Dynamic int8 quantization

We selectively apply dynamic int8 quantization to both the UNet and the VAE. This is because quantization adds additional conversion overhead to the model that is hopefully made up for by faster matmuls (dynamic quantization). If the matmuls are too small, these techniques may degrade performance.
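
To illustrate where the conversion overhead comes from, here is a simplified pure-Python sketch of dynamic quantization for a single dot product. It uses symmetric per-tensor scales computed at runtime; the function names are hypothetical, and torchao's actual implementation differs (for example in the granularity of the scales and in using fused GPU kernels).

```python
def quantize_dynamic(values):
    """Symmetric int8 quantization with a scale computed from the data at runtime."""
    max_abs = max(abs(v) for v in values) or 1.0
    scale = max_abs / 127.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def int8_dot(a, b):
    """Quantize both operands, accumulate in integers, then dequantize."""
    qa, scale_a = quantize_dynamic(a)   # conversion overhead on every call
    qb, scale_b = quantize_dynamic(b)
    acc = sum(x * y for x, y in zip(qa, qb))  # integer multiply-accumulate
    return acc * scale_a * scale_b


activations = [0.5, -1.0, 0.25, 2.0]
weights = [1.0, 0.5, -0.5, 0.1]

exact = sum(x * y for x, y in zip(activations, weights))
approx = int8_dot(activations, weights)
print(f"exact={exact:.4f} approx={approx:.4f}")
```

The quantize step runs for every input, so it only pays off when the integer matmul that follows is large enough to amortize it.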

Through experimentation, we found that certain linear layers in the UNet and the VAE don’t benefit from dynamic int8 quantization. You can check out the full code for filtering those layers here (referred to as dynamic_quant_filter_fn below).

We leverage the ultra-lightweight pure PyTorch library torchao to use its user-friendly APIs for quantization:

from torchao.quantization import apply_dynamic_quant

apply_dynamic_quant(pipe.unet, dynamic_quant_filter_fn)
apply_dynamic_quant(pipe.vae, dynamic_quant_filter_fn)

Since this quantization support is limited to linear layers only, we also turn suitable pointwise convolution layers into linear layers to maximize the benefit. We also specify the following compiler flags when using this option:

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

To prevent any numerical issues stemming from quantization, we run everything in the bfloat16 format.

Applying quantization this way improved the latency from 2.52 seconds to 2.43 seconds.

SDXL Chart

Resources

We welcome you to check out the following codebases to reproduce these numbers and extend the techniques to other text-to-image diffusion systems as well:

Other links

Improvements in other pipelines

We applied these techniques to other pipelines to test the generality of our approach. Below are our findings:

SSD-1B

SSD-1B Chart

Stable Diffusion v1-5

Stable Diffusion v1-5 chart

PixArt-alpha/PixArt-XL-2-1024-MS

It’s worth noting that PixArt-Alpha uses a Transformer-based architecture as its denoiser for the reverse diffusion process instead of a UNet.

PixArt-alpha/PixArt-XL-2-1024-MS chart

Note that for Stable Diffusion v1-5 and PixArt-Alpha, we didn’t explore the best shape combination criteria for applying dynamic int8 quantization. It might be possible to get better numbers with a better combination.

Collectively, the methods we presented offer substantial speedup over the baseline without degradation in the generation quality. Furthermore, we believe that these methods should complement other optimization methods popular in the community (such as DeepCache, Stable Fast, etc.).

Conclusion and next steps

In this post, we presented a basket of simple yet effective techniques that can help improve the inference latency of text-to-image Diffusion models in pure PyTorch. In summary:

  • Using a reduced precision to perform our computations
  • Scaled-dot product attention for running the attention blocks efficiently
  • torch.compile with “max-autotune” to optimize for latency
  • Combining the different projections together for computing attention
  • Dynamic int8 quantization

We believe there’s a lot to be explored in terms of how we apply quantization to a text-to-image diffusion system. We didn’t exhaustively explore which layers in the UNet and the VAE tend to benefit from dynamic quantization. There might be opportunities to further speed things up with a better combination of the layers being targeted for quantization.

We kept the text encoders of SDXL untouched other than just running them in bfloat16. Optimizing them might also lead to improvements in latency.

Acknowledgements

Thanks to Ollin Boer Bohan whose VAE was used throughout the benchmarking process as it is numerically more stable under reduced numerical precisions.

Thanks to Hugo Larcher from Hugging Face for helping with infrastructure.

Understanding GPU Memory 2: Finding and Removing Reference Cycles

This is part 2 of the Understanding GPU Memory blog series. Our first post Understanding GPU Memory 1: Visualizing All Allocations over Time shows how to use the memory snapshot tool. In this part, we will use the Memory Snapshot to visualize a GPU memory leak caused by reference cycles, and then locate and remove them in our code using the Reference Cycle Detector.

Sometimes when we were using the Memory Snapshot, we saw plots of GPU memory that looked similar to this.

GPU memory

In this snapshot, each peak shows GPU tensors building up over time and then several tensors getting released at once. In addition, a CUDA OOM happens on the right side causing all the tensors to be released. Seeing the tensors accumulate like this is a clear indication of a problem, but it doesn’t immediately suggest why.

Tensors in Reference Cycles

During early debugging, we dug in further and found that this pattern happens a lot when your Python code has objects with reference cycles. Python will clean up non-cyclic objects immediately using reference counting. However, objects in reference cycles are only cleaned up later by a cycle collector. If these cycles refer to a GPU tensor, the GPU tensor will stay alive until that cycle collector runs and removes the reference cycle. Let’s take a look at a simplified example.

Simple reference cycle

Code Snippet behind the snapshot (full code in Appendix A):

    def leak(tensor_size, num_iter=100000, device="cuda:0"):
      class Node:
        def __init__(self, T):
          self.tensor = T
          self.link = None

      for _ in range(num_iter):
        A = torch.zeros(tensor_size, device=device)
        B = torch.zeros(tensor_size, device=device)
        a, b = Node(A), Node(B)

        # A reference cycle will force refcounts to be non-zero.
        a.link, b.link = b, a
        # Python will eventually garbage collect a & b, but will
        # OOM on the GPU before that happens (since python
        # runtime doesn't know about CUDA memory usage).

In this code example, the tensors A and B are created and wrapped in the nodes a and b, where a has a link to b and vice versa. This forces a non-zero reference count when a and b go out of scope. When we run this for 100,000 iterations, we expect the automatic garbage collection to free the reference cycles when going out of scope. However, this actually results in a CUDA OOM.
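
The same behavior can be observed without a GPU. The sketch below is a CPU-only illustration for CPython that uses a bytearray as a stand-in for the tensor and weak references to check when an object is actually freed; the helper make_pair is written for this example only.

```python
import gc
import weakref


class Node:
    def __init__(self, payload):
        self.payload = payload  # stands in for a GPU tensor
        self.link = None


def make_pair(cyclic: bool):
    """Create two nodes, optionally linked in a cycle; return a weakref to one."""
    a, b = Node(bytearray(1024)), Node(bytearray(1024))
    if cyclic:
        a.link, b.link = b, a
    return weakref.ref(a)


gc.disable()  # keep the cycle collector from running on its own during the demo
try:
    acyclic_ref = make_pair(cyclic=False)
    freed_without_cycle = acyclic_ref() is None       # refcount reached zero

    cyclic_ref = make_pair(cyclic=True)
    alive_before_collect = cyclic_ref() is not None   # the cycle keeps it alive

    gc.collect()
    freed_after_collect = cyclic_ref() is None        # cycle collector reclaimed it
finally:
    gc.enable()

print(freed_without_cycle, alive_before_collect, freed_after_collect)
```

Without a cycle the object disappears as soon as the function returns. With a cycle it survives until the cycle collector runs.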

Why doesn’t automatic garbage collection work?

The automatic garbage collection works well when there is a lot of extra memory as is common on CPUs because it amortizes the expensive garbage collection by using Generational Garbage Collection. But to amortize the collection work, it defers some memory cleanup making the maximum memory usage higher, which is less suited to memory constrained environments. The Python runtime also has no insights into CUDA memory usage, so it cannot be triggered on high memory pressure either. It’s even more challenging as GPU training is almost always memory constrained because we will often raise the batch size to use any additional free memory.

CPython’s garbage collector frees unreachable objects held in reference cycles via mark-and-sweep. Garbage collection runs automatically when the number of objects exceeds certain thresholds. There are 3 generations of thresholds to help amortize the expensive costs of running garbage collection on every object. The later generations are run less frequently. This explains why automatic collections only clear several tensors on each peak while other tensors still leak, resulting in the CUDA OOM. Those tensors were held by reference cycles in later generations.
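
The thresholds and per-generation counters can be inspected from Python's gc module. This is a small sketch for exploration; the default values differ between CPython versions, so none are hard-coded here.

```python
import gc

# Per-generation thresholds; the exact defaults vary across CPython versions.
thresholds = gc.get_threshold()

# Current per-generation counters that are compared against those thresholds.
counts = gc.get_count()

print("thresholds:", thresholds)
print("counts:", counts)

# gc.collect() with no argument runs a full collection across all generations,
# which is what reclaims cycles that have aged into later generations.
unreachable = gc.collect()
print("unreachable objects found:", unreachable)
```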

Explicitly calling gc.collect()

One way to fix this is by explicitly calling the garbage collector frequently. Here we can see that the GPU memory for tensors out of scope gets cleaned up when we explicitly call the garbage collector every 100 iterations. This also controls the maximum GPU peak memory held by leaking tensors.

memory leak

Although this works and fixes the CUDA OOM issue, calling gc.collect() too frequently can cause other issues, including QPS regressions. Therefore, we cannot simply increase the frequency of garbage collection on every training job. It’s best to just avoid creating reference cycles in the first place. More on this in the Reference Cycle Detector section.
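
The effect of a collection interval on peak memory can be sketched on the CPU by counting how many cycle members are still alive. The function run and its parameters are hypothetical names for this illustration, and automatic collection is disabled inside it only to isolate the effect of the explicit calls.

```python
import gc
import weakref


class Node:
    def __init__(self):
        self.link = None


def run(num_iter, gc_interval=None):
    """Create one reference cycle per iteration; return the peak number of
    cycle members that were alive at the same time."""
    live = weakref.WeakSet()
    peak = 0
    gc.disable()  # demo only: isolate the explicit gc.collect() calls
    try:
        for i in range(num_iter):
            a, b = Node(), Node()
            a.link, b.link = b, a
            live.add(a)
            live.add(b)
            peak = max(peak, len(live))
            if gc_interval and (i + 1) % gc_interval == 0:
                gc.collect()
    finally:
        gc.enable()
        gc.collect()
    return peak


print("no explicit collection:", run(1000))
print("collect every 100 iterations:", run(1000, gc_interval=100))
```

A shorter interval lowers the peak but spends more time in the collector, which is the trade-off described above.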

Sneaky Memory Leak in Callback

Real examples are more complicated, so let’s look at a more realistic example that has a similar behavior. In this snapshot, we can observe the same behavior of tensors being accumulated and freed during automatic garbage collection, until we hit a CUDA OOM.

memory leak

Code Snippet behind this snapshot (full code sample in Appendix A):

    class AwaitableTensor:
      def __init__(self, tensor_size):
        self._tensor_size = tensor_size
        self._tensor = None

      def wait(self):
        self._tensor = torch.zeros(self._tensor_size, device="cuda:0")
        return self._tensor

    class AwaitableTensorWithViewCallback:
      def __init__(self, tensor_awaitable, view_dim):
        self._tensor_awaitable = tensor_awaitable
        self._view_dim = view_dim
        # Add a view filter callback to the tensor.
        self._callback = lambda ret: ret.view(-1, self._view_dim)

      def wait(self):
        return self._callback(self._tensor_awaitable.wait())

    async def awaitable_leak(
      tensor_size=2**27, num_iter=100000,
    ):
      for _ in range(num_iter):
        A = AwaitableTensor(tensor_size)
        AwaitableTensorWithViewCallback(A, 4).wait()

In this code, we define two classes. The class AwaitableTensor will create a tensor when waited upon. Another class AwaitableTensorWithViewCallback will apply a view filter on the AwaitableTensor via callback lambda.

When running awaitable_leak, which creates tensor A (512 MB) and applies a view filter for 100,000 iterations, we expect that A should be reclaimed each time it goes out of scope because the reference count should reach 0. However, this will actually OOM!

While we know there is a reference cycle here, it isn’t clear from the code where the cycle is created. To help with these situations, we have created a tool to locate and report these cycles.

Reference Cycle Detector

Introducing the Reference Cycle Detector, which helps us find reference cycles keeping GPU tensors alive. The API is fairly simple:

  • During model initialization:
    • Import: from torch.utils.viz._cycles import warn_tensor_cycles
    • Start: warn_tensor_cycles()

The Reference Cycle Detector will issue warnings every time that the cycle collector runs and finds a CUDA tensor that gets freed. The warning provides an object graph showing how the reference cycle refers to the GPU tensor.

object graph

For instance in this object graph, we can easily observe that there is a circular dependency on the outer circle of the graph, and highlighted in red is the GPU tensor kept alive.

Most cycles are pretty easy to fix once they are discovered. For instance here we can remove the reference to self created by self._view_dim in the callback.

code snippet
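
A CPU-only sketch of this kind of fix is shown below. The classes WithCycle and WithoutCycle are simplified, hypothetical versions of the callback class (they return a tuple instead of calling view on a tensor), and capturing the local view_dim is one possible way to avoid the cycle.

```python
import gc
import weakref


class WithCycle:
    def __init__(self, view_dim):
        self._view_dim = view_dim
        # The lambda closes over `self`: self -> lambda -> closure cell -> self.
        self._callback = lambda ret: (ret, self._view_dim)


class WithoutCycle:
    def __init__(self, view_dim):
        self._view_dim = view_dim
        # Capture the plain value instead of `self`; no cycle is formed.
        self._callback = lambda ret: (ret, view_dim)


def freed_by_refcount(cls) -> bool:
    """Return True if an instance is reclaimed as soon as it goes out of scope."""
    gc.disable()  # make sure only reference counting can free the object
    try:
        obj = cls(4)
        ref = weakref.ref(obj)
        del obj
        return ref() is None
    finally:
        gc.enable()
        gc.collect()


print("with self in lambda:", freed_by_refcount(WithCycle))
print("with local in lambda:", freed_by_refcount(WithoutCycle))
```

Both callbacks compute the same result; only the second lets reference counting free the object immediately.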

We’ve spent some time fixing cycles in existing models using these tools. For example in TorchRec, we’ve found and removed a reference cycle in PR#1226.

code snippet

Once we’ve removed the reference cycles, the code no longer hits a CUDA OOM or shows any memory leaks in its snapshots.

What are the other benefits of using the Reference Cycle Detector?

Removing these cycles will also directly lower the maximum GPU memory usage as well as make it less likely for memory to fragment because the allocator returns to the same state after each iteration.

Where can I find these tools?

We hope that the Reference Cycle Detector will greatly improve your ability to find and remove memory leaks caused by reference cycles. The Reference Cycle Detector is available in the v2.1 release of PyTorch as an experimental feature, and more information about it can be found in the PyTorch Memory docs here.

Feedback

We look forward to hearing from you about any enhancements, bugs or memory stories that our tools helped to solve! As always, please feel free to open new issues on PyTorch’s Github page.

We are also open to contributions from the OSS community. Feel free to tag Aaron Shi and Zachary DeVito in any Github PRs for reviews.

Acknowledgements

We really appreciate the content reviewers, Mark Saroufim, Gregory Chanan, and Adnan Aziz, for reviewing this post and improving its readability.

Appendix

Appendix A – Code Sample

This code snippet was used to generate the plots and examples shown. Here are the arguments to reproduce the sections:

  • Introduction: python sample.py
  • Explicitly calling gc.collect(): python sample.py --gc_collect_interval=100
  • Sneaky Memory Leak in Callback: python sample.py --workload=awaitable
  • Ref Cycle Detector: python sample.py --workload=awaitable --warn_tensor_cycles

sample.py:

# (c) Meta Platforms, Inc. and affiliates. 
import argparse
import asyncio
import gc
import logging
import socket
from datetime import datetime

import torch

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

def start_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Starting snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(
       max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
   )

def stop_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Stopping snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(enabled=None)

def export_memory_snapshot() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not exporting memory snapshot")
       return

   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   try:
       logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
       torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
   except Exception as e:
       logger.error(f"Failed to capture memory snapshot {e}")
       return

# This function will leak tensors due to the reference cycles.
def simple_leak(tensor_size, gc_interval=None, num_iter=30000, device="cuda:0"):
    class Node:
        def __init__(self, T):
            self.tensor = T
            self.link = None

    for i in range(num_iter):
        A = torch.zeros(tensor_size, device=device)
        B = torch.zeros(tensor_size, device=device)
        a, b = Node(A), Node(B)
        # A reference cycle will force refcounts to be non-zero, when
        # a and b go out of scope.
        a.link, b.link = b, a
        # Python will eventually gc a and b, but may OOM on the CUDA
        # device before that happens (since python runtime doesn't
        # know about CUDA memory usage).

        # Since implicit gc is not called frequently enough due to
        # generational gc, adding an explicit gc is necessary as Python
        # runtime does not know about CUDA memory pressure.
        # https://en.wikipedia.org/wiki/Tracing_garbage_collection#Generational_GC_(ephemeral_GC)
        if gc_interval and i % int(gc_interval) == 0:
            gc.collect()

async def awaitable_leak(
    tensor_size, gc_interval=None, num_iter=100000, device="cuda:0"
):
    class AwaitableTensor:
        def __init__(self, tensor_size, device) -> None:
            self._tensor_size = tensor_size
            self._device = device
            self._tensor = None

        def wait(self) -> torch.Tensor:
            self._tensor = torch.zeros(self._tensor_size, device=self._device)
            return self._tensor

    class AwaitableTensorWithViewCallBack:
        def __init__(
            self,
            tensor_awaitable: AwaitableTensor,
            view_dim: int,
        ) -> None:
            self._tensor_awaitable = tensor_awaitable
            self._view_dim = view_dim
            # Add a view filter callback to the tensor.
            self._callback = lambda ret: ret.view(-1, self._view_dim)

        def wait(self) -> torch.Tensor:
            return self._callback(self._tensor_awaitable.wait())

    for i in range(num_iter):
        # Create an awaitable tensor
        a_tensor = AwaitableTensor(tensor_size, device)

        # Apply a view filter callback on the awaitable tensor.
        AwaitableTensorWithViewCallBack(a_tensor, 4).wait()

        # a_tensor will go out of scope.

        if gc_interval and i % int(gc_interval) == 0:
            gc.collect()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="A memory_leak binary instance")
    parser.add_argument(
        "--gc_collect_interval",
        default=None,
        help="Explicitly call GC every given interval. Default is off.",
    )
    parser.add_argument(
        "--workload",
        default="simple",
        help="Toggle which memory leak workload to run. Options are simple, awaitable.",
    )
    parser.add_argument(
        "--warn_tensor_cycles",
        action="store_true",
        default=False,
        help="Toggle whether to enable reference cycle detector.",
    )
    args = parser.parse_args()

    if args.warn_tensor_cycles:
        from tempfile import NamedTemporaryFile

        from torch.utils.viz._cycles import observe_tensor_cycles

        logger.info("Enabling warning for Python reference cycles for CUDA Tensors.")

        def write_and_log(html):
            with NamedTemporaryFile("w", suffix=".html", delete=False) as f:
                f.write(html)
                logger.warning(
                    "Reference cycle includes a CUDA Tensor see visualization of cycle %s",
                    f.name,
                )

        observe_tensor_cycles(write_and_log)
    else:
        # Start recording memory snapshot history
        start_record_memory_history()

    # Run the workload with a larger tensor size.
    # For smaller sizes, we will not CUDA OOM as gc will kick in often enough
    # to reclaim reference cycles before an OOM occurs.
    size = 2**26  # 256 MB
    try:
        if args.workload == "awaitable":
            size *= 2
            logger.info(f"Running tensor_size: {size*4/1024/1024} MB")
            asyncio.run(
                awaitable_leak(tensor_size=size, gc_interval=args.gc_collect_interval)
            )
        elif args.workload == "simple":
            logger.info(f"Running tensor_size: {size*4/1024/1024} MB")
            simple_leak(tensor_size=size, gc_interval=args.gc_collect_interval)
        else:
            raise Exception("Unknown workload.")
    except Exception:
        logger.exception(f"Failed to allocate {size*4/1024/1024} MB")

    # Create the memory snapshot file
    export_memory_snapshot()

    # Stop recording memory snapshot history
    stop_record_memory_history()

Training Production AI Models with PyTorch 2.0

1. Introduction

PyTorch 2.0 (abbreviated as PT2) can significantly improve the training and inference performance of an AI model using a compiler called torch.compile while being 100% backward compatible with PyTorch 1.x. There have been reports on how PT2 improves the performance of common benchmarks (e.g., Hugging Face’s diffusers). In this blog, we discuss our experiences in applying PT2 to production AI models at Meta.

2. Background

2.1 Why is automatic performance optimization important for production?

Performance is particularly important for production. For example, even a 5% reduction in the training time of a heavily used model can translate to substantial savings in GPU cost and data-center power. Another important metric is development efficiency, which measures how many engineer-months are required to bring a model to production. Typically, a significant part of this bring-up effort is spent on manual performance tuning such as rewriting GPU kernels to improve the training speed. By providing automatic performance optimization, PT2 can improve both cost and development efficiency.

2.2 How PT2 improves performance

As a compiler, PT2 can view multiple operations in the training graph captured from a model (unlike in PT1.x, where only one operation is executed at a time). Consequently, PT2 can exploit a number of performance optimization opportunities, including:

  • Fusing multiple operations into a single GPU kernel:
    • A typical type of performance overhead in running a GPU program is the CPU overhead of launching small GPU kernels. By fusing multiple operations into a single GPU kernel, PT2 can significantly reduce the kernel-launching overhead on the CPU. For instance, consider the PyTorch program in Figure 1(a). When it is executed on GPU with PT1, it has three GPU kernels (two for the two sin() ops and one for the addition op). With PT2, there is only one kernel generated, which fuses all three ops.
    • After fusing some operations, certain operations in the graph may become dead and hence can be optimized away. This can save both compute and memory bandwidth on the GPU. For instance, in Figure 1(b), one of the duplicated sin() ops can be optimized away.
    • In addition, fusion can also reduce GPU device memory reads/writes (by composing pointwise kernels) and help improve hardware utilization.

Fig. 1: How PT2 improves performance with fusion and dead-code elimination.

  • Reducing the type conversion overhead for using lower-precision data types:
    • PyTorch 1.x supports Automatic Mixed Precision (AMP). While AMP can reduce the compute time of an op, it introduces type conversion overhead before and after the op. PT2 can increase AMP performance by optimizing away unnecessary type conversion code, significantly reducing its overhead. As an example, Figure 2(a) converts three 32-bit input tensors (a32, b32, c32) to bf16 before doing the matrix multiplications. Nevertheless, in this example, a32 and c32 are actually the same tensor (a_float32). So, there is no need to convert a_float32 twice, as shown in the code generated by torch.compile in Figure 2(b). Note that while both this example and the previous one optimize away redundant computations, they are different in the sense that the type conversion code in this example is implicit via torch.autocast, unlike in the previous example where the torch.sin(x).cuda() is explicit in user code.

Fig. 2: How PT2 reduces type conversion overhead when using AMP.

  • Reusing buffers on the GPU:
    • With a global view, the scheduler in torch.compile can reuse buffers on the GPU, thereby reducing both memory allocation time and memory consumption. Figure 3 shows the driver program that calls the Triton kernels generated for the program in Figure 2(a). We can see that buf1 is reused as buf4.

Fig. 3: Reuse of buffers.

  • Autotuning:
    • PT2 has options to enable autotuning (via Triton) on matrix-multiply ops, pointwise ops, and reduction ops. Tunable parameters include block size, number of stages, and number of warps. With autotuning, the most performant implementation of an op can be found empirically.

3. Production environment considerations

In this section, we describe a number of important considerations in applying PT2 to production.

3.1 Ensuring no model quality degradation with torch.compile

Applying torch.compile to a model will cause numerical changes because of (1) reordering of floating-point ops during various optimizations such as fusion and (2) use of lower precision data types like bf16 if AMP is enabled. Therefore 100% bitwise compatibility with PT 1.x is not expected. Nevertheless, we still need to make sure that the model quality (measured in some form of numeric scores) is preserved after applying torch.compile. Typically, each production model will have its own range of acceptable scores (e.g., percentage change must be within 0.01%).

In case of a model-quality drop caused by torch.compile, we need to do a deep-dive debug.

One useful technique for debugging a torch.compile-related numeric issue is to apply torch.compile with different backends, in particular “eager” and “aot_eager”, in addition to “inductor”:

  • If the numeric issue happens with the “eager” backend, then the forward graph constructed by torch.compile is likely incorrect;
  • If the numeric issue doesn’t happen with “eager” but happens with “aot_eager”, then the backward graph constructed by torch.compile is likely incorrect;
  • If the numeric issue doesn’t happen with either “eager” or “aot_eager” but happens with “inductor”, then the code generation inside the inductor is likely incorrect.
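The checklist above can be turned into a small bisection helper. The sketch below is one possible shape for it, not a PyTorch utility: `run`, `model_fn` and `evaluate` are hypothetical user-supplied callables, and the default tolerance of 1e-4 simply mirrors the 0.01% example threshold mentioned earlier. The only PyTorch call it assumes is `torch.compile(model, backend=...)`.

```python
# Backends in the order suggested above, with the component each one implicates.
DIAGNOSIS = [
    ("eager", "forward graph captured by torch.compile"),
    ("aot_eager", "backward graph constructed by torch.compile"),
    ("inductor", "code generation inside Inductor"),
]


def locate_numeric_issue(run, reference, rel_tol=1e-4):
    """Return (backend, suspect) for the first backend whose score drifts.

    `run(backend)` is a hypothetical callable that trains or evaluates the
    model with that backend and returns a quality score.
    """
    for backend, suspect in DIAGNOSIS:
        score = run(backend)
        if abs(score - reference) > rel_tol * abs(reference):
            return backend, suspect
    return None, "no numeric issue reproduced"


def make_torch_runner(model_fn, evaluate):
    """Build `run` from hypothetical `model_fn()` and `evaluate(model)` helpers."""
    import torch  # imported lazily so the helper above stays torch-free

    def run(backend):
        compiled = torch.compile(model_fn(), backend=backend)
        return evaluate(compiled)

    return run


# Stand-in scores so the sketch runs without a model or a GPU.
fake_scores = {"eager": 0.9000, "aot_eager": 0.9000, "inductor": 0.8900}
print(locate_numeric_issue(fake_scores.__getitem__, reference=0.9000))
```

Checking the backends in this fixed order matters: each one adds a layer on top of the previous one, so the first backend that reproduces the drift points at the layer that was just added.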

3.2 Autotuning in production

By default, the autotuning in torch.inductor is done online while the model is executed. For some production models, we find that the autotuning time can take several hours, which is not acceptable for production. Therefore, we add offline autotuning, which works as depicted in Figure 4. The very first time that a model is run, the details (e.g., input tensor shape, data type, etc.) of all ops that require tuning are logged to a database. Then, a tuning process for these ops is run overnight to search for the most performant implementation of each op; the search result is written to a persistent cache (implemented as a source file of torch.inductor). The next time the model is run, the tuned implementation of each op is found in the cache and chosen for execution.

Fig. 4: The offline autotuning used in production.
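The log, tune offline, then look up flow described in this section can be pictured with a toy cache. Everything in the sketch below is an assumption made for illustration (the class name, the key layout, the config strings and the `benchmark` callable); it is not how the production database or the torch.inductor cache is implemented.

```python
from typing import Callable, Dict, Iterable, Set, Tuple

OpKey = Tuple[str, Tuple[int, ...], str]  # (op name, input shape, dtype)


class OfflineTuningCache:
    """Toy model of the log -> tune offline -> look up flow."""

    def __init__(self, default_config: str = "default") -> None:
        self._default = default_config
        self._pending: Set[OpKey] = set()   # stands in for the logging database
        self._tuned: Dict[OpKey, str] = {}  # stands in for the persistent cache

    def lookup(self, key: OpKey) -> str:
        """Called while the model runs: never tunes, only logs cache misses."""
        if key in self._tuned:
            return self._tuned[key]
        self._pending.add(key)
        return self._default

    def tune_offline(self, candidates: Iterable[str],
                     benchmark: Callable[[OpKey, str], float]) -> None:
        """Called out of band (e.g. overnight): keep the fastest candidate."""
        candidates = list(candidates)
        for key in sorted(self._pending):
            self._tuned[key] = min(candidates, key=lambda c: benchmark(key, c))
        self._pending.clear()


cache = OfflineTuningCache()
key = ("mm", (1024, 512), "bf16")
first = cache.lookup(key)  # first run: miss, logged for later tuning
cache.tune_offline(
    ["block64", "block128"],
    benchmark=lambda k, c: {"block64": 2.0, "block128": 1.5}[c],  # fake timings
)
second = cache.lookup(key)  # next run: tuned config is served from the cache
print(first, second)
```

The point of the split is that `lookup` stays cheap on the training path, while the expensive search runs separately and only for the op configurations that were actually observed.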

3.3 Profiling support for torch.compile

As we previously discussed in this blog, a profiler is essential for debugging the performance of production models. We have enhanced the profiler to display torch.compile related events on the timeline. The most useful events mark which parts of the model are running compiled code, so that we can quickly validate whether the parts of the model that are supposed to be compiled are actually compiled by torch.compile. For example, the trace in Figure 5 has two compiled regions (with the label “CompiledFunction”). Other useful events are the time spent on compilation and the time spent on accessing the compiler’s code-cache.

Fig. 5: A trace with two compiled regions.

3.4 Controlling just-in-time compilation time

torch.compile uses just-in-time compilation. The compilation happens when the first batch of data is trained. In our production setting, there is an upper limit on how much time is allowed for a training job to reach its first batch, aka Time-To-First-Batch (TTFB). We need to make sure that enabling torch.compile will not increase TTFB beyond the limit. This can be challenging because production models are large and torch.compile can take substantial compilation time. We enable parallel compilation to keep the compile time under control (this is controlled by the global variable compile_threads inside torch/_inductor/config.py, which is already set to the CPU count on OSS Linux). A model is decomposed into one or more computational graphs; each graph is decomposed into multiple Triton kernels. If parallel compilation is enabled, all the Triton kernels in the same graph can be compiled simultaneously (nevertheless, kernels from different graphs are still compiled serially). Figure 6 illustrates how parallel compilation helps.

Fig. 6: Using parallel compilation in production.
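The scheduling described above (kernels within a graph compiled in parallel, graphs handled one after another) can be mimicked in a few lines. This is only a sketch of the scheduling idea: `compile_kernel` is a hypothetical stand-in that sleeps instead of compiling, and the thread pool is our choice for the illustration, not a statement about how Inductor distributes the work.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def compile_kernel(name, seconds=0.05):
    """Hypothetical stand-in for compiling one Triton kernel."""
    time.sleep(seconds)
    return f"{name}.compiled"


def compile_model(graphs, compile_threads):
    """Compile graphs one after another; kernels within a graph in parallel."""
    compiled = []
    for kernels in graphs:  # graphs are still handled serially
        with ThreadPoolExecutor(max_workers=compile_threads) as pool:
            compiled.append(list(pool.map(compile_kernel, kernels)))
    return compiled


graphs = [["g0_k0", "g0_k1", "g0_k2", "g0_k3"], ["g1_k0", "g1_k1"]]
for threads in (1, 4):
    start = time.perf_counter()
    compile_model(graphs, compile_threads=threads)
    print(threads, "thread(s):", round(time.perf_counter() - start, 2), "s")
```

Because graphs are processed serially, the benefit grows with the number of kernels per graph, which is the same effect the compilation-time results attribute to models with more distinct kernels per graph.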

4. Results

In this section, we use three production models to evaluate PT2. First we show the training time speedups with PT2, using different optimization configs. Second, we show the importance of parallel compilation on the compilation time.

4.1 Training-time speedup with torch.compile

Figure 7 reports the training-time speedup with PT2. For each model, we show four cases: (i) no-compile with bf16, (ii) compile with fp32, (iii) compile with bf16, (iv) compile with bf16 and autotuning. The y-axis is the speedup over the baseline, which is no-compile with fp32. Note that no-compile with bf16 is actually slower than no-compile with fp32, due to the type conversion overhead. In contrast, compiling with bf16 achieves much larger speedups by reducing much of this overhead. Overall, given that these models are already heavily optimized by hand, we are excited to see that torch.compile can still provide 1.14-1.24x speedup.

Fig. 7: Training-time speedup with torch.compile (note: the baseline, no-compile/fp32, is omitted in this figure).

4.2 Compilation-time reduction with parallel compilation

Figure 8 shows the compilation time with and without parallel compilation. While there is still room for improvement on the serial compilation time, parallel compilation has reduced the compilation overhead on TTFB to an acceptable level. Models B and C benefit more from parallel compilation than Model A does because they have more distinct Triton kernels per graph.

Fig. 8: PT2 compilation time.

5. Concluding Remarks

In this blog, we demonstrate that PT2 can significantly accelerate the training of large and complex production AI models with reasonable compilation time. In our next blog, we will discuss how PT2 can do general graph transformations.

6. Acknowledgements

Many thanks to Mark Saroufim, Adnan Aziz, and Gregory Chanan for their detailed and insightful reviews.

Empowering Models with Performance: The Art of Generalized Model Transformation Approach

Introduction

PyTorch 2.0 (PT2) offers a compiled execution mode which rewrites Python bytecode to extract sequences of PyTorch operations, translating them into a Graph IR. The IR is then just-in-time compiled through a customizable back end, improving training performance without user intervention. Production models often go through multiple stages of optimization/lowering to hit performance targets, so a compiled mode is desirable: it separates the work of improving model performance from direct modification of the PyTorch model implementation. PyTorch users can thus enhance model performance without changing their model code, which is particularly valuable for optimizing complex models, including large-scale and production-ready ones.

In our previous blog post, we outlined how heuristic model transformation rules are employed to optimize intricate production models. While these rules enabled substantial performance gains for some pilot models, they lacked universal adaptability: they don’t consistently perform well across different models, or sometimes even within different sections of a single model.

Fig. 1: PT1 Graph mode vs PT2 Compile mode.

In this blog post, we propose a more generalized model transformation solution that serves as a plugin to the PT2 compiler, as shown in Fig.1. It is more general, performant and user-friendly, bringing performance improvements to both model training and inference without manual effort. As illustrated in Fig.2, by incorporating the previously user-defined transformations into the compiler, we have streamlined the production stack. These changes benefit a broad range of PyTorch models beyond Meta’s own: the solution has already been incorporated in PT2 and is ready for use with all PyTorch models.

Fig. 2: Simplified stack with PT2 compile mode.

Guiding Principle: Atomic Rules

Traditionally, people might use predefined heuristic rules to replace a model subgraph with another more performant subgraph to reduce launch overhead, minimize memory bandwidth usage, and fully occupy SMs. However, this approach doesn’t scale well as it is hard to craft a set of rules that fits all models perfectly.

Instead of grappling with bulky, complex rules, we can actually break them down into smaller, more digestible pieces – what we call ‘atomic rules’. These tiny powerhouses of efficiency target the transformation of individual operators, to conduct one step of the fusion/transformation. This makes them easy to handle and apply, offering a straightforward path to optimizing models. So, with these atomic rules in hand, optimizing any model for top-tier performance becomes a breeze!

We will walk through some simple examples to demonstrate how we use a chain of atomic rules to replace complicated heuristic rules.

Case 1: Horizontal fusion of computation chains started with accesses to embedding tables

Horizontal fusion means fusing parallel operators into one so as to reduce the number of kernels to be launched and improve performance. In our previous blog (Section 3.2), we described model transformations that fused layernorm and activation functions after embedding bags, as shown in the figure provided. However, this method had limitations:

  1. It only worked with layernorm and activation functions after embedding.
  2. It was restricted to models with specific architecture rules, causing various issues in our production stack, including parameter changes and inference disruptions.

To improve, we can use three atomic rules as shown in Fig.3 to replace the complicated heuristic rule:

  • Fuse layernorms that follow the same split nodes horizontally.
  • Then, fuse tanh functions following the same split nodes horizontally.
  • Lastly, fuse vertical split-cat nodes.

These atomic rules offer a clean and streamlined way for model simplification and optimization.

Fig. 3: Before, we optimized the model in one go by replacing subgraphs. Now, with atomic rules, we optimize step-by-step, covering more cases.

Case 2: Fuse horizontal MLP

MLPs (Multilayer Perceptrons) are fundamental components of deep neural networks, often consisting of linear, normalization, and activation functions. In complex models, there’s often a need to fuse many horizontal MLPs. Traditional methods find and replace parallel MLPs with a fused module as shown in Fig.4, but this isn’t always straightforward. Some models might not have normalization, or they might use different activation functions, making it hard to apply a one-size-fits-all rule.

This is where our atomic rules come in handy. These simplified rules target individual operators one at a time, making the process easier and more manageable. We use the following atomic rules for horizontal MLP fusion:

  • Fusing horizontal linear operators.
  • Fusing horizontal layernorms.
  • Fusing horizontal activation functions.

Fig. 4: Pseudocode for fusing MLP. Traditional optimizations need manual Python code changes.
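As an illustration of the first atomic rule (fusing horizontal linear operators), the sketch below shows why the fused form computes the same result. It deliberately uses plain Python lists so it runs anywhere, and it covers only the narrower case where the parallel linears share one input; the helper names are ours, and this is not the PT2 pass itself.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def linear(x, weight):
    """y = x @ weight.T, with weight laid out as [out_features][in_features]."""
    return matmul(x, transpose(weight))


def fused_linears(x, weights):
    """Run several linears that share the input `x` as one wider matmul."""
    stacked = [row for w in weights for row in w]  # concatenate output rows
    wide = linear(x, stacked)                      # one "kernel" instead of many
    outputs, start = [], 0
    for w in weights:                              # split columns back per op
        outputs.append([row[start:start + len(w)] for row in wide])
        start += len(w)
    return outputs


x = [[1, 2], [3, 4]]
w1 = [[1, 0], [0, 1], [1, 1]]  # linear with 3 output features
w2 = [[2, -1]]                 # linear with 1 output feature
print(fused_linears(x, [w1, w2]) == [linear(x, w1), linear(x, w2)])
```

Since the rule only looks at linear operators, it applies whether or not the surrounding MLPs use normalization or share an activation function, which is what makes it more broadly applicable than a whole-MLP pattern.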

The beauty of these rules is that they’re not limited to one case. They can be applied broadly. Since PyTorch models are built with torch operators, focusing on a smaller set of operators simplifies the process. This approach is not only more manageable but also more general compared to writing a specific large pattern replacement rule, making it easier to optimize various models efficiently.

Compile-time Graph Search

Our principle is to use chained atomic rules to replace heuristic rules. While this approach covers a wider range of cases, it does entail a longer time for graph search and pattern matching. The next question is: how can we minimize compilation time while performing compile-time graph searches efficiently?

We design a two-step greedy algorithm as illustrated in Fig. 5. The first step is to identify the target nodes according to certain rules, e.g., identifying all linear operations with the same input shapes. Once they are identified, we use a Breadth-First Search (BFS) strategy to separate these nodes into different sets, so that nodes within a set have no data dependency on one another. The nodes within each set are therefore independent and can be fused horizontally.

Fig. 5: Process of model transformation with graph IR.
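The second step of the algorithm can be sketched on a toy graph. The code below assumes the first step is already done (the `targets` list holds the matched nodes in topological order) and represents the graph as a hypothetical dict mapping each node to its producers. It is a simplified illustration of greedy grouping with BFS, not the search implemented in the compiler.

```python
from collections import deque


def ancestors(node, inputs):
    """Breadth-first search over producer edges; returns every upstream node."""
    seen, queue = set(), deque(inputs.get(node, ()))
    while queue:
        current = queue.popleft()
        if current not in seen:
            seen.add(current)
            queue.extend(inputs.get(current, ()))
    return seen


def group_independent(targets, inputs):
    """Greedily place each target in the first group it does not depend on.

    `targets` must be listed in topological order, so an earlier node can
    never depend on a later one; `inputs` maps a node to the nodes feeding it.
    """
    groups = []
    for node in targets:
        upstream = ancestors(node, inputs)
        for group in groups:
            if not upstream.intersection(group):
                group.append(node)
                break
        else:
            groups.append([node])
    return groups


# x feeds lin1, lin2 and lin4; lin3 consumes relu(lin1), so it depends on lin1.
inputs = {
    "lin1": ["x"], "lin2": ["x"], "relu": ["lin1"],
    "lin3": ["relu"], "lin4": ["x"],
}
print(group_independent(["lin1", "lin2", "lin3", "lin4"], inputs))
```

Each resulting group contains only mutually independent nodes, so every group is a candidate for one horizontal fusion. Here lin3 lands in its own group because it sits downstream of lin1.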

With our approach, the search time is roughly 60 seconds for one of our largest internal models. It is manageable for on-the-fly tasks.

In the End

In our tests with internal ranking models, we observed approximately 5% to 15% training performance improvement across five models on top of the performance gain brought by torch.compile. We have enabled the optimization in PT2 compiler stack and landed it as default when users choose Inductor as the backend (config). We expect our generalized transformation approach could benefit models beyond Meta, and look forward to more discussion and improvement through this compiler level transformation framework.

Acknowledgements

Many thanks to Mark Saroufim, Gregory Chanan, Adnan Aziz, and Rocky Liu for their detailed and insightful reviews.

Understanding GPU Memory 1: Visualizing All Allocations over Time

During your time with PyTorch on GPUs, you may be familiar with this common error message:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 0 has a total capacity of 79.32 GiB of which 401.56 MiB is free.

In this series, we show how to use memory tooling, including the Memory Snapshot, the Memory Profiler, and the Reference Cycle Detector to debug out of memory errors and improve memory usage.

Memory Timeline

The Memory Snapshot tool provides a fine-grained GPU memory visualization for debugging GPU OOMs. Captured memory snapshots will show memory events including allocations, frees and OOMs, along with their stack traces.

In a snapshot, each tensor’s memory allocation is color coded separately. The x axis is over time, and the y axis is the amount of GPU memory in MB. The snapshot is interactive, so we can observe the stack trace for any allocation by mousing over.

In this snapshot, there are 3 peaks showing the memory allocations over 3 training iterations. When looking at the peaks, it is easy to see the rise of memory in the forward pass and the fall during the backward pass as the gradients are computed. It is also possible to see that the program has the same pattern of memory use from iteration to iteration. One thing that stands out is the many tiny spikes in memory; by mousing over them, we see that they are buffers used temporarily by convolution operators.

Capturing Memory Snapshots

The API to capture memory snapshots is fairly simple and available in torch.cuda.memory:

  • Start: torch.cuda.memory._record_memory_history(max_entries=100000)
  • Save: torch.cuda.memory._dump_snapshot(file_name)
  • Stop: torch.cuda.memory._record_memory_history(enabled=None)

Code Snippet (for full code sample, see Appendix A):

   # Start recording memory snapshot history, initialized with a buffer
   # capacity of 100,000 memory events, via the `max_entries` field.
   torch.cuda.memory._record_memory_history(
       max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
   )

   # Run your PyTorch Model.
   # At any point in time, save a snapshot to file for later.
   for _ in range(5):
       pred = model(inputs)
       loss_fn(pred, labels).backward()
       optimizer.step()
       optimizer.zero_grad(set_to_none=True)

   # In this sample, we save the snapshot after running 5 iterations.
   #   - Save as many snapshots as you'd like.
   #   - Snapshots will save last `max_entries` number of memory events
   #     (100,000 in this example).
   try:
       torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
   except Exception as e:
       logger.error(f"Failed to capture memory snapshot {e}")

   # Stop recording memory snapshot history.
   torch.cuda.memory._record_memory_history(enabled=None)

To visualize the snapshot file, we have a tool hosted at https://pytorch.org/memory_viz. There, you can drag and drop your saved snapshot file and it will plot each allocation over time.

Memory Timeline

Alternatively, you can generate an HTML file from a .pickle by using the script at pytorch/torch/cuda/_memory_viz.py. Here is an example:

python torch/cuda/_memory_viz.py trace_plot snapshot.pickle -o snapshot.html

Debugging CUDA OOMs

Let’s look at how we can use the memory snapshot tool to answer:

  1. Why did a CUDA OOM happen?
  2. Where is the GPU Memory being used?

ResNet50 with a bug

We’ve taken a look at a properly working model in the first snapshot. Now, let’s take a look at a training example with a bug, see snapshot:

Memory Timeline

Notice how the second iteration uses far more memory than the first iteration. If this model were much larger, it could have CUDA OOM’d in the second iteration without much more insight into why.

Memory Timeline

When examining this snapshot further, we can clearly see that several tensors are staying alive from the first iteration to the second and later iterations. If we mouse over one of these tensors, it would show a stack trace suggesting that these were gradient tensors.

And indeed if we go to the code, we can see that it doesn’t clear the gradient tensors, when it could have cleared them before the forward.

Before:

        for _ in range(num_iters):
          pred = model(inputs)
          loss_fn(pred, labels).backward()
          optimizer.step()

After:

        for _ in range(num_iters):
          pred = model(inputs)
          loss_fn(pred, labels).backward()
          optimizer.step()
          # Add this line to clear grad tensors
          optimizer.zero_grad(set_to_none=True)

We can simply add an optimizer.zero_grad(set_to_none=True) instruction to clear the gradient tensors from iteration to iteration (more details about why we need to zero the gradients here: https://pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html).

This is a simplification of a bug we’ve found in more complicated programs using this tool. We encourage you to try out the Memory Snapshot on your GPU memory problems and let us know how it goes.

ResNet50 after bug fix

After applying the fix, the snapshot seems to be clearing the gradients now.

Memory Timeline

We now have the snapshot of a properly working ResNet50 model. Try out the code yourself (see code sample in Appendix A).

But you may be wondering, why is there still an increase in memory after the first iteration? To answer this, let’s visit the Memory Profiler in the next section.

Categorized Memory Usage

The Memory Profiler is an added feature of the PyTorch Profiler that categorizes memory usage over time. We still rely on the Memory Snapshot for stack traces for deep dives into memory allocations.

To generate a memory timeline, here is a code snippet (full code sample in Appendix B):

   # Initialize the profiler context with record_shapes, profile_memory,
   # and with_stack set to True.
   with torch.profiler.profile(
       activities=[
           torch.profiler.ProfilerActivity.CPU,
           torch.profiler.ProfilerActivity.CUDA,
       ],
       schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
       record_shapes=True,
       profile_memory=True,
       with_stack=True,
       on_trace_ready=trace_handler,
   ) as prof:
       # Run the PyTorch Model inside the profile context.
       for _ in range(5):
           prof.step()
           with record_function("## forward ##"):
               pred = model(inputs)

           with record_function("## backward ##"):
               loss_fn(pred, labels).backward()

           with record_function("## optimizer ##"):
               optimizer.step()
               optimizer.zero_grad(set_to_none=True)

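   # Construct the memory timeline HTML plot. Here file_prefix is an output
   # path prefix of your choice (Appendix B builds it from host name and time).
   prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")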

For further reference, see https://pytorch.org/docs/main/profiler.html.

The Memory Profiler automatically generates categories based on the graph of tensor operations recorded during profiling.

Memory Timeline

This Memory Timeline, collected using the Memory Profiler, shows the same training example as before. The gradients (in blue) are now cleared from iteration to iteration. The optimizer state (in yellow) is allocated after the first iteration and stays constant for the rest of the job.

This optimizer state is the reason for the increase in GPU memory from the first iteration to the second. Try out the code yourself (see code sample in Appendix B). The Memory Profiler makes training memory usage easier to understand, so that model authors can see which categories use the most GPU memory.
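The late allocation of optimizer state can also be observed directly, without a profiler. The following is a small sketch under the assumption of a tiny CPU model in place of ResNet50 (sizes and variable names are illustrative). SGD with momentum keeps one momentum buffer per parameter, and those buffers only come into existence during the first optimizer.step().

```python
import torch

# Tiny stand-in model (assumption: a small CPU model replaces ResNet50).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

# Before the first step, the optimizer holds no per-parameter state.
state_entries_before = len(optimizer.state)

model(torch.randn(8, 4)).sum().backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

# After the first step there is one state entry per parameter, each with a
# momentum buffer of the same shape as the parameter.
state_entries_after = len(optimizer.state)
buffers = [optimizer.state[p]["momentum_buffer"] for p in model.parameters()]
buffer_bytes = sum(b.numel() * b.element_size() for b in buffers)
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

print(state_entries_before, state_entries_after, buffer_bytes, param_bytes)
```

For SGD with momentum, the buffers add roughly one extra copy of the parameters, which matches a one-time step in the timeline that then stays flat.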

Where can I find these tools?

We hope that these tools will greatly improve your ability to debug CUDA OOMs and to understand your memory usage by category.

The Memory Snapshot and the Memory Profiler are available in the v2.1 release of PyTorch as experimental features.
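Because these are experimental features, it can be useful to check that the installed PyTorch build exposes the entry points used in this post before relying on them. The helper below is a hypothetical convenience (memory_tools_available is not a PyTorch API); it only tests for the attribute names that appear in the code samples here.

```python
import torch


def memory_tools_available() -> dict:
    """Report whether the APIs used in this post exist in the installed build.

    Hypothetical helper for illustration; it only checks attribute presence.
    """
    cuda_memory = torch.cuda.memory
    return {
        # Private Memory Snapshot entry points used in Appendix A.
        "memory_snapshot": all(
            hasattr(cuda_memory, name)
            for name in ("_record_memory_history", "_dump_snapshot")
        ),
        # Memory Profiler export method used in Appendix B.
        "memory_profiler": hasattr(
            torch.profiler.profile, "export_memory_timeline"
        ),
    }


availability = memory_tools_available()
print(torch.__version__, availability)
```

Note that this says nothing about whether a CUDA device is present; the samples in the appendix check torch.cuda.is_available() for that.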

Feedback

We look forward to hearing from you about any enhancements, bugs, or memory stories that our tools helped to solve! As always, please feel free to open new issues on PyTorch’s GitHub page.

We are also open to contributions from the OSS community. Feel free to tag Aaron Shi and Zachary DeVito in any GitHub PRs for reviews.

Acknowledgements

We really appreciate the content reviewers, Mark Saroufim, Gregory Chanan, and Adnan Aziz, for reviewing this post and improving its readability.

Appendix

Appendix A – ResNet50 Memory Snapshot Code Example

# (c) Meta Platforms, Inc. and affiliates. 
import logging
import socket
from datetime import datetime

import torch

from torchvision import models

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

def start_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Starting snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(
       max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
   )

def stop_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Stopping snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(enabled=None)

def export_memory_snapshot() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not exporting memory snapshot")
       return

   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   try:
       logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
       torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
   except Exception as e:
       logger.error(f"Failed to capture memory snapshot {e}")
       return

# Simple Resnet50 example to demonstrate how to capture memory visuals.
def run_resnet50(num_iters=5, device="cuda:0"):
   model = models.resnet50().to(device=device)
   inputs = torch.randn(1, 3, 224, 224, device=device)
   labels = torch.rand_like(model(inputs))
   optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
   loss_fn = torch.nn.CrossEntropyLoss()

   # Start recording memory snapshot history
   start_record_memory_history()

   for _ in range(num_iters):
       pred = model(inputs)
       loss_fn(pred, labels).backward()
       optimizer.step()
       optimizer.zero_grad(set_to_none=True)

   # Create the memory snapshot file
   export_memory_snapshot()

   # Stop recording memory snapshot history
   stop_record_memory_history()

if __name__ == "__main__":
   # Run the resnet50 model
   run_resnet50()

Appendix B – ResNet50 Memory Profiler Code Example

# (c) Meta Platforms, Inc. and affiliates. 
import logging
import socket
from datetime import datetime

import torch

from torch.autograd.profiler import record_function
from torchvision import models

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

def trace_handler(prof: torch.profiler.profile):
   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   # Construct the trace file.
   prof.export_chrome_trace(f"{file_prefix}.json.gz")

   # Construct the memory timeline file.
   prof.export_memory_timeline(f"{file_prefix}.html", device="cuda:0")

def run_resnet50(num_iters=5, device="cuda:0"):
   model = models.resnet50().to(device=device)
   inputs = torch.randn(1, 3, 224, 224, device=device)
   labels = torch.rand_like(model(inputs))
   optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
   loss_fn = torch.nn.CrossEntropyLoss()

   with torch.profiler.profile(
       activities=[
           torch.profiler.ProfilerActivity.CPU,
           torch.profiler.ProfilerActivity.CUDA,
       ],
       schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
       record_shapes=True,
       profile_memory=True,
       with_stack=True,
       on_trace_ready=trace_handler,
   ) as prof:
       for _ in range(num_iters):
           prof.step()
           with record_function("## forward ##"):
               pred = model(inputs)

           with record_function("## backward ##"):
               loss_fn(pred, labels).backward()

           with record_function("## optimizer ##"):
               optimizer.step()
               optimizer.zero_grad(set_to_none=True)

if __name__ == "__main__":
   # Warm up
   run_resnet50()
   # Run the resnet50 model
   run_resnet50()
