Powering the AI Revolution: The PyTorch Documentary

Now live: The official PyTorch Documentary! This film unveils the authentic narrative of PyTorch’s inception, attributing its existence to a dedicated group of unsung heroes driving technological innovation.

The documentary shares the strength of the PyTorch community, resonating with our communities across the globe. We hope this story of PyTorch inspires greater contributions, attracts more contributors to the project, and fosters widespread recognition of PyTorch’s significance in the open source community.

We couldn’t have produced this without the support of our PyTorch Foundation members and sponsors:

AMD

“PyTorch’s growth and adoption in the AI community is a testament to open collaboration. The collective efforts of all the contributors have helped propel PyTorch as one of the most widely adopted AI frameworks in the industry. AMD is proud to be a part of this movement – making sure that the future of AI is open – and we are excited to continue contributing to this vibrant ecosystem.”

– Niles Burbank, AMD

AWS

“The release of the PyTorch Documentary showcases the innovation and real-world impact of one of the most widely adopted open source machine learning frameworks. By supporting and contributing to the PyTorch community, AWS helps enable cutting-edge machine learning research that drives advancements in AI capabilities. We are excited about the documentary as it highlights the power of collaboration in propelling PyTorch to the forefront of machine learning and empowering developers and data scientists to create groundbreaking models. At AWS, we celebrate frameworks like PyTorch that foster environments where open source machine learning technologies can grow and benefit the community at-large, as well as our customers.”

– Brian Granger, AWS

Google Cloud

“Google recognizes the impact of PyTorch on the AI community, providing researchers and developers with powerful, flexible tools for innovation. This documentary not only celebrates the remarkable achievements of the PyTorch community but also highlights the collaborative spirit driving advancements in AI. We look forward to continuing our support for PyTorch and fostering an open ecosystem that accelerates machine learning research and application.”

– Dwarak Rajagopal, Google

Meta

“We have been so impressed with the growth and collaboration that PyTorch has created over the years. From very humble beginnings at Meta to a cornerstone in AI research and development, the documentary showcases the dedication of our contributors since the start. It’s an honor to be a part of something so impactful, and now it’s been documented for our community to take part in.”

– Soumith Chintala, Meta

Microsoft Azure

“We’re truly excited about the premiere of the PyTorch Documentary. At Microsoft, PyTorch has been our default deep learning framework for building AI solutions including Microsoft Copilot. Additionally, we have made significant investments to create an optimized environment for our customers to develop, train, fine-tune and deploy their PyTorch workloads on Azure and Windows, furthering our commitment to democratize AI.”

– Eric Boyd, Microsoft

PyTorch Foundation

“The release of the PyTorch documentary marks a significant milestone for our community, showcasing the incredible journey and rapid evolution of PyTorch. We are excited to share these stories and achievements with the world, and we look forward to continuing to foster innovation and growth of the PyTorch community and PyTorch’s evolving ecosystem.”

– Matt White, PyTorch Foundation

Training MoEs at Scale with PyTorch

Over the past year, Mixture of Experts (MoE) models have surged in popularity, fueled by powerful open-source models like DBRX, Mixtral, DeepSeek, and many more. In this blog post, we’ll talk about how we scale to over three thousand GPUs using PyTorch Distributed and MegaBlocks, an efficient open-source MoE implementation in PyTorch.

What is a MoE?

A MoE model is a model architecture that uses multiple expert networks to make predictions. A gating network is used to route and combine the outputs of experts, ensuring each expert is trained on a different, specialized distribution of tokens. The architecture of a transformer-based large language model typically consists of an embedding layer that leads into multiple transformer blocks (Figure 1, Subfigure A). Each transformer block contains an attention block and a dense feed forward network (Figure 1, Subfigure B). These transformer blocks are stacked such that the output of one transformer block leads to the input of the next block. The final output goes through a fully connected layer and softmax to obtain probabilities for the next token to output.

When using a MoE in LLMs, the dense feed forward layer is replaced by a MoE layer which consists of a gating network and a number of experts (Figure 1, Subfigure D). The gating network, typically a linear feed forward network, takes in each token and produces a set of weights that determine which tokens are routed to which experts. The experts themselves are typically implemented as a feed forward network as well. During training, the gating network adapts to assign inputs to the experts, enabling the model to specialize and improve its performance. The router outputs are then used to weigh expert outputs to give the final output of the MoE layer.

Figure 1: Using Mixture of Experts in a transformer block

Compared to dense models, MoEs provide more efficient training for a given compute budget. This is because the gating network only sends tokens to a subset of experts, reducing the computational load. As a result, the capacity of a model (its total number of parameters) can be increased without proportionally increasing the computational requirements. During inference, only some of the experts are used, so a MoE is able to perform faster inference than a dense model. However, the entire model needs to be loaded in memory, not just the experts being used.

The sparsity in MoEs that allows for greater computational efficiency comes from the fact that a particular token will only be routed to a subset of experts. The number of experts and how experts are chosen depends on the implementation of the gating network, but a common method is top k. The gating network first predicts a probability value for each expert, then routes the token to the top k experts to obtain the output. However, if all tokens always go to the same subset of experts, training becomes inefficient and the other experts end up undertrained. To alleviate this problem, a load balancing loss is introduced that encourages even routing to all experts.
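The top k routing and load balancing described above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the function names `top_k_route` and `load_balancing_loss` are hypothetical, and the auxiliary loss follows one common formulation (the fraction of routed tokens per expert multiplied by the mean gate probability per expert, scaled by the number of experts), which is not necessarily the loss used in our training runs.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_route(logits, k):
    """Pick the k experts with the highest gate probability for one token.

    Returns (expert_ids, weights); weights are renormalized to sum to 1.
    """
    probs = softmax(logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    norm = sum(probs[i] for i in chosen)
    return chosen, [probs[i] / norm for i in chosen]

def load_balancing_loss(all_logits, k):
    """Assumed auxiliary loss: num_experts * sum_e(routed_fraction_e * mean_prob_e)."""
    num_tokens = len(all_logits)
    num_experts = len(all_logits[0])
    routed = [0.0] * num_experts
    mean_prob = [0.0] * num_experts
    for logits in all_logits:
        probs = softmax(logits)
        chosen, _ = top_k_route(logits, k)
        for e in chosen:
            routed[e] += 1.0 / (num_tokens * k)
        for e in range(num_experts):
            mean_prob[e] += probs[e] / num_tokens
    return num_experts * sum(f * p for f, p in zip(routed, mean_prob))

# Four tokens that each prefer a different expert, versus four that all prefer expert 0.
balanced = [[5.0 if e == t else 0.0 for e in range(4)] for t in range(4)]
skewed = [[5.0, 0.0, 0.0, 0.0] for _ in range(4)]
print(load_balancing_loss(balanced, k=1), load_balancing_loss(skewed, k=1))
```

Under this formulation the loss is smallest when tokens are spread evenly, so adding it to the training objective pushes the gating network away from collapsing onto a few experts.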

The number of experts and the choice of top k are important factors in designing MoEs. A higher number of experts allows scaling up to larger models without increasing computational cost. This gives the model a higher capacity for learning; however, past a certain point the performance gains tend to diminish. The number of experts chosen needs to be balanced with the inference costs of serving the model, since the entire model needs to be loaded in memory. Similarly, when choosing top k, a lower top k during training results in smaller matrix multiplications, leaving free computation on the table if communication costs are large enough. During inference, however, a higher top k generally leads to slower inference speed.

MegaBlocks

MegaBlocks is an efficient MoE implementation that uses sparse matrix multiplication to compute expert outputs in parallel despite uneven token assignment. MegaBlocks implements a dropless MoE that avoids dropping tokens while using GPU kernels that maintain efficient training. Prior to MegaBlocks, dynamic routing formulations forced a tradeoff between model quality and hardware efficiency: users had to either drop tokens from computation or waste computation and memory on padding. With MegaBlocks, experts can receive a variable number of tokens, and the expert computation is performed efficiently using block sparse matrix multiplication. We’ve integrated MegaBlocks into LLM Foundry to enable scaling MoE training to thousands of GPUs.

Figure 2: Matrix multiplication for expert computations

Expert Parallelism

As models scale to larger sizes and fail to fit on a single GPU, we require more advanced forms of parallelism. Expert parallelism is a form of model parallelism where we place different experts on different GPUs for better performance. Instead of expert weights being communicated across all GPUs, tokens are sent to the device that contains the expert. By moving data instead of weights, we can aggregate data across multiple machines for a single expert. The router determines which tokens from the input sequence should be sent to which experts. This is typically done by computing a gating score for each token-expert pair, and then routing each token to the top-scoring experts. Once the token-to-expert assignments are determined, an all-to-all communication step is performed to dispatch the tokens to the devices hosting the relevant experts. This involves each device sending the tokens assigned to experts on other devices, while receiving tokens assigned to its local experts.

The key advantage of expert parallelism is processing a few larger matrix multiplications instead of many small ones. As each GPU only has a subset of experts, it only has to do computation for those experts. Correspondingly, as we aggregate tokens across multiple GPUs, the size of each matrix is proportionally larger. As GPUs are optimized for large-scale parallel computations, larger operations can better exploit their capabilities, leading to higher utilization and efficiency. A more in-depth explanation of the benefits of larger matrix multiplications can be found here. Once the computation is complete, another all-to-all communication step is performed to send the expert outputs back to their original devices.

Figure 3: Token routing in expert parallelism
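To make the two all-to-all steps concrete, here is a small single-process simulation in plain Python. It is a sketch, not the PyTorch Distributed implementation: the helpers `all_to_all` and `dispatch` are hypothetical names, each token is routed to a single expert (top 1) for brevity, and experts are assumed to be placed contiguously so that expert `e` lives on device `e // experts_per_device`.

```python
def all_to_all(send_buffers):
    """Simulate all-to-all: send_buffers[src][dst] becomes recv_buffers[dst][src]."""
    n = len(send_buffers)
    return [[send_buffers[src][dst] for src in range(n)] for dst in range(n)]

def dispatch(tokens_per_device, experts_per_token, experts_per_device):
    """Bucket each device's tokens by the device hosting the assigned expert,
    then exchange the buckets so every device receives tokens for its local experts."""
    n = len(tokens_per_device)
    send = [[[] for _ in range(n)] for _ in range(n)]
    for src in range(n):
        for tok, expert in zip(tokens_per_device[src], experts_per_token[src]):
            dst = expert // experts_per_device
            send[src][dst].append((tok, expert))
    return send, all_to_all(send)

# Two devices, two experts each: experts 0-1 on device 0, experts 2-3 on device 1.
tokens = [["a", "b", "c"], ["d", "e"]]
assigned = [[0, 3, 2], [1, 3]]
send, recv = dispatch(tokens, assigned, experts_per_device=2)
print(recv[0])  # tokens that device 0 must process, grouped by source device
print(recv[1])  # tokens that device 1 must process, grouped by source device

# After expert computation, a second all-to-all returns results to their source devices.
returned = all_to_all(recv)
```

Because the exchange is a transpose of the send buffers, applying it a second time restores the original per-device layout, which is why the same primitive can be used for both the dispatch and the return step.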

We leverage PyTorch’s DTensor, a low-level abstraction for describing how tensors are sharded and replicated, to effectively implement expert parallelism. We first manually place experts on different GPUs, typically sharding across a node to ensure we can leverage NVLink for fast GPU communication when we route tokens. We can then build a device mesh on top of this layout, which lets us succinctly describe the parallelism across the entire cluster. We can use this device mesh to easily checkpoint or rearrange experts when we need alternate forms of parallelism.

Scaling ZeRO-3 with PyTorch FSDP

In conjunction with expert parallelism, we use data parallelism for all other layers, where each GPU stores a copy of the model and optimizer and processes a different chunk of data. After each GPU has completed a forward and backward pass, gradients are accumulated across GPUs for a global model update.

ZeRO-3 is a form of data parallelism where weights and optimizer states are sharded across GPUs instead of being replicated. Each GPU now only stores a subset of the full model, dramatically reducing memory pressure. When a part of the model is needed for computation, it is gathered across all the GPUs, and after the computation is complete, the gathered weights are discarded. We use PyTorch’s implementation of ZeRO-3, called Fully Sharded Data Parallel (FSDP).

As we scale to thousands of GPUs, the cost of communication across devices increases, slowing down training. Communication increases due to the need to synchronize and share model parameters, gradients, and optimizer states across all GPUs which involves all-gather and reduce-scatter operations. To mitigate this issue while keeping the benefits of FSDP, we utilize Hybrid Sharded Data Parallel (HSDP) to shard the model and optimizer across a set number of GPUs and replicate this multiple times to fully utilize the cluster. With HSDP, an additional all reduce operation is needed in the backward pass to sync gradients across replicas. This approach allows us to balance memory efficiency and communication cost during large scale distributed training. To use HSDP we can extend our previous device mesh from expert parallelism and let PyTorch do the heavy lifting of actually sharding and gathering when needed.

Figure 4: FSDP and HSDP
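The rank layout behind HSDP can be illustrated with a short plain-Python sketch. The helper `hsdp_groups` is a hypothetical name, and the contiguous assignment of ranks to a shard group (for example, the GPUs of one node) is an assumption made for illustration; in practice the device mesh describes this layout.

```python
def hsdp_groups(world_size, shard_size):
    """Split ranks into shard groups and replica groups for a 2D (replicate, shard) mesh.

    Within a shard group, parameters are all-gathered and gradients reduce-scattered.
    Across a replica group, the extra all-reduce synchronizes gradients.
    """
    if world_size % shard_size != 0:
        raise ValueError("world_size must be divisible by shard_size")
    num_replicas = world_size // shard_size
    # Row r of the mesh holds one full copy of the model, sharded over shard_size ranks.
    mesh = [[r * shard_size + s for s in range(shard_size)] for r in range(num_replicas)]
    shard_groups = mesh
    # Column s holds the ranks that own the same shard in every replica.
    replica_groups = [[mesh[r][s] for r in range(num_replicas)] for s in range(shard_size)]
    return shard_groups, replica_groups

shard_groups, replica_groups = hsdp_groups(world_size=8, shard_size=4)
print(shard_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replica_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Keeping the shard groups small bounds the size of the all-gather and reduce-scatter collectives, while the replica groups only exchange already-reduced gradient shards.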

With PyTorch, we can effectively combine these two types of parallelism, leveraging FSDP’s higher level API while using the lower-level DTensor abstraction when we want to implement something custom like expert parallelism. We now have a 3D device mesh with expert parallel shard dimension, ZeRO-3 shard dimension, and a replicate dimension for pure data parallelism. Together, these techniques deliver near linear scaling across very large clusters, allowing us to achieve MFU numbers over 40%.

Elastic Checkpointing with Torch Distributed

Fault tolerance is crucial for ensuring that LLMs can be trained reliably over extended periods, especially in distributed environments where node failures are common. To avoid losing progress when jobs inevitably encounter failures, we checkpoint the state of the model, which includes parameters, optimizer states, and other necessary metadata. When a failure occurs, the system can resume from the last saved state rather than starting over. To ensure robustness to failures, we need to checkpoint often and save and load checkpoints in the most performant way possible to minimize downtime. Additionally, if too many GPUs fail, our cluster size may change. Accordingly, we need the ability to elastically resume on a different number of GPUs.

PyTorch supports elastic checkpointing through its distributed training framework, which includes utilities for both saving and loading checkpoints across different cluster configurations. PyTorch Distributed Checkpoint ensures the model’s state can be saved and restored accurately across all nodes in the training cluster in parallel, regardless of any changes in the cluster’s composition due to node failures or additions.

Additionally, when training very large models, the size of checkpoints may be very large, leading to very slow checkpoint upload and download times. PyTorch Distributed Checkpoint supports sharded checkpoints, which enables each GPU to save and load only its portion of the model. When combining sharded checkpointing with elastic training, each GPU reads the metadata file to determine which shards to download on resumption. The metadata file contains information on what parts of each tensor are stored in each shard. The GPU can then download the shards for its part of the model and load that part of the checkpoint.

Figure 5: Checkpointing saving and resumption resharded on additional GPUs
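The resharding logic can be sketched for a flat, one-dimensional tensor in plain Python. This is an illustration only: the metadata layout (a mapping from shard file name to the index range it stores) and the helpers `shard_ranges` and `shards_to_download` are hypothetical and do not reflect the actual on-disk format of PyTorch Distributed Checkpoint.

```python
def shard_ranges(length, num_ranks):
    """Evenly split [0, length) into num_ranks contiguous half-open ranges."""
    base, extra = divmod(length, num_ranks)
    ranges, start = [], 0
    for rank in range(num_ranks):
        stop = start + base + (1 if rank < extra else 0)
        ranges.append((start, stop))
        start = stop
    return ranges

def shards_to_download(metadata, new_range):
    """Plan the reads one rank needs after resharding.

    Returns a list of (file_name, (src_start, src_stop), dst_offset) entries:
    which saved shard to read, which slice of it, and where it lands locally.
    """
    lo, hi = new_range
    plan = []
    for file_name, (start, stop) in metadata.items():
        o_lo, o_hi = max(lo, start), min(hi, stop)
        if o_lo < o_hi:  # the saved shard overlaps this rank's new range
            plan.append((file_name, (o_lo - start, o_hi - start), o_lo - lo))
    return plan

# A 12-element tensor saved from 4 ranks, then resumed on 3 ranks.
saved = {f"shard_{i}": r for i, r in enumerate(shard_ranges(12, 4))}
new_ranges = shard_ranges(12, 3)
print(shards_to_download(saved, new_ranges[1]))
```

Each rank touches only the files that overlap its new range, which is what keeps download volume per GPU roughly constant as the cluster size changes.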

By parallelizing checkpointing across GPUs, we can spread out network load, improving robustness and speed. When training a model with 3000+ GPUs, network bandwidth quickly becomes a bottleneck. We take advantage of the replication in HSDP to first download checkpoints on one replica and then send the necessary shards to other replicas. With our integration in Composer, we can reliably upload checkpoints to cloud storage as frequently as every 30 minutes and automatically resume from the latest checkpoint in the event of a node failure in less than 5 minutes.

Conclusion

We’re very excited to see how PyTorch is enabling training state-of-the-art LLMs with great performance. In our post, we’ve shown how we implemented efficient MoE training through PyTorch Distributed and MegaBlocks on Foundry. Furthermore, PyTorch elastic checkpointing allowed us to quickly resume training on a different number of GPUs when node failures occurred. Using PyTorch HSDP has allowed us to scale training efficiently as well as improve checkpointing resumption times. We look forward to continuing to build on a strong and vibrant open-source community to help bring great AI models to everyone. Come join us in building great models at LLM Foundry and PyTorch.

🎉 PyTorch Docathon H1 2024 Wrap-up 🎉

We are thrilled to announce the successful completion of the H1 2024 PyTorch Docathon! The event was a resounding success, and we want to extend our heartfelt gratitude to all the participants who made it possible. The dedication, expertise, and tireless efforts of our open-source contributors have once again helped us improve the PyTorch documentation.

This Docathon ran from June 4 through June 20 with more than 176 registrants. The energy and enthusiasm were palpable, and entrants were judged on the difficulty of their submissions, which resulted in over 50 merged pull requests.

We want to give a special shout-out to our top contributors, who went above and beyond during this event. Your dedication and expertise have been invaluable in enhancing the PyTorch documentation and empowering developers worldwide.

Meet the top contributors

For the full list of participants, see here.

As we bring this Docathon to a close, we encourage each and every one of you to stay inspired and keep contributing to PyTorch documentation and code, and pushing the boundaries of what’s possible with PyTorch. Your collective efforts are shaping the landscape of deep learning and fostering innovation in the PyTorch community.

Thank you again for your participation and support. We look forward to seeing what you will achieve next!

Team PyTorch

Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity

Over the past year, we’ve added support for semi-structured (2:4) sparsity into PyTorch. With just a few lines of code, we were able to show a 10% end-to-end inference speedup on segment-anything by replacing dense matrix multiplications with sparse matrix multiplications.

However, matrix multiplications are not unique to neural network inference – they happen during training as well. By expanding on the core primitives we used earlier to accelerate inference, we were also able to accelerate model training. We wrote a replacement nn.Linear layer, SemiSparseLinear, that is able to achieve a 1.3x speedup across the forwards + backwards pass of the linear layers in the MLP block of ViT-L on a NVIDIA A100.

End-to-end, we see a wall time reduction of 6% for a DINOv2 ViT-L training, with virtually no accuracy degradation out of the box (82.8 vs 82.7 on ImageNet top-1 accuracy).

We compare 2 strategies for training a ViT model for 125k iterations on 4x NVIDIA A100s: either fully dense (blue), or sparse for 70% of the training, then dense (orange). Both achieve similar results on the benchmarks, but the sparse variant trains 6% faster. For both experiments, we evaluate the intermediate checkpoints with and without sparsity.

As far as we are aware, this is the first OSS implementation of accelerated sparse training and we’re excited to provide a user API in torchao. You can try accelerating your own training runs with just a few lines of code:

# Requires torchao and pytorch nightlies and CUDA compute capability 8.0+
import torch
from torchao.sparsity.training import (
    SemiSparseLinear,
    swap_linear_with_semi_sparse_linear,
)

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096)).cuda().half()

# Specify the fully-qualified-name of the nn.Linear modules you want to swap.
# For the bare nn.Sequential above, the first child's name is "0".
sparse_config = {
    "0": SemiSparseLinear
}

# Swap nn.Linear with SemiSparseLinear, you can run your normal training loop after this step
swap_linear_with_semi_sparse_linear(model, sparse_config)

How does this work?

The general idea behind sparsity is simple: skip calculations involving zero-valued tensor elements to speed up matrix multiplication. However, simply setting weights to zero isn’t enough, as the dense tensor still contains these pruned elements and dense matrix multiplication kernels will continue to process them, incurring the same latency and memory overhead. To achieve actual performance gains, we need to replace dense kernels with sparse kernels that intelligently bypass calculations involving pruned elements.

These kernels work on sparse matrices, which remove the pruned elements and store the specified elements in a compressed format. There are many different sparse formats, but we’re particularly interested in semi-structured sparsity, also known as 2:4 structured sparsity or fine-grained structured sparsity or more generally N:M structured sparsity.

2:4 sparse compressed representation. Original Source

A 2:4-sparse matrix is a matrix where at most 2 elements are non-zero for every 4 elements, as illustrated in the image above. Semi-structured sparsity is attractive because it exists in a goldilocks spot of performance and accuracy:

  1. NVIDIA GPUs since Ampere offer hardware acceleration and library support (cuSPARSELt) for this format, with matrix multiplication being up to 1.6x faster
  2. Pruning models to fit this sparsity pattern does not degrade accuracy as much as other patterns. NVIDIA’s whitepaper shows pruning then retraining is able to recover accuracy for most vision models.
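To make the 2:4 format concrete, the following plain-Python sketch compresses a row by keeping the 2 largest-magnitude values in every group of 4, together with their positions inside the group. The function names are hypothetical, and the position list is a simplified stand-in for the real metadata: the layout that the hardware and cuSPARSELt expect is different and more involved.

```python
def compress_2_4(row):
    """Keep the 2 largest-magnitude values per group of 4.

    Returns (values, positions): half as many values as the input, plus the
    in-group position (0-3, i.e. 2 bits) of every kept value.
    """
    assert len(row) % 4 == 0
    values, positions = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        by_magnitude = sorted(range(4), key=lambda i: abs(group[i]), reverse=True)
        keep = sorted(by_magnitude[:2])  # preserve the original left-to-right order
        values.extend(group[i] for i in keep)
        positions.extend(keep)
    return values, positions

def decompress_2_4(values, positions):
    """Rebuild the pruned dense row from the compressed representation."""
    row = [0.0] * (2 * len(values))
    for n, (v, p) in enumerate(zip(values, positions)):
        row[(n // 2) * 4 + p] = v
    return row

row = [0.1, -2.0, 0.3, 1.5, 4.0, 0.0, -0.5, 0.2]
values, positions = compress_2_4(row)
print(values)     # [-2.0, 1.5, 4.0, -0.5]
print(positions)  # [1, 3, 0, 2]
print(decompress_2_4(values, positions))
```

The compressed row stores half the values plus a small amount of index metadata, which is what lets a sparse kernel skip the pruned elements entirely.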

Illustration of 2:4 (sparse) matrix multiplication on NVIDIA GPUs. Original source

Accelerating inference with semi-structured sparsity is straightforward. Since our weights are fixed during inference, we can prune and compress the weight ahead of time (offline) and store the compressed sparse representation instead of our dense tensor.

Then, instead of dispatching to dense matrix multiplication we dispatch to sparse matrix multiplication, passing in the compressed sparse weight instead of the normal dense one. For more information about accelerating models for inference using 2:4 sparsity, please refer to our tutorial.

Extending sparse inference acceleration to training

In order to use sparsity to reduce the training time of our models, we need to consider when the mask is calculated, as once we store the compressed representation the mask is fixed.

Training with a fixed mask applied to an existing trained dense model (also known as pruning) does not degrade accuracy, but this requires two training runs – one to obtain the dense model and another to make it sparse, offering no speedups.

Instead we’d like to train a sparse model from scratch (dynamic sparse training), but training from scratch with a fixed mask will lead to a significant drop in evaluations, as the sparsity mask would be selected at initialization, when the model weights are essentially random.

To maintain the accuracy of the model when training from scratch, we prune and compress the weights at runtime, so that we can calculate the optimal mask at each step of the training process.

Conceptually you can think of our approach as an approximate matrix multiplication technique, where we `prune_and_compress` and dispatch to `sparse_GEMM` in less time than a `dense_GEMM` call would take. This is difficult because the native pruning and compression functions are too slow to show speedups.

Given the shapes of our ViT-L training matrix multiplications (13008x4096x1024), we measured the runtime of a dense and sparse GEMM respectively at 538us and 387us. In other words, the pruning and compression step of the weight matrix must run in less than 538-387=151us to have any efficiency gain. Unfortunately, the compression kernel provided in cuSPARSELt already takes 380us (without even considering the pruning step!).

Given the max NVIDIA A100 memory IO (2TB/s), and considering that a prune and compress kernel would be memory bound, we could theoretically prune and compress our weight (4096x1024x2 bytes=8MB) in 4us (8MB / 2TB/s)! And in fact, we were able to write a kernel that prunes and compresses a matrix into 2:4-sparse format, and runs in 36 us (10x faster than the compression kernel in cuSPARSELt), making the entire GEMM (including the sparsification) faster. Our kernel is available for use in PyTorch.
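The timing budget above can be rechecked with a few lines of arithmetic. The script below only uses the numbers quoted in the text; like the estimate in the text, the IO floor counts a single read of the fp16 weight at peak bandwidth and ignores the writes, so it is a rough lower bound rather than a measurement.

```python
# Numbers quoted in the text, in microseconds.
dense_gemm_us = 538.0
sparse_gemm_us = 387.0
cusparselt_compress_us = 380.0
custom_kernel_us = 36.0

# Time available for pruning + compression before sparsity stops paying off.
budget_us = dense_gemm_us - sparse_gemm_us  # 151 us

# Memory-bound floor: read the 4096x1024 fp16 weight once at 2 TB/s.
weight_bytes = 4096 * 1024 * 2
peak_bandwidth_bytes_per_s = 2e12
io_floor_us = weight_bytes / peak_bandwidth_bytes_per_s * 1e6  # roughly 4.2 us

total_with_cusparselt_us = sparse_gemm_us + cusparselt_compress_us  # 767 us, slower than dense
total_with_custom_us = sparse_gemm_us + custom_kernel_us            # 423 us, faster than dense

print(f"budget: {budget_us:.0f} us, IO floor: {io_floor_us:.1f} us")
print(f"cuSPARSELt compression path: {total_with_cusparselt_us:.0f} us")
print(f"custom kernel path: {total_with_custom_us:.0f} us")
```

The gap between the 36 us kernel and the roughly 4 us floor suggests the kernel is within an order of magnitude of the memory-bound limit, while comfortably fitting the 151 us budget.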

Our custom sparsification kernel, which includes pruning + compression, is ~30% faster across a linear layer forward+backward. Benchmarks run on a NVIDIA A100-80GB GPU.

Writing a performant runtime sparsification kernel

There were multiple challenges we faced in order to implement a performant runtime sparsification kernel, which we will explore below.

1) Handling the backwards pass

For the backwards pass, we need to calculate dL/dX and dL/dW for the gradient update and the subsequent layer, which means we need to calculate xW^T and x^T W respectively.

Overview of runtime sparsification for training acceleration (FW + BW pass)

However this is problematic, because the compressed representation cannot be transposed, since there’s no guarantee that the tensor is 2:4 sparse in both directions.

Both matrices are valid 2:4 matrices. However, the right one is no longer a valid 2:4 matrix once transposed because one column contains more than 2 elements

Therefore, we prune a 4×4 tile, instead of a 1×4 strip. We greedily preserve the largest values, ensuring that we take at most 2 values for each row / column. While this approach is not guaranteed to be optimal, as we sometimes only preserve 7 values instead of 8, it efficiently calculates a tensor that is 2:4 sparse both row-wise and column-wise.
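The greedy selection on a 4×4 tile can be written down in plain Python as follows. This is a reference sketch of the selection rule described above, not the CUDA kernel itself; the function name `prune_tile_4x4` is hypothetical, and ties are broken by the order in which Python's stable sort visits the elements.

```python
def prune_tile_4x4(tile):
    """Greedy keep-mask for a 4x4 tile that is 2:4 sparse along rows and columns.

    Elements are visited from largest to smallest magnitude and kept only if
    their row and their column both still have fewer than 2 kept elements.
    """
    order = sorted(
        ((r, c) for r in range(4) for c in range(4)),
        key=lambda rc: abs(tile[rc[0]][rc[1]]),
        reverse=True,
    )
    row_count, col_count = [0] * 4, [0] * 4
    mask = [[False] * 4 for _ in range(4)]
    for r, c in order:
        if row_count[r] < 2 and col_count[c] < 2:
            mask[r][c] = True
            row_count[r] += 1
            col_count[c] += 1
    return mask

tile = [
    [16, 15, 14, 13],
    [12, 11, 10, 9],
    [8, 7, 6, 5],
    [4, 3, 2, 1],
]
for mask_row in prune_tile_4x4(tile):
    print(["x" if kept else "." for kept in mask_row])
```

Because the row and column constraints are enforced together, the same mask is valid for the tile and for its transpose, at the cost of occasionally keeping fewer than 8 values.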

We then compress both the packed tensor and the packed transpose tensor, storing the transpose tensor for the backwards pass. By calculating both the packed and packed transpose tensor at the same time, we avoid a secondary kernel call in the backwards pass.

Our kernel prunes the weight matrix in registers and writes the compressed values to global memory. At the same time, it also prunes W.t, which is needed for the backward pass, minimizing memory IO.

There’s some additional transpose trickery needed to handle the backwards pass: the underlying hardware only supports operations where the first matrix is sparse. For weight sparsification during inference, when we need to calculate xW^T we rely on transpose properties to swap the order of the operands.

xW^T = (Wx^T)^T
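The operand swap relies on the transpose identity xW^T = (Wx^T)^T, which puts the (sparse) weight W first in the product. A quick numerical check in plain Python, using small hypothetical helper functions and integer matrices so that the comparison is exact:

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]

def transpose(a):
    """Transpose a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]

x = [[1, 2, 3], [4, 5, 6]]                          # activations, shape (2, 3)
W = [[1, 0, 2], [0, 3, 0], [4, 0, 0], [0, 5, 6]]    # weight, shape (4, 3)

direct = matmul(x, transpose(W))                    # x W^T, dense operand first
swapped = transpose(matmul(W, transpose(x)))        # (W x^T)^T, weight operand first
print(direct == swapped)
```

The swapped form costs one extra transpose of the result, which is why the surrounding text discusses where that transpose can be fused.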

During inference, we use torch.compile to fuse the outer transpose into subsequent pointwise ops in order to avoid paying a performance penalty.

However in the case of the backwards pass of training, we have no subsequent pointwise op to fuse with. Instead, we fuse the transposition into our matrix multiplication by taking advantage of cuSPARSELt’s ability to specify the row / column layout of the result matrix.

2) Kernel tiling for efficient memory-IO

In order for our kernel to be as efficient as possible, we want to coalesce our reads / writes, as we found memory IO to be the main bottleneck. This means that within a CUDA thread, we want to read/write chunks of 128 bytes at a time, so that multiple parallel reads/writes can be coalesced into a single request by the GPU memory controller.

Therefore, instead of a thread handling a single 4×4 tile, which is only 4x4x2 = 32 bytes, we decided that each thread will handle four 4×4 tiles (i.e., an 8×8 tile), which allows us to operate on 8x8x2 = 128-byte chunks.

3) Sorting elements in a 4×4 tile without warp-divergence

For each individual 4×4 tile within our thread we calculate a bitmask that specifies which elements to prune and which elements to keep. To do this we sort all 16 elements and greedily preserve elements, so long as they do not break our 2:4 row / col constraint. This preserves only the weights with the largest values.

Crucially we observe that we are only ever sorting a fixed number of elements, so by using a branchless sorting network, we can avoid warp divergence.

For clarity, the transposed packed tensor and metadata are omitted. Sorting network diagram taken from Wikipedia.
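As a small illustration of the idea, here is a sorting network for 4 elements in plain Python. The kernel sorts 16 elements, so this is a reduced example rather than the network we use; the names `compare_exchange`, `NETWORK_4`, and `sort4` are hypothetical. The key property is that the sequence of compare-exchange steps is fixed in advance and does not depend on the data, so every thread executes the same instructions. On a GPU, each compare-exchange can be expressed with min/max instructions instead of a branch.

```python
from itertools import permutations

def compare_exchange(v, i, j):
    """Place the smaller value at v[i] and the larger at v[j] using only min/max."""
    lo, hi = min(v[i], v[j]), max(v[i], v[j])
    v[i], v[j] = lo, hi

# A 5-comparator network for 4 inputs; the schedule is independent of the data.
NETWORK_4 = [(0, 1), (2, 3), (0, 2), (1, 3), (1, 2)]

def sort4(values):
    """Sort exactly 4 values in ascending order with a fixed comparator schedule."""
    v = list(values)
    for i, j in NETWORK_4:
        compare_exchange(v, i, j)
    return v

# Exhaustively check every ordering of 4 distinct values.
print(all(sort4(p) == [1, 2, 3, 4] for p in permutations([1, 2, 3, 4])))
```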

Warp divergence occurs when we have conditional execution inside a thread block. In CUDA, work items in the same work group (thread block) are dispatched at the hardware level in batches (warps). If we have conditional execution, such that some work items in the same batch run different instructions, then they are masked when the warp is dispatched, or dispatched sequentially.

For example, if we have some code like if (condition) do(A) else do(B), where condition is satisfied by all the odd-numbered work items, then the total runtime of this conditional statement is do(A) + do(B), since we would dispatch do(A) for all odd-numbered work-items, masking out even-numbered work-items, and do(B) for all even numbered work-items, masking out odd-numbered work-items. This answer provides more information about warp divergence.

4) Writing the compressed matrices and metadata

Once the bitmask has been computed, the weight data has to be written back in a compressed format in global memory. This is not trivial, because the data needs to stay in registers, and it’s not possible to index registers (e.g., C[i++] = a prevents us from storing C in registers). Furthermore, we found that nvcc was using many more registers than we expected, which caused register spilling and hurt overall performance. We write this compressed matrix to global memory in Column-Major format to make the writes more efficient.

We also need to write the cuSPARSELt metadata. This metadata layout is quite similar to the one from the open-source CUTLASS library and is optimized for being loaded efficiently through shared memory in the GEMM kernel with the PTX ldmatrix instruction.

However, this layout is not optimized to be written efficiently: the first 128 bits of the metadata tensor contain metadata about the first 32 columns of rows 0, 8, 16 and 24. Recall that each thread handles an 8×8 tile, which means that this information is scattered across 16 threads.

We rely on a series of warp-shuffle operations, once each for the original and transposed representations, to write the metadata. Fortunately, this data represents less than 10% of the total IO, so we can afford to not fully coalesce the writes.

DINOv2 Sparse Training: Experimental Setup and Results

For our experiments, the ViT-L model is trained on ImageNet for 125k steps using the DINOv2 method. All our experiments were run on 4x AMD EPYC 7742 64-core CPUs and 4x NVIDIA A100-80GB GPUs. During sparse training, the model is trained with 2:4 sparsity enabled for the first part of the training, where only half of the weights are enabled. This sparsity mask on the weights is dynamically recomputed at every step, as weights are continuously updated during the optimization. For the remaining steps, the model is trained densely, producing a final model without 2:4 sparsity (except the 100% sparse training setup), which is then evaluated.

| Training setup | ImageNet 1k log-regression |
| --- | --- |
| 0% sparse (125k dense steps, baseline) | 82.8 |
| 40% sparse (40k sparse -> 85k dense steps) | 82.9 |
| 60% sparse (75k sparse -> 50k dense steps) | 82.8 |
| 70% sparse (87.5k sparse -> 37.5k dense steps) | 82.7 |
| 80% sparse (100k sparse -> 25k dense steps) | 82.7 |
| 90% sparse (112.5k sparse -> 12.5k dense steps) | 82.0 |
| 100% sparse (125k sparse steps) | 82.3 (2:4-sparse model) |

sparsity training diagrams

During the sparse training steps, in the backward pass we obtain a dense gradient for the sparse weights. For the gradient descent to be sound, we should also sparsify this gradient before using it in the optimizer to update the weights. Instead of doing that, we use the full dense gradient to update the weights – we found this to work better in practice: this is the STE (Straight Through Estimator) strategy. In other words, we update all the parameters at every step, even the ones we don’t use.
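
As an illustration of the paragraph above, here is a minimal pure-Python sketch of a magnitude-based 2:4 mask combined with a straight-through update. It is not the post's CUDA kernel: the helper names (mask_2_4, ste_step), the magnitude criterion for choosing which two weights to keep, the learning rate and the gradient values are all assumptions made for the example.

```python
def mask_2_4(weights):
    """Return a 0/1 mask keeping the 2 largest-magnitude values in each group of 4."""
    assert len(weights) % 4 == 0
    mask = []
    for start in range(0, len(weights), 4):
        group = weights[start:start + 4]
        # Indices of the two largest |w| in this group (ties broken by position).
        keep = sorted(range(4), key=lambda i: -abs(group[i]))[:2]
        mask.extend(1.0 if i in keep else 0.0 for i in range(4))
    return mask


def ste_step(weights, dense_grad, lr):
    """One SGD step: the dense gradient updates every weight, masked or not."""
    return [w - lr * g for w, g in zip(weights, dense_grad)]


weights = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.25, 0.01]
mask = mask_2_4(weights)  # recomputed at every step
sparse_weights = [w * m for w, m in zip(weights, mask)]  # used in the forward pass

# Made-up dense gradient: note the large entry for a weight that is masked out.
dense_grad = [0.0, -9.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0]
weights = ste_step(weights, dense_grad, lr=0.1)
new_mask = mask_2_4(weights)  # the previously pruned weight can re-enter the mask
```

Because the pruned weight at index 1 still receives its gradient, it grows enough to displace another weight when the mask is recomputed, which is the behavior the STE strategy relies on.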

Conclusion and Future Work

In this blog post, we’ve shown how to accelerate neural network training with semi-structured sparsity and explained some of the challenges we faced. We were able to achieve a 6% end to end speedup on DINOv2 training with a small 0.1 pp accuracy drop.

There are several areas of expansion for this work:

  • Expansion to new sparsity patterns: Researchers have created new sparsity patterns like V:N:M sparsity that use the underlying semi-structured sparse kernels to allow for more flexibility. This is especially interesting for applying sparsity to LLMs, as 2:4 sparsity degrades accuracy too much, but we have seen some positive results for more general N:M patterns.
  • Performance optimizations for sparse fine-tuning: This post covers sparse training from scratch, but oftentimes we want to fine-tune a foundational model. In this case, a static mask may be sufficient to preserve accuracy which would enable us to make additional performance optimizations.
  • More experiments on pruning strategy: We calculate the mask at each step of the network, but calculating the mask every n steps may yield better training accuracy. Overall, figuring out the best strategy to use semi-structured sparsity during training is an open area of research.
  • Compatibility with fp8: The hardware also supports fp8 semi-structured sparsity (in the 4:8 format instead of 2:4), and this approach should work similarly with fp8 in principle. In practice, we would need to write similar sparsification kernels, and could possibly fuse them with the scaling of the tensors.
  • Activation Sparsity: Efficient sparsification kernels also make it possible to sparsify the activations during training. Because the sparsification overhead grows linearly with the sparsified matrix size, setups with large activation tensors compared to the weight tensors could benefit more from activation sparsity than weight sparsity. Furthermore, activations are naturally sparse because of the usage of ReLU or GELU activation functions, reducing accuracy degradation.

If you are interested in these problems, please feel free to open an issue / PR in torchao, a community we’re building for architecture optimization techniques like quantization and sparsity. Additionally, if you have a general interest in sparsity, please reach out in CUDA-MODE (#sparsity).

Reducing Model Checkpointing Times by Over 10x with PyTorch Distributed Asynchronous Checkpointing

Summary: With PyTorch distributed’s new asynchronous checkpointing feature, developed with feedback from IBM, we show how the IBM Research team was able to reduce effective checkpointing time by a factor of 10-20x. For example, the ‘down time’ of a checkpoint for a 7B model goes from an average of 148.8 seconds to 6.3 seconds, or 23.62x faster.

This directly translates into either more net training progress for every given 24 hour period while continuing to robustly checkpoint or more frequent checkpoints to shorten recovery window/time.

In this note, we showcase the usage code and architecture that makes asynchronous checkpointing possible, along with timing results verified by IBM’s Research team.

Async Checkpointing vs Standard Checkpointing

Model checkpointing is a vital part of large model training, but it is an expensive process, as each checkpoint blocks training progress in order to save out the latest model weights. However, not checkpointing or reducing checkpointing frequency can result in a significant loss in training progress. For example, failures such as deadlocks, stragglers, and GPU errors require the training process to be restarted. In order to restart from a failure, all (training) workers must stop their training process and be restarted from the last saved checkpoint.

Thus, there is an inherent tradeoff between robustness to failures and training progress. With asynchronous checkpointing, PyTorch Distributed is able to significantly reduce this tension and enable frequent checkpoints with minimal impact on the overall training time.

For background, it was almost exactly a year ago that we showcased how distributed checkpointing had massively sped up checkpointing times from the original torch.save() functionality. As IBM Research had noted, torch.save could take up to 30 minutes to checkpoint a single 11B model (PyTorch 1.13).

With advancements in distributed checkpointing, checkpoints could be done in under 4 minutes for up to 30B model sizes.

With asynchronous checkpointing, the training time lost due to checkpointing now moves to under 30 seconds, and often as short as 6 seconds.

To be clear, asynchronous checkpointing does not shorten the actual serialization time, as the previous update did. Rather, it moves the final checkpointing process off the critical path (to CPU threads), allowing GPU training to continue while the checkpoint is finalized under separate threads.

However, to the user, the effect is nearly the same in that down time for training due to checkpointing is substantially reduced, in many cases by 10x or even 20x.

Async Dist Checkpointing

As the above speedup chart shows, asynchronous checkpointing produces a 10x to 23x further improvement over the previous large improvements from a year ago.

How does Asynchronous Checkpointing work?

Asynchronous checkpointing modularizes the checkpointing process into two parts rather than one monolithic process. The first phase copies the data on each rank from GPU to CPU. This is the visible downtime to the user and can take from 6 to 14 seconds for 7B-13B model sizes. The second phase asynchronously copies the data from CPU memory to disk to persist the checkpoint.

Once data is copied to CPU in the first phase, the GPU is free to immediately resume training. Hence with asynchronous checkpointing the downtime for checkpointing is simply the time needed to copy over the latest model states to CPU.

At the same time that training resumes, non-blocking CPU threads work with the freshly arrived data in memory to complete the full checkpointing/serialization process to disk (i.e. persistent save).
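
The two phases can be mimicked with nothing but the Python standard library. The sketch below is an analogy, not PyTorch's implementation: a deep copy stands in for the GPU-to-CPU copy, a background thread stands in for the CPU threads that persist the checkpoint, and the names async_checkpoint and ckpt.json are invented for the example.

```python
import copy
import json
import os
import tempfile
import threading


def async_checkpoint(state, path):
    """Phase 1 (blocking): snapshot state in memory. Phase 2 (background): persist it."""
    snapshot = copy.deepcopy(state)  # stands in for the GPU -> CPU copy

    def persist():
        tmp_path = path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(snapshot, f)
        os.replace(tmp_path, path)  # atomic rename marks the checkpoint complete

    worker = threading.Thread(target=persist)
    worker.start()
    return worker  # the caller can join() before starting the next checkpoint


state = {"step": 100, "weights": [0.1, 0.2, 0.3]}
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.json")
worker = async_checkpoint(state, ckpt_path)

# "Training" resumes immediately and mutates the live state.
state["step"] = 101
state["weights"][0] = 9.9

worker.join()
with open(ckpt_path) as f:
    saved = json.load(f)
```

The only time the caller is blocked is the in-memory snapshot; the file on disk still reflects step 100 even though training has already moved on.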

flow diagram

Note that PyTorch’s Distributed Checkpointer relies on collective communication calls to exchange the per-rank metadata needed to optimize saves, as well as on a final synchronization which marks checkpointing as complete and makes the action atomic. This can interfere with distributed training (as distributed training also relies upon similar calls to synchronize training across multiple GPUs) if the checkpointing thread utilizes the same process group used for training.

Specifically, a race condition between the calls could potentially cause training and async checkpointing save threads to wait on collective calls at the same time, resulting in a true collective hang.

We avoided this scenario by initializing a separate process group for async checkpointing. This separates the checkpointing collectives into their own logical process group, which thus ensures it will not interfere with collective calls in the main training threads.

How do I use Asynchronous Checkpointing in my training?

Usage of asynchronous checkpointing is relatively straightforward. Using the latest nightly version of PyTorch, you will want to initialize your process group with both nccl and gloo. Gloo is required for the CPU threads portion.

From there, create a duplicate process group which the asynchronous checkpointing will utilize.
Then train as usual, but at the point when you want to checkpoint, use the asynchronous save API, passing in the states to save, the checkpoint id and the checkpoint process group.
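
A rough sketch of these steps follows. The wrapper names setup_process_groups and save_checkpoint_async, and the checkpoint_dir argument, are hypothetical; the PyTorch calls they wrap (init_process_group with a combined backend string, new_group, and torch.distributed.checkpoint.async_save) exist in recent releases, but the feature was new at the time of writing, so check the signatures against your installed version. The imports are done inside the functions so the snippet can be read and loaded without a distributed environment.

```python
def setup_process_groups():
    """Initialize training and checkpointing process groups (run under torchrun)."""
    import torch.distributed as dist  # imported lazily: requires a PyTorch install

    # NCCL serves the GPU collectives; Gloo serves the CPU-side checkpoint threads.
    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    # A separate group keeps checkpoint collectives off the training group.
    return dist.new_group(backend="gloo")


def save_checkpoint_async(state_dict, checkpoint_dir, checkpoint_pg):
    """Start an asynchronous save and return the future that tracks it."""
    import torch.distributed.checkpoint as dcp

    return dcp.async_save(
        state_dict,
        checkpoint_id=checkpoint_dir,
        process_group=checkpoint_pg,
    )
```

In a training loop, one would typically wait on the future returned by the previous save before starting the next one, so that two saves never overlap.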

Code snippet

Asynchronous checkpointing is also fully implemented in torchtitan. Here, it is implemented for use with pre-training your own Llama 2 or Llama 3 model. Using it is as simple as updating the toml config file:

Code snippet

Future work

Checkpointing has made huge strides over the past year, moving from almost half an hour per checkpoint to under 5 minutes with distributed checkpointing, and now to under 30 seconds with asynchronous checkpointing.

The last frontier is zero-overhead checkpointing, where even the < 30 seconds is eliminated by streaming the updated weights during the backward pass, such that checkpoint data is already on CPU at the point asynchronous checkpointing would kick in.

This would effectively move large model training to a point where checkpointing causes no disruption or downtime, enabling both more robustness (as checkpoints could be taken more frequently) and faster training progress.

Source code link: https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py

PyTorch Foundation Welcomes New Executive Director

Matt White
The PyTorch Foundation is excited to welcome Matt White, our new executive director. The PyTorch Foundation formed in 2022 with the goal to drive adoption of AI tooling by fostering and sustaining an ecosystem of open source, vendor-neutral projects with PyTorch. Over the past 2 years, we’ve seen excellent growth across the project – with both contributor and member growth.

“I am honored to be a part of the PyTorch Foundation, working with such a passionate and skilled community,” said Matt White. “I am looking forward to working with our contributors and members to advance the PyTorch ecosystem through research, cutting edge technologies and open source best practices.”

Matt is a career technologist, researcher and innovator and has over 25 years of experience in AI, data, autonomous systems and simulations. He is the Co-founder and Chair of the Open Metaverse Foundation, a part of the Linux Foundation. Previously, Matt was the Director of the Generative AI Commons at the Linux Foundation, leading the advancement of open science and open-source artificial intelligence projects. He is also the GM of AI at the Linux Foundation.

Learn more about the PyTorch Foundation:

INT4 Decoding GQA CUDA Optimizations for LLM Inference

An efficient decoding Grouped-Query Attention with low-precision KV cache

Introduction

Generative AI has taken the world by storm with its ability to generate content like humans. Many of these generative AI tools are powered by large language models (LLMs), like Meta Llama models and OpenAI’s ChatGPT. One of the main challenges of LLMs is supporting large “context lengths” (also known as “sequence lengths”). The context length refers to the number of tokens that the model uses to understand the input context and generate responses. Longer context lengths generally translate into higher precision and quality in the responses. However, long context lengths are compute and memory intensive. This is mainly due to the following reasons:

  • The computational complexity of attention layers increases proportionally with the context length (the growth rate depends on the attention algorithm). As a result, when using long context lengths, the attention layers can become a bottleneck, particularly during the prefill phase where attentions are compute bound.
  • The KV cache size grows linearly with the context length, thus, putting higher pressure on the memory requirement and consequently slowing down the already memory-bound attention decoding. Moreover, since the memory capacity is limited, the batch size reduces when the KV cache gets bigger, which generally results in a drop in throughput.

The computational complexity growth is the harder of the two problems to solve. The KV cache size growth, on the other hand, can be addressed by using a low-precision KV cache. From our experiments, group-wise INT4 quantization provides accuracy comparable to a BF16 KV cache during the decode phase in Meta Llama 2 inference. However, we did not observe any latency improvement, despite reading 4x less data in the attention decoding layers. This means that the INT4 attention is 4x less efficient at utilizing precious HBM bandwidth than BF16 attention.

In this note, we discuss the CUDA optimizations that we applied to INT4 GQA (grouped-query attention – the attention layer that we use in the LLM inference phase) to improve its performance by up to 1.8x on the NVIDIA A100 GPU and 1.9x on the NVIDIA H100 GPU.

  • The optimized CUDA INT4 GQA outperformed INT4 Flash-Decoding GQA (the best performing INT4 GQA that we used in the experiment mentioned above) by 1.4x-1.7x on A100 and 1.09x-1.3x on H100.
  • The optimized CUDA INT4 GQA performs better than BF16 Flash-Decoding GQA by 1.5x-1.7x on A100 and 1.4x-1.7x on H100.

Background

GQA for LLM Inference

Grouped-Query Attention (GQA) is a variant of multi-head attention (MHA) where each KV cache head is shared across a group of query heads. Our LLM inference adopts GQA as an attention layer in both the prefill and decode phases in order to reduce the capacity requirement for the KV cache. We use multiple GPUs in inference where the KV cache and query heads are distributed across GPUs. Each GPU runs an attention layer with a single KV head and a group of Q heads. Therefore, when viewed from a single GPU perspective, the GQA component can also be described as MQA (Multi-Query Attention).

The simplified workflow of decoding GQA is illustrated in Figure 1. GQA takes three main inputs: input query (denoted Q), K cache (denoted K), and V cache (denoted V). Our current GQA inference uses BF16 for Q, K, and V.

  • Q is a 4D BF16 tensor of shape (B, 1, HQ, D)
  • K is a 4D BF16 tensor of shape (B, Tmax, HKV, D)
  • V is a 4D BF16 tensor of shape (B, Tmax, HKV, D)

where

  • B is the batch size (the number of input prompts)
  • HQ is the number of query heads
  • HKV is the number of KV heads (HQ must be divisible by HKV)
  • Tmax is the maximum context length
  • D is the head dimension (fixed to 128)

GQA is simply bmm(softmax(bmm(Q, K^T) / sqrt(D)), V). This yields a single output tensor (denoted as O), which is a 4D BF16 tensor that has the same shape as Q. Note that matrix multiplications are performed using BF16; however, accumulation and softmax are carried out in FP32. We call this “BF16 GQA” as the KV cache is BF16.
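
To make the formula concrete, here is a small pure-Python reference for the single-GPU view described above: one KV head shared by a group of query heads, and one new token per prompt. It ignores the batch dimension and BF16/FP32 precision details, and the function names are illustrative only.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def decode_gqa_one_kv_head(q, k, v):
    """q: [HQ][D] query heads for one new token; k, v: [T][D] shared by all heads."""
    d = len(q[0])
    out = []
    for q_head in q:
        # One row of Q K^T / sqrt(D): a score per cached token.
        scores = [sum(a * b for a, b in zip(q_head, k_t)) / math.sqrt(d) for k_t in k]
        p = softmax(scores)
        # One row of P V: a weighted average of the cached value vectors.
        out.append([sum(p[t] * v[t][j] for t in range(len(v))) for j in range(d)])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]              # HQ = 2, D = 2
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # T = 3
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
o = decode_gqa_one_kv_head(q, k, v)       # same shape as q

# With an all-zero query, every token gets equal weight, so O is the mean of V.
o_uniform = decode_gqa_one_kv_head([[0.0, 0.0]], k, v)
```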

Figure 1 The simplified workflow of BF16 GQA for LLM inference

INT4 GQA

To further reduce the size of the KV cache, we explore the possibility of using INT4 for KV cache instead of BF16. We estimate the potential performance improvement by calculating the computational intensity (CI) of INT4 GQA and comparing it to that of BF16 GQA, as CI represents FLOPS per byte. We compute the CI for QK^T and PV (as shown in Equation 1) as they take KV cache as an operand. Note that we disregard the Q load as it is negligible compared to the KV cache. We also ignore any intermediate data loads/stores that are not on global memory. Thus, the CI only takes into account the computation FLOPS and KV cache loads.

Equation (1)

Assuming that HQ = 8 and HKV = 1, CI for BF16 KV cache is 8 while CI for INT4 KV cache is 32. The CIs indicate that both BF16 and INT4 GQAs are memory bound (the peak CIs for BF16 tensor cores for A100 and H100 are 312 TF / 2 TB/s = 141 and 990 TF / 3.35 TB/s = 269, values that correspond to taking 1 TB as 2^40 bytes; note that these TF numbers are without sparsity). Moreover, with INT4 KV cache, we should expect up to 4x performance improvement compared to BF16 GQA.
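
The numbers above can be reproduced with a few lines of arithmetic. The formula in the sketch is one reading that is consistent with the stated values (2 FLOPs per multiply-add, and only KV cache bytes counted); it is not copied from Equation 1. Treating 1 TB as 2^40 bytes is likewise an assumption, chosen because it reproduces the quoted 141 and 269.

```python
def compute_intensity(hq, hkv, bytes_per_element):
    """FLOPs per KV-cache byte for Q K^T (P V gives the same ratio).

    FLOPs = 2 * B * HQ * T * D (one multiply and one add per element pair).
    Bytes = B * HKV * T * D * bytes_per_element, so B, T and D cancel.
    """
    return 2 * hq / (hkv * bytes_per_element)


ci_bf16 = compute_intensity(hq=8, hkv=1, bytes_per_element=2)    # BF16: 2 bytes
ci_int4 = compute_intensity(hq=8, hkv=1, bytes_per_element=0.5)  # INT4: half a byte

TB = 2**40  # assumption: binary terabytes
peak_ci_a100 = 312e12 / (2 * TB)     # peak BF16 FLOPS / peak HBM bandwidth
peak_ci_h100 = 990e12 / (3.35 * TB)

# Both kernels sit far below the hardware's peak CI, i.e. they are memory bound.
memory_bound = ci_bf16 < peak_ci_a100 and ci_int4 < peak_ci_a100
```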

To enable INT4 KV cache support in GQA, we can dequantize the KV cache from INT4 to BF16 before passing it to the BF16 GQA operator. However, since KV cache is typically large, copying it from/to global memory can be costly. Moreover, decoding GQA is a memory bound operation (the memory unit is utilized much more heavily than the compute unit). Figure 2 shows the NCU profile of the FMHA CUTLASS BF16 GQA kernel in xFormers, which is one of the state of the art implementations of GQA. From the figure, it is obvious that memory is a bottleneck.

Figure 2 The NCU profile of the FMHA CUTLASS BF16 kernel in xFormers

A more efficient alternative is to fuse INT4 dequantization with the GQA operation (shown in Figure 3). In other words, GQA reads the INT4 KV cache directly and performs the INT4 to BF16 conversion within the kernel. This change can potentially reduce the amount of global memory reads required for the KV cache, which could lead to a decrease in latency. We call this “INT4 GQA.”

Figure 3 The workflow of fused INT4 GQA

We list the state of the art implementations of GQA, along with their features, in Table 1.

Table 1 State of the art GQA implementations

Implementation | Abbreviation | BF16 GQA | Fused INT4 GQA
Flash-Decoding (Triton implementation) | FD | Yes | Yes
Flash Attention (v2.3.3) | FA | Yes | No
CUDA baseline | CU | Yes | Yes

All implementations, except for CU, support both split-K and non split-K. CU only has the split-K implementation. Only FA has a heuristic in the backend to determine whether to run the split-K or non split-K kernel. For other implementations, users must explicitly choose which version to run. In this note, we focus on long context lengths (in our experiments, we use a context length of 8192) and therefore opt for the split-K version wherever possible.
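
For readers unfamiliar with split-K attention, the sketch below shows one common way to combine per-split results: each split reports its local maximum score, its sum of exponentials and its unnormalized output, and the reduction rescales everything to the global maximum. This is the standard log-sum-exp style merge; the post does not spell out the exact reduction used by each implementation, so treat it as an illustration with made-up names and data.

```python
import math


def partial_attention(scores, values):
    """Attention over one split: returns (max score, sum of exps, unnormalized output)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    out = [sum(e * v[j] for e, v in zip(exps, values)) for j in range(len(values[0]))]
    return m, sum(exps), out


def merge_splits(partials):
    """Combine per-split results by rescaling each one to the global max."""
    m_all = max(m for m, _, _ in partials)
    total = 0.0
    out = [0.0] * len(partials[0][2])
    for m, s, o in partials:
        scale = math.exp(m - m_all)
        total += scale * s
        out = [acc + scale * x for acc, x in zip(out, o)]
    return [x / total for x in out]


scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]  # q . k_t / sqrt(D) for 6 cached tokens
values = [[float(t), float(t * t)] for t in range(6)]

# No split: a single block sees all tokens.
full = merge_splits([partial_attention(scores, values)])
# Three splits of two tokens each, as three thread blocks might process them.
split = merge_splits([partial_attention(scores[i:i + 2], values[i:i + 2])
                      for i in range(0, 6, 2)])
```

Both paths give the same output up to floating-point rounding, which is why the token dimension can be divided among thread blocks for long context lengths.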

As the baseline, we measured the performance of the state of the art GQA implementations on NVIDIA A100 and H100 GPUs. The latency (time in microseconds) and achieved bandwidth (GB/s) are reported in Table 2. Note that we ran a range of split-Ks (from 2 to 128 splits) and reported the best performance for each implementation. For all experiments, we use a context length of 8192. For INT4 GQA, we used row-wise quantization (i.e., num quantized groups = 1).

Table 2 Baseline GQA performance

On A100

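Time (us)
Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU
32 | 139 | 133 | 183 | 137 | n/a | 143
64 | 245 | 229 | 335 | 234 | n/a | 257
128 | 433 | 555 | 596 | 432 | n/a | 455
256 | 826 | 977 | 1127 | 815 | n/a | 866
512 | 1607 | 1670 | 2194 | 1581 | n/a | 1659

Effective Bandwidth (GB/s)
Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU
32 | 965 | 1012 | 736 | 262 | n/a | 250
64 | 1097 | 1175 | 802 | 305 | n/a | 278
128 | 1240 | 968 | 901 | 331 | n/a | 314
256 | 1301 | 1100 | 954 | 351 | n/a | 331
512 | 1338 | 1287 | 980 | 362 | n/a | 345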

On H100

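Time (us)
Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU
32 | 91 | 90 | 114 | 70 | n/a | 96
64 | 148 | 146 | 200 | 113 | n/a | 162
128 | 271 | 298 | 361 | 205 | n/a | 294
256 | 515 | 499 | 658 | 389 | n/a | 558
512 | 1000 | 1011 | 1260 | 756 | n/a | 1066

Effective Bandwidth (GB/s)
Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU
32 | 1481 | 1496 | 1178 | 511 | n/a | 371
64 | 1815 | 1840 | 1345 | 631 | n/a | 443
128 | 1982 | 1802 | 1487 | 699 | n/a | 487
256 | 2087 | 2156 | 1634 | 736 | n/a | 513
512 | 2150 | 2127 | 1706 | 757 | n/a | 537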

First, let’s discuss the BF16 GQA performance: CU ranks last in terms of performance among all implementations. FD and FA have comparable performance. When the batch size is less than or equal to 64, FA utilizes the split-K kernel and performs slightly better than FD. However, when the batch size is greater than 64, FD performs better.

The same trend holds true for INT4 GQAs. However, we did not measure the performance of FA as it does not support INT4 KV cache. FD outperforms CU for all cases.

When comparing the latencies of FD between BF16 and INT4 GQAs, we find that they are almost identical. This suggests that INT4 GQA is highly inefficient, which can be further confirmed by the significantly lower achievable bandwidth for INT4 GQA compared to BF16 GQA. The same trend is also true when looking at the performance of CU.

CUDA with Tensor Cores INT4 GQA Implementation

In this section, we briefly describe our baseline implementation which is CUDA with tensor cores INT4 GQA (CU). Each thread block processes only one KV head and a group of query heads from one input prompt. Therefore, each thread block performs mm(softmax(mm(Q, K^T) / sqrt(D)), V); notice that mm is performed, not bmm. Moreover, since this is a split-K implementation, tokens in the KV cache are split among different thread blocks. Note that each thread block contains 4 warps (each warp contains 32 threads for NVIDIA A100 and H100 GPUs). Work in each thread block is split among warps. Within each warp, we use the WMMA API to compute matrix multiplication on tensor cores. Figure 4 demonstrates the work partitioning in CU.

Figure 4 CU work partitioning

Optimizing CUDA with Tensor Cores Kernel of INT4 GQA

In this note, we discuss the optimizations that we have applied to the CUDA with tensor cores implementation of INT4 GQA (CU). The ideal goal is to improve the INT4 GQA performance by 4 times based on the CI analysis in the previous section. Note that the query size is negligible compared to the KV cache size when the context length is long.

In our analysis, we used NVIDIA Nsight Compute (NCU) as the main profiler. Our general bottleneck elimination approach is to minimize the stall cycles. We applied 10 optimizations to INT4 GQA, three of which are specific to NVIDIA A100/H100 GPUs. These optimizations are well known CUDA optimization techniques which can be generalized to many applications.

It is worth noting that we chose to optimize the CUDA implementation rather than the Flash-Decoding implementation (FD), which is Triton based, because CUDA gives us better control over how the low-level instructions are generated. Many optimization techniques that we apply, such as operating on tensor core fragments directly (Optimizations 7-9), cannot be done through Triton since it does not expose low-level details to developers. However, these optimizations can be integrated into the compiler-based solution to make them available to a broader set of operators, which is indeed a part of our future plan.

Optimization 1: Unroll K Loads

Problem Analysis:

The NCU profile shows that during K loading, there are only 2 global loads followed by memory stalls at dequantize_permuted_int4. The memory stalls are long scoreboard stalls, which indicate waits on global memory access. This suggests that the kernel does not issue sufficient memory loads to hide the global load latency. The kernel issues data loading and then waits to consume the data immediately, causing the global load latency to be exposed. The stalls are shown in Figure 5.

Figure 5 K loading before unrolling (the numbers that the arrows point to are stall cycles caused by global memory wait)

Solution:

In the baseline implementation, we use uint32_t to load 8 INT4 K values in a single load and we perform 2 uint32_t loads in each iteration, which is 16 INT4 K values. To allow for a better global load latency hiding, we issue 8 uint32_t loads instead of two before consuming the K values in dequantize_permuted_int4. This allows the compiler to unroll the loads as well as reorder the instructions to hide the global load latency better. Figure 6 shows the NCU profile of K loading after unrolling. Comparing Figure 5 and Figure 6, we effectively reduce the stall cycles by unrolling the K loads.

Figure 6 K loading after unrolling (the numbers that the arrows point to are stall cycles caused by global memory wait)

Results:

Table 3 Performance of Optimization 1 for INT4 GQA (row-wise quantization)

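Batch size | FD time (us) | CU baseline time (us) | CU Opt 1 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 1 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline
32 | 137 | 143 | 134 | 262 | 250 | 267 | 1.02 | 1.07
64 | 234 | 257 | 237 | 305 | 278 | 302 | 0.99 | 1.09
128 | 432 | 455 | 422 | 331 | 314 | 339 | 1.02 | 1.08
256 | 815 | 866 | 806 | 351 | 331 | 355 | 1.01 | 1.07
512 | 1581 | 1659 | 1550 | 362 | 345 | 369 | 1.02 | 1.07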

Optimization 2: Improve P Type Casting (FP32->BF16)

Problem Analysis:

Since the product of softmax(bmm(Q, K^T) / sqrt(D)) is FP32 (denoted as P in Figure 3), the kernel has to convert P from FP32 to BF16 before feeding it to the next bmm computation. The kernel performs the FP32 to BF16 conversion of P by copying the FP32 data from one location in shared memory to another location in shared memory. This causes stalls during the shared memory access (shown in Figure 7) which might be caused by (1) the shared memory indirection; and (2) the shared memory bank conflict since each thread accesses a 16-bit element (because of this, two threads can access the same memory bank simultaneously).

Figure 7 P type casting before Optimization 2 (the number that the arrow points to is stall cycles caused by shared memory wait)

Solution:

We use all threads in the thread block to do in-place type conversion. Each thread operates on two consecutive elements in order to avoid the shared memory bank conflict when storing BF16. All threads work on the same head (h) at the same time to guarantee correctness of the conversion. The in-place conversion steps are as follows:

  1. Each thread loads 2 FP32 token elements from the same head from the shared memory into registers
  2. Call __syncthreads() to make sure that every thread finishes reading the data
  3. Each thread converts its data to 2 BF16 token elements and then stores the results to the same shared memory

Some optimizations that we apply to the implementation:

  • Use vector types (especially nv_bfloat2)
  • Unroll data loading/storing, i.e., performing multiple loads before calling __syncthreads() and performing multiple stores after __syncthreads()
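
The ordering of the three in-place conversion steps can be simulated on a byte buffer. In the sketch below, a bytearray plays the role of shared memory and a Python list plays the role of the per-thread registers. Two simplifications are assumed: the loop is serial rather than parallel, and BF16 is produced by truncating the FP32 bit pattern, whereas a real kernel would normally use a rounding conversion intrinsic.

```python
import struct


def fp32_to_bf16_bits(x):
    """Keep the top 16 bits of the FP32 encoding (truncation, for simplicity)."""
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def convert_in_place(buf, n):
    """buf holds n FP32 values; afterwards its first 2*n bytes hold n BF16 values."""
    # Step 1: every "thread" reads its FP32 inputs into "registers".
    regs = [struct.unpack_from("<f", buf, 4 * i)[0] for i in range(n)]
    # Step 2: the barrier (__syncthreads() in the kernel) sits here, so no
    # store below can clobber an FP32 value that has not been read yet.
    # Step 3: every "thread" stores its BF16 results into the same buffer.
    for i, x in enumerate(regs):
        struct.pack_into("<H", buf, 2 * i, fp32_to_bf16_bits(x))


p = [0.5, 0.25, 0.125, 1.0]  # values that are exactly representable in BF16
buf = bytearray(struct.pack("<4f", *p))
convert_in_place(buf, len(p))
bf16 = struct.unpack_from("<4H", buf, 0)
# Widening BF16 back to FP32 is a 16-bit left shift of the bit pattern.
restored = [struct.unpack("<f", struct.pack("<I", b << 16))[0] for b in bf16]
```

Without the barrier, the BF16 store for element 2 would land on the bytes of FP32 element 1, which another thread might not have read yet; reading everything first is what makes the in-place scheme safe.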

After this optimization, long stalls are not observed during P type casting as shown in Figure 8.

Figure 8 P type casting after Optimization 2 (the numbers that the arrow points to are stall cycles caused by shared memory wait)

Caveats:

Since we unroll data loading/storing by using registers as an intermediate storage, the number of registers per thread increases resulting in reduced occupancy.

Results:

Table 4 Performance of Optimization 2 for INT4 GQA (row-wise quantization)

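Batch size | FD time (us) | CU baseline time (us) | CU Opt 2 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 2 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline
32 | 137 | 143 | 126 | 262 | 250 | 285 | 1.09 | 1.14
64 | 234 | 257 | 221 | 305 | 278 | 324 | 1.06 | 1.16
128 | 432 | 455 | 395 | 331 | 314 | 362 | 1.09 | 1.15
256 | 815 | 866 | 749 | 351 | 331 | 382 | 1.09 | 1.16
512 | 1581 | 1659 | 1435 | 362 | 345 | 399 | 1.10 | 1.16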

Optimization 3: Remove Local Memory Usage for max QK^T computation

Problem Analysis:

During the softmax computation, the kernel has to compute max QK^T for each head. It uses a temporary “thread-local” storage for storing per-thread max QK^T results (one float value for each head). Depending on the compiler, the thread-local storage can be allocated in registers (on chip) or in local memory (off chip, backed by global memory). Unfortunately, in the baseline, the thread-local storage resides in local memory, which is much slower than registers (shown in Figure 9). We suspect that this is because the compiler cannot determine the indices of the thread-local storage at compile time (since the number of heads (H) in the kernel is a runtime variable). Accessing local memory as frequently as one would access registers hurts the performance of the kernel.

Figure 9 Local memory access during max QK^T computation

Solution:

We realize that we do not need H (number of heads) floats as temporary storage per thread since each thread can compute max QK^T for only one head instead of all the heads. Thus, we only need one float per thread, which can be easily stored in a register. To accumulate the max results among warps, we use shared memory. This optimization eliminates the local memory usage during max QK^T computation.

Results:

Table 5 Performance of Optimization 3 for INT4 GQA (row-wise quantization)

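Batch size | FD time (us) | CU baseline time (us) | CU Opt 3 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 3 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline
32 | 137 | 143 | 119 | 262 | 250 | 300 | 1.14 | 1.20
64 | 234 | 257 | 206 | 305 | 278 | 348 | 1.14 | 1.25
128 | 432 | 455 | 368 | 331 | 314 | 389 | 1.17 | 1.24
256 | 815 | 866 | 696 | 351 | 331 | 411 | 1.17 | 1.24
512 | 1581 | 1659 | 1338 | 362 | 345 | 428 | 1.18 | 1.24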

Optimization 4: Remove local memory usage for row sum

Problem Analysis:

Similar to Optimization 3, the local memory usage problem is also observed during the row sum computation in the softmax computation. Since local memory is off chip, accessing it as if accessing registers can hurt the performance of the kernel.

Solution:

We apply the same solution as for the max QK^T computation to the row sum computation. That is, each thread computes the row sum of only one head, which requires only one float per thread. This eliminates the need for local memory.

Results:

Table 6 Performance of Optimization 4 for INT4 GQA (row-wise quantization)

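Batch size | FD time (us) | CU baseline time (us) | CU Opt 4 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 4 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline
32 | 137 | 143 | 118 | 262 | 250 | 302 | 1.15 | 1.21
64 | 234 | 257 | 204 | 305 | 278 | 351 | 1.15 | 1.26
128 | 432 | 455 | 364 | 331 | 314 | 393 | 1.19 | 1.25
256 | 815 | 866 | 688 | 351 | 331 | 416 | 1.18 | 1.26
512 | 1581 | 1659 | 1328 | 362 | 345 | 431 | 1.19 | 1.25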

Optimization 5: Add prefetch for V load

Problem Analysis:

The same issue as K loading is observed when loading V. That is, the kernel issues data loading and then waits to consume the data immediately, causing the global load latency to be exposed. However, when using the unrolling technique mentioned above, the compiler allocates the temporary buffer in local memory instead of registers, causing a large slowdown.

Solution:

We adopt the data prefetching technique for V loading. We load the next iteration's V values immediately after the current iteration's values are consumed. This allows the data loading to be overlapped with the PV computation, resulting in better kernel performance.

Results:

Table 7 Performance of Optimization 5 for INT4 GQA (row-wise quantization)

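Batch size | FD time (us) | CU baseline time (us) | CU Opt 5 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 5 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline
32 | 137 | 143 | 109 | 262 | 250 | 327 | 1.25 | 1.31
64 | 234 | 257 | 194 | 305 | 278 | 370 | 1.21 | 1.33
128 | 432 | 455 | 345 | 331 | 314 | 414 | 1.25 | 1.32
256 | 815 | 866 | 649 | 351 | 331 | 441 | 1.26 | 1.33
512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33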

Optimization 6: Add Group-Wise INT4 (Groups = 4) with Vector Load

Problem Analysis:

Prior to this optimization, CU only supported row-wise INT4 quantization. That is, every column in each row shares the same scales. The scales of each row are stored in the first 4 bytes of each row as shown in Figure 10. In the kernel, each thread loads only one row at a time. Since each row contains 68 bytes (4 bytes for scales and 64 bytes for data), we cannot guarantee that every row is aligned to the size of a vector type. Thus, vector loads cannot be used for loading the KV cache.

Figure 10 The layout of each row of INT4 KV cache with row-wise quantization

Solution:

We have implemented support for group-wise INT4 quantization with num groups = 4. In this case, columns in each row in the KV cache tensor are divided into 4 equal groups. Columns within the same group share the same scales for quantization/dequantization. The data layout for INT4 KV cache is shown in Figure 11. The scales for all groups are serialized and stored at the beginning of each row. The INT4 data is also serialized and laid out next to the scales.
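
The effect of splitting a row into groups can be seen with a small quantization sketch. The asymmetric scheme below (a scale and a shift per group, with 16 levels) is an assumption made for illustration; the post only states that columns in the same group share the same scales. The function names and the synthetic row are likewise invented.

```python
def quantize_group(values):
    """Asymmetric 4-bit quantization of one group: returns (scale, shift, codes)."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 15 or 1.0  # 4 bits give 16 levels; avoid a zero scale
    codes = [round((x - lo) / scale) for x in values]
    return scale, lo, codes


def quantize_row(row, num_groups=4):
    size = len(row) // num_groups
    return [quantize_group(row[g * size:(g + 1) * size]) for g in range(num_groups)]


def dequantize_row(groups):
    return [code * scale + shift for scale, shift, codes in groups for code in codes]


def max_error(row, num_groups):
    restored = dequantize_row(quantize_row(row, num_groups))
    return max(abs(a - b) for a, b in zip(row, restored))


# A D = 128 row whose two halves have very different value ranges.
row = [0.01 * i for i in range(64)] + [10.0 + 0.5 * i for i in range(64)]
groups = quantize_row(row)            # num groups = 4
group_wise_err = max_error(row, 4)
row_wise_err = max_error(row, 1)      # row-wise is the num groups = 1 case
```

With one scale for the whole row, the small values in the first half collapse onto a single level; with four groups, each group gets a scale that matches its own range, so the worst-case error drops.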

Because each row now contains 80 bytes (16 bytes for scales and 64 bytes for data), we can use a vector type, i.e., uint2 in our case, to load data. (We do not use uint4 since each thread loads only 16 INT4s at a time due to the tensor core fragment size.) Vector load is generally better than scalar load since it does not cause extra byte loads.
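The row-size arithmetic can be checked with a short sketch. The constants below are taken from the text (4 bytes of scales per group, 64 bytes of INT4 data per row, and an 8-byte uint2); the helper names are ours and purely illustrative.

```python
DATA_BYTES_PER_ROW = 64      # 128 INT4 values, two per byte
SCALE_BYTES_PER_GROUP = 4    # scales stored at the start of each row
UINT2_BYTES = 8              # size of the uint2 vector type
INT4_PER_LOAD = 16           # INT4 values each thread loads at a time


def row_bytes(num_groups):
    """Total bytes per KV cache row for a given number of quantization groups."""
    return num_groups * SCALE_BYTES_PER_GROUP + DATA_BYTES_PER_ROW


def rows_stay_aligned(num_groups, vector_bytes):
    """True if every row starts at a multiple of vector_bytes (for an aligned base)."""
    return row_bytes(num_groups) % vector_bytes == 0


row_wise = row_bytes(1)      # 68 bytes: not a multiple of 8
group_wise = row_bytes(4)    # 80 bytes: a multiple of 8
bytes_per_load = INT4_PER_LOAD // 2  # 16 INT4 values occupy 8 bytes, i.e. one uint2
```

With one group, the 68-byte row breaks 8-byte alignment for every other row; with four groups, the 80-byte row keeps every row aligned for uint2 loads.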

Figure 11 The layout of each row of INT4 KV cache with group-wise quantization (num groups = 4)

Results:

Table 8 Performance of Optimization 6 for INT4 GQA (row-wise quantization)

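| Batch size | Time (us): FD | Time (us): CU baseline | Time (us): CU Opt 6 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CU baseline | Bandwidth (GB/s): CU Opt 6 | Speed up vs FD | Speed up vs CU baseline |
|---|---|---|---|---|---|---|---|---|
| 32 | 137 | 143 | 111 | 262 | 250 | 322 | 1.23 | 1.29 |
| 64 | 234 | 257 | 192 | 305 | 278 | 372 | 1.22 | 1.34 |
| 128 | 432 | 455 | 346 | 331 | 314 | 414 | 1.25 | 1.32 |
| 256 | 815 | 866 | 642 | 351 | 331 | 446 | 1.27 | 1.35 |
| 512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |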

Table 9 Performance of Optimization 6 for INT4 GQA (group-wise quantization with num groups = 4)

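| Batch size | Time (us): FD | Time (us): CUDA_WMMA Opt 6 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CUDA_WMMA Opt 6 | Speed up vs FD |
|---|---|---|---|---|---|
| 32 | 129 | 116 | 325 | 364 | 1.31 |
| 64 | 219 | 195 | 385 | 431 | 1.36 |
| 128 | 392 | 347 | 429 | 484 | 1.39 |
| 256 | 719 | 638 | 468 | 527 | 1.41 |
| 512 | 1375 | 1225 | 489 | 550 | 1.43 |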

Optimization 7: Compute max QKT From WMMA Fragment Directly (A100/H100 specific)

Problem Analysis:

We observe large stalls due to shared memory accesses during the max QKT computation (these appear as large short scoreboard stalls), as shown in Figure 12.

Figure 12 Stalls due to shared memory access during max QKT computation (the number that the arrow points to is stall cycles caused by shared memory wait)

Solution:

We bypass shared memory when computing max QKT by computing it from the WMMA fragment (i.e., the tensor core fragment) directly. The layout of the WMMA fragment is specific to the GPU architecture, so we enabled this optimization only for the NVIDIA A100/H100 GPUs. Other GPUs still use shared memory for the max QKT computation. By bypassing shared memory, we effectively eliminate the stalls caused by shared memory access. The tensor core layout of the C fragment, which is used for storing the QKT results, is shown in Figure 13.

Figure 13 C fragment (QKT storage) tensor core layout on A100/H100

Table 10 Performance of Optimization 7 for INT4 GQA (row-wise quantization)

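| Batch size | Time (us): FD | Time (us): CU baseline | Time (us): CU Opt 7 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CU baseline | Bandwidth (GB/s): CU Opt 7 | Speed up vs FD | Speed up vs CU baseline |
|---|---|---|---|---|---|---|---|---|
| 32 | 137 | 143 | 107 | 262 | 250 | 333 | 1.27 | 1.33 |
| 64 | 234 | 257 | 183 | 305 | 278 | 391 | 1.28 | 1.40 |
| 128 | 432 | 455 | 333 | 331 | 314 | 430 | 1.30 | 1.37 |
| 256 | 815 | 866 | 620 | 351 | 331 | 461 | 1.31 | 1.40 |
| 512 | 1581 | 1659 | 1206 | 362 | 345 | 475 | 1.31 | 1.38 |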

Table 11 Performance of Optimization 7 for INT4 GQA (group-wise quantization with num groups = 4)

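| Batch size | Time (us): FD | Time (us): CUDA_WMMA Opt 6 | Time (us): CUDA_WMMA Opt 7 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CUDA_WMMA Opt 6 | Bandwidth (GB/s): CUDA_WMMA Opt 7 | Speed up vs FD | Speed up vs CUDA_WMMA Opt 6 |
|---|---|---|---|---|---|---|---|---|
| 32 | 129 | 116 | 111 | 325 | 364 | 380 | 1.17 | 1.04 |
| 64 | 219 | 195 | 187 | 385 | 431 | 449 | 1.17 | 1.04 |
| 128 | 392 | 347 | 333 | 429 | 484 | 506 | 1.18 | 1.04 |
| 256 | 719 | 638 | 615 | 468 | 527 | 547 | 1.17 | 1.04 |
| 512 | 1375 | 1225 | 1184 | 489 | 550 | 569 | 1.16 | 1.03 |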

Optimization 8: Write FP32->BF16 Results to P Fragment Directly (A100/H100 specific)

Problem Analysis:

During the FP32-BF16 conversion for the P fragment, the kernel loads the FP32 data from shared memory, does the conversion and then stores the BF16 data back to shared memory. Moreover, the conversion requires many thread block synchronizations (__syncthreads()).

Solution:

Due to the data partitioning design of the kernel, each warp performs only one pass through the P fragment. Thus, we do not have to write the conversion results back to the shared memory for future usage. To avoid writing the BF16 data to the shared memory and thread block synchronizations, we have each warp load the FP32 data of the P WMMA fragment from the shared memory, do the conversion and then write the BF16 data directly to the P fragment.

Note that this optimization is applied only to the NVIDIA A100 and H100 GPUs because the WMMA fragment layout is architecture dependent. For non-A100/H100 GPUs, the kernel will fall back to the original path.

The P fragment tensor core layout is shown in Figure 14. Note that this layout is specific to the NVIDIA A100/H100 GPU.

Figure 14 P fragment tensor core layout on A100/H100

Table 12 Performance of Optimization 8 for INT4 GQA (row-wise quantization)

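| Batch size | Time (us): FD | Time (us): CU baseline | Time (us): CU Opt 8 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CU baseline | Bandwidth (GB/s): CU Opt 8 | Speed up vs FD | Speed up vs CU baseline |
|---|---|---|---|---|---|---|---|---|
| 32 | 137 | 143 | 101 | 262 | 250 | 353 | 1.35 | 1.41 |
| 64 | 234 | 257 | 174 | 305 | 278 | 410 | 1.34 | 1.47 |
| 128 | 432 | 455 | 317 | 331 | 314 | 451 | 1.36 | 1.43 |
| 256 | 815 | 866 | 590 | 351 | 331 | 485 | 1.38 | 1.47 |
| 512 | 1581 | 1659 | 1143 | 362 | 345 | 501 | 1.38 | 1.45 |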

Table 13 Performance of Optimization 8 for INT4 GQA (group-wise quantization with num groups = 4)

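| Batch size | Time (us): FD | Time (us): CUDA_WMMA Opt 6 | Time (us): CUDA_WMMA Opt 8 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CUDA_WMMA Opt 6 | Bandwidth (GB/s): CUDA_WMMA Opt 8 | Speed up vs FD | Speed up vs CUDA_WMMA Opt 6 |
|---|---|---|---|---|---|---|---|---|
| 32 | 129 | 116 | 106 | 325 | 364 | 396 | 1.22 | 1.09 |
| 64 | 219 | 195 | 180 | 385 | 431 | 467 | 1.21 | 1.08 |
| 128 | 392 | 347 | 319 | 429 | 484 | 528 | 1.23 | 1.09 |
| 256 | 719 | 638 | 596 | 468 | 527 | 565 | 1.21 | 1.07 |
| 512 | 1375 | 1225 | 1138 | 489 | 550 | 591 | 1.21 | 1.08 |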

Optimization 9: Swizzle P Shared Memory Layouts (A100/H100 specific)

Problem Analysis:

We observe large shared memory bank conflicts during P loading. The amount of bank conflict depends on the memory access stride. For instance, for split-Ks = 32 and max seq length = 8192, we observed that only 4 out of 32 banks are being accessed in parallel (memory access stride = 256). From Figure 14, when all threads access element 0, threads that have the same threadIdx.x % 4 access the same bank.

Figure 15 P fragment in shared memory before swizzling
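The bank count quoted above can be reproduced with a small model of shared memory. This is a sketch under stated assumptions, not the kernel code: we assume 32 banks that are 4 bytes wide, FP32 elements, and a fragment layout in which thread t holds element 0 at row t // 4 and column 2 * (t % 4). The helper names are hypothetical.

```python
NUM_BANKS = 32          # shared memory banks
BANK_WIDTH_BYTES = 4    # each bank serves one 4-byte word


def banks_touched(byte_addresses):
    """Return the set of banks hit by a group of simultaneous accesses."""
    return {(addr // BANK_WIDTH_BYTES) % NUM_BANKS for addr in byte_addresses}


def element0_addresses(row_stride, elem_bytes=4, warp_size=32):
    """Byte address of element 0 for each thread in a warp.

    Assumed layout: thread t holds element 0 at row t // 4, column 2 * (t % 4).
    row_stride is the memory access stride in elements.
    """
    return [
        ((t // 4) * row_stride + 2 * (t % 4)) * elem_bytes
        for t in range(warp_size)
    ]


banks_before = banks_touched(element0_addresses(row_stride=256))
```

With a stride of 256, every row starts in the same bank, so the bank depends only on `t % 4` and the warp touches just 4 of the 32 banks.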

Solution:

We shuffle the layout of P load/store in the shared memory in such a way that avoids bank conflicts. In other words, we store the QKT results (C fragment) and load them (P fragment) using the swizzled layout. Moreover, instead of using the original memory access stride which is dependent on the number of tokens per thread block, we use the fragment’s column size as the stride which is constant. Thus, the load and store of the P fragment is always contiguous.

The new layouts for the C and P fragments are shown in Figure 16. With the new layout, it is guaranteed that 16 banks are being accessed in parallel as shown in Figure 17.

Figure 16 The swizzled layouts of C and P fragments

Figure 17 P fragment in shared memory after swizzling

Table 14 Performance of Optimization 9 for INT4 GQA (row-wise quantization)

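| Batch size | Time (us): FD | Time (us): CU baseline | Time (us): CU Opt 9 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CU baseline | Bandwidth (GB/s): CU Opt 9 | Speed up vs FD | Speed up vs CU baseline |
|---|---|---|---|---|---|---|---|---|
| 32 | 137 | 143 | 98 | 262 | 250 | 365 | 1.39 | 1.46 |
| 64 | 234 | 257 | 167 | 305 | 278 | 429 | 1.41 | 1.54 |
| 128 | 432 | 455 | 299 | 331 | 314 | 479 | 1.45 | 1.52 |
| 256 | 815 | 866 | 549 | 351 | 331 | 521 | 1.48 | 1.58 |
| 512 | 1581 | 1659 | 1060 | 362 | 345 | 540 | 1.49 | 1.56 |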

Table 15 Performance of Optimization 9 for INT4 GQA (group-wise quantization with num groups = 4)

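| Batch size | Time (us): FD | Time (us): CUDA_WMMA Opt 6 | Time (us): CUDA_WMMA Opt 9 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CUDA_WMMA Opt 6 | Bandwidth (GB/s): CUDA_WMMA Opt 9 | Speed up vs FD | Speed up vs CUDA_WMMA Opt 6 |
|---|---|---|---|---|---|---|---|---|
| 32 | 129 | 116 | 105 | 325 | 364 | 400 | 1.23 | 1.10 |
| 64 | 219 | 195 | 174 | 385 | 431 | 484 | 1.26 | 1.12 |
| 128 | 392 | 347 | 302 | 429 | 484 | 558 | 1.30 | 1.15 |
| 256 | 719 | 638 | 560 | 468 | 527 | 601 | 1.28 | 1.14 |
| 512 | 1375 | 1225 | 1065 | 489 | 550 | 632 | 1.29 | 1.15 |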

Optimization 10: Pad Shared Memory for INT4 Dequantization

Problem Analysis:

Once the kernel reads the INT4 K or V cache from global memory, it performs dequantization and stores the results (BF16) in the shared memory. Then, the BF16 data is loaded to the WMMA fragment from shared memory (via the WMMA interface). We observed a large number of bank conflicts for both K and V accesses. For instance, for K stores, only 4 out of 32 banks are being accessed in parallel. For K loads, 16 banks are being accessed in parallel. The same also occurs for V stores and loads. See the figures in the solution section.

Solution:

We pad the shared memory to reduce bank conflicts. Specifically, we pad each row by 2. That is, the row stride of K becomes F_K + 2 and the row stride of V becomes F_N + 2 (F_K and F_N are the fixed widths of the K and V WMMA fragments, respectively). With this optimization, we are able to reduce bank conflicts by 1.8x, as shown in Figure 18.
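The effect of the padding can be sketched with a simple bank model. The access pattern here is an assumption chosen for illustration (each of 32 threads touches column 0 of a different row of BF16 data), as is the fragment width F_K = 16; the real kernel's access patterns are the ones shown in the figures.

```python
NUM_BANKS = 32          # shared memory banks
BANK_WIDTH_BYTES = 4    # each bank serves one 4-byte word
BF16_BYTES = 2          # size of one dequantized element


def banks_for_column_access(row_stride_elems, num_rows=32, col=0):
    """Banks hit when each thread touches the same column of a different row."""
    return {
        ((row * row_stride_elems + col) * BF16_BYTES // BANK_WIDTH_BYTES) % NUM_BANKS
        for row in range(num_rows)
    }


F_K = 16  # assumed WMMA fragment width, for illustration only

unpadded = banks_for_column_access(F_K)      # row stride of 16 elements
padded = banks_for_column_access(F_K + 2)    # row stride of 18 elements
```

Under these assumptions, the unpadded stride of 16 BF16 elements advances 8 banks per row and cycles through only 4 banks, while the padded stride of 18 elements advances 9 banks per row and reaches all 32.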

Figure 18 Bank conflicts before and after Optimization 10

After Optimization 10, for K stores, 32 banks are being accessed in parallel (shown in Figure 19), while for K loads, 29 banks are accessed in parallel (shown in Figure 20).

Figure 19 K fragment store shared memory layout without and with padding

Figure 20 K fragment load shared memory layout without and with padding

Table 16 Performance of Optimization 10 for INT4 GQA (row-wise quantization)

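| Batch size | Time (us): FD | Time (us): CU baseline | Time (us): CU Opt 10 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CU baseline | Bandwidth (GB/s): CU Opt 10 | Speed up vs FD | Speed up vs CU baseline |
|---|---|---|---|---|---|---|---|---|
| 32 | 137 | 143 | 94 | 262 | 250 | 380 | 1.45 | 1.52 |
| 64 | 234 | 257 | 151 | 305 | 278 | 475 | 1.55 | 1.71 |
| 128 | 432 | 455 | 266 | 331 | 314 | 538 | 1.63 | 1.71 |
| 256 | 815 | 866 | 489 | 351 | 331 | 586 | 1.67 | 1.77 |
| 512 | 1581 | 1659 | 930 | 362 | 345 | 616 | 1.70 | 1.79 |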

Table 17 Performance of Optimization 10 for INT4 GQA (group-wise quantization with num groups = 4)

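| Batch size | Time (us): FD | Time (us): CUDA_WMMA Opt 6 | Time (us): CUDA_WMMA Opt 10 | Bandwidth (GB/s): FD | Bandwidth (GB/s): CUDA_WMMA Opt 6 | Bandwidth (GB/s): CUDA_WMMA Opt 10 | Speed up vs FD | Speed up vs CUDA_WMMA Opt 6 |
|---|---|---|---|---|---|---|---|---|
| 32 | 129 | 116 | 99 | 325 | 364 | 425 | 1.31 | 1.17 |
| 64 | 219 | 195 | 161 | 385 | 431 | 523 | 1.36 | 1.21 |
| 128 | 392 | 347 | 282 | 429 | 484 | 598 | 1.39 | 1.23 |
| 256 | 719 | 638 | 509 | 468 | 527 | 662 | 1.41 | 1.25 |
| 512 | 1375 | 1225 | 965 | 489 | 550 | 698 | 1.43 | 1.27 |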

Performance Evaluation

Microbenchmark results

We also evaluated BF16 GQA performance using our optimized kernel (as shown in Table 19). CU still generally performs worse than FD and FA for BF16. This is expected since our optimizations focus on INT4.

INT4 GQA is still not as efficient as BF16 GQA (see the achieved bandwidths). However, comparing FD BF16 GQA against CU INT4 GQA shows that the latency of INT4 is lower than that of BF16.

Table 19 Performance of BF16 GQA and INT4 GQA after CU optimizations

On A100

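Time (us)

| Batch size | BF16 GQA: FD | BF16 GQA: FA | BF16 GQA: CU before | BF16 GQA: CU after | INT4 GQA: FD | INT4 GQA: FA | INT4 GQA: CU before | INT4 GQA: CU after |
|---|---|---|---|---|---|---|---|---|
| 32 | 139 | 133 | 183 | 163 | 137 | - | 143 | 94 |
| 64 | 245 | 229 | 335 | 276 | 234 | - | 257 | 151 |
| 128 | 433 | 555 | 596 | 517 | 432 | - | 455 | 266 |
| 256 | 826 | 977 | 1127 | 999 | 815 | - | 866 | 489 |
| 512 | 1607 | 1670 | 2194 | 1879 | 1581 | - | 1659 | 930 |

Effective Bandwidth (GB/s)

| Batch size | BF16 GQA: FD | BF16 GQA: FA | BF16 GQA: CU before | BF16 GQA: CU after | INT4 GQA: FD | INT4 GQA: FA | INT4 GQA: CU before | INT4 GQA: CU after |
|---|---|---|---|---|---|---|---|---|
| 32 | 965 | 1012 | 736 | 824 | 262 | - | 250 | 380 |
| 64 | 1097 | 1175 | 802 | 972 | 305 | - | 278 | 475 |
| 128 | 1240 | 968 | 901 | 1039 | 331 | - | 314 | 538 |
| 256 | 1301 | 1100 | 954 | 1075 | 351 | - | 331 | 586 |
| 512 | 1338 | 1287 | 980 | 1144 | 362 | - | 345 | 616 |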

On H100

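Time (us)

| Batch size | BF16 GQA: FD | BF16 GQA: FA | BF16 GQA: CU before | BF16 GQA: CU after | INT4 GQA: FD | INT4 GQA: FA | INT4 GQA: CU before | INT4 GQA: CU after |
|---|---|---|---|---|---|---|---|---|
| 32 | 91 | 90 | 114 | 100 | 70 | - | 96 | 64 |
| 64 | 148 | 146 | 200 | 183 | 113 | - | 162 | 101 |
| 128 | 271 | 298 | 361 | 308 | 205 | - | 294 | 170 |
| 256 | 515 | 499 | 658 | 556 | 389 | - | 558 | 306 |
| 512 | 1000 | 1011 | 1260 | 1066 | 756 | - | 1066 | 575 |

Effective Bandwidth (GB/s)

| Batch size | BF16 GQA: FD | BF16 GQA: FA | BF16 GQA: CU before | BF16 GQA: CU after | INT4 GQA: FD | INT4 GQA: FA | INT4 GQA: CU before | INT4 GQA: CU after |
|---|---|---|---|---|---|---|---|---|
| 32 | 1481 | 1496 | 1178 | 1341 | 511 | - | 371 | 560 |
| 64 | 1815 | 1840 | 1345 | 1470 | 631 | - | 443 | 710 |
| 128 | 1982 | 1802 | 1487 | 1743 | 699 | - | 487 | 844 |
| 256 | 2087 | 2156 | 1634 | 1934 | 736 | - | 513 | 935 |
| 512 | 2150 | 2127 | 1706 | 2015 | 757 | - | 537 | 996 |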

E2E results

We evaluated our optimized INT4 GQA kernel in Llama 2 70B on 8 H100 GPUs. We ran the model end-to-end but report only the decode latency. We used FP8 FFN (feed forward network) to emphasize the attention performance in the decoding phase. We varied the batch size from 1 to 256 and the context length from 2,048 (2K) to 16,384 (16K). The E2E performance results are shown in the figure below.

Figure 21 Meta Llama 2 decode latency (ms) comparison (BF16 GQA runs out of memory in large batch size configurations)

Code

If you are interested, please check out our code here. If you have any questions, please feel free to open an issue on GitHub, and we will be happy to help. Your contributions are welcome!

Ready, Set, Contribute: PyTorch Docathon Kickoff H1 2024

The PyTorch Docathon is now live! This event is dedicated to enhancing the quality of the PyTorch documentation with the invaluable assistance of our community. Our hope with this Docathon is to simplify the process for new users to get started with PyTorch, guide them in effectively utilizing its features, and ultimately expedite the transition from research to production in machine learning.

JOIN THE KICK-OFF EVENT
on June 4th at 10 AM PT

Event Details

  • June 4: Kick-off – join a 30-minute livestream kick-off event on Discord on June 4th at 10 AM PT here. If you can’t join the kick-off event, watch our welcome video on YouTube
  • June 4-June 16: Submissions and Feedback
  • June 17-18: Final Reviews
  • June 20: Winner Announcements

How to Contribute

Review the Docathon H1 2024 issue in the pytorch/pytorch or pytorch/tutorials repo, which contains all the necessary information on participating in the Docathon and highlights the specific issues to work on. Remember to sign the CLA in your first PR and adhere to the Code of Conduct guidelines.

Read the Code of Conduct

Take a moment to review the PyTorch code of conduct found here. This document outlines the expectations for behavior and communication within our team, and it is important that everyone is aware of and adheres to these guidelines.

Join our Discord

This channel serves as the main communication hub during the Docathon. You can join it by using this link:

JOIN DISCORD SERVER

When you first join the server, you will have limited access. To gain full access to our Discord PyTorch Docathon Channel:

  1. Enter the server and navigate to the #self-roles channel.
  2. In the #self-roles channel, click on the ‘Join Docathon’ button in the relevant post to assign yourself the docathon role.
  3. After assigning the role, you will see the ‘PyTorch Docathon H1 2024 Section’ in the left-hand menu for discussions.
  4. To help prevent spam, we ask that you change your server username to your GitHub username or the email username you registered with.

Explore the GitHub Issues

All the Docathon issues are posted on GitHub. You can find them by the docathon-h1-2024 label in the following participating repositories:

The issues are categorized into three levels of difficulty: easy, medium, and advanced. If this is your first time contributing to PyTorch, we recommend starting with an issue at the easy level.

Prizes for Winners

We will have a leaderboard throughout the duration of the Docathon. The more you contribute, the higher you’ll get on the board! Our top three winners will get free admission to PyTorch Conference 2024.

Thank you to our Partners

This year, we’re thrilled to work with the PyTorch Teams at Meta, Google and Snowflake to help us put on a successful event. We’ll also be at Snowflake Dev Day on June 6 where you can hear from Meta’s Matthias Reso, and check out our PyTorch booth.

Happy contributing!

Maximizing Training Throughput Using PyTorch FSDP and Torch.compile

Recently, we demonstrated how FSDP and selective activation checkpointing can be used to achieve 57% MFU (Model Flops Utilization) for training a 7B model on A100 GPUs. We also demonstrated how it can train a high quality model, which we open sourced as Granite 7B base model on Hugging Face Hub under the Apache v2.0 license.

We continued our quest to improve the utilization of GPUs by leveraging torch.compile. Using torch.compile and the selective activation checkpointing from our previous work, we achieve an MFU of 68% for the 7B model on A100 GPUs! torch.compile improves training MFU by between 10% and 23% for various model sizes.

This blog is organized into three parts: (1) Challenges addressed in order to train using torch.compile, (2) Numerical parity of compile with no-compile, and (3) MFU report.

We open sourced all the code and updated it in the fms-fsdp repository. We are also working with Team PyTorch at Meta to contribute these to the newly released torch titan repository for pre-training.

Challenges of using torch.compile

torch.compile is a graph compilation technique that improves GPU utilization. For details on how torch.compile works, we refer readers to the recent PyTorch paper and associated tutorials. A key challenge in getting torch.compile to perform well is to minimize (or eliminate) graph breaks. We initially started with the Llama implementation provided by Meta, but compiling it caused too many graph breaks, resulting in reduced training throughput.

Several portions of the model architecture had to be fixed, with the most important one being the positional embedding layer (RoPE). The typical RoPE implementation uses complex numbers, which were not supported in torch.compile at the time of testing. We implemented RoPE using einops while maintaining parity with the original model architecture implementation. We had to properly cache the frequencies so that we did not run into graph breaks within the RoPE implementation.
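To illustrate why RoPE does not need complex numbers, the sketch below rotates one vector both ways and checks that the results agree. It is a pure-Python illustration, not the einops-based implementation described above; the function names, the adjacent-pair convention, and the base `theta = 10000.0` are assumptions for this example. The rotation table is cached so the frequencies are computed once per position.

```python
import cmath
import math
from functools import lru_cache


@lru_cache(maxsize=None)
def rotation_table(pos, dim, theta=10000.0):
    """Cached (cos, sin) pairs for every feature pair at a given position."""
    angles = [pos * theta ** (-2 * i / dim) for i in range(dim // 2)]
    return tuple((math.cos(a), math.sin(a)) for a in angles)


def rope_complex(x, pos, theta=10000.0):
    """Reference: treat each adjacent pair as a complex number and rotate it."""
    dim = len(x)
    out = []
    for i in range(dim // 2):
        angle = pos * theta ** (-2 * i / dim)
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * angle)
        out.extend([z.real, z.imag])
    return out


def rope_real(x, pos, theta=10000.0):
    """Same rotation using only real arithmetic and the cached table."""
    out = []
    for i, (c, s) in enumerate(rotation_table(pos, len(x), theta)):
        a, b = x[2 * i], x[2 * i + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out


x = [0.5, -1.0, 2.0, 3.0]
max_diff = max(abs(p - q) for p, q in zip(rope_complex(x, 7), rope_real(x, 7)))
```

The two versions agree up to floating-point rounding, since multiplying a + bi by exp(iθ) is exactly the 2D rotation written out in `rope_real`.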

Compiling an FSDP model does result in graph breaks, which the PyTorch team at Meta is working to remove. However, these graph breaks as of PyTorch 2.3 are at FSDP unit boundaries and do not affect throughput significantly.

When using custom kernels, we need to wrap each kernel by exposing its API to torch.compile. This involves indicating which parameters are modified in-place, how they are modified, and what shapes and strides their return values will have based on the inputs. In our case, SDPA Flash attention is already integrated appropriately and we were able to get that kernel to work with torch.compile with no graph breaks.

We also noticed that when increasing the amount of data from 2T to 6T tokens, the data loader became a bottleneck. A key reason for this is the fact that previously, we implemented document shuffling in our dataloader naively, by having each worker maintain a list of shuffled document pointers.

With the larger dataset, these pointer lists were growing to hundreds of thousands of entries per worker. Maintaining pointer lists at this scale became expensive enough that CPU contention throttled our training throughput. We re-implemented document shuffling without any pointer lists using a Linear Congruential Generator. LCG is a pseudorandom number generator algorithm that implements a random walk over a population, providing sampling without replacement.

We leveraged the same idea to produce implicit bijective mappings from ordered to shuffled document indices. This enables us to shrink those annoying lists of hundreds of thousands of pointers down to a single integer state for the LCG. This eliminated 80% of the bottleneck and provided a significant boost to our performance. We will devote a separate blog to go into all the details of our performant pre-training data loader.
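The idea of replacing a shuffled pointer list with a single integer of state can be sketched as follows. This is a minimal illustration and not the data loader's implementation: the function name, the constants `a = 5` and `c = 1`, and the power-of-two modulus are assumptions chosen so that the generator has a full period (the Hull-Dobell conditions), and values outside the document range are simply skipped.

```python
def lcg_permutation(num_docs, seed=0, a=5, c=1):
    """Yield every index in range(num_docs) exactly once, in LCG order.

    The only state carried between steps is the integer `state`.
    """
    # Smallest power of two that covers all document indices.
    modulus = 1
    while modulus < num_docs:
        modulus *= 2

    # With a power-of-two modulus, odd c and (a - 1) divisible by 4 give a
    # full period: the walk visits every value in [0, modulus) once per cycle.
    state = seed % modulus
    emitted = 0
    while emitted < num_docs:
        if state < num_docs:  # skip values that are not valid document indices
            yield state
            emitted += 1
        state = (a * state + c) % modulus


order = list(lcg_permutation(10, seed=3))
```

Every document index appears exactly once, so the walk is a bijection from visit order to document index without storing a list. With these small constants, a different seed only changes the starting point of the cycle, so a production version would likely use larger, carefully chosen constants.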

Numerical Parity of torch.compile and torch.no-compile

We had previously observed parity issues when training with compile and no-compile options, with one of these being related to the use of SDPA. After a few days of intense debugging sessions between the PyTorch teams at Meta and IBM, we were able to achieve parity between PyTorch compile and no-compile modes. To document and verify this parity, we take a mini-Llama model architecture of 1.4B size and train it to 100B tokens in four variations – no-compile, compile with no activation checkpointing, compile with selective activation checkpointing, and compile with full activation checkpointing.

We plot the loss curves and gradient norm for these options below:

Figure 1: Loss curve and gradient norm for various compile options

Further, we run the lm-evaluation-harness and compare the various model scores on different benchmarks and observe no major differences between compile and no-compile, which is shown below.

Figure 2: lm-evaluation-harness comparison of various benchmarks between compile and no-compile

We observe from all these results that compile, in all its variants, matches the no-compile option, thus demonstrating parity between compile and no-compile.

MFU report

Finally, like our previous blog, we compute the MFU for four different model sizes on two clusters. One cluster is 128 A100 GPUs with 400 Gbps inter-node connectivity, and the other is 464 H100 GPUs with 3.2 Tbps inter-node connectivity. We use the selective activation checkpointing that we covered in the prior blog in addition to compile. We capture the results in the table below.

| Model size | Batch size | MFU no-compile | MFU compile | Percentage gain (%) |
|---|---|---|---|---|
| 7B | 2 | 0.57 | 0.68 | 20 |
| 13B | 2 | 0.51 | 0.60 | 17 |
| 34B | 2 | 0.47 | 0.54 | 15 |
| 70B | 2 | 0.50 | 0.55 | 10 |

Table 1: MFU results with compile and no compile for Llama2 model architectures on 128 A100 80GB GPUs with 400Gbps internode interconnect

| Model size | Batch size | MFU no-compile | MFU compile | Percentage gain (%) |
|---|---|---|---|---|
| 7B | 2 | 0.37 | 0.45 | 21 |
| 13B | 2 | 0.35 | 0.43 | 23 |
| 34B | 2 | 0.32 | 0.38 | 19 |
| 70B | 2 | 0.32 | 0.38 | 19 |

Table 2: MFU results with compile and no compile for Llama2 model architectures on 464 H100 80GB GPUs with 3.2Tbps internode interconnect

We also had an internal production run on 448 GPUs using a Llama2 7B architecture. Using compile and selective activation checkpointing, with a global batch size of 3.7M, we trained for 4T tokens in 13 days 10 hours!
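As a rough sanity check on the production run, the per-GPU throughput follows from the numbers above. The MFU helper below uses the common approximation of 6 FLOPs per parameter per token, which ignores attention FLOPs and may differ from how the MFU values in this post were computed; the function name and the peak FLOPs argument are hypothetical.

```python
TOKENS_TRAINED = 4e12                  # 4T tokens
NUM_GPUS = 448
TRAINING_SECONDS = (13 * 24 + 10) * 3600  # 13 days 10 hours

tokens_per_sec = TOKENS_TRAINED / TRAINING_SECONDS
tokens_per_sec_per_gpu = tokens_per_sec / NUM_GPUS


def estimate_mfu(num_params, tokens_per_sec_per_gpu, peak_flops_per_gpu):
    """Approximate MFU as (6 * params * tokens/s) / peak FLOPs per second.

    Assumption: 6 FLOPs per parameter per token for forward plus backward,
    with attention FLOPs ignored.
    """
    achieved_flops = 6 * num_params * tokens_per_sec_per_gpu
    return achieved_flops / peak_flops_per_gpu
```

This works out to roughly 7,700 tokens per second per GPU. To turn it into an MFU estimate, pass the peak FLOPs of the GPU used for the run to `estimate_mfu`.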

During training, the data center cooling had to kick in with extra air conditioning and our training team was alerted to this, since we were using the GPUs quite effectively ☺

One key observation from Tables 1 and 2 is that the MFU numbers do not scale linearly with model size. We are actively investigating two possible explanations. The first is the scalability of FSDP as model size increases, and the point at which tensor parallelism needs to be enabled to use the GPU more effectively. The second is batch size, which can be increased further to get better MFU. We plan to explore FSDP v2 and selective operator checkpointing along with the tensor parallel feature to study the scaling laws of FSDP with model size.

Future Work

We plan to start testing FSDP v2, which will be released as part of PyTorch 2.4. FSDP2 provides per-parameter sharding and a selective operator checkpointing feature that can potentially provide even better memory-compute tradeoffs.

We have also been engaged with the PyTorch team at Meta to evaluate the new asynchronous checkpointing feature that can further improve the GPU utilization by reducing the time to write checkpoints.

We are exploring extending various Triton kernels currently used in inference to perform backward operations to gain speedups beyond inference only.

Finally, as recent work on use of fp8 is emerging, we plan to explore how we can even further accelerate model training using the new data type that promises a 2x acceleration.

Acknowledgements

There are several teams that have been involved in reaching this proof point and we would like to thank the teams across Meta and IBM. Specifically, we extend our gratitude to the Meta PyTorch distributed and compiler teams and IBM Research.

Multiple people were extensively involved in the effort of achieving torch.compile numerical parity with our models, and we wish to acknowledge the key folks involved in this effort: Animesh Jain and Less Wright at Meta, and Linsong Chu, Davis Wertheimer, Brian Vaughan, Antoni i Viros Martin, Mudhakar Srivatsa, and Raghu Ganti at IBM Research.

Special thanks to Stas Bekman, who provided extensive feedback and helped improve this blog. Their insights have been invaluable in highlighting key aspects of optimizing the training and exploring further enhancements.

Achieving Sustainability Goals with PyTorch and Intel AI

This post was contributed by Intel AI in partnership with the PyTorch Foundation.

In 2017, the UN Global Compact emphasized digital technology, particularly open source, as crucial for achieving Sustainable Development Goals (SDGs), projecting a potential $2.1 trillion boost to the tech sector by 2030. The SDGs, part of the “2030 Agenda for Sustainable Development,” address global prosperity across various sectors.

The Linux Foundation’s Sustainability Initiative aligns projects with sustainable development goals. By assessing project impact, resources can be better allocated for enhancement. Intel is also a contributor to this initiative, and recently presented three use cases with PyTorch and Intel AI to address UN SDG-aligned issues.

Sustainability Goals

SDG 15: Life on Land

  • Using a bone likelihood map to pinpoint dinosaur bones, which paves the way for transfer learning to tackle contemporary challenges like wildfire prediction.
  • Employing transfer learning for wildfire prediction and generating data with Stable Diffusion.

SDG 9: Industry, Innovation, Infrastructure

  • Identifying crucial minerals, oil, and gas through subsurface models.

Here are the key highlights from the workshops. Read below for a summary, and be sure to watch the full workshop videos and visit the GitHub repositories.

Session 1: Introduction to Dinosaur Bone Bed Maps

Bob Chesebrough recently led a PyTorch workshop demonstrating how to create a dinosaur bone bed map for Dinosaur National Monument. He shared footage of his discoveries and explained his AI-driven approach, utilizing geological data to pinpoint possible bone-rich areas.

Attendees learned to set up JupyterLab, access the training section, and launch a BASH shell. Bob’s classification model, applied to aerial images, facilitated heatmap generation to identify potential bone locations, refined through field data. The GitHub repo “Jurassic” guided participants through directory setup and model optimization steps.

Rahul Unnikrishnan Nair demonstrated the use of PyTorch, focusing on performance enhancements. The workshop covered modeling best practices, such as data transformations, class distribution, dropout layers, and efficient training methods. Training and scoring procedures were examined, with a focus on model accuracy and transportability to other regions. Heatmap creation involved cutting images into tiles, considering context for accurate environmental identification.

Watch the full workshop video here and visit the GitHub repository to access the code sample and experiment with the code using Intel® Extension for PyTorch. Try it out with PyTorch and explore what works best for you. Happy dinosaur bone hunting!

Session 2: Seismic Data to Subsurface Models with OpenFWI: Training an AI Model with PyTorch

Seismic exploration is crucial for subsurface imaging in mineral and oil/gas exploration. Full waveform inversion (FWI) recreates subsurface sound wave velocities, akin to ultrasound for the Earth.

Ben Consolvo, an AI Software Engineering Manager at Intel, presented training AI models directly from seismic data using PyTorch on Intel high-performance processors. FWI, though accurate, is computationally intensive and relies on precise initial models. AI models offer an alternative approach, learning directly from data without the need for precise initializations. Ben explained the challenges of AI models, highlighting the need for diverse datasets and the potential use of CPUs for fine-tuning. He also discussed FWI’s surprising medical applications.

Watch the full video here and go to the paper for more details. The GitHub repo is OpenFWI.

Session 3: Using PyTorch to Aid Wildfire Prediction

Forest fires pose significant threats to ecosystems, wildlife, and communities. Machine learning presents a promising approach to enhance prediction accuracy. In this Earth Day webinar, Bob Chesebrough and Rahul Unnikrishnan Nair demonstrated image analysis techniques using the MODIS dataset to predict early forest fire probabilities. They fine-tuned a pre-trained ResNet18 model on aerial photos with the Intel® Extension for PyTorch, using geo-spatial and color data for fire risk assessment.

The presenters emphasized the temporal and geographical filtering required for dataset analysis and showcased images from fire-affected areas such as Paradise, CA. They also highlighted the model’s adaptability to different hardware configurations and the use of Stable Diffusion to synthesize data when real datasets were unavailable. They encouraged the audience to experiment with PyTorch for early fire detection and extended a challenge to apply these tools to critical predictive tasks. Join them in this endeavor to enhance wildfire prevention and protection efforts.

Watch the full video here and go to the paper for more details. The GitHub repo is ForestFirePrediction.

About the Intel Speakers

Bob Chesebrough, Sr Solutions Architect

Bob Chesebrough has over three decades of industry experience in software development and AI solution engineering for Fortune 100 companies and national laboratories. He is also a hobbyist who has logged over 800 miles and 1000 hours in the field finding dinosaur bones. He and his sons discovered an important fossil of the only known crocodilian from the Jurassic in New Mexico. They have also discovered and logged into the museum 2000+ bone localities and described a new mass bone bed in New Mexico.

Rahul Unnikrishnan Nair, Architect in Applied AI and the Engineering Lead at Intel® Liftoff

In his current role at the Intel® Liftoff for Startups program, Rahul Nair brings his extensive experience in applied AI and engineering to mentor early-stage AI startups. His dedication lies in helping these startups transform their innovative ideas into fully-fledged, market-ready products with a strong emphasis on use-case-driven, practical engineering and optimization.

Ben Consolvo, AI Software Engineering Manager

Ben Consolvo is an AI Solutions Engineering Manager at Intel. He has been building a team and a program around Intel’s AI technology paired with Intel’s hardware offerings. He brings a background and passion in data science, particularly in deep learning (DL) and computer vision. He has applied his skills in DL in the cybersecurity industry to automatically identify phishing websites, as well as to the oil and gas industry to identify subsurface features for geophysical imaging.

Kelli Belcher, AI Solutions Engineer

Kelli Belcher is an AI Solutions Engineer at Intel with over 5 years of experience across the financial services, healthcare, and tech industries. In her current role, Kelli helps build Machine Learning solutions using Intel’s portfolio of open AI software tools. Kelli has experience with Python, R, SQL, and Tableau, and holds a Master of Science in Data Analytics from the University of Texas.
