Deep Learning Energy Measurement and Optimization

Zeus logo

This post is authored by Jae-Won Chung, a PhD student at the University of Michigan and the lead of the ML.ENERGY Initiative.

Deep learning consumes quite a bit of energy. For instance, training a single 200B LLM on AWS p4d instances consumed around 11.9 GWh (source: CIDR 2024 keynote), which is an amount that can single-handedly power more than a thousand average US households for a year.

Zeus is an open-source toolbox for measuring and optimizing the energy consumption of deep learning workloads. Our goal is to make energy optimization based on accurate measurements as easy as possible for diverse deep learning workloads and setups by offering composable tools with minimal assumptions.

Zeus largely provides two types of tools:

  1. Programmatic and command line GPU energy measurement tools
  2. Several energy optimization tools that find the best ML and/or GPU configurations

Zeus can benefit those who would like to

  • measure and optimize their electricity cost
  • reduce heat dissipation from their GPUs (by lowering power draw)
  • report energy usage from research and development
  • reduce carbon footprint from electricity usage

Part 1: Measuring Energy

Just like performance optimization, accurate measurement is the basis of effective energy optimization. Popular proxies for estimating power consumption, such as the maximum power draw of the hardware, can sometimes be vastly off compared to actual measurements.

To make energy measurement as easy and transparent as possible, the core utility Zeus offers is the ZeusMonitor class. Let’s take a look at the actual snippet:

from zeus.monitor import ZeusMonitor

# All four GPUs are measured simultaneously.
monitor = ZeusMonitor(gpu_indices=[0,1,2,3])

# Measure total time and energy within the window.
monitor.begin_window("training")
for e in range(100):

    # Measurement windows can arbitrarily be overlapped.
    monitor.begin_window("epoch")
    for x, y in train_dataloader:
        y_hat = model(x)
        loss = criterion(y, y_hat)
        loss.backward()
        optim.step()
    measurement = monitor.end_window("epoch")
    print(f"Epoch {e}: {measurement.time} s, {measurement.total_energy} J")

measurement = monitor.end_window("training")
print(f"Entire training: {measurement.time} s, {measurement.total_energy} J")

What you see above is a typical PyTorch training loop which uses four GPUs for data parallel training. Inside, we created an instance of ZeusMonitor and passed in a list of GPU indices to monitor. Then, using the monitor, we can measure the time and energy consumption of arbitrary execution windows within the training script by pairing calls to begin_window and end_window. Multiple windows can overlap and nest in arbitrary ways without affecting the measurement of each, as long as their names are different.

ZeusMonitor adds very little overhead – typically single digit milliseconds – around the window. This allows ZeusMonitor to be used in various applications. For instance:

  • The ML.ENERGY Leaderboard: The first open-source benchmark on how much energy LLM text generation consumes.
  • The ML.ENERGY Colosseum: An online service that lets users compare LLM responses side-by-side based on response quality and energy consumption.

See our blog post for a deeper technical dive into accurate GPU energy measurement.

Part 2: Optimizing Energy

Let me introduce you to two of the energy optimizers provided by Zeus.

GlobalPowerLimitOptimizer

GPUs allow users to configure their maximum power draw, called the power limit. Typically, as you lower the GPU's power limit from the default maximum, computation may get slightly slower, but you'll save disproportionately more energy. The GlobalPowerLimitOptimizer in Zeus automatically finds the optimal GPU power limit globally across all GPUs.

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

# The optimizer measures time and energy through the ZeusMonitor.
monitor = ZeusMonitor(gpu_indices=[0,1,2,3])
plo = GlobalPowerLimitOptimizer(monitor)

for e in range(100):
    plo.on_epoch_begin()
    for x, y in train_dataloader:
        plo.on_step_begin()

        y_hat = model(x)
        loss = criterion(y, y_hat)
        loss.backward()
        optim.step()

        plo.on_step_end()
    plo.on_epoch_end()

In our familiar PyTorch training loop, we have instantiated GlobalPowerLimitOptimizer and passed it an instance of the ZeusMonitor, through which the optimizer sees the GPUs. Then, we just need to let the optimizer know about training progress (step and epoch boundaries), and the optimizer will transparently do all the necessary profiling and converge to the optimal power limit.
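
For intuition, the sketch below shows the kind of power-limit sweep that the optimizer automates: profile one training epoch at each candidate power limit and record its time and energy. The train_one_epoch helper is hypothetical, and changing power limits via NVML typically requires administrator privileges; this is only an illustration of what the optimizer does for you.

import pynvml
from zeus.monitor import ZeusMonitor

# Hypothetical manual sweep over GPU 0's supported power-limit range.
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
min_mw, max_mw = pynvml.nvmlDeviceGetPowerManagementLimitConstraints(handle)

monitor = ZeusMonitor(gpu_indices=[0])
results = {}

for limit_mw in range(max_mw, min_mw - 1, -50_000):  # step down in 50 W increments
    # Setting the power limit usually requires root privileges.
    pynvml.nvmlDeviceSetPowerManagementLimit(handle, limit_mw)

    monitor.begin_window("sweep")
    train_one_epoch()  # hypothetical: one epoch of your training loop
    measurement = monitor.end_window("sweep")
    results[limit_mw // 1000] = (measurement.time, measurement.total_energy)

for limit_w, (t, e) in sorted(results.items()):
    print(f"{limit_w} W: {t:.1f} s, {e:.1f} J")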

If you’re using the HuggingFace Trainer or SFTTrainer, integration is even easier:

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import HFGlobalPowerLimitOptimizer

# ZeusMonitor actually auto-detects CUDA_VISIBLE_DEVICES.
monitor = ZeusMonitor()
pl_optimizer = HFGlobalPowerLimitOptimizer(monitor)

# Pass in the optimizer as a Trainer callback. Also works for SFTTrainer.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    ...,
    callbacks=[pl_optimizer],
)

The HFGlobalPowerLimitOptimizer wraps GlobalPowerLimitOptimizer so that it automatically detects step and epoch boundaries. We have example integrations here, including running Gemma 7B supervised fine-tuning with QLoRA.

Now, we know how to integrate the optimizer, but what is the optimal power limit? We know different users can have different preferences regarding trading off time and energy, so we allow users to specify an OptimumSelector (basically the Strategy Pattern) to express their needs.

# Built-in strategies for selecting the optimal power limit.
from zeus.optimizer.power_limit import (
    GlobalPowerLimitOptimizer,
    Time,
    Energy,
    MaxSlowdownConstraint,
)

# Minimize energy while tolerating at most 10% slowdown.
plo = GlobalPowerLimitOptimizer(
    monitor,
    MaxSlowdownConstraint(factor=1.1),
)

Some of the built-in strategies include “Minimize time” (Time, this might still reduce the power limit from the default since some workloads exhibit almost no slowdown even on lower power limits), “Minimize energy” (Energy), “Somewhere in between” (ZeusCost), and “Minimize energy given maximum slowdown” (MaxSlowdownConstraint). Users can also create their own optimum selectors as needed.
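
If none of the built-in strategies fit, you can write your own. The sketch below assumes the OptimumSelector interface is a single select method over per-power-limit profiling results with power_limit, time, and energy attributes; these names and the signature are assumptions made for illustration, so check the Zeus documentation for the exact interface.

from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer, OptimumSelector

class MinEnergyWithinBudget(OptimumSelector):
    """Pick the lowest-energy power limit whose profiled time stays under a budget.

    NOTE: the select() signature and the power_limit/time/energy attribute names
    on the profiling results are assumptions made for illustration.
    """

    def __init__(self, max_seconds: float) -> None:
        self.max_seconds = max_seconds

    def select(self, measurements):
        feasible = [m for m in measurements if m.time <= self.max_seconds]
        candidates = feasible or measurements  # fall back if nothing fits the budget
        return min(candidates, key=lambda m: m.energy).power_limit

# plo = GlobalPowerLimitOptimizer(monitor, MinEnergyWithinBudget(max_seconds=600))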

PipelineFrequencyOptimizer

The pipeline frequency optimizer, based on our research paper Perseus, is our latest work on energy optimization for large model training, like GPT-3. Perseus can reduce the energy consumption of large model training with no or negligible training throughput degradation. We’ll briefly talk about how.

one iteration of training with four stage pipeline parallelism

The above is a visualization of one iteration of training with four-stage pipeline parallelism running with the 1F1B schedule. Each box is either a forward or a backward computation, and is colored by its power consumption.

The key observation here is that when models are partitioned into pipeline stages, it's very difficult to slice them into perfectly equal sizes. This leads to forward/backward boxes of varying widths, and therefore computation idle time between boxes. Notice that the smaller boxes can run slightly slower than the wider boxes without changing the overall critical path (blue line) at all.

one iteration of training with four stage pipeline parallelism

That's what Perseus does automatically. Based on profiling, it identifies computation boxes that are not on the critical path and figures out the precise amount of slowdown for each box that minimizes energy consumption. When done correctly, the computations we slow down consume less power and energy, while the overall iteration time of the pipeline does not change.

See our guide to get started with Perseus!

Final Words

For users who run their own on-premise compute, energy consumption and the resulting electricity bill are not something that can be easily overlooked. On a larger scale, energy consumption is not just about electricity bills, but also about data center power delivery. With thousands of GPUs running in clusters, finding stable, affordable, and sustainable electricity sources to power data centers is becoming increasingly challenging. Reducing energy disproportionately more than the accompanying slowdown lowers average power consumption, which can help with this power delivery challenge.

With Zeus, we hope to take the first step towards deep learning energy measurement and optimization.

Wondering where to go from here? Here are a couple of helpful links:

Read More

Introducing depyf: mastering torch.compile with ease

depyf logo

We are thrilled to introduce depyf, a new project to the PyTorch ecosystem designed to help users understand, learn, and adapt to torch.compile!

Motivation

torch.compile is a cornerstone of PyTorch 2.x, offering a straightforward path to accelerate machine learning workflows with just a single line of code for both training and inference. The mere inclusion of @torch.compile can dramatically enhance the performance of your code. However, identifying the optimal insertion point for torch.compile is not easy, not to mention the complexity of adjusting various knobs for maximum efficiency.

The intricacies of the torch.compile stack, encompassing Dynamo, AOTAutograd, Inductor, and more, present a steep learning curve. These components, essential for deep learning performance optimization, can be daunting without a solid foundation in the subject.

Note: For an introductory example of how torch.compile works, please refer to this walk-through explanation.

A common tool: TORCH_COMPILE_DEBUG

To demystify torch.compile, the common approach involves leveraging the TORCH_COMPILE_DEBUG environment variable. While it provides more information, deciphering the output remains a formidable task.

For example, when we have the following code:

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   main()

When we run it with TORCH_COMPILE_DEBUG=1 python test.py, we get a directory named torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520, under which there are these files:

.
├── torchdynamo
│   └── debug.log
└── torchinductor
   ├── aot_model___0_debug.log
   ├── aot_model___10_debug.log
   ├── aot_model___11_debug.log
   ├── model__4_inference_10.1
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   ├── model__5_inference_11.2
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   └── model___9.0
       ├── fx_graph_readable.py
       ├── fx_graph_runnable.py
       ├── fx_graph_transformed.py
       ├── ir_post_fusion.txt
       ├── ir_pre_fusion.txt
       └── output_code.py

The generated files and logs often raise more questions than they answer, leaving developers puzzled over the meaning and relationships within the data. Common puzzles for TORCH_COMPILE_DEBUG include:

  • What does model__4_inference_10.1 mean?
  • I have one function but three model__xxx directories; what is their correspondence?
  • What is all that LOAD_GLOBAL stuff in debug.log?

A better tool: depyf comes to the rescue

Let's see how depyf can help developers resolve the above challenges. To use depyf, simply execute pip install depyf or follow the project page https://github.com/thuml/depyf to install the latest version, and then wrap the main code in a with depyf.prepare_debug block.

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()

After executing python test.py, depyf will produce a directory named depyf_debug_dir (the argument of the prepare_debug function). Under the directory, there will be these files:

.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py

And there are two obvious benefits:

  1. The long and difficult-to-understand torchdynamo/debug.log is gone. Its content is cleaned up and shown as human-readable source code, in full_code_for_xxx.py and __transformed_code_{n}_for_xxx.py. It is worth noting that the most tedious and difficult job of depyf is to decompile the bytecode inside torchdynamo/debug.log into Python source code, freeing developers from the intimidating internals of Python.
  2. The correspondence between function names and computation graphs is preserved. For example, in __transformed_code_0_for_toy_example.py, we can see a function named __compiled_fn_0, and we immediately know its corresponding computation graphs are in __compiled_fn_0_xxx.py, because they share the same __compiled_fn_0 prefix.

Starting with full_code_for_xxx.py and following the functions involved, users will have a clear view of what torch.compile does to their code.

One more thing: step-through debuggability

Stepping through code line by line using debuggers is a great way to understand how code works. However, under TORCH_COMPILE_DEBUG, those files are only for users' information and cannot be executed with the data users care about.

Note: By “debug”, we mean the process of inspecting and improving a program, rather than correcting buggy code.

A standout feature of depyf is its capability to facilitate step-through debugging for torch.compile: all of the files it generates are linked with runtime code objects inside the Python interpreter, and we can set breakpoints in these files. The usage is simple: just add one context manager, with depyf.debug(), and it should do the trick:

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()
   with depyf.debug():
       main()

Just one caveat: the workflow of debugging torch.compile deviates from the standard debugging workflow. With torch.compile, much of the code is dynamically generated. Therefore, we need to:

  1. launch the program
  2. when the program exits the with depyf.prepare_debug("depyf_debug_dir") block, the code will be available in depyf_debug_dir.
  3. when the program enters the with depyf.debug() block, it will automatically set a breakpoint internally, so that the program is paused.
  4. navigate to depyf_debug_dir to set breakpoints.
  5. continue running the code, and the debugger will hit these breakpoints!

depyf screenshot

Here is a screenshot of what it looks like. All code and tensor variables are live, and we can inspect any variable and step through the code, just as in our daily debugging workflow! The only difference is that we are debugging torch.compile-generated code rather than human-written code.

Conclusion

torch.compile serves as an invaluable tool for accelerating PyTorch code effortlessly. However, for those looking to delve deeper into torch.compile, whether to leverage its full potential or to integrate custom operations, the learning curve can be very steep. depyf is designed to lower this barrier, offering a user-friendly experience to understand, learn, and adapt to torch.compile.

Do explore depyf and experience its benefits firsthand! The project is open-source and readily available at https://github.com/thuml/depyf. Installation is straightforward via pip install depyf. We hope depyf can enhance everyone’s development workflow with torch.compile.

Read More

A Hitchhiker’s Guide to Speculative Decoding

Speculative decoding is an optimization technique for inference that makes educated guesses about future tokens while generating the current token, all within a single forward pass. It incorporates a verification mechanism to ensure the correctness of these speculated tokens, thereby guaranteeing that the overall output of speculative decoding is identical to that of vanilla decoding. Optimizing the cost of inference of large language models (LLMs) is arguably one of the most critical factors in reducing the cost of generative AI and increasing its adoption. Towards this goal, various inference optimization techniques are available, including custom kernels, dynamic batching of input requests, and quantization of large models.

In this blog post, we provide a guide to speculative decoding and demonstrate how it can coexist with other optimizations. We are proud to open source the following, which includes the first speculator for Llama3 models:

  1. Speculator models for Meta Llama3 8B, IBM Granite 7B lab, Meta Llama2 13B, and Meta Code Llama2 13B.
  2. The code for inference via IBM’s fork of HF TGI.
  3. The code for training your own speculators and corresponding recipes.

We have deployed these speculators in an internal production-grade environment with thousands of daily users and observed a 2x speedup on language models (Llama3 8B, Llama2 13B, and IBM Granite 7B) and a 3x speedup on IBM's Granite 20B code models. We provide a detailed explanation of our approach in this technical report and are planning an in-depth analysis in an upcoming ArXiv paper.

Speculative decoding: Inference

We run IBM TGIS in our internal production environment that has optimizations such as continuous batching, fused kernels, and quantization kernels. To enable speculative decoding in TGIS, we modified the paged attention kernel from vLLM. In what follows, we will describe the key changes to the inference engine to enable speculative decoding.

Speculative decoding is based on the premise that the model is powerful enough to predict multiple tokens in a single forward pass. However, current inference servers are optimized to predict only a single token at a time. In our approach, we attach multiple speculative heads (in addition to the usual one) to the LLM to predict the (N+1)-th, (N+2)-th, (N+3)-th, … tokens. For example, 3 heads will predict 3 additional tokens. Details of the speculator architecture are explained in a later part of this blog. There are two challenges to achieving efficiency and correctness during inference: one is to predict without replicating the KV-cache, and the other is to verify that the predictions match the original model's outcomes.

In a typical generation loop, after the prompt is processed in a single forward step, a sequence length of 1 (the next predicted token) is fed into the forward pass of the model along with the KV-cache. In a naive speculative decoding implementation, each speculative head would have its own KV-cache; instead, we modify the paged attention kernel developed in the vLLM project to enable efficient KV-cache maintenance. This ensures that throughput does not drop at larger batch sizes. Further, we modify the attention masks to enable verification of the (N+1)-th token and thus enable speculative decoding without deviating from the original model's output. The details of this implementation are captured here.
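
To make the verification step concrete, below is a hedged sketch of the generic accept/reject logic in greedy speculative decoding: the speculated tokens are scored by the base model in one forward pass and accepted up to the first mismatch, plus one "bonus" token from the base model itself. This is a simplified, framework-agnostic illustration, not the TGIS/paged-attention implementation, and it ignores KV-cache handling and sampling.

import torch

def verify_speculation(base_model, input_ids, speculated):
    """Greedy verification sketch: accept speculated tokens up to the first mismatch.

    base_model: a callable returning logits of shape (1, seq_len, vocab_size);
    a stand-in for the real serving engine. input_ids: (1, n) tokens so far.
    speculated: list of k token ids proposed by the speculator heads.
    """
    # Score the prompt-so-far plus all speculated tokens in a single forward pass.
    candidate = torch.cat(
        [input_ids, torch.tensor([speculated], dtype=input_ids.dtype)], dim=1
    )
    logits = base_model(candidate)

    # Base-model predictions at each speculated position, plus one extra position.
    n = input_ids.shape[1]
    base_preds = logits[:, n - 1 : n + len(speculated), :].argmax(dim=-1)[0]

    accepted = []
    for spec, base in zip(speculated, base_preds[:-1].tolist()):
        if spec != base:
            break
        accepted.append(spec)

    # The base model's own next prediction after the accepted prefix comes for free,
    # so every iteration produces at least one token.
    bonus = base_preds[len(accepted)].item()
    return accepted + [bonus]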

Results

We illustrate the speedup obtained with Meta's chat version of Llama2 13B using a simple prompt.

Figure 2: Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

We deployed the above solution in an internal production environment. The figures below report two metrics, time to first token (TTFT) and inter-token latency (ITL), with different numbers of concurrent users (captured in the numbers on the graph lines). We observe that the speculative decoding version is nearly twice as fast for the Llama2 13B chat model and nearly thrice as fast for the Granite 20B code model compared to the non-speculative version for all batch sizes. We observe similar behavior for the smaller models – IBM's Granite 7B and Meta Llama3 8B models.

Figure 3: Time to first token (TTFT – left) and Inter-token latency (ITL – right) for Llama 13B with number of concurrent users indicated on the graph

Figure 4: Time to first token (TTFT – left) and Inter-token latency (ITL – right) for Granite 20B Code with number of concurrent users indicated on the graph

Note on efficiency

We performed numerous experiments to determine the right configuration for speculator training. The main considerations are:

  1. Speculator architecture: The current approach allows for the number of heads to be modified, which maps to the number of tokens that we can look ahead. Increasing the number of heads also increases the amount of extra compute needed and the complexity of training. In practice, for language models, we find that 3-4 heads work well, whereas code models can reap benefits from 6-8 heads.
  2. Compute: Increasing the number of heads increases compute in two dimensions: the latency of a single forward pass grows, and more compute is needed for the additional tokens. If the speculator is not accurate with more heads, the wasted compute increases latency and reduces throughput.
  3. Memory: The increased compute is offset by the round trips to HBM saved on each forward pass. Note that if we get a 3-token lookahead correct, we have saved three round-trip times to HBM.

We settled on 3-4 heads for the language models and 6-8 heads for the code models. Across different model sizes ranging from 7B to 20B, we observed significant latency improvements without throughput loss compared to non-speculative decoding. We begin to observe throughput reduction beyond a batch size of 64, which happens rarely in practice.

Speculative decoding: Training

There are two broad approaches to speculative decoding: one is to leverage a smaller model (e.g., Llama 7B as a speculator for Llama 70B), and the other is to attach speculator heads and train them. In our experiments, we find the approach of attaching speculator heads to be more effective in both model quality and latency gains.

Speculator architecture

Medusa made speculative decoding popular; their approach is to add a head to the existing model which is then trained to do speculation. We modify the Medusa architecture by making the “heads” hierarchical, where each head stage predicts a single token and then feeds it to the next head stage. These multi-stage heads are depicted in the below figure. We are exploring ways of minimizing the embeddings table by sharing these across the multiple stages and base model.

Figure 5: A simple architecture diagram for a 3-headed multi-stage speculator. Z is the state from the base model.
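
For intuition, here is a minimal PyTorch sketch of such a multi-stage head. It only illustrates the hierarchical idea described above (each stage consumes the base-model state Z plus the embedding of the previously predicted token); the layer sizes, projection, and naming are illustrative assumptions, not the released speculator architecture.

import torch
import torch.nn as nn

class MultiStageSpeculator(nn.Module):
    """Illustrative k-stage speculator: stage i predicts token N+i+1.

    Each stage conditions on the base model's hidden state `z` and the embedding
    of the token predicted by the previous stage. Sizes and choices are assumptions.
    """

    def __init__(self, hidden_dim: int, vocab_size: int, num_stages: int = 3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.stages = nn.ModuleList(
            nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.GELU())
            for _ in range(num_stages)
        )
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, z: torch.Tensor, last_token: torch.Tensor) -> list:
        # z: (batch, hidden_dim) state from the base model
        # last_token: (batch,) token the base model just produced
        speculated = []
        prev_emb = self.embed(last_token)
        for stage in self.stages:
            h = stage(torch.cat([z, prev_emb], dim=-1))
            next_token = self.lm_head(h).argmax(dim=-1)
            speculated.append(next_token)
            prev_emb = self.embed(next_token)  # feed the prediction to the next stage
        return speculated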

Speculator training

We have a two-phase approach to training a speculator for efficiency reasons. In the first phase, we train on small batches with long sequence lengths (4k tokens) and use the standard causal LM approach for training. In phase 2, we use large batches with short sequence lengths (256 tokens) generated from the base model. In this training phase, we tune the heads to match the output of the base model. Through numerous experiments, we find that a 5:2 ratio of steps for phase 1 vs phase 2 works well. We depict the progress of these phases in the below figure. We use PyTorch FSDP and IBM FMS for the training of speculators.

Figure 6: Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2

Conclusion and Future Work

Through this blog, we are releasing a new approach for speculative decoding and the following assets:

  1. Models for improving the inter-token latencies for a range of models – Llama3 8B, Llama2 13B, Granite 7B, and CodeLlama 13B
  2. Production quality code for inference
  3. Recipes for training speculators

We are working on training speculators for Llama3 70B and Mistral models and invite the community to contribute as well as help improve on our framework. We would also love to work with major open source serving frameworks such as vLLM and TGI to contribute back our speculative decoding approach to benefit the community.

Acknowledgements

There are several teams that helped us get to these latency improvements for inference. We would like to thank the vLLM team for creating the paged attention kernel in a clean and reusable manner. We extend our gratitude to the Team PyTorch at Meta that helped provide feedback on this blog as well as continued efforts on optimal usage of PyTorch. Special thanks to our internal production teams at IBM Research who took this prototype to production and hardened it. A shout out to Stas Bekman for providing insightful comments on the blog resulting in an improved explanation of the tradeoffs between compute, memory, and speculator effectiveness.

The paged attention kernel was integrated into IBM FMS by Josh Rosenkranz and Antoni Viros i Martin. The speculator architecture and training was done by Davis Wertheimer, Pavithra Ranganathan, and Sahil Suneja. The integration of the modeling code with the inference server was done by Thomas Parnell, Nick Hill, and Prashant Gupta.

Read More

Announcing PyTorch Docathon June, 2024

We are thrilled to announce the upcoming PyTorch Docathon in June! The Docathon, akin to a hackathon, is an event dedicated to enhancing the quality of the PyTorch documentation with the invaluable assistance of our community. Documentation is a vital component of any technology. By refining it, we can simplify the process for new users to get started with PyTorch, guide them in effectively utilizing its features, and ultimately expedite the transition from research to production in machine learning. See our previous events here and here.

Why Participate

The Docathon is an inclusive event designed to be accessible to newcomers, requiring only a basic understanding of Python, PyTorch, and Machine Learning, with some tasks not even requiring these skills. It offers a rewarding experience as participants can see the direct impact of their contributions on the project’s usability and accessibility. The Docathon promotes a collaborative environment, allowing participants to work with other contributors and PyTorch maintainers, fostering the exchange of ideas and networking. It also provides a rich learning experience, offering the opportunity to explore PyTorch modules, update docstrings, and test tutorials.

Event Details

June 4: Kick-off
June 4-June 16: Submissions and Feedback
June 17-18: Final Reviews
June 20: Winner Announcements

Further details for the Docathon will be announced at the Kick-off call on June 4.

Please register to join this year’s event.

Read More

Accelerating Llama3 FP8 Inference with Triton Kernels

1.0 Summary

We present an optimized Triton FP8 GEMM (General Matrix-Matrix Multiply) kernel, TK-GEMM, which leverages SplitK parallelization. For small batch size inference, TK-GEMM delivers up to a 1.94x speedup over the base Triton matmul implementation, 1.87x over cuBLAS FP8, and 1.71x over cuBLAS FP16 for Llama3-70B inference problem sizes on NVIDIA H100 GPUs.

Figure 1. TK-GEMM Speedup over PyTorch (calling cuBLAS) for Llama3-70B Attention Layer Matrix Shapes (N=K=8192)

In this blog, we will cover how we designed an optimized kernel using Triton for FP8 inference and tuned it for Llama3-70B inference. We will cover FP8 (8-bit floating point), a new datatype supported by Hopper generation GPUs (SM90), the key SM90 features that Triton supports, and how we modified the parallelization to be able to maximize memory throughput for memory-bound (inference) problem sizes.

We also dedicate a section on CUDA graphs, an important technology that will help materialize kernel level speedups and enable developers who want to use Triton kernels in production settings to get additional performance gain.

Repo and code available at: https://github.com/pytorch-labs/applied-ai

2.0 FP8 Datatype

The FP8 datatype was introduced jointly by Nvidia, Arm and Intel and serves as a successor to 16-bit floating point types. With half the bit count, it has the potential to provide significant throughput improvements over its predecessors for Transformer networks. The FP8 datatype consists of 2 formats:

E4M3 (4-bit exponent and 3-bit mantissa). Able to store +/- 448 and nan.
E5M2 (5-bit exponent and 2-bit mantissa). Able to store +/- 57,344, nan and inf.
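
PyTorch exposes both formats as native dtypes in recent releases, so their numeric ranges can be inspected directly. The short sketch below is only a format illustration and is independent of the Triton kernel.

import torch

# Inspect the two FP8 formats exposed as native PyTorch dtypes.
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest normal={info.tiny}, eps={info.eps}")

# Cast a half-precision tensor down to FP8 E4M3 (watch for values outside the range).
x = torch.randn(16, 8192, dtype=torch.float16)
x_fp8 = x.to(torch.float8_e4m3fn)
print(x_fp8.dtype, x_fp8.shape)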

Above: BF16, FP16, FP8 E4M3 and FP8 E5M2.
To show precision differences, the closest representation to 0.3952 is shown in each format.
Image Credit: Nvidia

We use E4M3 in inference and forward-pass training due to its higher precision, and E5M2 in the training backward pass due to its higher dynamic range. Nvidia has designed their H100 FP8 Tensor Core to provide a peak of 3958 TFLOPS, 2x the FLOPS of the FP16 Tensor Core.

We designed our Triton kernel with these hardware innovations in mind and in the rest of the blog we will discuss methods to leverage and verify that these features are indeed being utilized by the Triton compiler.

3.0 Triton Hopper Support and FP8 Tensor Core Instruction

The Hopper GPU architecture has added the following new features that we can expect will accelerate FP8 GEMM.

  • TMA (Tensor Memory Accelerator) Hardware Unit
  • WGMMA (Warp Group Matrix Multiply-Accumulate Instruction)
  • Threadblock Clusters

Triton currently takes advantage of one of these features, the wgmma instruction, whereas PyTorch (calling cuBLAS) leverages all three, which makes these speedups even more impressive. To fully take advantage of the Hopper FP8 Tensor Core, the wgmma instruction is necessary, even though the older mma.sync instruction is still supported.

The key difference between the mma and wgmma instructions is that instead of 1 CUDA warp being responsible for an output shard, an entire warp group, 4 CUDA warps, asynchronously contributes to an output shard.

To see what this instruction looks like in practice, and to verify that our Triton kernel is indeed utilizing this feature, we analyzed the PTX and SASS assembly using Nsight Compute.

Figure 2. PTX Assembly

This instruction is further lowered into a QGMMA instruction in SASS.

Figure 3. SASS Assembly

Both instructions tell us that we are multiplying two FP8 E4M3 input tensors and accumulating in F32, which confirms that the TK-GEMM Kernel is utilizing the FP8 Tensor Core and the lowering is being done correctly.

4.0 SplitK Work Decomposition

Figure 4. TK-GEMM vs Base Triton GEMM TFLOPS for M = 1-64

The base Triton FP8 GEMM implementation does not perform well for the small M regime, where for a matrix multiplication of A (MxN) x B (NxK), M < N, K. To optimize for this type of matrix profile, we applied a SplitK work decomposition instead of the Data Parallel decomposition found in the base Triton kernel. This greatly improved latencies for the small M regime.

For background, SplitK launches additional thread blocks along the k dimension to calculate partial output sums. The partial results from each thread block are then summed using an atomic reduction. This allows for finer grained work decomposition with resultant performance improvements. More details on SplitK are available in our arxiv paper.
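
The decomposition itself can be illustrated outside of Triton. The PyTorch snippet below is only a numerical sketch of the idea: the K dimension is split into chunks, each chunk produces a partial output, and the partials are summed, which is the role the atomic reduction plays inside the kernel.

import torch

M, N, K, split_k = 16, 8192, 8192, 4
a = torch.randn(M, K)
b = torch.randn(K, N)

# Each of the split_k "thread blocks" computes a partial product over its K-chunk...
chunk = K // split_k
partials = [a[:, i * chunk : (i + 1) * chunk] @ b[i * chunk : (i + 1) * chunk, :]
            for i in range(split_k)]

# ...and the partial results are reduced into the final output
# (an atomic add in the actual kernel).
c_splitk = torch.stack(partials).sum(dim=0)
assert torch.allclose(c_splitk, a @ b, atol=1e-3)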

After carefully tuning the other relevant hyperparameters for our kernel, such as tile sizes, number of warps, and number of pipeline stages, to Llama3-70B problem sizes, we were able to produce up to a 1.94x speedup over the Triton base implementation. For a more comprehensive introduction to hyperparameter tuning, see our blog.

Above: NCU profiler times for TK-GEMM under varying batch sizes, and compared with PyTorch (calling cuBLAS) FP8 and FP16.

Note that starting at M=32, the cuBLAS FP8 kernel starts to outperform TK-GEMM. For M >= 32, we suspect that the hyperparameters we found are not optimal, and thus another set of experiments is required to determine the optimal parameters for the mid-sized M regime.

5.0 CUDA Graphs to Enable End-to-End Speedup

To be able to realize these speedups in an end-to-end setting, we must take into account both the kernel execution time (GPU duration) and the wall time (CPU + GPU duration). Handwritten Triton kernels (as opposed to torch.compile-generated ones) are known to suffer from high kernel launch latencies. If we use the torch profiler to trace the TK-GEMM kernel, we can see the call stack on the CPU side and pinpoint exactly what is causing the slowdown.

Figure 5. CPU Launch Overhead: 2.413ms

From the above, we see that the wall time of our optimized kernel is dominated by JIT (Just-in-Time) compilation overhead. To combat this, we can use CUDA graphs.

Figure 6. CUDA Graphs Visualization
Image Credit: PyTorch

The key idea is that instead of multiple kernel launches, we create and instantiate a graph once (a one-time cost) and then submit that instance of the graph for execution. To illustrate this point, we simulate a Llama3-70B attention layer. As shown in the figure below, generated using Nsight Systems, the time between each GEMM is 165us, compared to the 12us spent on the actual matmul, due to the CPU kernel launch overhead. This means that 92% of the time in an attention layer, the GPU is idle and not doing any work.

Figure 7. Simulated Llama3-70B Attention Layer with TK-GEMM

To show the impact of CUDA graphs, we then created a graph of the TK-GEMM kernel in the toy Attention layer and replayed the graph. Below, we can see that the gaps between kernel executions are reduced to 6.65us.

Figure 8. Simulated Llama3-70B Attention Layer with TK-GEMM and CUDA Graphs

In practice, this optimization would result in a 6.4x speedup of a single attention layer in Llama3-70B, over naively using TK-GEMM in a model without CUDA graphs.
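
For reference, a hedged sketch of the capture-and-replay pattern with PyTorch's CUDA graph API is shown below. The kernel being captured is a plain matmul stand-in rather than TK-GEMM, and the side-stream warmup follows the usual CUDA graphs recipe.

import torch

# Static buffers: a CUDA graph replays fixed memory addresses, so new inputs are
# copied into these tensors before each replay.
static_a = torch.randn(16, 8192, device="cuda", dtype=torch.float16)
static_b = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)

def gemm(a, b):  # stand-in for the TK-GEMM kernel launch
    return torch.matmul(a, b)

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        gemm(static_a, static_b)
torch.cuda.current_stream().wait_stream(s)

# Capture once (a one-time cost), then replay with negligible CPU launch overhead.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = gemm(static_a, static_b)

static_a.copy_(torch.randn_like(static_a))  # feed new data into the captured buffers
g.replay()
torch.cuda.synchronize()
print(static_out.float().norm())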

6.0 Potential Future Optimization Paths

Figure 9. TMA Hardware Unit
Image Credit: Nvidia

The Nvidia H100 features a TMA hardware unit. The dedicated TMA unit frees up registers and threads to do other work, as address generation is completely handled by the TMA. For memory-bound problem sizes, this can provide even further gains once Triton enables support for this feature.

Figure 10. Tensor Core Utilization (Arrows Indicate Degrees of Freedom)

To identify how well we are utilizing the Tensor Core, we can analyze the roofline chart. Notice that we are in the memory-bound region, as expected for small M. To improve kernel latency, we can either increase the arithmetic intensity, which with a fixed problem size can only be achieved by exploiting data locality and other loop optimizations, or increase the memory throughput. This requires a more optimal parallel algorithm specialized for the FP8 datatype and for the problem size characteristics we expect to see in FP8 inference.

Figure 11. DRAM Throughput Circled, 1.65TB/s vs Peak 3.35TB/s on H100 (M=16, N=8192, K=8192)

Lastly, we can see that we are only achieving around 50% of peak DRAM throughput on the NVIDIA H100. High-performance GEMM kernels typically achieve around 70-80% of peak throughput. This means that there is still a lot of room to improve, and the techniques mentioned above (loop unrolling, optimized parallelization) are needed for additional gains.

7.0 Future Work

For future research, we would like to explore CUTLASS 3.x and CuTe to leverage more direct control over Hopper features especially in terms of obtaining direct TMA control and exploring pingpong architectures, which have shown promising results for FP8 GEMM.

Read More

ExecuTorch Alpha: Taking LLMs and AI to the Edge with Our Community and Partners

We are excited to announce the release of ExecuTorch alpha, focused on deploying large language models (LLMs) and large ML models to the edge, stabilizing the API surface, and improving our installation processes. It has been an exciting few months from our 0.1 (preview) release in collaboration with our partners at Arm, Apple, and Qualcomm Technologies, Inc.

In this post we’ll discuss our full support for Meta’s Llama 2, early support for Meta’s Llama 3, broad model support in ExecuTorch, and highlight the important work our partners have done to move us forward.

Large Language Models on Mobile

Mobile devices are highly constrained for compute, memory, and power. To bring LLMs to these devices, we heavily leverage quantization and other techniques to pack these models appropriately.

ExecuTorch alpha supports 4-bit post-training quantization using GPTQ. We’ve provided broad device support on CPU by landing dynamic shape support and new dtypes in XNNPack. We’ve also made significant improvements in export and lowering, reduced memory overhead and improved runtime performance. This enables running Llama 2 7B efficiently on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S22, S23, and S24 phones and other edge devices. Early support for Llama 3 8B is also included. We are always improving the token/sec on various edge devices and you can visit GitHub for the latest performance numbers.

We’re working closely with our partners at Apple, Arm, and Qualcomm Technologies to delegate to GPU and NPU for performance through Core ML, MPS, TOSA, and Qualcomm AI Stack backends respectively.

Supported Models

We remain committed to supporting an ever-expanding list of models with ExecuTorch. Since preview, we have significantly expanded our tested models across NLP, vision and speech, with full details in our release notes. Although support for on-device LLMs is early, we anticipate most traditional models to function seamlessly out of the box, with delegation to XNNPACK, Core ML, MPS, TOSA, and HTP for performance. If you encounter any problems please open a GitHub issue with us.

Productivity

Deploying performant models tuned for specific platforms often requires deep visibility into the on-device runtime data to determine the right changes to make in the original PyTorch model. With ExecuTorch alpha, we provide a powerful SDK with observability throughout the process from model authoring to deployment, including delegate and hardware-level information.

The ExecuTorch SDK was enhanced to include better debugging and profiling tools. Because ExecuTorch is built on PyTorch, the debugging capabilities include the ability to map from operator nodes back to original Python source code for more efficient anomaly resolution and performance tuning for both delegated and non-delegated model instances. You can learn more about the ExecuTorch SDK here.

Partnerships

ExecuTorch has only been possible because of strong collaborations across Arm, Apple, and Qualcomm Technologies. The collaboration for the initial launch of ExecuTorch continues as we support LLMs and large AI models on the edge for PyTorch. As we’ve seen with this early work for ExecuTorch alpha, there are unique challenges with these larger models and we’re excited to develop in the open.

We also want to highlight the great partnership with Google on XNNPACK for CPU performance. The teams continue to work together upstreaming our changes and across the TensorFlow and PyTorch teams to make sure we can all support generative AI models on the edge with SOTA performance.

Lastly, our hardware partner MediaTek has been doing work enabling the Llama collection of models with ExecuTorch on their SoCs. We’ll have more to share in the future.

Alpha and Production Usage

With our alpha release, we have production-tested ExecuTorch. Meta is using ExecuTorch for hand tracking on Meta Quest 3 and a variety of models on Ray-Ban Meta Smart Glasses. In addition, we have begun the rollout of ExecuTorch with Instagram and are integrating with other Meta products. We are excited to see how ExecuTorch can be used for other edge experiences.

Community

We are excited to see various efforts in the community to adopt or contribute to ExecuTorch. For instance, Unity recently shared their work at the Game Developers Conference (GDC) on leveraging ExecuTorch and Edge IR to run PyTorch models with their neural network inference library Sentis. Leveraging ExecuTorch’s hackability and extensibility, Unity introduced their own custom backend that serializes ExecuTorch’s Edge Dialect IR into Sentis’ native serialized format enabling developers to begin using PyTorch models easily in their games and apps.

We’ve been building and innovating with ExecuTorch in the open. Our north star is to empower the community to deploy any ML model on edge devices painlessly and efficiently. Whether you are a hobbyist or this is your day job, we’d love for you to jump in to bring your ML models to the edge. We are looking for your help to:

  1. Use ExecuTorch to run your LLM models locally on various deployment targets and share your feedback
  2. Expand our supported models, including bug reports
  3. Expand our quantization schemes
  4. Help us build out delegates to GPU and NPU

To all individual contributors and early adopters of ExecuTorch, a big thank you as well. We can’t wait to have more of you join us!

Read More

PyTorch 2.3 Release Blog

We are excited to announce the release of PyTorch® 2.3 (release note)! PyTorch 2.3 offers support for user-defined Triton kernels in torch.compile, allowing users to migrate their own Triton kernels from eager without experiencing performance regressions or graph breaks. Tensor Parallelism improves the experience for training Large Language Models using native PyTorch functions, which has been validated on training runs for 100B parameter models. In addition, semi-structured sparsity is implemented as a Tensor subclass, with observed speedups of up to 1.6x over dense matrix multiplication.

This release is composed of 3393 commits and 426 contributors since PyTorch 2.2. We want to sincerely thank our dedicated community for your contributions. As always, we encourage you to try these out and report any issues as we improve 2.3. More information about how to get started with the PyTorch 2-series can be found at our Getting Started page.

Beta:
  • User-defined Triton kernels in torch.compile
  • Tensor parallelism within PyTorch Distributed
  • Support for semi-structured sparsity

Prototype:
  • torch.export adds new API to specify dynamic_shapes
  • Asynchronous checkpoint generation

Performance Improvements:
  • Weight-Only-Quantization introduced into Inductor CPU backend

*To see a full list of public feature submissions click here.

Beta Features

[Beta] Support for User-defined Triton kernels in torch.compile

Allows for PyTorch code that contains triton kernels to be executed natively using torch.compile. This enables users to migrate code containing triton kernels from eager PyTorch to torch.compile without running into performance regressions or graph breaks. Native support also creates an opportunity for Torch Inductor to precompile the user-defined Triton kernel as well as better organize code around the Triton kernel allowing for further optimizations.

You can find more information about how to utilize user defined Triton kernels in torch.compile within this tutorial.
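
As a small illustration (adapted from the standard Triton vector-add tutorial), a user-defined kernel can now be called inside a torch.compile'd function without a graph break; the shapes and block size below are arbitrary.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

@torch.compile
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    # The user-defined Triton kernel is traced natively, with no graph break.
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
torch.testing.assert_close(add(x, y), x + y)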

[Beta] Tensor Parallelism introduces more efficient ways to train LLMs

The Tensor Parallel API facilitates various tensor manipulations across GPUs/hosts and integrates with FSDP for 2D Parallelism (Tensor parallelism across devices + Data Parallelism across hosts). It also offers a low-level API for constructing higher-level Tensor parallel APIs. This API has been validated to support the training of transformer models with over 100 billion parameters.

You can find more information on how to utilize this within your workflows within this tutorial.
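
A minimal sketch of the API on a toy MLP is shown below; it assumes the script is launched with torchrun with one process per GPU, and the module names in the plan ("up", "down") belong to this toy model only.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class ToyMLP(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim)
        self.down = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

# Assumes torchrun launched one process per local GPU.
mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
model = ToyMLP().cuda()

# Shard the first linear column-wise and the second row-wise, the standard
# Megatron-style pairing for an MLP block.
tp_model = parallelize_module(
    model,
    mesh,
    {"up": ColwiseParallel(), "down": RowwiseParallel()},
)

out = tp_model(torch.randn(8, 1024, device="cuda"))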

[Beta] Semi-structured sparsity provides users with a way to take advantage of accelerated sparse inference and memory savings

torch.sparse.SparseSemiStructuredTensor implements semi-structured sparsity as a Tensor subclass, which has observed speedups of up to 1.6x over dense matrix multiplication.

In particular it adds:

  • Additional support for quantization composability (mixed dtype, dequant fusion)
  • Updated cuSPARSELt and CUTLASS kernels
  • torch.compile support

You can find more information on how to take advantage of semi-structured sparsity here.
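
A short sketch of the Tensor-subclass workflow is below; it assumes a GPU with 2:4 sparse Tensor Core support (Ampere or newer) and a weight that already follows the 2:4 sparsity pattern.

import torch
from torch.sparse import to_sparse_semi_structured

# Build a weight with the 2:4 pattern (two non-zeros in every block of four),
# which is what the semi-structured kernels require.
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").tile(4096, 2048)
dense_weight = torch.randn(4096, 8192, dtype=torch.float16, device="cuda") * mask

sparse_weight = to_sparse_semi_structured(dense_weight)

x = torch.randn(8192, 1024, dtype=torch.float16, device="cuda")
# The sparse subclass dispatches to the accelerated cuSPARSELt/CUTLASS kernels.
out_sparse = sparse_weight @ x
out_dense = dense_weight @ x
torch.testing.assert_close(out_sparse, out_dense, rtol=1e-2, atol=1e-2)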

Prototype Features

[PROTOTYPE] torch.export adds new API to specify dynamic_shapes

You can now use torch.export.Dim to better represent dynamic shapes by enabling developers to specify ranges (min and max values) that can be reused across different input dimensions that are constrained to be equal.

To learn more about torch.export.Dim as well as how it can be used to express more interesting relationships (such as linear arithmetic expressions) check out the tutorial here.
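
A brief sketch of the API is shown below; the toy model and the bounds are illustrative.

import torch
from torch.export import Dim, export

class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y  # x and y must share the same batch size

example_args = (torch.randn(8, 16), torch.randn(8, 16))

# One Dim object reused across inputs: both batch dimensions are dynamic,
# bounded by [min, max], and constrained to be equal to each other.
batch = Dim("batch", min=2, max=1024)
ep = export(Model(), example_args, dynamic_shapes={"x": {0: batch}, "y": {0: batch}})

print(ep)
ep.module()(torch.randn(32, 16), torch.randn(32, 16))  # runs for other batch sizes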

[PROTOTYPE] Asynchronous checkpoint generation

Asynchronous checkpoint generation allows users to continue their training loops while checkpoints are being generated, essentially offloading much of the checkpointing cost.

You can find out how to utilize this within your own workflows with this example.

Performance Improvements

[PROTOTYPE] Weight-Only-Quantization introduced into Inductor CPU backend

PyTorch 2.3 enhances LLM inference performance on the torch inductor CPU backend. The project gpt-fast offers a simple and efficient PyTorch-native acceleration for transformer text generation with torch.compile. Prior to 2.3, only CUDA devices were supported; this feature enables the CPU counterpart by providing highly optimized kernels for int4 and int8 weight-only quantized Linear layers.

For more information / how to utilize this feature please refer to the gpt-fast README.

Read More

torchtune: Easily fine-tune LLMs using PyTorch

We’re pleased to announce the alpha release of torchtune, a PyTorch-native library for easily fine-tuning large language models.

Staying true to PyTorch’s design principles, torchtune provides composable and modular building blocks along with easy-to-extend training recipes to fine-tune popular LLMs on a variety of consumer-grade and professional GPUs.

torchtune supports the full fine-tuning workflow from start to finish, including

  • Downloading and preparing datasets and model checkpoints.
  • Customizing the training with composable building blocks that support different model architectures, parameter-efficient fine-tuning (PEFT) techniques, and more.
  • Logging progress and metrics to gain insight into the training process.
  • Quantizing the model post-tuning.
  • Evaluating the fine-tuned model on popular benchmarks.
  • Running local inference for testing fine-tuned models.
  • Checkpoint compatibility with popular production inference systems.

To get started, jump right into the code or walk through our many tutorials!

Why torchtune?

Over the past year there has been an explosion of interest in open LLMs. Fine-tuning these state of the art models has emerged as a critical technique for adapting them to specific use cases. This adaptation can require extensive customization from dataset and model selection all the way through to quantization, evaluation and inference. Moreover, the size of these models poses a significant challenge when trying to fine-tune them on consumer-level GPUs with limited memory.

Existing solutions make it hard to add these customizations or optimizations by hiding the necessary pieces behind layers of abstractions. It’s unclear how different components interact with each other and which of these need to be updated to add new functionality. torchtune empowers developers to adapt LLMs to their specific needs and constraints with full control and visibility.

torchtune’s Design

torchtune was built with the following principles in mind

  • Easy extensibility – New techniques emerge all the time and everyone’s fine-tuning use case is different. torchtune’s recipes are designed around easily composable components and hackable training loops, with minimal abstraction getting in the way of fine-tuning your fine-tuning. Each recipe is self-contained – no trainers or frameworks, and is designed to be easy to read – less than 600 lines of code!
  • Democratize fine-tuning – Users, regardless of their level of expertise, should be able to use torchtune. Clone and modify configs, or get your hands dirty with some code! You also don’t need beefy data center GPUs. Our memory efficient recipes have been tested on machines with a single 24GB gaming GPU.
  • Interoperability with the OSS LLM ecosystem – The open source LLM ecosystem is absolutely thriving, and torchtune takes advantage of this to provide interoperability with a wide range of offerings. This flexibility puts you firmly in control of how you train and use your fine-tuned models.

Over the next year, open LLMs will become even more powerful, with support for more languages (multilingual), more modalities (multimodal) and more tasks. As the complexity of these models increases, we need to pay the same attention to “how” we design our libraries as we do to the features provided or performance of a training run. Flexibility will be key to ensuring the community can maintain the current pace of innovation, and many libraries/tools will need to play well with each other to power the full spectrum of use cases. torchtune is built from the ground up with this future in mind.

In the true PyTorch spirit, torchtune makes it easy to get started by providing integrations with some of the most popular tools for working with LLMs.

  • Hugging Face Hub – Hugging Face provides an expansive repository of open source models and datasets for fine-tuning. torchtune seamlessly integrates through the tune download CLI command so you can get started right away with fine-tuning your first model.
  • PyTorch FSDP – Scale your training using PyTorch FSDP. It is very common for people to invest in machines with multiple consumer level cards like the 3090/4090 by NVidia. torchtune allows you to take advantage of these setups by providing distributed recipes powered by FSDP.
  • Weights & Biases – torchtune uses the Weights & Biases AI platform to log metrics and model checkpoints during training. Track your configs, metrics and models from your fine-tuning runs all in one place!
  • EleutherAI’s LM Evaluation Harness – Evaluating fine-tuned models is critical to understanding whether fine-tuning is giving you the results you need. torchtune includes a simple evaluation recipe powered by EleutherAI’s LM Evaluation Harness to provide easy access to a comprehensive suite of standard LLM benchmarks. Given the importance of evaluation, we will be working with EleutherAI very closely in the next few months to build an even deeper and more “native” integration.
  • ExecuTorch – Models fine-tuned with torchtune can be easily exported to ExecuTorch, enabling efficient inference to be run on a wide variety of mobile and edge devices.
  • torchao – Easily and efficiently quantize your fine-tuned models into 4-bit or 8-bit using a simple post-training recipe powered by the quantization APIs from torchao.

What’s Next?

This is just the beginning and we’re really excited to put this alpha version in front of a vibrant and energetic community. In the coming weeks, we’ll continue to augment the library with more models, features and fine-tuning techniques. We’d love to hear any feedback, comments or feature requests in the form of GitHub issues on our repository, or on our Discord channel. As always, we’d love any contributions from this awesome community. Happy Tuning!

Read More

Accelerating MoE model inference with Locality-Aware Kernel Design

1.0 Summary

We show that by implementing column-major scheduling to improve data locality, we can accelerate the core Triton GEMM (General Matrix-Matrix Multiply) kernel for MoEs (Mixture of Experts) up to 4x on A100, and up to 4.4x on H100 Nvidia GPUs. This post demonstrates several different work decomposition and scheduling algorithms for MoE GEMMs and shows, at the hardware level, why column-major scheduling produces the highest speedup.

Repo and code available at: https://github.com/pytorch-labs/applied-ai/tree/main/triton/.

Figure 1A. Optimized Fused MoE GEMM Kernel TFLOPs on A100 for varying Batch Sizes M

Figure 1B. Optimized Fused MoE GEMM Kernel TFLOPs on H100 for varying Batch Sizes M

2.0 Background

OpenAI's Triton is a hardware-agnostic language and compiler that, as our prior blog post has shown, can be used to accelerate quantization workflows. We also showed that in terms of kernel development, much of the same learnings and performance analysis tools from CUDA can be leveraged to provide similar insights into how Triton kernels work under the hood and subsequent measures to speed up these kernels in latency-sensitive environments. As Triton becomes increasingly adopted in production settings, it is important that developers understand the common tips and tricks to developing performant kernels as well as the generality of these methods to various different architectures and workflows. Thus, this post will explore how we optimized the Triton kernel developed by vLLM for the popular Mixture of Experts (MoE) Mixtral model using classical techniques and how these techniques can be implemented in Triton to achieve performance gains.

Mixtral 8x7B is a sparse Mixture of Experts Language Model. Unlike the classical dense transformer architecture, each transformer block houses 8 MLP layers where each MLP is an ‘expert’. As a token flows through, a router network selects which 2 of the 8 experts should process that token and the results are then combined. The selected experts for the same token vary at each layer. As a result, while Mixtral 8x7B has a total of 47B params, during inference only 13B params are active.

The MoE GEMM (General Matrix-Matrix Multiply) kernel receives a stacked weight matrix containing all the experts, and must subsequently route each token to the TopK (2 for Mixtral) experts by utilizing a mapping array produced by the resultant scores of the router network. In this post, we provide methods to efficiently parallelize this computation during inference time, specifically during autoregression (or decoding stages).
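As a toy illustration of this routing step (this snippet is ours, not the vLLM code), the following builds the flattened token-to-expert mapping from router scores with TopK = 2:

# Toy sketch of MoE routing: each token is assigned to its TopK experts,
# producing the flattened (token, expert) mapping the GEMM kernel consumes.
import torch

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)
scores = torch.softmax(router_logits, dim=-1)

topk_scores, topk_ids = scores.topk(top_k, dim=-1)      # (num_tokens, top_k)
token_ids = torch.arange(num_tokens).repeat_interleave(top_k)
expert_ids = topk_ids.reshape(-1)

# Each (token_ids[i], expert_ids[i]) pair tells the kernel which expert's
# weight slice a given token row must be multiplied against.
print(token_ids.tolist(), expert_ids.tolist())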

3.0 Work Decomposition – SplitK

We have previously shown that for the matrix problem sizes found in LLM inference, specifically in the context of W4A16 quantized inference, GEMM kernels can be accelerated by applying a SplitK work decomposition. Thus, we started our MoE acceleration research by implementing SplitK in the vLLM MoE Kernel, which produced speedups of approximately 18-20% over the Data Parallel approach.

This result shows that the SplitK optimization can be used as a part of a more formulaic approach to improving/developing Triton kernels in inference settings. To build intuition about these different work decompositions, let’s consider a simple example for the multiplication of two 4×4 matrices and SplitK=2.

In the data parallel GEMM kernel shown below, the computation for a single block of the output matrix will be handled by 1 threadblock, TB0.

Figure 2. Data Parallel GEMM

In contrast, in the SplitK kernel, the work required to compute 1 block of the output matrix is “split”, or shared, between 2 thread blocks, TB0 and TB1. This provides better load balancing and increased parallelism.

Figure 3. SplitK GEMM

The key idea is that we’ve increased our parallelism from M*N to M*N*SplitK. This approach does incur some costs, such as inter-threadblock communication via atomic operations, but these costs are minimal compared to the savings in other constrained GPU resources like shared memory and registers. Most importantly, the SplitK strategy provides superior load balancing for skinny matrices, which are the common matrix profile in MoE inference during decoding.
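To make the decomposition concrete, here is a toy sketch (not the actual Triton kernel) that emulates SplitK = 2 on a 4×4 matmul: each “threadblock” computes a partial product over half of K, and the partials are accumulated, mimicking the atomic adds used on the GPU.

# Toy SplitK illustration: two partial products over halves of the K dimension
# are accumulated into the same output tile.
import torch

A = torch.randn(4, 4)
B = torch.randn(4, 4)

split_k = 2
chunk = A.shape[1] // split_k

C = torch.zeros(4, 4)
for s in range(split_k):               # on the GPU these run as separate thread blocks
    ks = slice(s * chunk, (s + 1) * chunk)
    C += A[:, ks] @ B[ks, :]           # partial product, accumulated via atomics in the kernel

assert torch.allclose(C, A @ B, atol=1e-5)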

4.0 GEMM Hardware Scheduling – Column Major

To improve on the ~20% speedup from SplitK, we focused our investigation on the logic that controls the hardware scheduling of the GEMM in Triton kernels. Our profiling of the vLLM MoE kernel showed a low L2 cache hit rate, so we investigated three scheduling options – column-major, row-major, and grouped launch. Because of some intrinsic properties of MoE models, such as large expert matrices and the need to dynamically load the TopK (2 for Mixtral) expert matrices over the duration of the kernel, cache reuse/hit rate becomes the bottleneck that this optimization targets.

For background, in our previous blog, we touched on the concept of “tile swizzling”, a method to achieve a greater L2 cache hit rate. This concept relates to how the software schedules the GEMM onto the SMs of a GPU. In Triton, this schedule is determined by the pid_m and pid_n calculations. Our key insight is that for skinny matrix multiplications, a column-major ordering ensures optimal reuse of the columns of the weight matrix B. To illustrate this, let’s take a look at a snippet of what a column-major computation of pid_m and pid_n would look like:

Figure 4. Column Major ordering in PyTorch

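As an informal complement to Figure 4 (the snippet below is ours, not the kernel code), this plain-Python illustration shows how the pid-to-(pid_m, pid_n) mapping determines the order in which output blocks of C are scheduled:

# Enumerate the scheduling order of output blocks for a 3x3 grid of tiles.
grid_m, grid_n = 3, 3

row_major = [(pid // grid_n, pid % grid_n) for pid in range(grid_m * grid_n)]
col_major = [(pid % grid_m, pid // grid_m) for pid in range(grid_m * grid_n)]

print("row-major:", row_major)  # C(0,0), C(0,1), C(0,2), ...
print("col-major:", col_major)  # C(0,0), C(1,0), C(2,0), ...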

From the above, we note that with this mapping, we schedule the GEMM such that the output blocks of C are calculated in the following order: C(0, 0), C(1, 0), C(2, 0), etc. To understand the implications, we provide the following illustration:

Figure 5. Cache Reuse Pattern for a Column-Major GEMM Schedule (activation matrix A and weight matrix B resident in the L1/L2 cache, output matrix C)

In the simplified view of a column-major schedule above, let’s assume that for a GEMM with a skinny activation matrix A, the entire matrix fits in the GPU cache, which is a reasonable assumption for the problem sizes we encounter in MoE inference. This allows for maximal reuse of the columns of the weight matrix B, because the same column of B can be reused across the corresponding output tile calculations C(0,0), C(1,0), and C(2,0). Consider instead a row-major schedule, C(0,0), C(0,1), C(0,2), etc.: we would have to evict the column of B and issue multiple load instructions to DRAM to calculate the same number of output blocks.

An important design consideration when optimizing kernels is a memory access pattern that results in the fewest global load instructions. This optimal memory access pattern is achieved with the column-major schedule. The results below showcase the performance of the three schedules we investigated:

Figure 6. Comparison of GEMM Schedules on A100 for varying Batch Sizes M

The column-major schedule provides up to a 4x speedup over the other patterns, and as we’ll show in the next section, provides an optimal memory access pattern due to greatly improved data locality.

5.0 Nsight Compute Analysis – Throughput and Memory Access Pattern

For performance analysis, we focus on the M = 2 case for the H100. A similar study can be done for the A100, as many of the same observations carry over. We note the following salient results, which showcase the impact of our optimizations.

Figure 7. H100 Memory Throughput Chart for M = 2. Note the very large increase in cache hit rates: L1 cache hit rate (+2696%) and L2 cache hit rate (+254%).

Figure 8. H100 Memory Instruction Statistics for M = 2. Note the 49% reduction in global memory loads.

These statistics show that our optimizations had the intended effect, which can be seen in the reduced cache misses, reduced memory accesses and the resultant 2.7x speedup. More concretely, the trace shows us a 2.54x increase in L2 hit rate (Figure 7), and a ~50% reduction in DRAM accesses (Figure 8).

These improvements ultimately yield the reduced latency, with the optimized kernel being 2.7x faster for bs=2 and 4.4x for bs=512.

6.0 Future Work

Our kernel was tested in FP16, which showcases the numerics and performance of column-major scheduling for MoE, but most production models use BFloat16. We encountered two limitations: tl.atomic_add in Triton does not support BFloat16, and kernel launch latency would require CUDA graph support before the column-major kernel could be used in production. In initial testing this translated to a 70% end-to-end speedup, but we encountered some expert mapping inconsistencies in the end-to-end environment that are not reflected in the test environment, so further work is needed to fully realize these speedups.

For future work, we intend to move this into a CUDA kernel which will ensure full BFloat16 support and reduced launch latency relative to Triton, and potentially resolve the expert routing inconsistency. We’ve also previously published work on enabling GPTQ W4A16 with Triton GEMM kernels, so natural follow-on work would include fusing dequantization into this kernel to allow for a GPTQ quantized inference path.

7.0 Reproducibility

We have open sourced the Triton kernel code along with an easy-to-run performance benchmark for readers interested in comparing or verifying the performance on their own GPU.

Acknowledgements

We want to thank Daniel Han, Raghu Ganti, Mudhakar Srivatsa, Bert Maher, Gregory Chanan, Eli Uriegas, and Geeta Chauhan for their review of the presented material, and Woo Suk from the vLLM team, whose implementation of the Fused MoE kernel we built upon.

Read More

Maximizing training throughput using PyTorch FSDP

In this blog, we demonstrate the scalability of FSDP with a pre-training exemplar, a 7B model trained for 2T tokens, and share various techniques we used to achieve a rapid training speed of 3,700 tokens/sec/GPU, or 40B tokens/day on 128 A100 GPUs. This translates to a model FLOPS utilization (MFU) and hardware FLOPS utilization (HFU) of 57%. Additionally, we have observed near linear scaling of FSDP to 512 GPUs, implying that training a 7B model on 512 GPUs to 2T tokens using this method would take just under two weeks.

IBM researchers trained a Meta Llama 2 7B architecture to 2T tokens, which we will refer to as LlamaT(est). This model demonstrates comparable model quality to Llama 2 on various academic benchmarks. All of the training code, along with our methodology to achieve this throughput, can be found in this blog. We also share the configuration knobs that work well for the Llama 2 models – 7B, 13B, 34B, and 70B – on A100s and H100s.

In this process, we also propose a new selective activation checkpointing mechanism that applies to FSDP and gives us a 10% boost beyond out-of-the-box FSDP. We have open sourced the training code base and an associated scalable data loader, along with the methodology to achieve this throughput.

One key benefit of a PyTorch native pathway for training is the ability to seamlessly train on multiple hardware backends. For example, the recent end-to-end stack for training that was released by AllenAI through OLMo also leverages PyTorch FSDP for training on AMD and NVIDIA GPUs. There are three main components that we leverage from FSDP to achieve our throughput:

  1. SDPA flash attention, which enables fused attention kernels and efficient attention computation (see the sketch after this list)
  2. Overlap in computation and communication allows for better utilization of the GPU
  3. Selective activation checkpointing enables us to tradeoff between GPU memory and compute
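For item (1), here is a minimal sketch of the SDPA path with toy shapes (ours, not the training code); it assumes a CUDA device and a dtype for which a fused flash-attention kernel is available.

# PyTorch dispatches scaled_dot_product_attention to a fused flash-attention
# kernel when one is available for the given device, dtype, and shapes.
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)  # (batch, heads, seq, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)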

IBM has been working closely with Team PyTorch at Meta on PyTorch FSDP for nearly two years: introducing the rate limiter for achieving better throughput on Ethernet interconnects, distributed checkpointing to improve the checkpoint times by an order of magnitude, and implementing the early version of checkpointing for the hybrid sharding mode of FSDP. Late last year, we used FSDP to train a model end-to-end.

Training Details

The 7B model is trained on 128 A100 GPUs with 400Gbps network connectivity and GPU Direct RDMA. We use SDPA FlashAttention v2 for attention computation, and for this model we turned off activation checkpointing; this limits the batch size but provides the highest throughput. The batch size is 1 million tokens per batch across 128 GPUs, and it improves throughput by about 10% compared to training with activation checkpointing. With these parameters, we achieve an almost full overlap of computation and communication. We use the AdamW optimizer in 32-bit with beta1 of 0.9, beta2 of 0.95, and weight decay of 0.1, with a learning rate that warms up to a maximum of 3e-4 and then follows a cosine schedule down to 3e-5 over 2T tokens. Training was performed using bf16 mixed precision on an internal dataset. The training stack uses IBM’s Foundation Model Stack for the model architecture and PyTorch nightlies (post-2.2 release) for FSDP and SDPA. We tried a few different nightlies during the period of Nov 2023 through Feb 2024 and observed an improvement in throughput.
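Below is a sketch of the optimizer and learning-rate schedule described above; the warmup length and total step count are illustrative assumptions, not values given in the post.

# AdamW with beta1=0.9, beta2=0.95, weight decay 0.1, warmup to 3e-4,
# then cosine decay to 3e-5 (warmup/total steps are illustrative).
import math
import torch

model = torch.nn.Linear(10, 10)  # stand-in for the 7B model
optim = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)

warmup_steps, total_steps, min_lr, max_lr = 1000, 500_000, 3e-5, 3e-4

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return (min_lr + (max_lr - min_lr) * cosine) / max_lr

scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)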

Selective activation checkpointing

We jointly implemented a simple and effective mechanism of selective activation checkpointing (AC). In FSDP, the common practice is to checkpoint each transformer block. A simple extension is to checkpoint every n blocks, reducing the amount of recomputation while increasing the memory needed. This is quite effective for the 13B model size, increasing throughput by 10%. For the 7B model size, we did not need activation checkpointing at all. Future versions of FSDP will provide selective activation checkpointing at an operator level, enabling an optimal compute-memory tradeoff. The code for the above is implemented here.
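As a hedged sketch of the “checkpoint every n-th block” idea (module and variable names here are illustrative; the code linked above is the actual implementation):

# Wrap every n-th transformer block with activation checkpointing using
# PyTorch's checkpoint_wrapper utilities.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(512, 512)

    def forward(self, x):
        return self.ff(x).relu()

model = nn.Sequential(*[Block() for _ in range(12)])

every_n = 3
blocks = [m for m in model.modules() if isinstance(m, Block)]
to_wrap = set(blocks[::every_n])  # checkpoint (and later recompute) every third block

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=checkpoint_wrapper,
    check_fn=lambda submodule: submodule in to_wrap,
)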

Throughput and MFU, HFU computation

While we only trained the 7B model to 2T tokens, we performed numerous experiments on the other model sizes to provide the best configuration options. This is summarized in the table below for two types of infrastructure — an A100 cluster with 128 GPUs and 400Gbps inter-node interconnect, and an H100 cluster with 96 GPUs and 800Gbps inter-node interconnect.

Model size | Batch size | Activation checkpoint | Throughput tokens/sec/GPU (A100 80GB, 400Gbps interconnect) | MFU (A100 80GB) | HFU (A100 80GB) | Throughput tokens/sec/GPU (H100 80GB, 800Gbps interconnect) | MFU (H100 80GB) | HFU (H100 80GB)
7B | 2 | No | 3700 | 0.57 | 0.57 | 7500 | 0.37 | 0.37
13B | 2 | Selective | 1800 | 0.51 | 0.59 | 3800 | 0.35 | 0.40
34B | 2 | Yes | 700 | 0.47 | 0.64 | 1550 | 0.32 | 0.44
70B | 2 | Yes | 370 | 0.50 | 0.67 | 800 | 0.34 | 0.45

Table 1: Model and Hardware FLOPS utilization of various model sizes on A100 and H100 GPUs

HFU numbers are computed using the PyTorch FLOP counter and the theoretical bf16 performance of A100 and H100 GPUs, whereas MFU numbers are computed using the methodology outlined in nanoGPT and the PaLM paper. We also note that the batch sizes for the larger models are intentionally kept at 2 per GPU to mimic the choices made when training models with 4k sequence length, and to scale up to 512 GPUs without exceeding the popular 4M-token global batch size. Beyond that scale, we would need tensor parallelism or sequence parallelism.
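As a back-of-the-envelope check of this accounting, the snippet below recomputes the A100 MFU for the 7B model using assumed Llama 2 7B dimensions (32 layers, 32 heads, head dimension 128, 4k sequence length) and the A100 bf16 peak of 312 TFLOPS:

# PaLM-style model-FLOPs accounting: 6*N per token plus the attention term
# 12 * layers * heads * head_dim * seq_len.
params = 6.74e9
n_layers, n_heads, head_dim, seq_len = 32, 32, 128, 4096
tokens_per_sec_per_gpu = 3700
peak_flops = 312e12  # A100 bf16 peak

flops_per_token = 6 * params + 12 * n_layers * n_heads * head_dim * seq_len
mfu = tokens_per_sec_per_gpu * flops_per_token / peak_flops
print(f"MFU ~= {mfu:.2f}")  # roughly 0.56, in line with the 0.57 reported in Table 1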

We note in the table above that for A100s, activation recomputation causes the MFU to decrease while the HFU increases! With the introduction of better activation checkpointing schemes, we expect MFU to increase and catch up with HFU. However, we observe that for H100s, both MFU and HFU are relatively low. Analyzing the PyTorch profiler traces on H100, we observe a 10% gap due to the network “peeking” out. In addition, we hypothesize that the HBM bandwidth of H100s is the cause of the reduced HFU/MFU on H100s and of not being able to obtain the full 3x improvement (H100s are theoretically 3x faster than A100s – 989 vs 312 TFLOPS – but have less than 2x the HBM bandwidth of A100s – 3.35 vs 2.0 TB/s). We plan to try other configuration options like tensor parallelism to improve the knobs for the 70B model on H100s.

Model details

The loss curve for training is shown in the below figure.

Figure 1: LlamaT training loss curve

The 2T checkpoint is converted to Hugging Face format using a script provided in the repository, and we then use lm-evaluation-harness to compute key academic benchmarks and compare against the same benchmarks run on Llama2-7B. These results are captured in the table below.
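As a hedged sketch of running such an evaluation with the harness’s Python API (it assumes a recent lm-eval release; the model path and task list are illustrative):

# Evaluate a Hugging Face format checkpoint on a few benchmarks with lm-eval.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=/path/to/converted-checkpoint",  # illustrative path
    tasks=["hellaswag", "winogrande", "boolq"],
    num_fewshot=0,
)
print(results["results"])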

Evaluation metric | Llama2-7B (baseline) | LlamaT-7B
MMLU (zero shot) | 0.41 | 0.43
MMLU (5-shot weighted avg) | 0.47 | 0.50
Arc challenge | 0.46 | 0.44
Arc easy | 0.74 | 0.71
Boolq | 0.78 | 0.76
Copa | 0.87 | 0.83
Hellaswag | 0.76 | 0.74
Openbookqa | 0.44 | 0.42
Piqa | 0.79 | 0.79
Sciq | 0.91 | 0.91
Winogrande | 0.69 | 0.67
Truthfulqa | 0.39 | 0.39
GSM8k (8-shot) | 0.13 | 0.11

Table 2: LM eval harness scores

We observe that the model performs competitively with Llama2-7B across these benchmarks (higher is better).

Training chronicles

Training was stable with no crashes, though we did observe a few hiccups:

0-200B tokens: We observed a slowdown in the iteration time (time taken to execute one training step). We stopped the job to ensure that the data loader was not causing any slowdowns and the checkpointing was performant and accurate. We did not find any issues. By this time, HSDP checkpointing code was available in PyTorch, and we took this opportunity to make the switch to PyTorch checkpointing code.

200B tokens-1.9T tokens: We did not do any manual intervention in the job through late December. When we came back in early January, disk space had been exhausted and checkpoints were failing to be written, although the training job continued. The last known checkpoint was at 1.5T tokens.

1.5T-1.7T tokens: We evaluated the 1.5T checkpoint with lm-evaluation-harness and discovered that the model had been trained with an extra special token between documents: the Hugging Face tokenizer was introducing a separator token, and our dataloader was also appending its own document separator. We modified the dataloader to eliminate the extra special token, and continued training with the modified dataloader from the 1.7T-token mark onwards.

1.7T-2T tokens: The loss initially spiked due to the change in special tokens, but it recovered within a few billion tokens. Training finished without any other manual intervention!

Key takeaways and even more speed

We demonstrated how one can use FSDP to train a model to 2T tokens at an excellent throughput of 3,700 tokens/sec/GPU while producing a good quality model. As part of this exercise, we open sourced all of our training code and the knobs to achieve this throughput. These knobs can be leveraged not only by large-scale runs, but also by smaller-scale tuning runs. You can find the code here.

FSDP APIs implement the ZeRO algorithms in a PyTorch-native manner and allow for tuning and training of large models. In the past, we have seen FSDP proof points (Stanford Alpaca, Hugging Face, Llama 2 recipes) on tuning a variety of LLMs (such as Meta Llama 2, from 7B to 70B) using simple training loops and achieving good throughputs and training times.

Finally, we note that there are several levers for speeding up training:

  1. Node optimizations that can speed up specific operations (e.g., attention computation using Flash Attention V2)
  2. Graph optimizations (e.g., fusing kernels, torch.compile)
  3. Overlap in compute-communications
  4. Activation recomputation

We have leveraged 1, 3, and a variation of 4 in this blog and are working closely with Team PyTorch at Meta to get torch.compile (2) as well as a more advanced version of 4 with per-operator selective activation recomputation. We plan to share a simple formatting code and example data to ingest into our data loader to enable others to use the code base for training of models.

Acknowledgements

There are several teams that have been involved in reaching this proof point and we would like to thank the teams across Meta and IBM. Specifically, we extend our gratitude to the PyTorch distributed team, Facebook Research and Applied AI teams that built the FSDP APIs and made enhancements based on our feedback. We also wish to thank the data team at IBM Research that curated the data corpus used in this exercise and the infrastructure team at IBM Research (especially, Claudia Misale, Shweta Salaria, and Seetharami Seelam) that optimized NCCL and network configurations. By building and leveraging all of these components, we have successfully demonstrated the LlamaT proof point.

The selective activation checkpointing was conceptualized at IBM by Linsong Chu, Davis Wertheimer, Mudhakar Srivatsa, and Raghu Ganti and implemented by Less Wright at Meta.

Special thanks to Stas Bekman and Minjia Zhang, who provided extensive feedback and helped improve the blog. Their insights have been invaluable in highlighting key aspects of optimizing the training and exploring further enhancements.

Appendix

Communication computation overlap

Another key aspect of training in a multi-node setting is the ability to overlap communication and computation. In FSDP, there are multiple opportunities for overlap – during the FSDP unit all-gather phase in the forward pass, as well as during the backward pass computation. Overlapping the gather of a unit with the computation of the previous unit during the forward pass, and overlapping backward computation with the gathering of the next unit and gradient scattering, improves GPU utilization by nearly 2x. We illustrate this on the 400Gbps network interconnect with A100 80GB GPUs. In the case of HSDP, there is no inter-node traffic during the pre-fetch stage of the forward pass, and the overlap applies only to the backward gradient computation phase. Of course, HSDP is feasible only when the model can be sharded within a single node, limiting model size to around 30B parameters.
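As a minimal sketch (ours, not the post’s training code) of an FSDP wrapping configuration that enables this overlap, assuming the script is launched with torchrun so that a NCCL process group can be initialized:

# Wrap a toy model with FSDP so that parameter all-gathers overlap with compute.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(4096, 4096)

    def forward(self, x):
        return self.ff(x).relu()

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(*[Block() for _ in range(8)]).cuda()
fsdp_model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({Block}),       # each Block becomes one FSDP unit
    forward_prefetch=True,                            # prefetch the next unit's all-gather in forward
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # overlap gathers with backward compute
    use_orig_params=True,
)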

The below figure shows three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half of the image. For the 7B model with no activation recomputation, we observe the overlap to be complete. In practice, the overlap percentage possible is 90% since the first block during forward pass and the last block during backward pass are not able to overlap.

three steps in FSDP with the communication between nodes at the bottom and the compute stream at the top of the second half

A zoomed-in view of the above three-step process is shown below for a single step. We can clearly see the granularity of the computation and communication and how they overlap in an interleaved manner.

zoomed in view of the above three-step process

Read More