Enabling Fully Sharded Data Parallel (FSDP2) in Opacus

Introduction and Context

Opacus is making significant strides in supporting private training of large-scale models with its latest enhancements. Recently, we introduced Fast Gradient Clipping (FGC) and Ghost Clipping (GC), which enabled developers and researchers to perform gradient clipping without instantiating the per-sample gradients. These methods reduce the memory footprint of DP-SGD compared to the native Opacus implementation that relies on hooks.

Even with these advancements, training large-scale models such as Large Language Models (LLMs) remains a significant challenge for Opacus. As the demand for private training of large-scale models continues to grow, it is crucial for Opacus to support both data and model parallelism techniques. Currently, Opacus supports Differentially Private Distributed Data Parallel (DPDDP) to enable multi-GPU training. While DPDDP effectively scales model training across multiple GPUs and nodes, it requires each GPU to store a copy of the model and optimizer states, leading to high memory requirements, especially for large models.

This limitation underscores the need for alternative parallelization techniques, such as Fully Sharded Data Parallel (FSDP), which can offer improved memory efficiency and increased scalability via sharding of model parameters, gradients, and optimizer states. In the context of training Llama or other large language models, different parallelism strategies are typically employed to scale the training depending on the model size:

  • 1D Parallelism: DDP or FSDP for small-sized models (<10 billion parameters).
  • 2D Parallelism: FSDP combined with Tensor Parallelism (TP) for medium-sized models (10-100 billion parameters).
  • 4D Parallelism: FSDP combined with TP, Pipeline Parallelism (PP), and Context Parallelism (CP) for large-sized models (>100 billion parameters).

By adopting FSDP (example), Opacus is enhancing its capability to facilitate more efficient and scalable private training or fine-tuning of LLMs. This development marks a promising step forward in meeting the evolving needs of the ML community, paving the way for advanced parallelism strategies such as 2D and 4D parallelism to support private training of medium to large-scale models.

FSDP with FGC and GC

Fully Sharded Data Parallelism (FSDP) is a powerful data parallelism technique that enables the training of larger models by efficiently managing memory usage across multiple GPU workers. FSDP allows for the sharding of model parameters, gradients, and optimizer states across workers, which significantly reduces the memory footprint required for training. Although this approach incurs additional communication overhead due to parameter gathering and discarding during training, the cost is often mitigated by overlapping it with computation.

In FSDP, even though the computation for each microbatch of data is still local to each GPU worker, the full parameters are gathered for one block (e.g., a layer) at a time, resulting in a lower memory footprint. Once a block is processed, each GPU worker discards the parameter shards collected from other workers, retaining only its local shard. Consequently, the peak memory usage is determined by the maximum per-layer size of parameters + gradients + optimizer states, as well as the total size of activations, which depends on the per-device batch size. For more details on FSDP, please refer to the PyTorch FSDP paper.

The flow of FSDP with FGC or GC is as follows:

  • Forward pass
    • For each layer in layers:
      • [FSDP hook] all_gather full parameters of layer
      • Forward pass for layer
      • [FSDP hook] Discard full parameters of layer
      • [Opacus hook] Store the activations of layer
  • Reset optimizer.zero_grad()
  • First backward pass
    • For each layer in layers:
      • [FSDP hook] all_gather full parameters of layer
      • Backward pass for layer
      • [Opacus hook] compute per-sample gradient norms using FGC or GC
      • [FSDP hook] Discard full parameters of layer
      • [FSDP hook] reduce_scatter gradients of layer → not necessary
  • Rescale the loss function using the per-sample gradient norms
  • Reset optimizer.zero_grad()
  • Second backward pass
    • For each layer in layers:
      • [FSDP hook] all_gather full parameters of layer
      • Backward pass for layer
      • [FSDP hook] Discard full parameters of layer
      • [FSDP hook] reduce_scatter gradients of layer
  • Add noise on each device to the corresponding shard of parameters
  • Apply optimizer step on each device to the corresponding shard of parameters

Figure 1: Workflow of FSDP based Fast Gradient Clipping or Ghost Clipping in Opacus. Note that there is an overlap between compute and communication – 1) In the forward pass: computation of the current layer (l) is overlapped with the all_gather of the next layer’s (l+1) parameters. 2) In the backward pass: gradient computation of the current layer (l) is overlapped with the reduce_scatter of the previously processed layer’s (l+1) gradients and the all_gather of the next layer’s (l-1) parameters.

How to use FSDP in Opacus

The training loop is identical to the standard PyTorch loop. As before in Opacus, we use:

  • PrivacyEngine(), which configures the model and optimizer for running DP-SGD.
  • The argument grad_sample_mode="ghost_fsdp" to enable Ghost Clipping with FSDP.
  • FSDP2Wrapper to wrap the model before initializing the optimizer and calling make_private().

FSDP2Wrapper applies FSDP2 (the second version of FSDP) to the root module and also to each torch.nn layer that does not require functorch to compute per-sample gradient norms. Layer types that depend on functorch are not separately wrapped with FSDP2 and thus fall into the root module’s communication group. The layers attached to the root module’s communication group are unsharded first (at the beginning of the forward/backward pass) and resharded last (after the entire forward/backward pass). This impacts peak memory, since layers attached to the root module are not resharded immediately after they are executed.

We use FSDP2 in our implementation because the previous version (FSDP) is not compatible with the two-backward-pass setup of Ghost Clipping.

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim

from opacus import PrivacyEngine
from opacus.utils.fsdp_utils import FSDP2Wrapper

def launch(rank, world_size):
    # ToyModel, setup_distributed_env, args, noise_multiplier, max_grad_norm, and
    # train_loader are placeholders assumed to be defined elsewhere.
    torch.cuda.set_device(rank)
    setup_distributed_env(rank, world_size)
    criterion = nn.CrossEntropyLoss()  # example loss function
    model     = ToyModel()
    model     = FSDP2Wrapper(model)  # different from the DPDDP wrapper
    optimizer = optim.SGD(model.parameters(), lr=args.lr)

    privacy_engine = PrivacyEngine()
    model_gc, optimizer_gc, criterion_gc, train_loader_gc = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm,
        criterion=criterion,
        grad_sample_mode="ghost_fsdp",
    )

    # The training loop below is identical to that of PyTorch
    for input_data, target_data in train_loader_gc:
        input_data, target_data = input_data.to(rank), target_data.to(rank)
        output_gc = model_gc(input_data)  # Forward pass
        optimizer_gc.zero_grad()
        loss = criterion_gc(output_gc, target_data)
        loss.backward()
        optimizer_gc.step()  # Add noise and update the model

world_size = torch.cuda.device_count()
mp.spawn(
    launch,
    args=(world_size,),
    nprocs=world_size,
    join=True,
)

Memory Analysis

We provide results on memory consumption with full fine-tuning of the GPT2 series of models. We only train the transformer blocks; the embedding layers (token embedding layer, positional embedding layer) are frozen.

Figure 2 reports the maximum batch size supported by FSDP2 vs DPDDP when training the model with the Ghost Clipping method. With FSDP2, we can achieve a 2.6x larger batch size for a 1.5B parameter GPT2 model on 1×8 A100 40GB GPUs, compared to using DPDDP. FSDP2 shows significant improvements for larger models (>1B) where the size of parameters and optimizer states dominates the activations.

Figure 2: Maximum batch size on a 1×8 A100 40GB GPU node while training a series of GPT2 models with Ghost Clipping based DP-SGD. We used the Shakespeare Dataset with a maximum sequence length of 1024 and float32 AdamW optimizer.

Table 1 presents the peak memory for a given step and total occupied memory after the execution of the step. Notably, the total memory of FSDP2 after model initialization is 8x lower than that of DPDDP since FSDP2 shards the model across 8 GPUs. The forward pass, for both FSDP2 and DPDDP, roughly increases the peak memory by ~10GB as activations are not sharded in both types of parallelism. For the backward pass and optimizer step, the peak memory for DPDDP is proportional to the model size whereas for FSDP2 it is proportional to the model size divided by the number of workers. Typically, as model sizes increase, the advantages of FSDP2 become even more pronounced.

Table 1: Memory estimates on rank 0 when training GPT2-xl model (1.5B) with Ghost Clipping based DP-SGD (AdamW optimizer, iso batch size of 16) on 1×8 A100 40GB GPU node. Peak Memory indicates the maximum memory allotted during the step, and total memory is the amount of occupied memory after executing the given step.

Step | DPDDP Peak Memory (GB) | DPDDP Total Memory (GB) | FSDP2 Peak Memory (GB) | FSDP2 Total Memory (GB)
Model initialization | 5.93 | 5.93 | 1.08 | 0.78
Forward Pass | 16.17 | 16.13 | 11.31 | 10.98
GC Backward Pass | 22.40 | 11.83 | 12.45 | 1.88
Optimizer step | 34.15 | 28.53 | 3.98 | 3.29
Optimizer zero grad | 28.53 | 17.54 | 3.29 | 2.59
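
As a quick sanity check on the model-initialization row, a back-of-the-envelope estimate (assuming float32 parameters and ignoring allocator overheads and the frozen embedding layers) lines up with the measured numbers:

# Rough memory estimate for a ~1.5B-parameter GPT2-xl in float32
params = 1.5e9
bytes_per_param = 4  # float32

replicated_gb = params * bytes_per_param / 1e9  # ~6 GB per GPU with DPDDP
sharded_gb = replicated_gb / 8                  # ~0.75 GB per GPU with FSDP2 across 8 GPUs

# AdamW later adds two float32 states per trainable parameter and gradients add one more,
# which is why the optimizer-step rows diverge even further between DPDDP and FSDP2.
print(f"replicated: {replicated_gb:.2f} GB, sharded: {sharded_gb:.2f} GB")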

Latency Analysis

Table 2 shows the maximum batch size and latency numbers for LoRA fine-tuning of the Llama-3 8B model with DP-DDP and FSDP2. We observe that, for DP-SGD with Ghost Clipping, FSDP2 supports nearly twice the batch size but has lower throughput (0.6x) compared to DP-DDP with hooks at the same effective batch size. In this particular setup, with LoRA fine-tuning, using FSDP2 does not result in any significant improvements. However, if the dataset has samples with a sequence length of 4096, which DP-DDP cannot accommodate, FSDP2 becomes necessary.

Table 2a: LoRA fine-tuning of Llama-3 8B (trainable parameters: 6.8M) on the Tiny Shakespeare dataset, AdamW optimizer with 32-bit precision, max sequence length of 512, 1×8 A100 80GB GPUs. Here, we do not use any gradient accumulation.

Training Method | Parallelism | Max batch size per device | Total batch size | Tokens per second | Samples per second
SGD (non-private) | DP-DDP | 4 | 32 | 18,311 ± 20 | 35.76 ± 0.04
SGD (non-private) | FSDP2 | 4 | 32 | 13,158 ± 498 | 25.70 ± 0.97
SGD (non-private) | FSDP2 | 8 | 64 | 16,905 ± 317 | 33.02 ± 0.62
DP-SGD with Hooks | DP-DDP | 4 | 32 | 17,530 ± 166 | 34.24 ± 0.32
DP-SGD with Ghost Clipping | DP-DDP | 4 | 32 | 11,602 ± 222 | 22.66 ± 0.43
DP-SGD with Ghost Clipping | FSDP2 | 4 | 32 | 8,888 ± 127 | 17.36 ± 0.25
DP-SGD with Ghost Clipping | FSDP2 | 8 | 64 | 10,847 ± 187 | 21.19 ± 0.37

Table 2b: LoRA fine-tuning of Llama-3 8B (trainable parameters: 6.8M) on the Tiny Shakespeare dataset, AdamW optimizer with 32-bit precision, max sequence length of 512, 1×8 A100 80GB GPUs. Here, we enable gradient accumulation to increase the total batch size to 256.

Training Method | Parallelism | Max batch size per device | Gradient accumulation steps | Total batch size | Tokens per second | Samples per second
DP-SGD with Hooks | DP-DDP | 4 | 8 | 256 | 17,850 ± 61 | 34.86 ± 0.12
DP-SGD with Ghost Clipping | DP-DDP | 4 | 8 | 256 | 12,043 ± 39 | 23.52 ± 0.08
DP-SGD with Ghost Clipping | FSDP2 | 8 | 4 | 256 | 10,979 ± 103 | 21.44 ± 0.20

Table 3 presents the throughput numbers for full fine-tuning of Llama-3 8B. Currently, FSDP2 with Ghost Clipping does not support tied parameters (embedding layers). We freeze these layers during fine-tuning, which brings the trainable parameters down from 8B to 7.5B. As shown in Table 3, DP-DDP throws an OOM error even with a batch size of 1 per device. With FSDP2, on the other hand, each device can fit a batch size of 8, enabling full fine-tuning of Llama-3 8B.

To compare full fine-tuning of FSDP2 with DP-DDP, we switch from the AdamW optimizer to SGD without momentum and reduce the trainable parameters from 7.5B to 5.1B by freezing the normalization layers’ and gate projection layers’ weights. This allows DP-DDP to run with a batch size of 2 (1 if gradient accumulation is enabled). In this setting, we observe that FSDP2 is 1.65x faster than DP-DDP at the same batch size.

Table 3: Ghost Clipping DP-SGD based full fine-tuning of Llama-3 8B on Tiny Shakespeare dataset, max sequence length of 512, 1×8 A100 80GB GPUs.

Setup | Parallelism | Max batch size per device | Gradient accumulation steps | Total batch size | Tokens per second | Samples per second
Trainable parameters: 7.5B, Optimizer: AdamW | DP-DDP | 1 | 1 | 8 | OOM | OOM
Trainable parameters: 7.5B, Optimizer: AdamW | FSDP2 | 8 | 1 | 64 | 6,442 ± 68 | 12.58 ± 0.13
Trainable parameters: 5.1B, Optimizer: SGD | DP-DDP | 2 | 1 | 16 | 5,173 ± 266 | 10.10 ± 0.52
Trainable parameters: 5.1B, Optimizer: SGD | FSDP2 | 2 | 1 | 16 | 4,230 ± 150 | 8.26 ± 0.29
Trainable parameters: 5.1B, Optimizer: SGD | DP-DDP | 2 | 4 | 64 | OOM | OOM
Trainable parameters: 5.1B, Optimizer: SGD | DP-DDP | 1 | 8 | 64 | 4,762 ± 221 | 9.30 ± 0.43
Trainable parameters: 5.1B, Optimizer: SGD | FSDP2 | 8 | 1 | 64 | 7,872 ± 59 | 15.37 ± 0.12

Correctness Verification

We ran an integration test of Ghost Clipping DP-SGD with FSDP on an internal Meta use case, consisting of a Llama-3 8B model with LoRA fine-tuning for a next-word prediction task. Our results show that Ghost Clipping with FSDP reaches roughly the same training loss (negligible difference) as DP-DDP. These results, together with the unit tests (link), demonstrate the correctness of the implementation.

Figure 3: Training loss (y-axis) vs iterations (x-axis) for LoRA fine-tuning of Llama-3 8B using Ghost Clipping DP-SGD for next word prediction task.

Limitations

The current FSDP integration in Opacus does not support the following scenarios:

  1. Layers with tied parameters.
  2. Freezing/unfreezing the trainable parameters within the training phase.

The following are the two main limitations of the current implementation of gradient accumulation with FSDP2 Ghost Clipping:

  • Latency:
    • The current implementation of gradient accumulation with FSDP2 synchronizes the gradients after every backward pass. Since Ghost Clipping has two backward passes, we have 2k gradient synchronization calls (reduce_scatter) for k gradient accumulation steps.
    • This is because no_sync can’t be directly used when there are two backward passes for each forward pass.
    • Ideally, we should have only one gradient synchronization call for k gradient accumulation steps.
    • The latency of reduce_scatter is negligible in the case of LoRA fine-tuning. Also, with a reasonable compute/communication overlap, this overhead can be masked.
  • Memory:
    • Gradient accumulation uses an additional buffer to store the accumulated (sharded) gradients, irrespective of the number of gradient accumulation steps.
    • We would like to avoid this additional buffer when the number of gradient accumulation steps is equal to one. This is not specific to FSDP2 and is a bottleneck for the Opacus library in general.

Main takeaway

  • For models with a small number of trainable parameters, e.g., LoRA fine-tuning:
    • It is recommended to use DP-DDP with gradient accumulation whenever possible.
    • Shift to FSDP2 if DP-DDP throws an OOM error for the required sequence length or model size.
  • For full fine-tuning with a reasonably large number of trainable parameters:
    • It is recommended to use FSDP2, as it has higher throughput than DP-DDP.
    • In most cases, FSDP2 is the only option, as DP-DDP triggers OOM even with a batch size of one.
  • The above observations hold for both private and non-private cases.

Conclusion

In this post, we present the integration of Fully Sharded Data Parallel (FSDP) with Fast Gradient Clipping (FGC) and Ghost Clipping (GC) in Opacus, demonstrating its potential to scale the private training of large-scale models with over 1 billion trainable parameters. By leveraging FSDP, we have shown that it is possible to fully fine-tune the Llama-3 8B model, a feat that is not achievable with Differentially Private Distributed Data Parallel (DP-DDP) due to memory constraints.

The introduction of FSDP in Opacus marks a significant advancement to the Opacus library, offering a scalable and memory-efficient solution for private training of LLMs. This development not only enhances the capability of Opacus to handle large-scale models but also sets the stage for future integration of other model parallelism strategies.

Looking ahead, our focus will be on enabling 2D parallelism with Ghost Clipping and integrating FSDP with native Opacus using hooks. These efforts aim to further optimize the training process, reduce latency, and expand the applicability of Opacus to even larger and more complex models. We are excited about the possibilities that these advancements will unlock and are committed to pushing the boundaries of what is possible in private machine learning. Furthermore, we invite developers, researchers, and enthusiasts to join us in this journey. Your contributions and insights are invaluable as we continue to enhance Opacus.

Acknowledgments

We would like to thank Will Bullock, Wei Feng, Ilya Mironov, and Iden Kalemaj for their technical review and guidance.

Read More

Reducing Storage Footprint and Bandwidth Usage for Distributed Checkpoints with PyTorch DCP

Summary

PyTorch Distributed Checkpointing (DCP) is a versatile and powerful tool for managing model checkpoints in distributed training environments. Its modular design empowers developers to tailor its components to their specific requirements, making it an ideal solution for a wide range of use cases.

In this blog post, we’ll showcase how we leveraged PyTorch DCP’s modularity to integrate compression and achieve a 22% reduction in checkpoint size. We’ll also provide a deep dive into the implementation details of our customization, offering practical insights and guidance on how you can apply similar techniques to optimize your own checkpointing workflows and improve overall efficiency.

Motivation

Large Distributed Checkpoints

As models increase in complexity and size, distributed checkpointing becomes a critical component of the training process. However, these checkpoints often result in substantial storage demands and elevated bandwidth costs due to their large sizes.

Compression

To address this challenge, compression emerges as a natural solution. Given that checkpoints primarily consist of binary data (tensors), we aimed for an optimal compression ratio with minimal compression overhead. We chose the zstd compression algorithm for its efficiency and effectiveness.

DCP

The modular design of DCP, featuring well-defined and easily extensible components, made it an ideal choice as our checkpointing solution.

Details

Customizing StorageWriter

PyTorch DCP’s StorageWriter component is responsible for writing checkpoint data to storage. We customized this component by modifying _FileSystemWriter, which extends the base StorageWriter class. The _FileSystemWriter class now takes an additional parameter, _extensions, an optional sequence of StreamTransformExtension instances.

def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    # We used an extended _FileSystemWriter as the storage writer component
    storage_writer: Optional[StorageWriter] = None, 
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    no_dist: bool = False,
) -> Metadata:

class _FileSystemWriter(StorageWriter):

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
        # We customized _FileSystemWriter to take in extensions
        _extensions: Optional[Sequence[StreamTransformExtension]] = None,
        serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE,
        *args: Any,
        **kwargs: Any,
    ) -> None:

StreamTransformExtension is an abstract class that defines two methods: transform_to(), which is called on an output stream, and transform_from(), which is called on an input stream. These enable us to perform custom transformations on the stream data.

class StreamTransformExtension(Extension):

    @abc.abstractmethod
    def transform_to(self, output: IO[bytes]) -> IO[bytes]:

    @abc.abstractmethod
    def transform_from(self, input: IO[bytes]) -> IO[bytes]:

Implementing ZStandard Compression

We implemented a concrete subclass of StreamTransformExtension called ZStandard, which provides compression functionality using the zstd compression algorithm. Our ZStandard class implements transform_to() to compress the outgoing stream data and transform_from() to decompress the incoming stream data.

class ZStandard(StreamTransformExtension):

    def transform_to(self, output: IO[bytes]) -> IO[bytes]:
# Our compression implementation

    def transform_from(self, input: IO[bytes]) -> IO[bytes]:
# Our decompression implementation
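
For illustration, here is one way the two method bodies could be written using the zstandard Python package. This is a hedged sketch rather than the actual implementation; it only shows the stream-wrapping idea behind transform_to() and transform_from().

from typing import IO

import zstandard


def transform_to(output: IO[bytes]) -> IO[bytes]:
    # Wrap the output stream so that bytes written to the returned stream
    # are zstd-compressed before reaching the underlying checkpoint file.
    return zstandard.ZstdCompressor().stream_writer(output)


def transform_from(input: IO[bytes]) -> IO[bytes]:
    # Wrap the input stream so that reads return decompressed bytes.
    return zstandard.ZstdDecompressor().stream_reader(input)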

Combining Customizations

Finally, we combined our custom _FileSystemWriter class with the ZStandard compression extension while saving the checkpoint. We wrote a sample test to demonstrate how everything comes together:

fs_writer = FileSystemWriter(
    path=path,
    thread_count=thread_count,
    _extensions=[ZStandard()],
)

save(
    state_dict=state_dict_to_save,
    storage_writer=fs_writer,
)
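
For completeness, reading the checkpoint back goes through the mirror-image reader path. Below is a minimal, generic sketch of loading with DCP; how the ZStandard extension is rediscovered at load time (e.g., from checkpoint metadata or an explicitly configured reader) depends on the DCP version, so treat this as an outline rather than a drop-in recipe.

from torch.distributed.checkpoint import FileSystemReader, load

# DCP loads in place, so state_dict_to_load must be pre-populated with tensors
# of the right shapes (e.g., taken from an initialized model).
load(
    state_dict=state_dict_to_load,
    storage_reader=FileSystemReader(path),
)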

Evaluation

Results

In collaboration with IBM, we conducted an evaluation of our proposed solution on one of their internal training clusters. The results showed a significant 22% reduction in checkpoint sizes, albeit at the cost of increased compression time. However, with multi-threading, we were able to mitigate this trade-off and limit the increase in checkpointing time to just 9%. This demonstrates the potential of our solution to strike a balance between checkpoint size reduction and performance.

Model | Threads per Rank | Baseline Checkpoint Size (GB) | ZStd Checkpoint Size (GB) | Δ Size | Baseline Checkpointing Time (s) | ZStd Checkpointing Time (s) | Δ Time
granite-3b-code-instruct | 8 | 6.72 | 5.26 | -21.8% | 1.96 | 2.15 | +9.7%
granite-3b-code-instruct | 4 | 6.72 | 5.26 | -21.8% | 1.98 | 2.38 | +20.2%
granite-3b-code-instruct | 1 | 6.72 | 5.26 | -21.8% | 2.34 | 3.86 | +64.9%
granite-3.2-8b-instruct | 8 | 15.6 | 12.08 | -22.5% | 3.37 | 3.65 | +8.3%
granite-3.2-8b-instruct | 4 | 15.6 | 12.08 | -22.5% | 3.72 | 4.37 | +17.5%
granite-3.2-8b-instruct | 1 | 15.6 | 12.08 | -22.5% | 5.37 | 8.45 | +57.4%

Setup

We chose two of IBM’s open-source models (Granite-3B-Code-Instruct-128K and Granite-3.2-8B-Instruct). For the evaluation, we perform full-parameter FSDP fine-tuning on these models with the Alpaca dataset on IBM’s Vela AI supercomputer, which is hosted in IBM Cloud. Each of Vela’s nodes has eight 80GB A100 GPUs, connected to each other by NVLink and NVSwitch. In addition, each node has two 2nd Generation Intel Xeon Scalable processors (Cascade Lake) and 1.5TB of DRAM. We provision one node of Vela with the following resources:

Testbed

  • Openshift 4.14 Cluster
  • Pod: 64 Intel Cascade Lake CPU cores, 800GB host memory, 8 x A100-80GB GPUs
  • Storage options exposed as persistent volumes:
    • 1TB local GPFS
    • S3 bucket

Workload

  • Full-parameter FSDP finetuning with checkpointing every epoch

Checkpointing configuration

  • save_state_dict() to storage
  • 1 to 8 threads per rank
  • 1 file per rank
  • 8 ranks

Conclusion

PyTorch DCP’s modular design empowers developers to tailor its components to specific use cases, unlocking new levels of customization and extensibility. By customizing the StorageWriter component and implementing a compression extension, we achieved significant checkpoint size reductions, leading to lower storage requirements, and reduced bandwidth costs.

We invite you to explore the vast possibilities of PyTorch DCP customization by diving into our documentation and experimenting with various extensions and modifications. Join the conversation on PyTorch GitHub and connect with the PyTorch Checkpointing team (open GitHub issue with label “oncall: distributed checkpointing”) to share your experiences, ask questions, and stay up-to-date on the latest developments!

Read More

PyTorch + vLLM = ♥️

PyTorch + vLLM = ♥️

Key takeaways:

  • PyTorch and vLLM are both critical to the AI ecosystem and are increasingly being used together for cutting-edge generative AI applications, including inference, post-training, and agentic systems at scale. 
  • With the shift of the PyTorch Foundation to an umbrella foundation, we are excited to see projects being both used and supported by a wide range of customers, from hyperscalers to startups and everyone in between.
  • vLLM is leveraging the broader PyTorch ecosystem to accelerate innovation, benefiting from projects such as torch.compile, TorchAO, FlexAttention, and collaborating to support heterogeneous hardware and complex parallelism.
  • The teams (and others) are collaborating to build out PyTorch native support and integration for large-scale inference and post-training. 

Even prior to vLLM joining the PyTorch Foundation, we’ve seen organic and widespread adoption of PyTorch and vLLM together in some of the top companies deploying LLMs at scale in the world. Interestingly, the projects share many commonalities, including: strong visionary leaders; a broad organically-attained multilateral governance structure with committers from several entities, including both industry and academia; and an overwhelming focus on the developer experience. 

Additionally, over the last year plus, we’ve seen the two projects underpin many of the most popular open source LLMs, including the various Llama and DeepSeek models. With such similarities and how complementary the projects are, it’s really exciting to see all of the different integration points. 

PyTorch → vLLM Integrations

The overall goal for integrating at various points is to unlock performance and bring new capabilities to users. This includes optimization and support for Llama models, but also broader open models. 

torch.compile: torch.compile is a compiler that optimizes PyTorch code, delivering fast performance with minimal user effort. While manually tuning a model’s performance can take days, weeks, or even months, this approach becomes impractical for a large number of models. Instead, torch.compile provides a convenient solution for optimizing model performance. vLLM uses torch.compile by default to generate optimized kernels for the majority of its models. Recent benchmarks show significant speedups with torch.compile, ranging from 1.05x to 1.9x speedups on CUDA for popular models like Llama4, Qwen3, and Gemma3.
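
As a quick illustration of the user-facing API (a generic PyTorch example, not vLLM's internal integration):

import torch

def fused_op(x):
    # torch.compile traces this function and generates an optimized fused kernel.
    return torch.sin(x) + torch.cos(x) ** 2

compiled_op = torch.compile(fused_op)
print(compiled_op(torch.randn(1024)))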

TorchAO: We’re excited to announce that TorchAO is now officially supported as a quantization solution in vLLM. This integration brings high-performance inference capabilities using Int4, Int8, and FP8 data types, with upcoming support for MXFP8, MXFP4, and NVFP4 optimizations specifically designed for B200 GPUs. Additionally, we’re working on planned FP8 inference support for AMD GPUs, expanding hardware compatibility for high-performance quantized inference.

TorchAO’s quantization APIs are powered by a robust collection of high-performance kernels, including those from PyTorch Core, FBGEMM, and gemlite. TorchAO techniques are designed to compose with torch.compile, which means simpler implementations and automatic performance gains from PT2: write less code, get better performance.

One of the most exciting aspects of this integration is the seamless workflow it enables: vLLM users can now perform float8 training using TorchTitan, Quantization-Aware Training (QAT) using TorchTune, then directly load and deploy their quantized models through vLLM for production inference. This end-to-end pipeline significantly streamlines the path from model training and fine-tuning to deployment, making advanced quantization techniques more accessible to developers.
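
For illustration, here is a minimal sketch of applying TorchAO quantization to a module before serving. It uses the quantize_ and float8_dynamic_activation_float8_weight APIs from torchao.quantization; exact names and defaults may vary across TorchAO versions, and FP8 requires a recent GPU (e.g., H100 or newer).

import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

# A toy stand-in for a real model
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# Swap Linear weights to FP8 and quantize activations dynamically at runtime
quantize_(model, float8_dynamic_activation_float8_weight())

out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))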

FlexAttention: vLLM now includes FlexAttention – a new attention backend designed for flexibility. FlexAttention provides a programmable attention framework that allows developers to define custom attention patterns, making it easier to support novel model designs without extensive backend modifications.

This backend, enabled by torch.compile, produces JIT fused kernels. This allows for flexibility while maintaining performance for non-standard attention patterns.  FlexAttention is currently in early development within vLLM and not yet ready for production use. We’re continuing to invest in this integration and plan to make it a robust part of vLLM’s modeling toolkit. The goal is to simplify support for emerging attention patterns and model architectures, making it easier to bridge the gap between research innovations and deployment-ready inference.
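
As an example of the programmable-attention idea, here is a minimal, standalone use of PyTorch's flex_attention with a custom score_mod. This shows the core PyTorch API only, not vLLM's backend wiring.

import torch
from torch.nn.attention.flex_attention import flex_attention

def relative_bias(score, b, h, q_idx, kv_idx):
    # Custom attention pattern: penalize attention to distant positions.
    return score - 0.1 * (q_idx - kv_idx).abs()

# (batch, heads, sequence_length, head_dim)
q, k, v = (torch.randn(1, 8, 128, 64, device="cuda") for _ in range(3))
out = flex_attention(q, k, v, score_mod=relative_bias)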

Heterogeneous hardware: The PyTorch team worked with different hardware vendors and provided solid support for different types of hardware backends, including NVIDIA GPU, AMD GPU, Intel GPU, Google TPU, etc. vLLM inference engine leverages PyTorch as a proxy talking to different hardware, and this significantly simplified the support for heterogeneous hardware.

In addition, PyTorch engineers work closely with other vLLM contributors to support the next generation of NVIDIA GPUs. For example, we have thoroughly tested FlashInfer support in vLLM on Blackwell, conducted performance comparison, and debugged accuracy issues.

The PyTorch team also worked with AMD to enhance the support for vLLM + Llama4, such as day 0 llama 4 support on AMD as well as Llama4 perf optimization on MI300x.

Parallelism: At Meta, we leverage different types of parallelism and their combinations in production. Pipeline parallelism (PP) is one important type. The original PP in vLLM has a hard dependency on Ray. However, not all users leverage Ray to manage their service and coordinate different hosts. The PyTorch team developed PP support with plain torchrun and further optimized it to overlap computation between microbatches. In addition, PyTorch engineers developed data parallelism for the vision encoder, which is critical to multi-modal model performance.

Continuous integration (CI): With vLLM critical to the PyTorch ecosystem, we are collaborating to ensure that CI between the projects has good test coverage, is well funded, and that the community can rely on all of these integrations. Just integrating APIs isn’t enough; it’s also important that CI is in place to ensure that nothing breaks over time as vLLM and PyTorch both release new versions and features. More concretely, we are testing the combination of vLLM main and PyTorch nightlies, which we believe will give us and the community the signal needed to monitor the state of the integration between the two projects. At Meta, we started moving some of our development effort on top of vLLM main to stress test various correctness and performance aspects of vLLM. In addition, performance dashboards for vLLM v1 are now available on hud.pytorch.org.

What’s next

This is just the beginning. We are working together to build out the following advanced capabilities:

1. Large Scale Model Inference: The primary goal is to ensure vLLM runs efficiently at scale on cloud offerings, demonstrates key capabilities (prefill-decode disagg, multi-node parallelism, performant kernels and comms, context-aware routing and fault-tolerance) to scale to thousands of nodes, and becomes a stable foundation for enterprises to build on. In Q2, Meta engineers have prototyped disagg integration on top of the vLLM engine and KV connector APIs. The team is working with the community to try out new strategies and will plan to upstream the most successful ones to push further what can be done with vLLM.

Hardware: H100 GPUs, 96 GB HBM2e, AMD Genoa CPUs, CUDA 12.4

2. Post-training with reinforcement learning:  Inference time compute is quickly becoming critical for LLMs and agentic systems. We are working on end-to-end native post-training that incorporates RL at large scale with vLLM as the inference backbone of the system. 

Cheers!

-Team PyTorch (at Meta) & Team vLLM

Read More

FlagGems Joins the PyTorch Ecosystem: Triton-Powered Operator Library for Universal AI Acceleration

FlagGems Joins the PyTorch Ecosystem: Triton-Powered Operator Library for Universal AI Acceleration

FlagGems

In the race to accelerate large language models across diverse AI hardware, FlagGems delivers a high-performance, flexible, and scalable solution. Built on Triton language, FlagGems is a plugin-based PyTorch operator and kernel library designed to democratize AI compute. Its mission: to enable a write-once, JIT-everywhere experience, so developers can deploy optimized kernels effortlessly across a wide spectrum of hardware backends.  FlagGems recently joined the PyTorch Ecosystem upon acceptance by the PyTorch Ecosystem Working Group.

With over 180 operators already implemented—spanning native PyTorch ops and widely used custom ops for large models—FlagGems is evolving fast to keep pace with the generative AI frontier.

To view the PyTorch Ecosystem, see the PyTorch Landscape and learn more about how projects can join the PyTorch Ecosystem.

Key Features

  • Extensive Operator Library: 180+ PyTorch-compatible operators and growing
  • Performance Optimized: Select operators hand-tuned for speed
  • Torch.compile Independent: Fully functional in eager mode
  • Pointwise Operator Codegen: Auto-generates kernels for arbitrary input types and layouts
  • Fast Kernel Dispatching: Per-function runtime dispatch logic
  • C++ Triton Dispatcher: In development for even faster execution
  • Multi-Backend Ready: Works across 10+ hardware platforms with a backend-neutral runtime API

Architecture

FlagGems extends the PyTorch dispatch system out-of-tree through a multi-backend library powered by Triton. It intercepts ATen operator calls and provides backend-specific Triton implementations, making it easy to support alternative GPUs and domain-specific accelerators (DSAs).

Plug-and-Play

  • Registers with PyTorch’s dispatch system
  • Intercepts ATen operator calls
  • Seamlessly replaces CUDA operator implementations
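
For readers unfamiliar with out-of-tree dispatch registration, here is a minimal sketch of the general mechanism (illustrative only, not FlagGems’ actual code): a replacement implementation is registered for an existing ATen operator and dispatch key via torch.library.

import torch

# Open the existing aten namespace for implementation-only registrations
lib = torch.library.Library("aten", "IMPL")

def my_abs(x):
    # Stand-in implementation; a real backend would launch a Triton kernel here
    return torch.where(x >= 0, x, -x)

# Calls to torch.abs on CUDA tensors are now routed to my_abs
lib.impl("abs", my_abs, "CUDA")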

Write Once, Compile Anywhere

  • Unified operator library code
  • Compilable on any backend with Triton support
  • Supports GPUs and heterogeneous chips like DSAs

Getting Started in 3 Steps

  1. Install dependencies
pip install "torch>=2.2.0"  # 2.6.0 preferred

pip install "triton>=2.2.0" # 3.2.0 preferred
  2. Install FlagGems
git clone https://github.com/FlagOpen/FlagGems.git

cd FlagGems

pip install --no-build-isolation .

or editable install:

pip install --no-build-isolation -e .
  3. Enable FlagGems in your project
import flag_gems

flag_gems.enable()  # Replaces supported PyTorch ops globally

Prefer finer control? Use a managed context:

with flag_gems.use_gems():

    output = model.generate(**inputs)

Need explicit ops?

out = flag_gems.ops.slice_scatter(inp, dim=0, src=src, start=0, end=10, step=1)

Automatic Codegen for Pointwise Ops

With the @pointwise_dynamic decorator, FlagGems can auto-generate efficient kernels with broadcast, fusion, and memory layout support. Here’s an example implementing fused GeLU and element-wise multiplication:

@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
    x_fp32 = x.to(tl.float32)
    x_gelu = 0.5 * x_fp32 * (1 + tanh(x_fp32 * 0.79788456 * (1 + 0.044715 * pow(x_fp32, 2))))
    return x_gelu * y

Performance Validation

FlagGems includes built-in testing and benchmarking:

  • Accuracy Testing
cd tests

pytest test_<op_name>_ops.py --ref cpu
  • Performance Benchmarks
cd benchmark

pytest test_<op_name>_perf.py -s  # CUDA microbenchmarks

pytest test_<op_name>_perf.py -s --mode cpu  # End-to-end comparison

Benchmark Results

FlagGems Speedup

Initial benchmark results of FlagGems showcase its performance against PyTorch’s native operator implementations. The results represent the average measured speedups, a value greater than 1 indicating that FlagGems is faster than the native PyTorch operator. For a vast majority of operators, FlagGems either matches or significantly surpasses the performance of PyTorch’s native implementations.

For a significant portion of the 180+ operators, FlagGems achieves a speedup close to 1.0, indicating performance on par with the native PyTorch implementations.

Some of the core operations, such as LAYERNORM, CROSS_ENTROPY_LOSS, ADDMM, and SOFTMAX, also show impressive speedups.

Multi-Backend Support

FlagGems is vendor-flexible and backend-aware:

Set the desired backend with:

export GEMS_VENDOR=<vendor>

Check active backend in Python:

import flag_gems

print(flag_gems.vendor_name)


Summary

FlagGems delivers a unified kernel library for large models acceleration that bridges software portability and hardware performance. With broad backend support, a growing op set, and advanced codegen features, it’s your go-to Triton playground for pushing the limits of AI compute.

Read More

Presenting Flux Fast: Making Flux go brrr on H100s

Presenting Flux Fast: Making Flux go brrr on H100s

In our earlier post, diffusion-fast, we showed how the Stable Diffusion XL (SDXL) pipeline can be optimized up to 3x using native PyTorch code. Back then, SDXL was an open SoTA pipeline for image generation. Quite unsurprisingly, a lot has changed since then, and it’s safe to say that Flux is now one of the most capable open-weight models in the space.

In this post, we’re excited to show how we enabled ~2.5x speedup on Flux.1-Schnell and Flux.1-Dev using (mainly) pure PyTorch code and a beefy GPU like H100.

If you cannot wait to get started with the code, you can find the repository here.

Overview of the optimizations

The pipelines shipped in the Diffusers library try to be as `torch.compile`-friendly as possible. This means:

  • No graph breaks wherever possible 
  • No recompilations wherever possible
  • None-to-minimal CPU<->GPU syncs to reduce inductor cache lookup overhead

Therefore, it already gives us a reasonable starting point. For this project, we took the same underlying principles used in the diffusion-fast project and applied them to the FluxPipeline. Below, we share an overview of the optimizations we applied (details in the repository):

  • `torch.compile` with “fullgraph=True” and “max-autotune” mode, ensuring the use of CUDAgraphs
  • Combining q,k,v projections for attention computation. This is particularly helpful during quantization as it thickens the dimensionality, improving compute density
  • `torch.channels_last` memory format for the decoder output
  • Flash Attention v3 (FA3) with (unscaled) conversion of inputs to `torch.float8_e4m3fn`
  • Dynamic float8 activation quantization and quantization of Linear layer weights via torchao’s `float8_dynamic_activation_float8_weight`
  • Some flags for tuning Inductor performance on this model:
    • conv_1x1_as_mm = True
    • epilogue_fusion = False
    • coordinate_descent_tuning = True
    • coordinate_descent_check_all_directions = True
  • torch.export + Ahead-of-time Inductor (AOTI) + CUDAGraphs

Most of these optimizations are self-explanatory, barring these two:

  • Inductor flags. Interested readers can check out this blog post for more details.
  • With AoT compilation, we aim to eliminate the framework overhead and obtain a compiled binary that can be exported through `torch.export`. With CUDAGraphs, we want to enable optimization of kernel launches. More details are available in this post.
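
To make a few of these knobs concrete, below is a minimal sketch of applying the Inductor flags, channels_last, and torch.compile to a Diffusers FluxPipeline. The model id and the exact set of calls are illustrative; the linked repository contains the full, tested recipe (including FP8 quantization, FA3, and AOTI).

import torch
from diffusers import FluxPipeline

# Inductor tuning flags listed above
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# channels_last memory format for the decoder output path
pipe.vae.to(memory_format=torch.channels_last)

# Compile the heavy modules with fullgraph and max-autotune
pipe.transformer = torch.compile(pipe.transformer, fullgraph=True, mode="max-autotune")
pipe.vae.decode = torch.compile(pipe.vae.decode, fullgraph=True, mode="max-autotune")

image = pipe("a photo of a cat", num_inference_steps=4).images[0]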

Unlike LLMs, diffusion models are heavily compute-bound, so optimizations from gpt-fast don’t exactly carry over here. The figure below shows the impact of each of the optimizations (applied incrementally from left to right) on Flux.1-Schnell on an H100 700W GPU:

For Flux.1-Dev on H100, we have the following:

Below is a visual comparison of the images obtained with different optimizations applied to Flux.1-Dev:

It should be noted that only the FP8 quantization is lossy in nature, so for most of these optimizations the image quality should stay identical. Even with FP8, the differences we observe are negligible.

Note on CUDA syncs

During our investigations, we found out that at the first step of the denoising loop, there’s a CPU<->GPU sync point caused by this step in the scheduler. We could get rid of it by adding `self.scheduler.set_begin_index(0)` at the beginning of the denoising loop (PR). 

This matters even more when torch.compile is used, since the CPU has to wait for the sync before it can do a Dynamo cache lookup and then launch instructions on the GPU, and this cache lookup is a bit slow. Hence, the takeaway is that it’s always wise to profile your pipeline implementation and to try to eliminate these syncs as much as possible to benefit compilation.

Conclusion

The post went over a recipe to optimize Flux for Hopper architectures using native PyTorch code. The recipe tries to balance between simplicity and performance. Other kinds of optimizations are also likely possible (such as using fused MLP kernels and fused adaptive LayerNorm kernels), but for the purpose of simplicity, we didn’t go over them. 

Another crucial point is that GPUs with the Hopper architecture are generally costly. So, to provide reasonable speed-memory trade-offs on consumer GPUs, there are other (often `torch.compile`-compatible) options available in the Diffusers library, too. We invite you to check them here and here.

We invite you to try these techniques out on other models and share the results. Happy optimizing!

Read More

Fault Tolerant Llama: training with 2000 synthetic failures every ~15 seconds and no checkpoints on Crusoe L40S

Fault Tolerant Llama: training with 2000 synthetic failures every ~15 seconds and no checkpoints on Crusoe L40S

Collaborators: Less Wright, Howard Huang, Chien-Chin Huang, Crusoe: Martin Cala, Ethan Petersen

tl;dr: we used torchft and torchtitan to train a model in a real-world environment with extreme synthetic failure rates to prove reliability and correctness of fault tolerant training

Training loss across 1200 failures with no checkpoints. 

NOTE: Each small spike is a non-participating worker recovering which affects the metrics but not the model

Introduction

We want to demonstrate torchft in worst case scenarios by running a training job with the most extreme failure rates possible.

Most LLM pre-training uses sharded models using FSDP. torchft supports sharded models using HSDP2, which combines a sharded model with the fault tolerant DDP all reduce from torchft. We’ve integrated torchft into torchtitan so you can use fault tolerance out of the box. torchft+titan also support other sharding/parallelisms within each replica group, such as tensor parallelism (TP), pipeline parallelism (PP) and more.

Here’s the structure of a training job with torchft:

 

The structure of the training job. torchft’s fault tolerant DDP implementation is used across the replica groups to synchronize the gradients. Standard FSDP2 and other parallelisms are used within each replica group.

torchft uses a global Lighthouse server and per replica group Managers to do the real time coordination of workers. The Lighthouse knows the state of all workers and which ones are healthy via heartbeats.

torchft implements a few different algorithms for fault tolerance. The two most primary ones are:

  • Fault Tolerant HSDP: An extension of FSDPv2 that uses a fault tolerant all-reduce. This exactly emulates standard HSDP training with per step all_reduce of the gradients and per step fault tolerance. This works best for large scale training with fast backend networks such as infiniband.
  • LocalSGD/DiLoCo: A fault tolerant implementation of semi-sync training. These algorithms minimize communication overhead by synchronizing at specified intervals instead of every step like HSDP. This is often used in communication limited training scenarios such as over ethernet/TCP or in geographically separate locations (federated learning or multidatacenter training).

We’re always keeping an eye out for new algorithms, such as our upcoming support for streaming DiLoCo. If you have a new use case you’d like to collaborate on, please reach out!

Cluster Setup

Crusoe graciously lent us a cluster of 300 L40S GPUs. The GPUs were split up across 30 hosts, each with 10 NVIDIA L40S GPUs.

For the model, we used torchtitan with a Llama 3 model with 1B parameters to match the hardware available. 

NVIDIA L40S GPUs are typically used for inference, which gave us an opportunity to test torchft in a non-traditional environment where techniques such as DiLoCo really shine, since the TCP-only network (no InfiniBand/NVLink) is the bottleneck. The L40S has 48GB of VRAM (closer to consumer GPUs), so we used a smaller model and batch size. The average step time for training was ~9s.

To maximize performance with the limited network, we trained the model in a 30x1x10 configuration. We had 30 replica groups (fault tolerant domains), each with 1 host and 10 gpus/workers. torchft can have many, many hosts in each replica group, but for this cluster, a single host/10 gpus per replica group had the best performance due to limited network bandwidth. We ran with 30 replica groups, as more groups stressed the coordination and reconfiguration algorithms more.

For network communication, we used NCCL for all communication (i.e., FSDP) within each replica group and Gloo for communication across replica groups. Gloo, while often not as performant, initializes much faster and can also fail much faster, which is important for quick detection of failures. torchft does support fault tolerance using NCCL for IB clusters with some caveats but wasn’t used in this demo. Since we wanted to maximize the total number of failures and recoveries, we used Gloo since it can reinitialize in <1s for our use case, and we were able to set the timeout on all operations at 5s.

For the fault tolerance algorithms, we did the bulk of the testing with Fault Tolerant HSDP, as it stresses the communication and quorum layers the most. For the final test, we used DiLoCo, which is a better fit for the ethernet based cluster.

Recovering with No Checkpoints

Traditional machine learning achieves “fault tolerance” by reloading from checkpoints when an error occurs. This involves a complete stop-the-world operation where all workers restart and load from the most recently persisted checkpoint.

With torchft, we instead focus on isolating failures to an individual group of GPUs. When an error occurs within that group we can restart that group asynchronously and all other groups can reconfigure and continue training without that group.

When that group recovers through a restart or the scheduler replaces the machines, those workers no longer have a valid copy of the weights and optimizer states. If we tried to recover from a checkpoint, the other groups would have already moved on. Instead, we rely on an asynchronous weight transfer at runtime. This does a peer-to-peer transfer of the weights from a healthy replica.

Since we’re always recovering from another worker, it turns out that we don’t actually need any checkpoints as long as we can guarantee that at least one group is healthy. For this demonstration, we turned off checkpointing entirely, as a persistent checkpoint save and load takes much longer than our P2P recovery.

Here’s a diagram showing how a recovering replica (replica 1) can join the quorum and recover from a healthy peer (replica 0) without having any downtime or impacting the healthy worker training:

torchft adapts a number of concepts from distributed databases:

  • The quorum operation determines which workers are healthy using frequent heartbeats and guarantees that we can quickly determine which workers are alive, exchange metadata in a fault tolerant way, and enforce no split-brain conditions.
  • To ensure consistency and identify when we need to recover a worker, we effectively treat training with traditional database semantics. Traditional databases use “transactions” where each operation is either committed (entirely applied) or rolled back (discarded). torchft treats each training step the same way. Each training step within a replica group is handled as a distributed transaction, where we ensure all workers commit the step by stepping the optimizer, or, if an error occurs, they all roll back by discarding the gradients.

For more details, please see the torchft README, which has links to the documentation, design docs, and presentations. 

Training Loop Integration

TorchFT has already been integrated with TorchTitan, and thus, enabling it is just a matter of setting a configuration flag. For a typical model, torchft provides wrappers which automatically call hooks into torchft’s Manager to provide fault tolerance.

import torch.nn as nn
import torch.optim as optim

from torchft import Manager, DistributedDataParallel, Optimizer, ProcessGroupGloo

# Instantiate your model and optimizer as normal
m = nn.Linear(2, 3)
optimizer = optim.AdamW(m.parameters())

# Setup torchft Manager and wrap the model and optimizer.
manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=lambda state_dict: m.load_state_dict(state_dict),
    state_dict=lambda: m.state_dict(),
)
m = DistributedDataParallel(manager, m)
optimizer = Optimizer(manager, optimizer)

for batch in dataloader:
    # When you call zero_grad, we start the asynchronous quorum operation 
    # and perform the async weights recovery if necessary.
    optimizer.zero_grad()

    out = m(batch)
    loss = out.sum()

    # The gradient allreduces will be done via torchft's fault tolerant 
    # ProcessGroupGloo wrapper.
    loss.backward()

    # The optimizer will conditionally step depending on whether any errors occurred.
    # The batch will be discarded if the gradient sync was interrupted.
    optimizer.step()

Fault Tolerant Scheduling

We can use standard ML job schedulers such as Slurm since the semantics for the workers within a replica group are the same as a normal job. If an error occurs on any of the workers within a group we expect the entire group to restart simultaneously. Within each replica group, the application is a completely standard training job using standard non-fault tolerant operations. 

To achieve fault tolerance on a traditional scheduler, we run multiple of these jobs. Each replica group ran on Slurm as a separate training job with the Lighthouse and a monitoring script running on the head node. All the cross-group communication is done via torchft’s managed ProcessGroup and quorum APIs. To restart groups on failure and inject failures we used a small script using the torchx Python API.

The monitoring script looks something like this:

import time

from torchx.runner import get_runner

NUM_REPLICA_GROUPS = 30

with get_runner() as runner:
    while True:
        jobs = runner.list("slurm")
        active_replicas = {
            parse_replica_id(job.name)
            for job in jobs
            if not job.is_terminal()
        }

        missing_replicas = set(range(NUM_REPLICA_GROUPS)) - active_replicas

        for replica_id in missing_replicas:
            app_def = make_app_def(replica_id=replica_id)
            app_handle = runner.run(
                app_def, 
                scheduler="slurm", 
                cfg={"partition": "batch"},
            )
            print("launched:", replica_id, app_handle)

        time.sleep(5.0)

The failures were injected by cancelling the specific replica group’s Slurm job using scancel. In a real world scenario we would expect the failure to be triggered by an error in the training process which would crash that replica group in isolation rather than an external failure.

Metrics and Logs

To ensure we had a consistent view of the job, we avoided injecting failures into one replica group to make it simpler to track metrics and quorum events for the job. That one group was able to consistently log the number of participants, step success/failures, and the loss.

Since we’re doing per step fault tolerance, the number of participants and thus batch size changes per step depending on which workers are healthy.

The loss is averaged across all workers/replica groups in the job using an allreduce across replica groups.

Note: the small spikes in the loss graphs below are due to how we average the loss across all hosts, including recovering workers that have out-of-date weights, which leads to an artificially higher loss on those steps.

Runs

We ran three different runs showcasing various failure scenarios and features of torchft.

Run 1: Injected Failure Every 60s for 1100 Failures

This run lasted a little over 19 hours and 6249 steps. On average, each step took 10.9 seconds.

For the initial run, we injected a failure every 60 seconds with a very repeatable pattern. We initially had a bad machine in the cluster, so we briefly shrunk the world size to 25 hosts until the machine was replaced, and we scaled the job back up with zero downtime.

With the failure every 60s we expected to be able to do ~5 steps between each failure without any issue. Looking at the results, we see that there were 6249 steps and 5145 successful commits. torchft is designed to be as safe as possible, and if any errors occurred, it will discard the step via “should_commit” prior to running the optimizer.

For the overall step efficiency, we have:

5145 successful steps / 6249 total steps = 82.3%

With a step time of ~11 seconds and a failure every 60 seconds we should be able to complete 5 out of every 6 steps (83.3%) and that matches almost exactly with the measured performance.

We averaged 29.6 participating replica groups per step, so the total training efficiency of this was 81.2%. Not bad for over 1000 failures.

Run 2: Injected Failure Every 15s for 1015 Failures

We wanted to see how much further we could push this and make it even more challenging. For the second run, we injected a failure at a random interval between 0 and 30 seconds, i.e., on average every 15 seconds.

This failure rate is extreme compared to training jobs, which typically have mean time between failures in the 10s of minutes to hours range, but lets us validate that we can recover no matter when the error happens and lets us run a huge amount of test cycles to gain confidence in our implementation.

By randomizing the failure interval, we cause failures to happen while workers are still initializing rather than in steady state and are much more likely to hit edge cases. We’re happy to report that torchft behaved as expected with no unrecoverable errors.

As you can see, this job behaves much more erratically. Rather than the nearly 30 machines we had with a 60-second failure interval, with a failure every 15 seconds we have anywhere from 1 to 30 machines available on each step.

On average, we had 18.9 (18.9/30 = 63%) workers healthy and participating on any given step and an average step time of 15.46 seconds.

Out of the first 888 steps, 268 of those steps were committed successfully, which gives us a 30.2% step efficiency.

This gives us a training efficiency of just 13.4%, which in any normal training job would be terrible, but it's remarkable that the model keeps converging despite a crash every 15 seconds! Just loading a model from a checkpoint often takes longer than 1 minute.

The loss converges slower as compared to our 60s MTBF run, but that’s expected as many more batches are being discarded due to errors.

We do see some bigger spikes in the loss, which are correlated with times when only 1 participant is healthy and thus has 1/30th the batch size. This is easily avoided by adjusting the minimum number of replicas. We had it set to 1 for this test.

Run 3: Semi-synchronous Training

TorchFT also supports semi-synchronous training algorithms, including LocalSGD and DiLoCo, with plans to add more in the future. Unlike HSDP2, these algorithms do not synchronize at every step. Instead, they perform local training for several steps before synchronizing weights through averaging parameters or gradients. This approach enhances performance by reducing communication costs to once every N steps (a configurable hyperparameter), rather than at every step. Our tests on the cluster demonstrate a noticeable improvement in throughput. When synchronizing every 40 steps, we minimize the communication overhead, resulting in higher overall throughput. Below is a comparison of DiLoCo’s throughput (yellow), averaging around 4000 tps, with that of regular HSDP2 (purple), which averages around 1200 tps.

Naturally, the longer the interval between synchronizations, the more the models within replica groups will diverge. This divergence can potentially impact the convergence of the model. However, in our testing, we observed that the model was still able to train effectively and reach convergence despite these longer synchronization intervals. This resilience is beneficial in dynamic environments where replicas might leave the group unexpectedly. Even in such scenarios, the model demonstrated the ability to continue training without significant disruption.

Next Steps

torchft is under active development, and we have many planned improvements: newer algorithms such as streaming DiLoCo, making PyTorch Distributed more robust to failures (even on InfiniBand/NVLink!), and further efficiency gains.

If you’re interested in using torchft, please take a look at the torchft README and the torchft documentation. We’d also love to chat with you, so feel free to reach out directly via GitHub, LinkedIn, or Slack.

Read More

PyTorch Docathon 2025: Wrap Up

Huge congratulations and a massive thank you to all the amazing participants of the PyTorch Docathon 2025!

Over the past two weeks (June 3rd-15th), our virtual Docathon brought together more than 150 registrants who actively contributed to resolving long-standing documentation issues. We’re thrilled to announce that your efforts resulted in more than 60 merged pull requests across two PyTorch repositories!

We’d like to extend a special shout-out to our top contributors who went above and beyond during this event. Your dedication, expertise, and commitment to improving PyTorch documentation are truly inspiring. You’re the driving force behind open source projects like PyTorch, and we’re grateful for your contributions. 

First place: j-silv, kiszk, windsonsea

Second place: Rachel0619, jafraustro, loganthomas, nirajkamal, Dhia-naouali

Third place: Juliandlb, ggsmith842, ParagEkbote

PyTorch Docathon Top Community Contributors

Check out the full list of contributors here.

As we wrap up this Docathon, we encourage you to keep pushing the boundaries of what’s possible with PyTorch. Your collective efforts are revolutionizing the AI community, and we can’t wait to see what you achieve next.

Thank you again for being part of this incredible journey. Keep contributing, innovating, and inspiring others!

Team PyTorch

Read More

DeepNVMe: Affordable I/O scaling for Deep Learning Applications

Introduction

We introduced DeepNVMe in summer 2024 as a suite of optimizations for tackling I/O bottlenecks in Deep Learning (DL). DeepNVMe delivers significant speedups for I/O bound DL workloads by leveraging storage innovations including local NVMe SSDs, NVIDIA Magnum IO™ GPUDirect® Storage (GDS), and Linux Asynchronous I/O (AIO). In this update, we are delighted to announce DeepNVMe improvements on multiple fronts: (i) expanding application coverage to FastPersist model checkpointing and SGLang inference, (ii) I/O performance scaling by upgrading from PCIe Gen4 to Gen5 NVMe SSDs, and (iii) expanding usability to CPU-only environments, offset-based I/O operations, and tensor data type casting. The results reported in this blog are available in DeepSpeed versions >= 0.17.1.

Evaluation environments

Our experiments are conducted on an Azure ND-H200-v5 VM. The key software configurations are summarized in the following table.

Software   Version
Ubuntu     24.04.2
PyTorch    2.6.0
CUDA       12.6
SGLang     0.4.4.post4

Addressing I/O Bottlenecks of Deep Learning

We used DeepNVMe to develop FastPersist and ZeRO-Inference to target I/O bottlenecks in DL training and inference, respectively. Our experiments are conducted on a single VM, in which we combine the available NVMe SSDs into a single RAID-0 (i.e., disk striping) volume to leverage their aggregate read and write bandwidths. Since DeepNVMe can offload tensors using either CPU bounce buffers (a.k.a. AIO) or NVIDIA GPUDirect Storage (a.k.a. GDS), we report results for both modes.

FastPersist: Faster Model Checkpoint Creation

Although saving model checkpoints to persistent storage is critical in model training, it is also a major bottleneck due to the inefficiencies of existing approaches. We developed FastPersist to address the performance challenges of checkpointing. FastPersist makes checkpointing overheads negligible during training through three key techniques: (i) DeepNVMe, (ii) data parallelism, and (iii) overlapping I/O and computation.

Our goal here is to demonstrate the impact of DeepNVMe in FastPersist using single-process micro-benchmarks (available here), which serialize a model checkpoint state from HBM to local NVMe. We use the popular PyTorch torch.save() as the baseline in our experiments, and integrate FastPersist into torch.save() to simplify adoption and performance comparisons.
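
For reference, the baseline measurement is essentially a timed torch.save() of checkpoint state from HBM to the NVMe volume. The sketch below is illustrative only: the state dict, sizes, and mount point are hypothetical stand-ins for the actual Phi-3-Mini benchmark.

import time
import torch

# Hypothetical stand-in for a real checkpoint; the benchmark serializes Phi-3-Mini state from HBM.
state_dict = {f"layer{i}.weight": torch.randn(4096, 4096, device="cuda") for i in range(16)}
num_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())

torch.cuda.synchronize()
start = time.perf_counter()
torch.save(state_dict, "/mnt/nvme_raid0/model.pt")  # illustrative RAID-0 NVMe mount point
elapsed = time.perf_counter() - start
print(f"baseline torch.save throughput: {num_bytes / elapsed / 1e9:.2f} GB/s")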

Faster Saving of PyTorch Models to local NVMe Storage

We measure the throughput of serializing Phi-3-Mini checkpoint state from HBM to local NVMe storage. The results are summarized in the Figure below. We observe significantly faster checkpointing with FastPersist compared to the baseline. We see speedups of over 20X in the 8xGen5 NVMe settings. We also observe FastPersist scaling with increased NVMe bandwidth of 8xGen5 compared with 4xGen5.

FastPersist provides significantly faster model checkpointing to local NVMe.

ZeRO-Inference: Democratizing Generative AI

ZeRO-Inference is a technology that democratizes access to state-of-the-art models by reducing the GPU costs of model inference. ZeRO-Inference enables inference computations of massive models (hundreds-of-billions of parameters) on as few as one GPU by offloading the model weights to DRAM and NVMe storage. ZeRO-Inference is designed for offline or throughput-oriented inference scenarios. In this blog, we share two updates on ZeRO-Inference. First, we have integrated ZeRO-Inference into SGLang, a state-of-the-art model serving framework. Second, we observed ZeRO-Inference performance scales with the faster NVMe SSDs in the latest Azure SKUs.

Democratizing SGLang through ZeRO-Inference integration

SGLang is a state-of-the-art serving framework for large language models (LLMs) and vision language models (VLMs). Our integration of ZeRO-Inference into SGLang makes SGLang available to budget-constrained users, and offers a cost-reduction option to existing SGLang users. We used SGLang’s offline benchmarking tool to measure the generation throughput of LLAMA3-70B on a single H200 with NVMe offloading (LLAMA3-70B cannot fit in the 141GB VRAM without offloading). The experiment is configured with prompt length of 512, generation length of 32, and batch size of 128. We summarize the results in the figure below for both AIO and GDS offloading.

ZeRO-Inference improves SGLang inference with NVMe offloading to reduce hardware costs.

Scaling HF Transformer Generation with Faster NVMe SSDs

ZeRO-Inference enhances HF Transformer inference with efficient model offloading to DRAM or NVMe. We previously evaluated LLAMA-3-70B generation performance with NVMe offloading on a single GPU and four Gen4 NVMes in an Azure NC_A100_v4 VM. We measured the generation speed for a prompt of 512 tokens, an output of 32 tokens, and a batch size of 96. Since NVMe bandwidth was the main bottleneck, we repeated the experiments on an Azure ND-H200-v5 VM, which offers Gen5 NVMes. The results summarized in the Figure below show that ZeRO-Inference uses the increased NVMe bandwidth to improve generation speeds. For example, with GDS, generation speed improves from 7 tokens/sec with four Gen4 NVMes to 17 tokens/sec with four Gen5 NVMes, and further to 26 tokens/sec with eight Gen5 NVMes. We observe similar improvements without GDS. These results show that ZeRO-Inference performance can be improved in a cost-effective manner by increasing NVMe bandwidth.

ZeRO-Inference leverages available NVMe bandwidth to scale LLAMA-3-70B generation.

I/O performance scaling

We used our ds_io benchmarking tool to demonstrate DeepNVMe proportionally scaling I/O performance with available NVMe bandwidth. This empowers users to accelerate I/O-bound DL applications at modest cost by using more or faster NVMe SSDs. In our experiments, we measure the achieved read and write bandwidths of 1GB data transfers between HBM and NVMes. We evaluate scaling up NVMes from PCIe Gen4 to Gen5, and scaling out from 4 to 8 SSDs. The SSDs are combined into a single RAID-0 (disk striping) volume. We summarize the results in the Figure below, which shows that DeepNVMe scales I/O performance along both dimensions. Scaling up from 4xGen4 SSDs to 4xGen5 SSDs improves reads from 10GB/sec to 27GB/sec, and writes from 5GB/sec to 11GB/sec. Scaling out from 4xGen5 to 8xGen5 further improves reads to 48GB/sec, and writes to 26GB/sec.

Microbenchmark shows DeepNVMe scales I/O performance with available NVMe bandwidth

Broadening usability

We have broadened DeepNVMe’s usage scenarios by removing restrictions on hardware environments and I/O operations, as explained below.

CPU-Only environments

Although GPUs (and similar accelerators) dominate DL, CPUs are still used in important machine learning (ML) workloads such as recommendation systems. However, DeepNVMe was previously unusable in CPU-only environments because it relied on Tensor.pin_memory() for page-locked CPU tensors, which does not work in CPU-only builds of torch, as illustrated below.

>>> import torch
>>> torch.__version__
'2.6.0+cpu'
>>> x = torch.empty(1024).pin_memory()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Cannot access accelerator device when none is available.
>>>

We have made DeepNVMe usable in CPU environments by adding mechanisms for allocating (new_cpu_locked_tensor()) and releasing (free_cpu_locked_tensor()) page-locked CPU tensors. The snippet below illustrates allocating a pinned CPU tensor (x).

>>> import torch
>>> torch.__version__
'2.6.0+cpu'
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> x = h.new_cpu_locked_tensor(1024, torch.Tensor())
>>> x.shape
torch.Size([1024])
>>> x.dtype
torch.float32
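
The corresponding release call is the free_cpu_locked_tensor() mechanism mentioned above; the exact signature shown below is an assumption.

>>> h.free_cpu_locked_tensor(x)  # release the page-locked buffer when done (signature assumed)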

Offset-based I/O operations

Previously, DeepNVMe functionality was restricted to reading or writing the entire contents of a file. We have now improved DeepNVMe to read or write a user-specified portion of file content from a given offset. In particular, we have extended the existing read/write APIs to accept a user-specified file offset argument (with a default value of 0), as shown below:

>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> help(AsyncIOBuilder().load().aio_handle().pread)
Help on method pread in module async_io:

pread(...) method of async_io.aio_handle instance
    pread(self: async_io.aio_handle, buffer: torch.Tensor, filename: str, validate: bool, async: bool, file_offset: int = 0) -> int
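
For example, based on the signature above, the call below reads a buffer’s worth of bytes starting at byte offset 4096. The file path is illustrative, and the arguments are passed positionally because async is a reserved word in Python.

>>> import torch
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> buffer = h.new_cpu_locked_tensor(1024, torch.Tensor())
>>> num_bytes_read = h.pread(buffer, '/data/blob.bin', False, False, 4096)  # validate=False, async=False, file_offset=4096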

Tensor data type casting

While developing FastPersist, we needed to manipulate model tensors, typically of floating point data types, in byte format for both performance and convenience of I/O operations. However, we could not find a zero-copy mechanism for casting tensors from arbitrary data types to a byte data type (i.e., torch.uint8), so we decided to create one. This functionality is available via the UtilsBuilder op as demonstrated in the example below. In the example, we cast a torch.bfloat16 tensor into torch.uint8. Note that due to the zero-copy nature of the functionality, bf16_tensor and byte_tensor are aliases.

>>> import torch
>>> from deepspeed.ops.op_builder import UtilsBuilder
>>> util_ops = UtilsBuilder().load()
>>> bf16_tensor = torch.zeros(1024, dtype=torch.bfloat16, device='cuda')
>>> bf16_tensor
tensor([0., 0., 0., ..., 0., 0., 0.], device='cuda:0', dtype=torch.bfloat16)
>>> byte_tensor = util_ops.cast_to_byte_tensor(bf16_tensor)
>>> byte_tensor
tensor([0, 0, 0, ..., 0, 0, 0], device='cuda:0', dtype=torch.uint8)
>>> bf16_tensor += 1.0
>>> bf16_tensor
tensor([1., 1., 1., ..., 1., 1., 1.], device='cuda:0', dtype=torch.bfloat16)
>>> byte_tensor
tensor([128, 63, 128, ..., 63, 128, 63], device='cuda:0',
dtype=torch.uint8)

Summary

This blog post has provided updates on our continued development of DeepNVMe, an I/O optimization technology for accelerating DL applications. We have announced DeepNVMe improvements on multiple aspects, including application coverage, I/O performance scaling, and usability.

Acknowledgements

This blog describes work done by Joe Mayer, Logan Adams, and Olatunji Ruwase of the DeepSpeed team at Microsoft.

Read More

ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization

The field of large language models is shifting toward lower-precision computation. This shift necessitates rethinking scaling laws to account for the effects of quantization on the performance of the resulting quantized models. In this work, we demonstrate that previous conclusions on low-bit scaling laws can be significantly sharpened by better quantization scheme design and training improvements.

We propose ParetoQ, the first algorithm that unifies binary, ternary, and 2-to-4 bit quantization-aware training. ParetoQ demonstrates its robustness by yielding state-of-the-art (SOTA) models at all bit widths, surpassing prior works tailored for individual bit levels. We’ve released the MobileLLM low-bit model collection on Hugging Face, featuring models quantized with our ParetoQ method. The smallest model is an ultra-efficient 1-bit 125M variant, with just ~16MB equivalent storage size.

These SOTA points in the Pareto chart ensure that our scaling law comparisons are both reliable and consistent, as they derive from homogeneous settings. Our scaling laws reveal that binary quantization significantly compromises accuracy, while ternary, 2-bit, and 3-bit quantization are tied in performance, often surpassing 4-bit. 

ParetoQ is based on PyTorch models, including LLaMA and MobileLLM. We utilized a popular PyTorch library, Hugging Face Transformers, for the accuracy experiments. For the latency experiments, we utilized low-bit quantization kernels on the CPU with ExecuTorch and compared their speed with that of 4-bit quantization. Additionally, we implemented state-of-the-art 2-bit GPU kernels, which showed up to a 4.14x speedup compared to FP16 and a 1.24x speedup over the Machete 4-bit kernel on TritonBench.

ParetoQ has been integrated into torchao [pull]. This integration enables users to leverage ParetoQ by specifying “paretoq” as the quantization method within torchao’s codebase. Once set, users can utilize torchao’s ParetoQ workflow to optimize quantization parameters, balancing accuracy and compression trade-offs, and compare different quantization bit-widths apples-to-apples using Pareto frontier analysis. This allows for the efficient deployment of models on edge devices without requiring manual tuning of quantization settings.

To obtain the ParetoQ-quantized models, simply navigate to the torchao/prototype/paretoq directory and execute the training script:

cd torchao/prototype/paretoq && bash 1_run_train.sh $w_bit

Here, $w_bit specifies the target weight bit-width for quantization.

ParetoQ code is available at: https://github.com/facebookresearch/ParetoQ

Paper link: https://arxiv.org/abs/2502.02631 

1 A Better QAT Scheduling Strategy for Extreme Low-Bit LLMs

1.1 Training Budget Allocation

Given a fixed training budget B_train = B_FPT +B_QAT, how should the budget be optimally allocated between full-precision training (B_FPT) and quantization-aware training/fine-tuning (B_QAT) to maximize the accuracy of the quantized model?

Figure 1: Optimal allocation between full-precision pretraining and QAT fine-tuning.

Finding-1: QAT fine-tuning consistently surpasses both PTQ with B_FPT = B_train and QAT from scratch with B_QAT = B_train. Optimal performance is nearly achieved by dedicating the majority of the training budget to full-precision (FP) training and approximately 10% to QAT.

1.2 Fine-tuning Characteristics

Figure 2: Analysis of training token requirements for quantization-aware fine-tuning and training from scratch

Finding-2: While fine-tuning enhances performance across all bit-widths, even binary and ternary, the optimal fine-tuning effort inversely correlates with bit-width. For 3-bit and 4-bit weights, fine-tuning adjusts within a nearby grid to mitigate accuracy loss and requires fewer fine-tuning tokens. In contrast, binary and ternary weights break the grid, creating new semantic representations to maintain performance, which requires longer fine-tuning.

Figure 3: L1 norm difference between QAT-finetuned weights and the full-precision initialization (||W_finetune − W_init||_1 / ||W_init||_1).

2 A Hitchhiker’s Guide to Quantization Method Choices

In sub-4-bit quantization, the choice of quantization function is highly sensitive and can drastically alter scaling law outcomes.

Figure 4: Impact of quantization grid choice across bit widths.

2.1.1 Range clipping

Compared to statistics-based quantization (e.g., min-max quantization), learnable scales, which optimize quantization ranges as network parameters to balance outlier suppression and precision, yield more stable and superior performance. As shown in Figure 4 (b)-(e), learnable policies consistently outperform stats-based methods across all bit widths.

2.1.2 Quantization grids

Level symmetry in quantization grids is vital for lower-bit quantization but often overlooked. Including “0” in even-level quantization (e.g., 2-bit, 3-bit, 4-bit) can cause imbalance. For instance, 2-bit quantization options like (-2, -1, 0, 1) limit positive representation to only one level, while (-1.5, -0.5, 0.5, 1.5) offers more balanced representation. We propose Stretched Elastic Quant (SEQ) to address this in lower-bit scenarios.

SEQ balances the quantized levels and evenly divides the full-precision weight span, which is crucial for extremely low-bit quantization. The figures show SEQ’s advantage in ternary and 2-bit quantization, while LSQ including “0” slightly excels in the 3-bit and 4-bit cases.

Figure 5: Comparison of quantization methods across different bit-widths
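
To make the imbalance concrete, the toy sketch below snaps weights onto the two 2-bit grids mentioned above. It is illustrative only and is not the actual SEQ implementation, which also learns the quantization scale.

import torch

def quantize_to_grid(w: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Snap each weight to its nearest level in the given quantization grid."""
    idx = (w.unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return grid[idx]

w = torch.randn(8)
zero_inclusive = torch.tensor([-2.0, -1.0, 0.0, 1.0])  # 2-bit grid with "0": only one positive level
balanced_seq = torch.tensor([-1.5, -0.5, 0.5, 1.5])    # balanced 2-bit grid (unit scale, SEQ-style)
print(quantize_to_grid(w, zero_inclusive))
print(quantize_to_grid(w, balanced_seq))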

2.2 Quantization Function

Based on our analysis, we combine the optimal quantization functions identified for each bit-width into one formula, denoted as ParetoQ. This includes Elastic Binarization [1] for 1-bit quantization, LSQ [2] for 3 and 4-bit quantization, and the proposed SEQ for 1.58 and 2-bit quantization.

Here, k equals 3 in the ternary case and 2^N_bit otherwise; n = -2^(N_bit - 1) and p = 2^(N_bit - 1) - 1. In the backward pass, the gradients to the weights and the scaling factor can be easily calculated using a straight-through estimator.
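
For intuition, a simplified sketch of the LSQ-style branch with a straight-through estimator is shown below. It is illustrative only: it omits the gradient path to the learnable scale as well as the SEQ and binarization branches that ParetoQ combines.

import torch

def lsq_fake_quant(w: torch.Tensor, alpha: torch.Tensor, n_bit: int) -> torch.Tensor:
    """Illustrative LSQ-style fake quantization for n_bit >= 2 weights.

    n and p follow the clipping bounds described above; alpha is the learnable scale.
    The straight-through estimator passes gradients to w as if rounding were the identity
    (the gradient path to alpha is omitted in this simplified sketch).
    """
    n = -(2 ** (n_bit - 1))
    p = 2 ** (n_bit - 1) - 1
    w_q = torch.clamp(torch.round(w / alpha), n, p) * alpha
    return w + (w_q - w).detach()  # forward: quantized weights; backward: identity w.r.t. w

# Usage: fake-quantize a weight tensor to 4 bits with a given scale.
weight = torch.randn(128, 128, requires_grad=True)
alpha = torch.tensor(0.05, requires_grad=True)
w_q = lsq_fake_quant(weight, alpha, n_bit=4)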

With ParetoQ, we present a robust comparison framework across five bit-widths (1-bit, 1.58-bit, 2-bit, 3-bit, 4-bit), each achieving state-of-the-art accuracy. This facilitates direct, apple-to-apple comparisons to identify the most effective bit-width selection.

3 Comparison with SoTA

3.1 Comparisons on 1.58-bit quantization

The figure below illustrates that ParetoQ consistently outperforms previous methods targeting ternary quantization-aware training, including Spectra [3] and 1-bit Era [4]. Given that a full-precision LLaMA-3 3B model achieves 69.9 accuracy, it’s remarkable that the ParetoQ ternary 3B-parameter model narrows the gap to just 4.1 points, while previous methods experience drops exceeding 11.7 points.

Figure 6: Ternary quantization accuracy averaged across six tasks: ARC-e, ARC-c, BoolQ, PIQA, HellaSwag, and WinoGrande. ParetoQ consistently outperforms all prior methods in ternary quantization-aware training.

3.2 Comparisons on 2-bit / 3-bit / 4-bit quantization

As evidenced by Figure 1, compared to previous state-of-the-art PTQ and QAT methods in 2-, 3-, or 4-bit quantization settings, our approach consistently resides on the Pareto front, with a particularly pronounced advantage in lower-bit quantization settings. These results confirm that our bit-accuracy trade-off conclusions are benchmarked against SoTA results across all bit settings, ensuring their reliability.

Figure 7: Accuracy comparison on 8 models. ParetoQ outperforms all state-of-the-art PTQ and QAT methods in 2, 3, and 4-bit settings.

4 Pareto Curve

4-bit quantization-aware training (QAT) achieves near-lossless compression in many scenarios. With ParetoQ, we are able to further improve the trade-off curve. Figure 8 (a) demonstrates that sub-4-bit quantization, including binary, ternary, 2-bit, and 3-bit, often surpasses 4-bit. Notably, 2-bit and ternary models reside on the Pareto frontier.

To evaluate potential speedup benefits beyond memory reduction, we utilize the High-Performance Low-Bit Operators for 2-bit quantization and compare the latency with 4-bit quantization. The curves in Figure 8 (c) demonstrate that, within our experimental range, 2-bit quantized models consistently outperform 4-bit models in terms of accuracy-speed performance, positioning 2-bit quantization as a superior choice for on-device applications where both latency and storage are critical.

Figure 8: (a) (b) In sub-4-bit regime, 1.58-bit, 2-bit, and 3-bit quantization outperform 4-bit in terms of the accuracy-model size trade-off. (c) Under hardware constraints, 2-bit quantization demonstrates superior accuracy-speed trade-offs compared to higher-bit schemes.

5 GPU Latency

We measured the latency of LLaMA 3.2 models (1B, 3B, 8B) on an H100 NVL GPU (94GB memory). The W4A16 kernel used the Machete kernel from vLLM, while the W2A16 kernel was implemented on top of the CUTLASS mixed-precision backbone kernel. All tests were performed on a single GPU with a context length of 2048 tokens. For kernel-level latency, we compared the 2-bit kernel to the 4-bit Machete kernel across three weight shapes: (4096 x 4096), (8192 x 8192), and (16384 x 16384) on TritonBench. For larger kernel sizes, the 2-bit kernel achieves a ~24% speedup compared to the 4-bit Machete kernel.

Conclusion

In this study, we propose ParetoQ, an advanced quantization framework that achieves state-of-the-art performance across all bit-width levels. This framework uniquely enables a direct, consistent comparison across different bit-widths, ensuring an equitable evaluation of performance metrics. Our empirical analysis indicates that quantization at 1.58-bit, 2-bit, and 3-bit offers a superior trade-off between accuracy and effective quantized model size compared to 4-bit, highlighting their potential for optimized model deployment.

Feel free to try running ParetoQ from torchao/prototype/paretoq, following the steps in that repo. If you have any questions, feel free to reach out to Zechun Liu <zechunliu@meta.com>, Changsheng Zhao <cszhao@meta.com>, or Andrew Or <andrewor@meta.com>.

References

[1] BiT: Robustly Binarized Multi-Distilled Transformer.

[2] Learned Step Size Quantization.

[3] Spectra: A Comprehensive Study of Ternary, Quantized, and FP16 Language Models.

[4] The Era of 1-bit LLMs: All Large Language Models Are in 1.58 Bits.

Read More

HuggingFace Safetensors Support in PyTorch Distributed Checkpointing

Summary 

PyTorch Distributed Checkpointing (DCP) is investing in addressing interoperability blockers to ensure that popular formats, like HuggingFace safetensors, work well with PyTorch’s ecosystem. Since HuggingFace has become a leading format for inference and fine-tuning, DCP is beginning to support HuggingFace safetensors. The first customer of these changes is torchtune, which has seen an improved user experience since it can now cleanly read and write directly to HuggingFace with DCP APIs.

Problem

Since HuggingFace is widely used, with over 5 million users, many ML engineers would like to save and load their checkpoints in the safetensors format so they can easily work with that ecosystem. Supporting the safetensors format natively in DCP simplifies checkpointing for our users in the following ways:

  • DCP has its own custom format, so users who wanted to work with HuggingFace models while leveraging DCP’s performance wins and features had to build custom converters and components to bridge the two systems.
  • Instead of having to download and upload checkpoints to local storage every time, users can now save and load HuggingFace models directly to and from the fsspec-supported storage of their choosing.

How to Use

From a user’s perspective, the only change needed to use safetensors is to call load with the new load planner and storage reader, and similarly save with the new save planner and storage writer.

The load and save APIs are called as follows:


# Imports shown for completeness; the exact module path of the HuggingFace storage
# components may differ by PyTorch release.
from torch.distributed.checkpoint import load, save
from torch.distributed.checkpoint import HuggingFaceStorageReader, HuggingFaceStorageWriter

load(
    state_dict=state_dict,
    storage_reader=HuggingFaceStorageReader(path=path),
)

save(
    state_dict=state_dict,
    storage_writer=HuggingFaceStorageWriter(
        path=path,
        fqn_to_index_mapping=mapping,
    ),
)

The HuggingFaceStorageReader and HuggingFaceStorageWriter can take any fsspec-based path, so they can read and write the HF safetensors format to any fsspec-supported back-end, including local storage and HF storage. Since HuggingFace safetensors metadata doesn’t natively provide the same level of information as DCP metadata, distributed checkpoints are currently not well supported by these APIs, but DCP plans to support this natively in the future.

 

torchtune

Our first customer of HuggingFace DCP support is torchtune – a post-training library written in native PyTorch. The primary way torchtune users retrieve model weights is from the Hugging Face Hub. Before, users had to download the model weights and upload the trained checkpoints via extra CLI commands; the new DCP APIs allow them to directly read and write to HuggingFace, resulting in a much better user experience. 

In addition, support for safetensors serialization in DCP greatly simplifies the checkpointing code in torchtune. There is no longer a need for format-specific checkpointing solutions, which increases developer efficiency in the project.

Future Work

DCP plans to handle the distributed loading and saving of HuggingFace safetensors checkpoints with resharding. DCP also plans to support the ability to produce a consolidated final checkpoint to a single file for publishing.

Read More