PyTorch 2.0 & XLA—The Latest Cutting Edge Features

Today, we are excited to share our latest work for PyTorch/XLA 2.0. The release of PyTorch 2.0 is yet another major milestone for this storied community and we are excited to continue to be part of it. When the PyTorch/XLA project started in 2018 between Google and Meta, the focus was on bringing cutting edge Cloud TPUs to help support the PyTorch community. Along the way, others in the community such as Amazon joined the project and very quickly the community expanded. We are excited about XLA’s direction and the benefits this project continues to bring to the PyTorch community. In this blog we’d like to showcase some key features that have been in development, show code snippets, and illustrate the benefit through some benchmarks.

TorchDynamo / torch.compile (Experimental)

TorchDynamo (Dynamo) is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It provides a clean API for compiler backends to hook in; its biggest feature is to dynamically modify Python bytecode just before execution. In the PyTorch/XLA 2.0 release, an experimental backend for Dynamo is provided for both inference and training.

Dynamo provides a Torch FX (FX) graph when it recognizes a model pattern, and PyTorch/XLA uses a Lazy Tensor approach to compile the FX graph and return the compiled function. For more insight into the technical details of PyTorch/XLA’s Dynamo implementation, check out this dev-discuss post and the Dynamo doc.

Here is a small code example of running ResNet18 with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.eval()
  dynamo_resnet18 = torch.compile(
      xla_resnet18, backend='torchxla_trace_once')
  for data, _ in loader:
    output = dynamo_resnet18(data)

With torch.compile, PyTorch/XLA traces the ResNet18 model only once, at initialization, and executes the compiled binary every time dynamo_resnet18 is invoked, instead of tracing the model at every step. To illustrate the benefits of Dynamo+XLA, below is an inference speedup analysis comparing Dynamo and LazyTensor (without Dynamo) using TorchBench on a Cloud TPU v4-8, where the y-axis is the speedup multiplier.
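The difference between tracing at every step and tracing once can be sketched in plain Python. This is a toy illustration only, not the PyTorch/XLA implementation: `TraceCounter`, `run_trace_every_step`, and `run_trace_once` are hypothetical names, and keying the cache on input length is an assumption that loosely mirrors recompilation when input shapes change.

```python
from typing import Callable, Dict, List

Batch = List[float]


def model(batch: Batch) -> Batch:
    """Toy stand-in for a model's forward pass."""
    return [2.0 * x + 1.0 for x in batch]


class TraceCounter:
    """Records how often the (pretend) tracing step runs."""

    def __init__(self) -> None:
        self.traces = 0

    def trace(self, fn: Callable[[Batch], Batch]) -> Callable[[Batch], Batch]:
        self.traces += 1
        return fn  # a real tracer would return a compiled graph here


def run_trace_every_step(counter: TraceCounter, batches: List[Batch]) -> List[Batch]:
    # Lazy-tensor style: the Python model is traced again on each step.
    return [counter.trace(model)(b) for b in batches]


def run_trace_once(counter: TraceCounter, batches: List[Batch]) -> List[Batch]:
    # Dynamo style: trace on the first call per input "shape", then reuse.
    cache: Dict[int, Callable[[Batch], Batch]] = {}
    outputs = []
    for b in batches:
        key = len(b)  # toy shape key (assumption for this sketch)
        if key not in cache:
            cache[key] = counter.trace(model)
        outputs.append(cache[key](b))
    return outputs


batches = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
lazy_counter, once_counter = TraceCounter(), TraceCounter()
lazy_out = run_trace_every_step(lazy_counter, batches)
once_out = run_trace_once(once_counter, batches)
print(lazy_counter.traces, once_counter.traces)
```

Both paths produce the same outputs; only the number of tracing passes differs, which is where the per-step Python overhead is saved.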

Inference Speedup - PyTorch/XLA Dynamo on TPU

Dynamo for training is still in development, and its implementation is at an earlier stage than inference. Developers are welcome to test this early feature. Note, however, that in the 2.0 release PyTorch/XLA supports only the forward and backward pass graphs, not the optimizer graph; the optimizer graph is available in the nightly builds and will land in the PyTorch/XLA 2.1 release. Below is an example of what training looks like using the ResNet18 example with torch.compile:

import torch
import torchvision
import torch_xla.core.xla_model as xm

def train_model(model, data, target):
  loss_fn = torch.nn.CrossEntropyLoss()
  pred = model(data)
  loss = loss_fn(pred, target)
  loss.backward()
  return pred

def train_model_main(loader):
  device = xm.xla_device()
  xla_resnet18 = torchvision.models.resnet18().to(device)
  xla_resnet18.train()
  dynamo_train_model = torch.compile(
      train_model, backend='aot_torchxla_trace_once')
  for data, target in loader:
    output = dynamo_train_model(xla_resnet18, data, target)

Note that the backend for training is aot_torchxla_trace_once (API will be updated for stable release) whereas the inference backend is torchxla_trace_once (name subject to change). With Dynamo, we expect to extract and execute 3 graphs per training step, instead of the 1 graph per training step that the Lazy Tensor approach produces. Below is a training speedup analysis comparing Dynamo and Lazy Tensor using TorchBench on a Cloud TPU v4-8.

Training Speedup - PyTorch/XLA Dynamo on TPU

PJRT Runtime (Beta)

PyTorch/XLA is migrating from XRT to the new PJRT runtime. PJRT is a better-maintained stack with demonstrated performance advantages, including, on average, a 35% performance improvement for training on TorchBench 2.0 models. It also supports a richer set of features, enabling technologies like SPMD. In the PyTorch/XLA 2.0 release, PJRT is the default runtime for TPU and CPU; GPU support is experimental. The PJRT features included in the PyTorch/XLA 2.0 release are:

  • TPU runtime implementation in libtpu using the PJRT Plugin API improves performance by up to 30%
  • torch.distributed support for TPU v2 and v3, including pjrt:// init_method (Experimental)
  • Single-host GPU support. Multi-host support coming soon. (Experimental)

Switching to PJRT requires no change to user code, or only minimal changes for GPUs (see pjrt.md for more details). Runtime configuration is as simple as setting the PJRT_DEVICE environment variable to the local device type (TPU, GPU, or CPU). Below are examples of using PJRT runtimes on different devices.

# TPU Device
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
# TPU Pod Device
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git"

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
# GPU Device (Experimental)
PJRT_DEVICE=GPU GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
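The same environment variables can also be set from Python. The sketch below is a minimal, hypothetical helper (`configure_pjrt` is not a torch_xla API) that validates the device type and sets PJRT_DEVICE, plus GPU_NUM_DEVICES when requested. It assumes, conservatively, that the variables should be in place before torch_xla is imported.

```python
import os
from typing import Dict, Optional

# Device types named in this post.
SUPPORTED_DEVICES = ("TPU", "GPU", "CPU")


def configure_pjrt(device: str, gpu_num_devices: Optional[int] = None) -> Dict[str, str]:
    """Set PJRT environment variables (hypothetical helper, not a torch_xla API)."""
    device = device.upper()
    if device not in SUPPORTED_DEVICES:
        raise ValueError(f"unsupported PJRT device: {device!r}")
    settings = {"PJRT_DEVICE": device}
    if device == "GPU" and gpu_num_devices is not None:
        settings["GPU_NUM_DEVICES"] = str(gpu_num_devices)
    os.environ.update(settings)
    return settings


# Call this before importing torch_xla so the runtime sees the settings.
print(configure_pjrt("cpu"))
```

Setting the variables on the command line, as in the examples above, remains the simplest option; a helper like this is mainly useful in notebooks or launcher scripts.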

Below is a performance comparison between XRT and PJRT by task on TorchBench 2.0 on v4-8 TPU. To learn more about PJRT vs. XRT please review the documentation.

TorchBench Training Time

Parallelization

GSPMD (Experimental)

We are delighted to introduce General and Scalable Parallelization for ML Computation Graphs (GSPMD) in PyTorch as a new experimental data and model sharding solution. GSPMD provides automatic parallelization for common ML workloads, allowing developers to write PyTorch programs as if on a single large device, without custom sharded computation ops or collective communication ops. The XLA compiler transforms the single-device program into a partitioned one with the proper collectives, based on user-provided sharding hints. The API (RFC) will be available in the PyTorch/XLA 2.0 release as an experimental feature on a single TPU VM host.

Next Steps for GSPMD

GSPMD is experimental in the 2.0 release. To bring it to Stable status, we plan to address a number of feature gaps and known issues in the following releases, including multi-host support, DTensor integration, partial replication sharding, asynchronous data loading, and checkpointing.

FSDP (Beta)

PyTorch/XLA introduced experimental support for fully sharded data parallel (FSDP) in version 1.12. This feature is a parallel representation of PyTorch FSDP, and there are subtle differences in how XLA and upstream CUDA kernels are set up. auto_wrap_policy is a new argument that enables developers to automatically specify conditions for propagating partitioning specifications to neural network submodules. An auto_wrap_policy callable can simply be passed in as an argument when wrapping a model with FSDP. Two auto_wrap_policy callables worth noting are size_based_auto_wrap_policy and transformer_auto_wrap_policy.

size_based_auto_wrap_policy enables users to wrap submodules with a minimum number of parameters. The example below wraps model submodules having at least 10M parameters.

auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=1e7)

transformer_auto_wrap_policy enables users to wrap all submodules that match a specific layer type. The example below wraps model submodules of type torch.nn.Conv2d. To learn more, review this ResNet example by Ronghang Hu.

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
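To make the two selection rules concrete, here is a toy sketch in plain Python. It does not call torch_xla: `ToyModule`, `size_policy`, `kind_policy`, and `select_wrapped` are hypothetical stand-ins, the parameter counts are made up for illustration, and the size rule here counts all parameters under a submodule, ignoring the real policy's bookkeeping of parameters that are already wrapped.

```python
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, List, Set


@dataclass
class ToyModule:
    """Hypothetical stand-in for a module: name, layer kind, own params, children."""
    name: str
    kind: str
    own_params: int = 0
    children: List["ToyModule"] = field(default_factory=list)

    def num_params(self) -> int:
        return self.own_params + sum(c.num_params() for c in self.children)


def size_policy(module: ToyModule, min_num_params: float) -> bool:
    # Wrap when the submodule holds at least min_num_params parameters.
    return module.num_params() >= min_num_params


def kind_policy(module: ToyModule, layer_kinds: Set[str]) -> bool:
    # Wrap when the submodule's layer kind is one of the requested kinds.
    return module.kind in layer_kinds


def select_wrapped(root: ToyModule, policy: Callable[[ToyModule], bool]) -> List[str]:
    """Return names of submodules (excluding root) the policy would wrap, children first."""
    picked: List[str] = []
    for child in root.children:
        picked.extend(select_wrapped(child, policy))
        if policy(child):
            picked.append(child.name)
    return picked


net = ToyModule("net", "Sequential", children=[
    ToyModule("conv1", "Conv2d", own_params=9_408),
    ToyModule("block", "Block", children=[
        ToyModule("block.conv", "Conv2d", own_params=2_359_296),
        ToyModule("block.fc", "Linear", own_params=12_000_000),
    ]),
])

by_size = select_wrapped(net, partial(size_policy, min_num_params=1e7))
by_kind = select_wrapped(net, partial(kind_policy, layer_kinds={"Conv2d"}))
print(by_size)
print(by_kind)
```

As in the snippets above, `functools.partial` binds the policy's configuration up front, so the wrapping code only needs to call the policy with a module.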

PyTorch/XLA FSDP is now integrated into the Hugging Face Trainer class (PR), enabling users to train much larger models on PyTorch/XLA (official Hugging Face documentation). A 16B-parameter GPT2 model trained on Cloud TPU v4-64 with this FSDP configuration achieved 39% hardware utilization.

TPU Accelerator – Num Devices    v4-64
GPT2 Parameter Count             16B
Layers Wrapped with FSDP         GPT2Block
TFLOPs / Chip                    275
PFLOPs / Step                    50
Hardware Utilization             39%
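The relationship between these figures can be written as a small formula: utilization = FLOPs per step / (step time × peak FLOPs per chip × number of chips). The sketch below applies it under stated assumptions that are not in the table: that "TFLOPs / Chip" is the per-chip peak, that "PFLOPs / Step" is the total work per training step, and that a v4-64 slice has 32 chips. The step time is not reported here, so the code only back-solves the step time these readings would imply.

```python
def hardware_utilization(flops_per_step: float, step_seconds: float,
                         peak_flops_per_chip: float, num_chips: int) -> float:
    """Achieved FLOP rate divided by aggregate peak FLOP rate."""
    return flops_per_step / (step_seconds * peak_flops_per_chip * num_chips)


def implied_step_seconds(flops_per_step: float, utilization: float,
                         peak_flops_per_chip: float, num_chips: int) -> float:
    """Invert the utilization formula to get the step time it implies."""
    return flops_per_step / (utilization * peak_flops_per_chip * num_chips)


PEAK_FLOPS_PER_CHIP = 275e12  # "TFLOPs / Chip" from the table, read as peak
FLOPS_PER_STEP = 50e15        # "PFLOPs / Step" from the table
NUM_CHIPS = 32                # assumption: a v4-64 slice has 32 chips

step_seconds = implied_step_seconds(FLOPS_PER_STEP, 0.39, PEAK_FLOPS_PER_CHIP, NUM_CHIPS)
round_trip = hardware_utilization(FLOPS_PER_STEP, step_seconds, PEAK_FLOPS_PER_CHIP, NUM_CHIPS)
print(f"implied step time: {step_seconds:.1f} s, utilization: {round_trip:.2f}")
```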

Differences Between FSDP & GSPMD

FSDP is a data parallelism technique that reduces the device memory footprint by sharding model parameters, optimizer states, and gradients across devices. Note that the actual computation is still local to each device and requires all-gathering the sharded model parameters for both the forward and backward passes, hence the name “data parallel”. FSDP is one of the newest additions to PyTorch/XLA for scaling large model training.
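A back-of-the-envelope calculation shows why sharding this state matters. The sketch below is illustrative arithmetic only, under assumptions not stated in this post: fp32 parameters and gradients (4 bytes each), an Adam-style optimizer with two fp32 moments (8 bytes), 32 devices, and no accounting for activations or temporary all-gather buffers.

```python
BYTES_PER_PARAM = 4 + 4 + 8  # assumed: fp32 param + fp32 grad + two fp32 optimizer moments


def model_state_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Total size of parameters, gradients, and optimizer state in GB (1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


def per_device_gb(num_params: float, num_devices: int, sharded: bool) -> float:
    """Per-device model-state footprint, replicated vs. fully sharded."""
    total = model_state_gb(num_params)
    return total / num_devices if sharded else total


replicated = per_device_gb(16e9, num_devices=32, sharded=False)
sharded = per_device_gb(16e9, num_devices=32, sharded=True)
print(f"replicated: {replicated:.0f} GB per device, sharded: {sharded:.0f} GB per device")
```

Under these assumptions, a 16B-parameter model's state shrinks from hundreds of GB per device when replicated to a single-digit number of GB per device when fully sharded, at the cost of the all-gather traffic described above.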

GSPMD, on the other hand, is a general parallelization system that enables various types of parallelism, including both data and model parallelism. PyTorch/XLA provides a sharding annotation API and an XLAShardedTensor abstraction, so a user can annotate any tensor with sharding specs in the PyTorch program. Developers don’t need to manually implement sharded computations or inject collective communication ops to get it right. The XLA compiler does the work so that each computation can run in a distributed manner on multiple devices.
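To give a feel for what a sharding annotation means, the toy sketch below maps a tensor shape, a device mesh shape, and a partition spec to the index ranges each device would hold. It is not the XLA partitioner: `shard_slices` is a hypothetical helper, it only handles even tiling, and it assumes a partition spec that names one mesh axis per tensor dimension, with `None` standing for replication along that dimension.

```python
from itertools import product
from typing import Dict, Optional, Sequence, Tuple

Range = Tuple[int, int]


def shard_slices(tensor_shape: Sequence[int],
                 mesh_shape: Sequence[int],
                 partition_spec: Sequence[Optional[int]]) -> Dict[Tuple[int, ...], Tuple[Range, ...]]:
    """Map each mesh coordinate to the (start, stop) range it holds per tensor dim."""
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    result = {}
    for coord in product(*(range(n) for n in mesh_shape)):
        ranges = []
        for dim_size, mesh_axis in zip(tensor_shape, partition_spec):
            if mesh_axis is None:
                ranges.append((0, dim_size))  # replicated along this dimension
                continue
            parts = mesh_shape[mesh_axis]
            if dim_size % parts:
                raise ValueError("this sketch only handles even tiling")
            chunk = dim_size // parts
            start = coord[mesh_axis] * chunk
            ranges.append((start, start + chunk))
        result[coord] = tuple(ranges)
    return result


# An (8, 4) tensor on a 2x2 mesh.
data_parallel = shard_slices((8, 4), (2, 2), (0, None))  # shard the batch dim only
two_d_sharded = shard_slices((8, 4), (2, 2), (0, 1))     # shard both dims
print(data_parallel[(1, 0)], two_d_sharded[(1, 1)])
```

Sharding only the batch dimension corresponds to data parallelism; sharding other dimensions as well splits the tensor itself across devices, and in the real system the compiler also inserts the collectives needed to keep results correct.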

Examples & Preliminary Results

To learn about the PyTorch/XLA sharding API, visit our RFC and see the Sample Code references. Below is a simple example that enables data and model parallelism.

import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# SimpleLinear, mesh, partition_spec, loader, optimizer, and loss_fn are
# assumed to be defined elsewhere.
model = SimpleLinear().to(xm.xla_device())
# Annotate the linear layer weights with a sharding spec.
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
  optimizer.zero_grad()
  data = data.to(xm.xla_device())
  target = target.to(xm.xla_device())
  # Annotate the input data with a sharding spec; any input dimension
  # can be sharded. Sharding the batch dimension enables data
  # parallelism; sharding the feature dimension enables spatial
  # partitioning.
  xs.mark_sharding(data, mesh, partition_spec)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()
  optimizer.step()
  xm.mark_step()

The following graph highlights the memory efficiency benefits of PyTorch/XLA FSDP and SPMD on Cloud TPU v4-8 running ResNet50.

Batch Size Scaling with Spatial Partitioning

Closing Thoughts…

We are excited to bring these features to the PyTorch community, and this is really just the beginning. Areas like dynamic shapes, deeper support for OpenXLA and many others are in development and we plan to put out more blogs to dive into the details. PyTorch/XLA is developed fully open source and we invite you to join the community of developers by filing issues, submitting pull requests, and sending RFCs on GitHub. You can try PyTorch/XLA on a variety of XLA devices including TPUs and GPUs. Here is how to get started.

Congratulations again to the PyTorch community on this milestone!

Cheers,

The PyTorch Team at Google
