New performance improvements in Amazon SageMaker model parallel library

Foundation models are large deep learning models trained on vast quantities of data at scale. They can be further fine-tuned to perform a variety of downstream tasks, and they form the backbone of many AI applications. The most prominent category is large language models (LLMs), including auto-regressive models such as GPT variants trained to complete natural text. LLMs typically contain billions of parameters, so they rarely fit on a single accelerator and require model parallelism techniques. Another category is diffusion models, notably Stable Diffusion, which have pushed AI image generation to an unprecedented milestone where remarkable visuals can be generated from a simple text description. Diffusion models are typically much smaller than LLMs, but distributed training still plays a critical role in facilitating their development.

The SageMaker model parallel (SMP) library is a large-model training solution available on the Amazon SageMaker platform. It can be integrated with PyTorch models to easily apply a range of state-of-the-art large-model distributed training techniques to train at scale. Earlier this year, SMP launched sharded data parallelism, a distributed training technique powered by Amazon's in-house MiCS technology. Sharded data parallelism shards model parameters, gradients, and optimizer states across data-parallel workers. MiCS performs a number of optimizations, including scale-aware partitioning, to provide near-linear scalability. In Train gigantic models with near-linear scaling using sharded data parallelism, we shared that sharded data parallelism in SMP achieved a 39.7% speedup compared to DeepSpeed ZeRO-3 on a 30B parameter GPT-2 model with sequence length 2048.

To help our customers further minimize training costs and accelerate time-to-market, we are thrilled to introduce two new performance improvements in SageMaker model parallel: SMDDP Collectives and FlashAttention. SMDDP Collectives, offered by the SageMaker distributed data parallel library, is the most performant collective library on AWS infrastructure for large model training. FlashAttention, introduced in Dao et al., re-implements the attention mechanism in an IO-aware manner, reducing the memory bandwidth requirement and improving both attention speed and memory footprint. Together, these two components make our sharded data parallel technique 30.58% faster when training a 100B parameter GPT-NeoX model on 32 p4d.24xlarge instances. For customers who are already using sharded data parallelism on supported models, no code changes are necessary to benefit from the performance boost offered by these latest features. Stability AI, the inventor of the Stable Diffusion family of models that showed unparalleled image generation abilities, chose to use SMP to build foundation models. With SMP, Stability AI achieved 163 TFLOPs per GPU for a 13B-parameter GPT-NeoX on 32 p4d.24xlarge instances, a 58% speedup compared to DeepSpeed. You can learn more about Stability AI’s mission and partnership with AWS in the Stability AI CEO’s talk at AWS re:Invent 2022 or in this blog post.

“Our mission at Stability AI is to build the foundation to activate humanity’s potential through AI. To achieve this mission, we need to efficiently train open-source foundation models on hundreds of accelerated compute instances. We rely on SageMaker and its distributed training libraries to optimize performance and implement state-of-the-art strategies to shard models and data across our training cluster. These optimizations reduce our training costs, help us meet customer needs faster, and speed up the development of new models.”

— Emad Mostaque, Founder and CEO of Stability AI.

In this blog post, we’ll first present our latest performance improvements in the SageMaker model parallel library. Then, we’ll revisit how to train foundation models using sharded data parallelism. Finally, we’ll benchmark the performance of 13B, 50B, and 100B parameter auto-regressive models and wrap up with future work.

New performance improvements in SageMaker model parallel library

Starting from AWS Deep Learning Containers (DLC) PyTorch 1.12.1, SageMaker model parallel library v1.13 comes with the following two new components that are critical to improving training performance. They are currently available on ml.p4d.24xlarge instances with Elastic Fabric Adapter (EFA) enabled:

1. AWS-optimized AllGather from SMDDP Collectives

In sharded data parallel, since only a shard of the model state is present on a GPU, an AllGather collective is needed to gather the full set of parameters from across all GPUs in the sharding group during forward or backward pass computations. In the previous versions of SageMaker model parallel, we used NVIDIA Collective Communications Library (NCCL) for these collectives. However, NCCL is a general purpose collective communications library not designed for AWS infrastructure, which leads to sub-optimal performance even with EFA enabled.

Previously, we developed the SMDDP Collectives library, which provides an AWS-optimized implementation of the All-Reduce collective to speed up pure data parallel training. To improve the performance of large model training with sharded data parallelism, we expanded the SMDDP Collectives library to include an optimized implementation of the AllGather collective. The key advantage of SMDDP Collectives AllGather is that it adopts an all-to-all-type communication pattern for inter-node communication, which gives our collective high throughput and makes it less latency-sensitive. In addition, our AllGather collective offloads the communication-related processing to the CPU, thereby freeing up valuable GPU cycles for gradient computation. This leads to significant performance improvement, especially on large models.
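
To make the role of AllGather concrete, here is a minimal single-process sketch in NumPy. It is not the SMDDP Collectives implementation and performs no real communication; the helper names `shard_parameters` and `all_gather` are hypothetical and only illustrate what each worker holds before and after the collective.

```python
import numpy as np

def shard_parameters(params, num_workers):
    """Split a flat parameter vector into equal shards, one per worker.

    The vector is zero-padded so that it divides evenly (an assumption of
    this sketch, not a statement about how SMP lays out memory).
    """
    pad = (-len(params)) % num_workers
    padded = np.concatenate([params, np.zeros(pad, dtype=params.dtype)])
    return np.split(padded, num_workers), len(params)

def all_gather(shards):
    """Simulate AllGather: every worker ends up with all shards concatenated."""
    full = np.concatenate(shards)
    return [full.copy() for _ in shards]

params = np.arange(10, dtype=np.float32)
shards, original_len = shard_parameters(params, num_workers=4)

# Before the collective, each worker stores only its own shard.
print([len(s) for s in shards])  # [3, 3, 3, 3]

# After the collective, every worker can see the full parameter vector.
gathered = all_gather(shards)
restored = [g[:original_len] for g in gathered]
```

In a real job the shards live on different GPUs, so the concatenation step is where network bandwidth and latency matter.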

2. FlashAttention

In modern transformer architectures, one of the largest sources of memory consumption is the activation footprint in the self-attention layer. This is because each attention head computes an SxS attention matrix for each input, where S is the sequence length, and this matrix goes through several operations, such as dropout, softmax, and matrix multiplication, with each intermediate output requiring memory space for use in back-propagation.
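
As a rough worked example of this footprint, the arithmetic below estimates the size of one SxS intermediate for a single layer. The batch size of 8 and head count of 40 are illustrative assumptions, not figures from this post; only the sequence length of 2048 and 2-byte (FP16) elements come from the benchmarks described later.

```python
def attention_matrix_bytes(batch_size, num_heads, seq_len, bytes_per_element=2):
    """Bytes needed to store one S x S attention intermediate for one layer."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element

# Hypothetical per-GPU batch size and head count; FP16 elements.
size_2k = attention_matrix_bytes(batch_size=8, num_heads=40, seq_len=2048)
size_4k = attention_matrix_bytes(batch_size=8, num_heads=40, seq_len=4096)

print(size_2k / 2**30)    # 2.5 (GiB per intermediate, per layer)
print(size_4k / size_2k)  # 4.0 (doubling S quadruples the footprint)
```

Several such intermediates are kept per layer for back-propagation, which is why this term quickly dominates activation memory.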

FlashAttention (Dao et al.) is a recent innovation from HazyResearch at Stanford that re-implements the self-attention mechanism in an I/O-aware manner. The main insight behind FlashAttention is that the self-attention mechanism is bottlenecked by memory bandwidth to and from GPU high bandwidth memory (HBM). To reduce this traffic, the self-attention layer can be computed in chunks across the sequence dimension, with each chunk going through the entire self-attention pipeline at a time. The intermediate results for a chunk are stored in the high-bandwidth SRAM, avoiding the expensive round-trip to the HBM for every iteration. Although a naive implementation would run into the issue of the cross-chunk dependency at the softmax layer, FlashAttention introduces a clever implementation that side-steps this dependency. Combined with re-computation in the backward pass, FlashAttention results in substantial memory savings and performance improvement (25% faster training for GPT-NeoX 13B over 16 p4d nodes), because it avoids both the HBM round-trip and the storage of SxS matrices. You can find visuals and more explanations in HazyResearch’s FlashAttention repository.
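
The cross-chunk softmax dependency can be side-stepped with a running maximum and running normalizer, often called an online softmax. The NumPy sketch below illustrates that idea only: it is not the FlashAttention CUDA kernel, it chunks just the keys and values (the real kernel also tiles the queries), and it omits dropout, masking, and the backward pass. The function names are hypothetical.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference attention that materializes the full S x S weight matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def chunked_attention(q, k, v, chunk_size):
    """Attention over key/value chunks, never holding the full S x S matrix."""
    seq_len, head_dim = q.shape
    scale = 1.0 / np.sqrt(head_dim)
    running_max = np.full((seq_len, 1), -np.inf)  # per-row max seen so far
    running_sum = np.zeros((seq_len, 1))          # per-row softmax denominator
    out = np.zeros((seq_len, v.shape[-1]))        # unnormalized output

    for start in range(0, k.shape[0], chunk_size):
        k_blk = k[start:start + chunk_size]
        v_blk = v[start:start + chunk_size]
        scores = (q @ k_blk.T) * scale            # shape: (seq_len, chunk)

        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        # Rescale what was accumulated under the old max to the new max.
        correction = np.exp(running_max - new_max)
        p = np.exp(scores - new_max)

        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v_blk
        running_max = new_max

    return out / running_sum

rng = np.random.default_rng(0)
q = rng.standard_normal((16, 8))
k = rng.standard_normal((16, 8))
v = rng.standard_normal((16, 8))

reference = naive_attention(q, k, v)
chunked = chunked_attention(q, k, v, chunk_size=5)
print(np.allclose(reference, chunked))  # True
```

The working set per step is a (seq_len, chunk) block rather than the full SxS matrix, which is the property that lets the real kernel keep intermediates in SRAM.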

Train foundation models at scale with SageMaker model parallel

To train foundation models with SMP powered by SMDDP Collectives, no additional changes are required in your sharded data parallel training jobs. If you’re new to sharded data parallelism, follow this complete tutorial notebook and blog post, which walk you through the entire process, from processing data and defining and submitting training jobs to monitoring training logs. A ready-to-use training script for the GPT-2 model can be found at train_gpt_simple.py. To train a different model type, follow the API document to learn how to apply the SMP APIs.

We highlight the key hyperparameters in the PyTorch Estimator of a sharded data parallel training job below. The hyperparameter ddp_dist_backend in smp_options now has a new option, "auto", as its default value. With "auto", SMP uses AWS-optimized AllGather for sharded data parallelism jobs and falls back to NCCL otherwise. You can refer to this document for supported configurations. If you want to run sharded data parallelism in SMP specifically with NCCL as the communication backend, set "ddp_dist_backend" to "nccl" in smp_options.

import sagemaker
from sagemaker.pytorch import PyTorch

smp_options = {
    "enabled": True,
    "parameters": {
        "ddp": True,
        "ddp_dist_backend": "auto", #OR "nccl" to disable SMDDP Collectives
        # To enable sharded data parallelism.
        # Here we shard model states across 128 GPUs.
        "sharded_data_parallel_degree": 128,  
    }
}

smp_estimator = PyTorch(
    entry_point="train_gpt_simple.py",
    role=sagemaker.get_execution_role(),
    instance_type='ml.p4d.24xlarge',
    instance_count=32,
    distribution={
        "smdistributed": {"modelparallel": smp_options},
        ...
    },
    ...
)

smp_estimator.fit(inputs=data_channels)
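
As a quick sanity check on the configuration above, the sketch below works out how the GPUs are grouped. Each ml.p4d.24xlarge instance has 8 GPUs. The helper `sharding_layout` is hypothetical, and the divisibility check is an assumption of this sketch; consult the SMP documentation for the actual constraints on sharded_data_parallel_degree.

```python
GPUS_PER_P4D = 8  # ml.p4d.24xlarge provides 8 NVIDIA A100 GPUs

def sharding_layout(instance_count, sharded_data_parallel_degree,
                    gpus_per_instance=GPUS_PER_P4D):
    """Return the total GPU count and how many sharding groups it forms."""
    world_size = instance_count * gpus_per_instance
    if world_size % sharded_data_parallel_degree != 0:
        # Assumption for this sketch: the degree evenly divides the GPU count.
        raise ValueError("sharding degree must divide the total number of GPUs")
    return {
        "world_size": world_size,
        "sharding_groups": world_size // sharded_data_parallel_degree,
    }

layout = sharding_layout(instance_count=32, sharded_data_parallel_degree=128)
print(layout)  # {'world_size': 256, 'sharding_groups': 2}
```

With 32 instances and a sharding degree of 128, the 256 GPUs form two groups of 128, and each group holds one complete sharded copy of the model state.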

With the latest SMP v1.13 release, the sharded data parallel training technique supports FlashAttention out of the box for popular models including BERT, RoBERTa, GPT-2, GPT-J, GPT-Neo, and GPT-NeoX. This is enabled by passing tensor_parallelism=True during model creation without setting tensor_parallel_degree. You can find an example in the same training script, train_gpt_simple.py.

Benchmarking performance

We benchmarked sharded data parallelism in the SageMaker model parallel library on three different scales of models to understand how the two new features, FlashAttention and AWS-optimized AllGather, contribute to performance improvement. A placement group is not required to reproduce these benchmarks on SageMaker.

13B parameter GPT-NeoX

In this setting, we focus on the performance gain contributed by FlashAttention and leave AWS-optimized AllGather out of the picture. Using FlashAttention saves substantial GPU memory, which helps us increase the batch size or reduce the sharding degree, thereby improving performance. As the results below show, we observed an average speedup of about 20.4% in SMP with FlashAttention for the 13B parameter GPT-NeoX model on various configurations across 16-64 p4d nodes. Memory usage during standard attention computation scales quadratically with sequence length, whereas FlashAttention's memory usage is linear in sequence length. Hence FlashAttention is even more helpful as sequence length increases, and it makes larger sequence lengths practical. Being memory-efficient without trading off model quality, FlashAttention has quickly gained traction in the large model training community in the past months, including integration with Hugging Face Diffusers and Mosaic ML.

| Model/Training | Cluster | SMP configuration | Without FlashAttention (TFLOPs/GPU) | With FlashAttention (TFLOPs/GPU) | % Speedup |
| --- | --- | --- | --- | --- | --- |
| 13B GPT-NeoX, seq length: 2048, global batch size: 1024, FP16 | 16 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 130 | 159 | 22.31 |
| 13B GPT-NeoX, seq length: 2048, global batch size: 2048, FP16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 131 | 157 | 19.85 |
| 13B GPT-NeoX, seq length: 2048, global batch size: 4096, FP16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 64, gradient_accumulation: 1 | 131 | 156 | 19.08 |
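
The speedup column and the roughly 20.4% average quoted above can be reproduced directly from the TFLOPs/GPU figures in the table. The function name `percent_speedup` is just a label for this sketch.

```python
def percent_speedup(baseline_tflops, improved_tflops):
    """Relative throughput gain of the improved run over the baseline, in percent."""
    return (improved_tflops / baseline_tflops - 1.0) * 100.0

# (without FlashAttention, with FlashAttention) for 16, 32, and 64 nodes
rows = [(130, 159), (131, 157), (131, 156)]

speedups = [round(percent_speedup(base, fast), 2) for base, fast in rows]
average = round(sum(speedups) / len(speedups), 1)

print(speedups)  # [22.31, 19.85, 19.08]
print(average)   # 20.4
```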

50B parameter Bloom

Now, we look at how AWS-optimized AllGather from SMDDP Collectives speeds up large model training with SMP. We benchmark a 50B-parameter Bloom model and compare the performance with and without the AWS-optimized AllGather collective. We observe that SMDDP Collectives speeds up model training by up to 40% on training jobs spanning 32 to 64 nodes. SMDDP Collectives achieves better performance by making better use of the 400 Gbps network bandwidth available with p4d.24xlarge instances. This, coupled with the design choice to offload communication-related processing to the CPU, helps achieve good compute-to-network overlap, leading to optimized performance. Compute-to-network overlap becomes especially important in large models, because the size of the data communicated across nodes scales linearly with model size.

| Model/Training | Cluster | SMP configuration | Without AWS-optimized AllGather (TFLOPs/GPU) | With AWS-optimized AllGather (TFLOPs/GPU) | % Speedup |
| --- | --- | --- | --- | --- | --- |
| 50B Bloom, seq length: 2048, global batch size: 2048, BF16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 128, gradient_accumulation: 1 | 102 | 143 | 40.20 |
| 50B Bloom, seq length: 2048, global batch size: 4096, BF16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 128, gradient_accumulation: 1 | 101 | 140 | 38.61 |
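
To illustrate why communication volume grows linearly with model size, the following back-of-the-envelope sketch estimates how many bytes each GPU must receive to assemble one full copy of the parameters. It assumes 2-byte (BF16 or FP16) parameters and a plain all-gather in which each GPU already holds 1/d of the data; it ignores how SMP actually schedules gathers layer by layer, and the helper name is hypothetical.

```python
def all_gather_receive_bytes(num_parameters, sharding_degree, bytes_per_element=2):
    """Bytes one GPU receives to assemble a full parameter copy from d shards."""
    total_bytes = num_parameters * bytes_per_element
    # Each GPU already owns 1/d of the parameters and receives the rest.
    return total_bytes * (sharding_degree - 1) / sharding_degree

bloom_50b = all_gather_receive_bytes(50 * 10**9, sharding_degree=128)
model_100b = all_gather_receive_bytes(100 * 10**9, sharding_degree=128)

print(bloom_50b / 1e9)         # 99.21875 (GB received per GPU per full gather)
print(model_100b / bloom_50b)  # 2.0 (twice the parameters, twice the traffic)
```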

100B parameter GPT-NeoX

Finally, we benchmark SMP with both of the latest features enabled. The results show that this new release, SMP v1.13, is about 30% faster than the previous version on a 100B-parameter GPT-NeoX model.

| Model/Training | Cluster | SMP configuration | Without FlashAttention and without AWS-optimized AllGather (TFLOPs/GPU) | With FlashAttention + AWS-optimized AllGather (TFLOPs/GPU) | % Speedup |
| --- | --- | --- | --- | --- | --- |
| 100B GPT-NeoX, seq length: 2048, global batch size: 2048, FP16 | 32 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 256, offload_activations. Without FlashAttention: batch size is 4 with gradient accumulation of 2 steps. With FlashAttention: batch size is 8 with no gradient accumulation. | 121 | 158 | 30.58 |
| 100B GPT-NeoX, seq length: 2048, global batch size: 4096, FP16 | 64 p4d.24xlarge nodes | Activation checkpointing, sharded_data_parallel_degree: 256, offload_activations. Without FlashAttention: batch size is 4 with gradient accumulation of 2 steps. With FlashAttention: batch size is 8 with no gradient accumulation. | 122 | 158 | 29.51 |
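
The two batch settings in the table reach the same global batch size, which the arithmetic below confirms. It reads "batch size" in the table as the per-GPU batch size, which is an interpretation rather than something the table states explicitly, and it uses the 8 GPUs available on each p4d.24xlarge node.

```python
GPUS_PER_NODE = 8  # GPUs on one p4d.24xlarge

def global_batch_size(nodes, per_gpu_batch, grad_accumulation_steps):
    """Samples consumed per optimizer step across all data-parallel GPUs."""
    return nodes * GPUS_PER_NODE * per_gpu_batch * grad_accumulation_steps

# 32 nodes: batch 4 with 2 accumulation steps vs. batch 8 with none.
without_flash = global_batch_size(32, per_gpu_batch=4, grad_accumulation_steps=2)
with_flash = global_batch_size(32, per_gpu_batch=8, grad_accumulation_steps=1)

print(without_flash, with_flash)  # 2048 2048
```

Under this reading, the memory saved by FlashAttention lets each GPU process the same number of samples in one pass instead of two.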

For future work, we’ll be working on supporting an AWS-optimized Reduce-Scatter in SMDDP Collectives. The Reduce-Scatter collective is critical for averaging and sharding gradients computed in the backward pass. We expect this to further speed up the SMP library in future releases.
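
For readers unfamiliar with Reduce-Scatter, the single-process NumPy sketch below shows its effect on gradients: every worker contributes a full gradient vector and ends up holding only its own shard of the average. It involves no real communication and is not related to any planned SMDDP implementation; `reduce_scatter_mean` is a hypothetical name.

```python
import numpy as np

def reduce_scatter_mean(per_worker_grads):
    """Average full gradients across workers, then hand each worker one shard.

    Assumes the gradient length is divisible by the number of workers.
    """
    num_workers = len(per_worker_grads)
    mean_grad = np.mean(np.stack(per_worker_grads), axis=0)
    return np.split(mean_grad, num_workers)

# Four workers, each with a different full-length gradient.
grads = [np.arange(8, dtype=np.float64) * (rank + 1) for rank in range(4)]
grad_shards = reduce_scatter_mean(grads)

# Worker 1 keeps elements 2 and 3 of the averaged gradient (scale factor 2.5).
print(grad_shards[1])  # [5.  7.5]
```

Concatenating all shards recovers the full averaged gradient, so Reduce-Scatter followed by AllGather is equivalent to an All-Reduce.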

Conclusion

In this post, we discussed the two latest performance improvements for the sharded data parallel technique in the SageMaker model parallel library. LLMs show great promise in improving the quality and re-usability of ML models. AWS teams are working closely with customers to keep reducing their training costs and time-to-market. You can find more SageMaker model parallel examples in the Amazon SageMaker Examples GitHub repo or attend our next distributed training workshops. If you are interested in speeding up large model training, check out these features and let us know what you build!


About the authors

Arjun Balasubramanian is a Senior Software Engineer at AWS focused on building high-performance, hardware accelerated collective communication algorithms for distributed deep learning. He is broadly interested in systems for large-scale machine learning and networking. Outside of work, he enjoys traveling and playing various sports.

Zhaoqi Zhu is a Software Development Engineer at AWS, specializing in distributed deep learning systems and working on the SageMaker Distributed Data Parallel library. Outside of work, Zhaoqi is passionate about soccer and hopes to not receive any red card in the upcoming season.

Can Karakus is a Senior Applied Scientist at AWS, optimizing large-scale distributed deep learning on AWS. His research interests cover deep learning, distributed optimization, distributed systems, and information theory. Outside of work, he enjoys cycling, traveling, reading and learning.

Rahul Huilgol is a Senior Software Engineer at AWS. He works on distributed deep learning systems, towards making it easy and performant to train large deep learning models in the cloud. In his spare time, he enjoys photography, biking and gardening.

Suhit Kodgule is a Software Development Engineer with AWS Artificial Intelligence group working on deep learning frameworks. In his spare time, he enjoys hiking, traveling and cooking.

Fei Wu is a Software Engineer at AWS. He works on distributed training for large-scale deep learning models on cloud. Outside of work, he enjoys basketball, gaming and cooking.
