INT4 Decoding GQA CUDA Optimizations for LLM Inference

An efficient decoding Grouped-Query Attention with low-precision KV cache

Introduction

Generative AI has taken the world by storm with its ability to generate content like humans. Many of these generative AI tools are powered by large language models (LLMs), like Meta Llama models and OpenAI’s ChatGPT. One of the main challenges of LLMs is supporting large “context lengths” (also known as “sequence lengths”). The context length refers to the number of tokens that the model uses to understand the input context and generate responses. Longer context lengths generally translate into higher precision and quality in the responses. However, long context lengths are compute and memory intensive. This is mainly due to the following reasons:

  • The computational complexity of attention layers grows with the context length (the growth rate depends on the attention algorithm). As a result, with long context lengths the attention layers can become a bottleneck, particularly during the prefill phase, where attention is compute bound.
  • The KV cache size grows linearly with the context length, which puts more pressure on memory and consequently slows down the already memory-bound attention decoding. Moreover, since memory capacity is limited, the batch size has to shrink as the KV cache gets bigger, which generally results in a drop in throughput.

The growth in computational complexity is the harder of the two problems to solve. The KV cache size growth, on the other hand, can be addressed by using a low-precision KV cache. In our experiments, group-wise INT4 quantization provides accuracy comparable to a BF16 KV cache during the decode phase of Meta Llama 2 inference. However, we did not observe any latency improvement, despite reading 4x less data in the attention decoding layers. This means that the INT4 attention is 4x less efficient at utilizing precious HBM bandwidth than BF16 attention.

In this note, we discuss the CUDA optimizations that we applied to INT4 GQA (grouped-query attention – the attention layer that we use in the LLM inference phase) to improve its performance by up to 1.8x on the NVIDIA A100 GPU and 1.9x on the NVIDIA H100 GPU.

  • The optimized CUDA INT4 GQA outperformed INT4 Flash-Decoding GQA (the best performing INT4 GQA that we used in the experiment mentioned above) by 1.4x-1.7x on A100 and 1.09x-1.3x on H100.
  • The optimized CUDA INT4 GQA performs better than BF16 Flash-Decoding GQA by 1.5x-1.7x on A100 and 1.4x-1.7x on H100.

Background

GQA for LLM Inference

Grouped-Query Attention (GQA) is a variant of multi-head attention (MHA) where each KV cache head is shared across a group of query heads. Our LLM inference adopts GQA as an attention layer in both the prefill and decode phases in order to reduce the capacity requirement for the KV cache. We use multiple GPUs in inference where the KV cache and query heads are distributed across GPUs. Each GPU runs an attention layer with a single KV head and a group of Q heads. Therefore, when viewed from a single GPU perspective, the GQA component can also be described as MQA (Multi-Query Attention).

The simplified workflow of decoding GQA is illustrated in Figure 1. GQA takes three main inputs: input query (denoted Q), K cache (denoted K), and V cache (denoted V). Our current GQA inference uses BF16 for Q, K, and V.

  • Q is a 4D BF16 tensor of shape (B, 1, HQ, D)
  • K is a 4D BF16 tensor of shape (B, Tmax, HKV, D)
  • V is a 4D BF16 tensor of shape (B, Tmax, HKV, D)

where

  • B is the batch size (the number of input prompts)
  • HQ is the number of query heads
  • HKV is the number of KV heads (HQ must be divisible by HKV)
  • Tmax is the maximum context length
  • D is the head dimension (fixed to 128)

GQA is simply bmm(softmax(bmm(Q, K^T) / sqrt(D)), V). This yields a single output tensor (denoted O), a 4D BF16 tensor with the same shape as Q. Note that the matrix multiplications are performed in BF16, whereas accumulation and softmax are carried out in FP32. We call this “BF16 GQA” because the KV cache is BF16.
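As a reference point, the formula can be sketched in plain Python for a single prompt and a single KV head. This is only an illustration of the math, not the kernel: it uses Python floats rather than BF16/FP32, and the function name `gqa_decode` is made up for this example.

```python
import math

def gqa_decode(q, k, v):
    """Decode-step attention for one prompt and one KV head.

    q: H_Q query vectors of length D (one new token per head).
    k, v: T cached key/value vectors of length D.
    Returns H_Q output vectors of length D.
    """
    d = len(q[0])
    out = []
    for q_h in q:  # every query head in the group reuses the same K and V
        scores = [sum(a * b for a, b in zip(q_h, k_t)) / math.sqrt(d) for k_t in k]
        m = max(scores)  # subtract the row max for numerical stability
        p = [math.exp(s - m) for s in scores]
        row_sum = sum(p)
        out.append([sum(p[t] * v[t][j] for t in range(len(v))) / row_sum
                    for j in range(d)])
    return out
```

The row max and row sum computed here are the same two softmax reductions that Optimizations 3, 4, and 7 target.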

Figure 1 The simplified workflow of BF16 GQA for LLM inference

INT4 GQA

To further reduce the size of the KV cache, we explore the possibility of using INT4 for the KV cache instead of BF16. We estimate the potential performance improvement by calculating the computational intensity (CI) of INT4 GQA and comparing it to that of BF16 GQA, where CI is the number of floating-point operations (FLOPs) performed per byte loaded. We compute the CI for QK^T and PV (as shown in Equation 1) as they take the KV cache as an operand. Note that we disregard the Q load as it is negligible compared to the KV cache. We also ignore any intermediate data loads/stores that are not on global memory. Thus, the CI only takes into account the computation FLOPs and the KV cache loads.

Equation (1)
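One way to write Equation (1) that is consistent with the definitions above and with the CI values quoted in the next paragraph is sketched here. It assumes 2 FLOPs per multiply-add; the symbols T (context length) and s (bytes per KV cache element) are introduced only for this illustration.

```latex
\mathrm{CI}_{QK^T} = \mathrm{CI}_{PV}
  = \frac{2 \, B \, H_Q \, T \, D}{s \, B \, H_{KV} \, T \, D}
  = \frac{2 \, H_Q}{s \, H_{KV}},
\qquad s = 2 \ \text{(BF16)}, \quad s = 0.5 \ \text{(INT4)}
```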

Assuming that HQ = 8 and HKV = 1, the CI for the BF16 KV cache is 8, while the CI for the INT4 KV cache is 32. These CIs indicate that both BF16 and INT4 GQA are memory bound: the peak CIs of the BF16 tensor cores on A100 and H100 are 312 TF / 2 TB/s = 141 and 990 TF / 3.35 TB/s = 269, respectively (these TF numbers are without sparsity, and the ratios take 1 TB as 2^40 bytes). Moreover, with an INT4 KV cache, we should expect up to a 4x performance improvement compared to BF16 GQA.
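The arithmetic in this paragraph can be checked with a few lines of Python. This is a sketch under two assumptions that the text does not spell out: each multiply-add counts as 2 FLOPs, and 1 TB is taken as 2^40 bytes (which is what reproduces the quoted 141 and 269). The helper names are hypothetical.

```python
def kv_ci(h_q, h_kv, bytes_per_elem):
    """FLOPs per KV-cache byte for QK^T (or PV), with 2 FLOPs per multiply-add."""
    return 2 * h_q / (bytes_per_elem * h_kv)

def ridge_point(peak_tflops, bandwidth_tb_per_s, bytes_per_tb=2**40):
    """Peak FLOPs per byte of a device; kernels below this CI are memory bound."""
    return peak_tflops * 1e12 / (bandwidth_tb_per_s * bytes_per_tb)

ci_bf16 = kv_ci(8, 1, 2.0)     # BF16: 2 bytes per element
ci_int4 = kv_ci(8, 1, 0.5)     # INT4: half a byte per element
a100 = ridge_point(312, 2.0)   # A100 BF16 tensor cores, no sparsity
h100 = ridge_point(990, 3.35)  # H100 BF16 tensor cores, no sparsity
```

Both CI values sit far below either ridge point, which is why the analysis treats the kernel as memory bound in both precisions.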

To enable INT4 KV cache support in GQA, we can dequantize the KV cache from INT4 to BF16 before passing it to the BF16 GQA operator. However, since the KV cache is typically large, copying it to and from global memory can be costly. Moreover, decoding GQA is a memory-bound operation (the memory unit is utilized much more heavily than the compute unit). Figure 2 shows the NCU profile of the FMHA CUTLASS BF16 GQA kernel in xFormers, which is one of the state-of-the-art implementations of GQA. The profile makes it clear that memory is the bottleneck.

Figure 2 The NCU profile of the FMHA CUTLASS BF16 kernel in xFormers

A more efficient alternative is to fuse INT4 dequantization with the GQA operation (shown in Figure 3). In other words, GQA reads the INT4 KV cache directly and performs the INT4 to BF16 conversion within the kernel. This change can potentially reduce the amount of global memory reads required for the KV cache, which could lead to a decrease in latency. We call this “INT4 GQA.”

Figure 3 The workflow of fused INT4 GQA

Table 1 lists the state-of-the-art implementations of GQA along with their features.

Table 1 State of the art GQA implementations

| Implementation | Denoted as | BF16 GQA | Fused INT4 GQA |
| --- | --- | --- | --- |
| Flash-Decoding (Triton implementation) | FD | Yes | Yes |
| Flash Attention (v2.3.3) | FA | Yes | No |
| CUDA baseline | CU | Yes | Yes |

All implementations except CU support both split-K and non-split-K. CU only has a split-K implementation. Only FA has a heuristic in the backend to determine whether to run the split-K or non-split-K kernel. For the other implementations, users must explicitly choose which version to run. In this note, we focus on long context lengths (in our experiments, we use a context length of 8192) and therefore opt for the split-K version wherever possible.

As the baseline, we measured the performance of the state-of-the-art GQA implementations on NVIDIA A100 and H100 GPUs. The latency (time in microseconds) and achieved bandwidth (GB/s) are reported in Table 2. Note that we ran a range of split-Ks (from 2 to 128 splits) and reported the best performance for each implementation. For all experiments, we used a context length of 8192. For INT4 GQA, we used row-wise quantization (i.e., num quantized groups = 1).

Table 2 Baseline GQA performance

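On A100

Time (us)

| Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU |
| --- | --- | --- | --- | --- | --- | --- |
| 32 | 139 | 133 | 183 | 137 | n/a | 143 |
| 64 | 245 | 229 | 335 | 234 | n/a | 257 |
| 128 | 433 | 555 | 596 | 432 | n/a | 455 |
| 256 | 826 | 977 | 1127 | 815 | n/a | 866 |
| 512 | 1607 | 1670 | 2194 | 1581 | n/a | 1659 |

Effective Bandwidth (GB/s)

| Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU |
| --- | --- | --- | --- | --- | --- | --- |
| 32 | 965 | 1012 | 736 | 262 | n/a | 250 |
| 64 | 1097 | 1175 | 802 | 305 | n/a | 278 |
| 128 | 1240 | 968 | 901 | 331 | n/a | 314 |
| 256 | 1301 | 1100 | 954 | 351 | n/a | 331 |
| 512 | 1338 | 1287 | 980 | 362 | n/a | 345 |

On H100

Time (us)

| Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU |
| --- | --- | --- | --- | --- | --- | --- |
| 32 | 91 | 90 | 114 | 70 | n/a | 96 |
| 64 | 148 | 146 | 200 | 113 | n/a | 162 |
| 128 | 271 | 298 | 361 | 205 | n/a | 294 |
| 256 | 515 | 499 | 658 | 389 | n/a | 558 |
| 512 | 1000 | 1011 | 1260 | 756 | n/a | 1066 |

Effective Bandwidth (GB/s)

| Batch size | BF16 FD | BF16 FA | BF16 CU | INT4 FD | INT4 FA | INT4 CU |
| --- | --- | --- | --- | --- | --- | --- |
| 32 | 1481 | 1496 | 1178 | 511 | n/a | 371 |
| 64 | 1815 | 1840 | 1345 | 631 | n/a | 443 |
| 128 | 1982 | 1802 | 1487 | 699 | n/a | 487 |
| 256 | 2087 | 2156 | 1634 | 736 | n/a | 513 |
| 512 | 2150 | 2127 | 1706 | 757 | n/a | 537 |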

First, let’s discuss the BF16 GQA performance: CU ranks last in terms of performance among all implementations. FD and FA have comparable performance. When the batch size is less than or equal to 64, FA utilizes the split-K kernel and performs slightly better than FD. However, when the batch size is greater than 64, FD performs better.

The same trend holds true for INT4 GQAs. However, we did not measure the performance of FA as it does not support INT4 KV cache. FD outperforms CU for all cases.

When comparing the latencies of FD between BF16 and INT4 GQA on A100, we find that they are almost identical. This suggests that INT4 GQA is highly inefficient, which is further confirmed by the significantly lower achieved bandwidth of INT4 GQA compared to BF16 GQA. The same trend is also true when looking at the performance of CU.
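The bandwidth columns can be related to the latency columns with a short calculation. The note does not state how effective bandwidth is computed; the sketch below assumes it is the number of KV cache bytes read divided by the latency, and that an INT4 row holds 4 bytes of scales plus 64 bytes of packed data (the row-wise layout described under Optimization 6). Under these assumptions it reproduces the A100 batch-size-32 entries of Table 2 to within rounding.

```python
def kv_cache_bytes(batch, tokens, h_kv, row_bytes):
    """Bytes read from the K and V caches for one decode step."""
    return 2 * batch * tokens * h_kv * row_bytes  # factor 2: K and V

def effective_bandwidth_gbps(num_bytes, time_us):
    """Bytes moved divided by latency, reported in GB/s (1 GB = 1e9 bytes)."""
    return num_bytes / (time_us * 1e-6) / 1e9

bf16_row = 128 * 2       # D = 128 BF16 values
int4_row = 4 + 128 // 2  # 4 bytes of scales + 64 bytes of packed INT4 values

# A100, batch size 32, context length 8192, one KV head per GPU.
bw_bf16_fd = effective_bandwidth_gbps(kv_cache_bytes(32, 8192, 1, bf16_row), 139)
bw_int4_cu = effective_bandwidth_gbps(kv_cache_bytes(32, 8192, 1, int4_row), 143)
```

The two kernels take nearly the same time, but the INT4 kernel moves roughly a quarter of the bytes, so its achieved bandwidth is roughly a quarter as well.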

CUDA with Tensor Cores INT4 GQA Implementation

In this section, we briefly describe our baseline implementation, CUDA with tensor cores INT4 GQA (CU). Each thread block processes only one KV head and a group of query heads from one input prompt. Therefore, each thread block performs mm(softmax(mm(Q, K^T) / sqrt(D)), V); note that this is a plain mm rather than a bmm. Moreover, since this is a split-K implementation, the tokens in the KV cache are split among different thread blocks. Each thread block contains 4 warps (each warp contains 32 threads on NVIDIA A100 and H100 GPUs), and the work in each thread block is split among the warps. Within each warp, we use the WMMA API to compute matrix multiplications on tensor cores. Figure 4 demonstrates the work partitioning in CU.

Figure 4 CU work partitioning
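The idea behind splitting tokens across thread blocks can be sketched in Python. The text does not describe how CU merges the per-block partial results, so the reduction below (rescaling every partial result to the global max, as is commonly done for split-K attention) is an assumption, and all function names are illustrative.

```python
import math

def partial_attention(scores, values):
    """One split: returns (max, sum of exp, unnormalized weighted sum of values)."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    acc = [sum(p[t] * values[t][j] for t in range(len(values)))
           for j in range(len(values[0]))]
    return m, sum(p), acc

def combine(partials):
    """Merge per-split results by rescaling each one to the global max."""
    m_all = max(m for m, _, _ in partials)
    total = sum(s * math.exp(m - m_all) for m, s, _ in partials)
    dim = len(partials[0][2])
    return [sum(acc[j] * math.exp(m - m_all) for m, _, acc in partials) / total
            for j in range(dim)]

def split_k_attention(scores, values, num_splits):
    """Split the tokens into chunks (one per 'thread block') and merge the results."""
    chunk = math.ceil(len(scores) / num_splits)
    partials = [partial_attention(scores[i:i + chunk], values[i:i + chunk])
                for i in range(0, len(scores), chunk)]
    return combine(partials)
```

With any number of splits the result matches the unsplit computation up to floating-point rounding, which is what lets the token dimension be parallelized across thread blocks.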

Optimizing CUDA with Tensor Cores Kernel of INT4 GQA

In this section, we discuss the optimizations that we have applied to the CUDA with tensor cores implementation of INT4 GQA (CU). The ideal goal is to improve the INT4 GQA performance by 4x, based on the CI analysis in the previous section. Note that the query size is negligible compared to the KV cache size when the context length is long.

In our analysis, we used NVIDIA Nsight Compute (NCU) as the main profiler. Our general bottleneck elimination approach is to minimize the stall cycles. We applied 10 optimizations to INT4 GQA, three of which are specific to NVIDIA A100/H100 GPUs. These optimizations are well-known CUDA optimization techniques that can be generalized to many applications.

It is worth noting that we chose to optimize the CUDA implementation rather than the Flash-Decoding implementation (FD), which is Triton based, because CUDA gives us better control over how the low-level instructions are generated. Many of the optimization techniques that we apply, such as operating on tensor core fragments directly (Optimizations 7-9), cannot be done through Triton since it does not expose low-level details to developers. However, these optimizations can be integrated into a compiler-based solution to make them available to a broader set of operators, which is indeed part of our future plan.

Optimization 1: Unroll K Loads

Problem Analysis:

The NCU profile shows that during K loading, there are only 2 global loads followed by memory stalls at dequantize_permuted_int4. The memory stalls are long scoreboard stalls, which indicate waits for global memory accesses. This suggests that the kernel does not issue enough memory loads to hide the global load latency. The kernel issues a data load and then waits to consume the data immediately, which exposes the global load latency. The stalls are shown in Figure 5.

Figure 5 K loading before unrolling (the numbers that the arrows point to are stall cycles caused by global memory wait)

Solution:

In the baseline implementation, we use uint32_t to load 8 INT4 K values in a single load, and we perform 2 uint32_t loads in each iteration, which is 16 INT4 K values. To hide the global load latency better, we issue 8 uint32_t loads instead of 2 before consuming the K values in dequantize_permuted_int4. This allows the compiler to unroll the loads and reorder the instructions so that more loads are in flight at once. Figure 6 shows the NCU profile of K loading after unrolling. Comparing Figure 5 and Figure 6, we see that unrolling the K loads effectively reduces the stall cycles.

Figure 6 K loading after unrolling (the numbers that the arrows point to are stall cycles caused by global memory wait)

Results:

Table 3 Performance of Optimization 1 for INT4 GQA (row-wise quantization)

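| Batch size | FD time (us) | CU baseline time (us) | CU Opt 1 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 1 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 134 | 262 | 250 | 267 | 1.02 | 1.07 |
| 64 | 234 | 257 | 237 | 305 | 278 | 302 | 0.99 | 1.09 |
| 128 | 432 | 455 | 422 | 331 | 314 | 339 | 1.02 | 1.08 |
| 256 | 815 | 866 | 806 | 351 | 331 | 355 | 1.01 | 1.07 |
| 512 | 1581 | 1659 | 1550 | 362 | 345 | 369 | 1.02 | 1.07 |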

Optimization 2: Improve P Type Casting (FP32->BF16)

Problem Analysis:

Since the output of softmax(bmm(Q, K^T) / sqrt(D)) is FP32 (denoted as P in Figure 3), the kernel has to convert P from FP32 to BF16 before feeding it to the next bmm computation. The kernel performs the FP32 to BF16 conversion of P by copying the FP32 data from one location in shared memory to another location in shared memory. This causes stalls during the shared memory access (shown in Figure 7), which might be caused by (1) the shared memory indirection; and (2) shared memory bank conflicts, since each thread accesses a 16-bit element and two threads can therefore access the same 32-bit memory bank simultaneously.

Figure 7 P type casting before Optimization 2 (the number that the arrow points to is stall cycles caused by shared memory wait)

Solution:

We use all threads in the thread block to do in-place type conversion. Each thread operates on two consecutive elements in order to avoid the shared memory bank conflict when storing BF16. All threads work on the same head (h) at the same time to guarantee correctness of the conversion. The in-place conversion steps are as follows:

  1. Each thread loads 2 FP32 token elements from the same head from the shared memory into registers
  2. Call __syncthreads() to make sure that every thread finishes reading the data
  3. Each thread converts its data to 2 BF16 token elements and then stores the results to the same shared memory

Some optimizations that we apply to the implementation:

  • Use vector types (especially nv_bfloat2)
  • Unroll data loading/storing, i.e., performing multiple loads before calling __syncthreads() and performing multiple stores after __syncthreads()
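The three steps above can be simulated on a byte buffer to show why the barrier is needed: the BF16 pair written by one thread lands on bytes that another thread still has to read as FP32. This is a Python sketch, not the kernel code; the function names are hypothetical, the element count is assumed to be even, and the rounding helper ignores NaN handling.

```python
import struct

def fp32_to_bf16_bits(x):
    """Round an FP32 value to BF16 bits (round-to-nearest-even); NaN handling omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def convert_in_place(smem, n, num_threads):
    """smem: bytearray holding n FP32 values; afterwards its first 2*n bytes hold BF16."""
    for start in range(0, n, 2 * num_threads):  # one pass of the whole thread block
        regs = {}
        # Step 1: every thread loads two consecutive FP32 values into registers.
        for tid in range(num_threads):
            i = start + 2 * tid
            if i < n:
                regs[tid] = struct.unpack_from("<2f", smem, 4 * i)
        # Step 2: __syncthreads() goes here, so all loads finish before any store.
        # Step 3: every thread stores its pair as one packed 4-byte BF16x2 value.
        for tid, (a, b) in regs.items():
            i = start + 2 * tid
            struct.pack_into("<2H", smem, 2 * i,
                             fp32_to_bf16_bits(a), fp32_to_bf16_bits(b))
```

Writing two BF16 values per thread means each store is a full 4-byte word, which is the property the solution relies on to avoid two threads sharing a bank.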

After this optimization, long stalls are not observed during P type casting as shown in Figure 8.

Figure 8 P type casting after Optimization 2 (the numbers that the arrow points to are stall cycles caused by shared memory wait)

Caveats:

Since we unroll data loading/storing by using registers as intermediate storage, the number of registers per thread increases, which reduces occupancy.

Results:

Table 4 Performance of Optimization 2 for INT4 GQA (row-wise quantization)

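| Batch size | FD time (us) | CU baseline time (us) | CU Opt 2 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 2 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 126 | 262 | 250 | 285 | 1.09 | 1.14 |
| 64 | 234 | 257 | 221 | 305 | 278 | 324 | 1.06 | 1.16 |
| 128 | 432 | 455 | 395 | 331 | 314 | 362 | 1.09 | 1.15 |
| 256 | 815 | 866 | 749 | 351 | 331 | 382 | 1.09 | 1.16 |
| 512 | 1581 | 1659 | 1435 | 362 | 345 | 399 | 1.10 | 1.16 |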

Optimization 3: Remove Local Memory Usage for Max QK^T Computation

Problem Analysis:

During the softmax computation, the kernel has to compute max QK^T for each head. It uses temporary “thread-local” storage for the per-thread max QK^T results (one float value for each head). Depending on the compiler, the thread-local storage can be allocated in registers (on chip) or in local memory (off chip, i.e., in global memory). Unfortunately, in the baseline, the thread-local storage resides in local memory, which is much slower than registers (shown in Figure 9). We suspect that this is because the compiler cannot determine the indices of the thread-local storage at compile time (since the number of heads (H) in the kernel is a runtime variable). Accessing local memory as if it were registers hurts the performance of the kernel.

Figure 9 Local memory access during max QK^T computation

Solution:

We realized that we do not need H (the number of heads) floats of temporary storage per thread, since each thread can compute max QK^T for only one head instead of all the heads. Thus, we only need one float per thread, which can easily be stored in a register. To accumulate the max results among warps, we use shared memory. This optimization eliminates the local memory usage during the max QK^T computation.

Results:

Table 5 Performance of Optimization 3 for INT4 GQA (row-wise quantization)

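| Batch size | FD time (us) | CU baseline time (us) | CU Opt 3 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 3 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 119 | 262 | 250 | 300 | 1.14 | 1.20 |
| 64 | 234 | 257 | 206 | 305 | 278 | 348 | 1.14 | 1.25 |
| 128 | 432 | 455 | 368 | 331 | 314 | 389 | 1.17 | 1.24 |
| 256 | 815 | 866 | 696 | 351 | 331 | 411 | 1.17 | 1.24 |
| 512 | 1581 | 1659 | 1338 | 362 | 345 | 428 | 1.18 | 1.24 |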

Optimization 4: Remove Local Memory Usage for Row Sum

Problem Analysis:

Similar to Optimization 3, the local memory usage problem is also observed during the row sum computation in the softmax computation. Since local memory is off chip, accessing it as if accessing registers can hurt the performance of the kernel.

Solution:

We apply the same solution to the row sum computation as to the max QK^T computation. That is, each thread computes the row sum of only one head, which requires only one float per thread. This eliminates the need for local memory.

Results:

Table 6 Performance of Optimization 4 for INT4 GQA (row-wise quantization)

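| Batch size | FD time (us) | CU baseline time (us) | CU Opt 4 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 4 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 118 | 262 | 250 | 302 | 1.15 | 1.21 |
| 64 | 234 | 257 | 204 | 305 | 278 | 351 | 1.15 | 1.26 |
| 128 | 432 | 455 | 364 | 331 | 314 | 393 | 1.19 | 1.25 |
| 256 | 815 | 866 | 688 | 351 | 331 | 416 | 1.18 | 1.26 |
| 512 | 1581 | 1659 | 1328 | 362 | 345 | 431 | 1.19 | 1.25 |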

Optimization 5: Add Prefetch for V Load

Problem Analysis:

The same issue as in K loading is observed when loading V. That is, the kernel issues a data load and then waits to consume the data immediately, which exposes the global load latency. However, when using the unrolling technique from Optimization 1, the compiler allocates the temporary buffer in local memory instead of registers, causing a large slowdown.

Solution:

We adopt the data prefetching technique for V loading. We load the next iteration's V values immediately after the current iteration's values are consumed. This allows the data loading to be overlapped with the PV computation, resulting in better kernel performance.
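The loop structure of prefetching can be sketched in Python. This only illustrates the ordering of loads and computation; Python executes sequentially, whereas on the GPU the prefetched load stays in flight while the math on the current values runs. The function name and the `events` log are made up for this example.

```python
def pv_with_prefetch(p, v_rows, events):
    """Accumulate sum_t p[t] * V[t], issuing the load for row t+1 before using row t."""
    def load(t):
        events.append(("load", t))
        return v_rows[t]

    nxt = load(0)  # prologue: fetch the first row before entering the loop
    acc = [0.0] * len(nxt)
    for t in range(len(p)):
        cur = nxt
        if t + 1 < len(p):
            nxt = load(t + 1)  # prefetch: issued before the math on the current row
        events.append(("compute", t))
        for j, x in enumerate(cur):
            acc[j] += p[t] * x
    return acc
```

Only one extra row is kept live at a time, in contrast to unrolling, where many loaded values have to be held at once.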

Results:

Table 7 Performance of Optimization 5 for INT4 GQA (row-wise quantization)

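| Batch size | FD time (us) | CU baseline time (us) | CU Opt 5 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 5 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 109 | 262 | 250 | 327 | 1.25 | 1.31 |
| 64 | 234 | 257 | 194 | 305 | 278 | 370 | 1.21 | 1.33 |
| 128 | 432 | 455 | 345 | 331 | 314 | 414 | 1.25 | 1.32 |
| 256 | 815 | 866 | 649 | 351 | 331 | 441 | 1.26 | 1.33 |
| 512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |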

Optimization 6: Add Group-Wise INT4 (Groups = 4) with Vector Load

Problem Analysis:

Prior to this optimization, CU only supported row-wise INT4 quantization. That is, every column in each row shares the same scales. The scales of each row are stored in the first 4 bytes of the row, as shown in Figure 10. In the kernel, each thread loads only one row at a time. Since each row contains 68 bytes (4 bytes for scales and 64 bytes for data), rows are not guaranteed to be aligned to the size of a vector type such as uint2 (8 bytes). Thus, vector loads cannot be used for loading the KV cache.

Figure 10 The layout of each row of INT4 KV cache with row-wise quantization
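To make the 68-byte row concrete, here is a Python sketch of packing and unpacking one row. Several details are assumptions for illustration only: the 4 bytes of scales are taken to be an FP16 scale followed by an FP16 shift, values dequantize as q * scale + shift, and nibbles are stored lowest first. The kernel's dequantize_permuted_int4 suggests a permuted element order that is not reproduced here, and both function names are hypothetical.

```python
import struct

ROW_BYTES = 4 + 64  # scales + packed INT4 data for D = 128

def quantize_row(values):
    """Hypothetical row-wise quantizer: FP16 scale and shift, then 128 packed nibbles."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / 15 or 1.0
    # Round-trip through FP16 so quantizer and dequantizer use identical parameters.
    scale, shift = struct.unpack("<2e", struct.pack("<2e", scale, lo))
    q = [min(15, max(0, round((x - shift) / scale))) for x in values]
    data = bytes(q[i] | (q[i + 1] << 4) for i in range(0, len(q), 2))
    return struct.pack("<2e", scale, shift) + data

def dequantize_row(row):
    """Read the scales, then unpack each uint32 word into 8 INT4 values."""
    scale, shift = struct.unpack_from("<2e", row, 0)
    out = []
    for (word,) in struct.iter_unpack("<I", row[4:]):
        for k in range(8):
            out.append(((word >> (4 * k)) & 0xF) * scale + shift)
    return out
```

Sixteen uint32 words of 8 nibbles each cover the full head dimension of 128, which matches the per-load granularity described in Optimization 1.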

Solution:

We have implemented support for group-wise INT4 quantization with num groups = 4. In this case, columns in each row in the KV cache tensor are divided into 4 equal groups. Columns within the same group share the same scales for quantization/dequantization. The data layout for INT4 KV cache is shown in Figure 11. The scales for all groups are serialized and stored at the beginning of each row. The INT4 data is also serialized and laid out next to the scales.

Because each row now contains 80 bytes, we can use a vector type, uint2 in our case, to load data. (We do not use uint4 since each thread loads only 16 INT4s at a time due to the tensor core fragment size.) Vector loads are generally better than scalar loads since they do not cause extra byte loads.

Figure 11 The layout of each row of INT4 KV cache with group-wise quantization (num groups = 4)
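The alignment argument can be checked numerically. The sketch below assumes that the cache base address is itself aligned and that each group carries 4 bytes of scales, as in the row-wise case (which gives the 80 bytes quoted above); the helper name is illustrative.

```python
UINT2_BYTES = 8  # a uint2 vector load moves 8 bytes and needs 8-byte alignment

def rows_aligned(row_bytes, vector_bytes, num_rows=8):
    """True if every row start is a multiple of the vector size (base assumed aligned)."""
    return all((r * row_bytes) % vector_bytes == 0 for r in range(num_rows))

row_wise = 1 * 4 + 64    # one group of scales: 68 bytes per row
group_wise = 4 * 4 + 64  # four groups of scales: 80 bytes per row

row_wise_ok = rows_aligned(row_wise, UINT2_BYTES)      # every other row is misaligned
group_wise_ok = rows_aligned(group_wise, UINT2_BYTES)  # every row starts on 8 bytes
```

With 68-byte rows, odd-numbered rows start 4 bytes off an 8-byte boundary, so only 4-byte loads are safe; with 80-byte rows every row start is 8-byte aligned.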

Results:

Table 8 Performance of Optimization 6 for INT4 GQA (row-wise quantization)

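| Batch size | FD time (us) | CU baseline time (us) | CU Opt 6 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 6 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 111 | 262 | 250 | 322 | 1.23 | 1.29 |
| 64 | 234 | 257 | 192 | 305 | 278 | 372 | 1.22 | 1.34 |
| 128 | 432 | 455 | 346 | 331 | 314 | 414 | 1.25 | 1.32 |
| 256 | 815 | 866 | 642 | 351 | 331 | 446 | 1.27 | 1.35 |
| 512 | 1581 | 1659 | 1244 | 362 | 345 | 460 | 1.27 | 1.33 |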

Table 9 Performance of Optimization 6 for INT4 GQA (group-wise quantization with num groups = 4)

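| Batch size | FD time (us) | CUDA_WMMA Opt 6 time (us) | FD bandwidth (GB/s) | CUDA_WMMA Opt 6 bandwidth (GB/s) | Speedup vs FD |
| --- | --- | --- | --- | --- | --- |
| 32 | 129 | 116 | 325 | 364 | 1.11 |
| 64 | 219 | 195 | 385 | 431 | 1.12 |
| 128 | 392 | 347 | 429 | 484 | 1.13 |
| 256 | 719 | 638 | 468 | 527 | 1.13 |
| 512 | 1375 | 1225 | 489 | 550 | 1.12 |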

Optimization 7: Compute Max QK^T From WMMA Fragment Directly (A100/H100 specific)

Problem Analysis:

We observe large stalls due to shared memory access during the max QK^T computation (showing up as large short scoreboard stalls), as shown in Figure 12.

Figure 12 Stalls due to shared memory access during max QK^T computation (the number that the arrow points to is stall cycles caused by shared memory wait)

Solution:

We bypass shared memory when computing max QK^T by computing it from the WMMA fragment (i.e., the tensor core fragment) directly. The layout of the WMMA fragment is specific to the GPU architecture, so we enabled this optimization only for NVIDIA A100/H100 GPUs. Other GPUs still use shared memory for the max QK^T computation. By bypassing shared memory, we effectively eliminate the stalls caused by shared memory access. The tensor core layout of the C fragment, which is used for storing the QK^T results, is shown in Figure 13.

Figure 13 C fragment (QK^T storage) tensor core layout on A100/H100

Table 10 Performance of Optimization 7 for INT4 GQA (row-wise quantization)

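| Batch size | FD time (us) | CU baseline time (us) | CU Opt 7 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 7 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 107 | 262 | 250 | 333 | 1.27 | 1.33 |
| 64 | 234 | 257 | 183 | 305 | 278 | 391 | 1.28 | 1.40 |
| 128 | 432 | 455 | 333 | 331 | 314 | 430 | 1.30 | 1.37 |
| 256 | 815 | 866 | 620 | 351 | 331 | 461 | 1.31 | 1.40 |
| 512 | 1581 | 1659 | 1206 | 362 | 345 | 475 | 1.31 | 1.38 |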

Table 11 Performance of Optimization 7 for INT4 GQA (group-wise quantization with num groups = 4)

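| Batch size | FD time (us) | CUDA_WMMA Opt 6 time (us) | CUDA_WMMA Opt 7 time (us) | FD bandwidth (GB/s) | CUDA_WMMA Opt 6 bandwidth (GB/s) | CUDA_WMMA Opt 7 bandwidth (GB/s) | Speedup vs FD | Speedup vs CUDA_WMMA Opt 6 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 129 | 116 | 111 | 325 | 364 | 380 | 1.17 | 1.04 |
| 64 | 219 | 195 | 187 | 385 | 431 | 449 | 1.17 | 1.04 |
| 128 | 392 | 347 | 333 | 429 | 484 | 506 | 1.18 | 1.04 |
| 256 | 719 | 638 | 615 | 468 | 527 | 547 | 1.17 | 1.04 |
| 512 | 1375 | 1225 | 1184 | 489 | 550 | 569 | 1.16 | 1.03 |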

Optimization 8: Write FP32->BF16 Results to P Fragment Directly (A100/H100 specific)

Problem Analysis:

During the FP32->BF16 conversion for the P fragment, the kernel loads the FP32 data from shared memory, does the conversion, and then stores the BF16 data back to shared memory. Moreover, the conversion requires many thread block synchronizations (__syncthreads()).

Solution:

Due to the data partitioning design of the kernel, each warp performs only one pass through the P fragment. Thus, we do not have to write the conversion results back to the shared memory for future usage. To avoid writing the BF16 data to the shared memory and thread block synchronizations, we have each warp load the FP32 data of the P WMMA fragment from the shared memory, do the conversion and then write the BF16 data directly to the P fragment.

Note that this optimization is applied only to the NVIDIA A100 and H100 GPUs because the WMMA fragment layout is architecture dependent. For non-A100/H100 GPUs, the kernel will fall back to the original path.

The P fragment tensor core layout is shown in Figure 14. Note that this layout is specific to the NVIDIA A100/H100 GPU.

Figure 14 P fragment tensor core layout on A100/H100

Table 12 Performance of Optimization 8 for INT4 GQA (row-wise quantization)

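| Batch size | FD time (us) | CU baseline time (us) | CU Opt 8 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 8 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 101 | 262 | 250 | 353 | 1.35 | 1.41 |
| 64 | 234 | 257 | 174 | 305 | 278 | 410 | 1.34 | 1.47 |
| 128 | 432 | 455 | 317 | 331 | 314 | 451 | 1.36 | 1.43 |
| 256 | 815 | 866 | 590 | 351 | 331 | 485 | 1.38 | 1.47 |
| 512 | 1581 | 1659 | 1143 | 362 | 345 | 501 | 1.38 | 1.45 |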

Table 13 Performance of Optimization 8 for INT4 GQA (group-wise quantization with num groups = 4)

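| Batch size | FD time (us) | CUDA_WMMA Opt 6 time (us) | CUDA_WMMA Opt 8 time (us) | FD bandwidth (GB/s) | CUDA_WMMA Opt 6 bandwidth (GB/s) | CUDA_WMMA Opt 8 bandwidth (GB/s) | Speedup vs FD | Speedup vs CUDA_WMMA Opt 6 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 129 | 116 | 106 | 325 | 364 | 396 | 1.22 | 1.09 |
| 64 | 219 | 195 | 180 | 385 | 431 | 467 | 1.21 | 1.08 |
| 128 | 392 | 347 | 319 | 429 | 484 | 528 | 1.23 | 1.09 |
| 256 | 719 | 638 | 596 | 468 | 527 | 565 | 1.21 | 1.07 |
| 512 | 1375 | 1225 | 1138 | 489 | 550 | 591 | 1.21 | 1.08 |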

Optimization 9: Swizzle P Shared Memory Layouts (A100/H100 specific)

Problem Analysis:

We observe a large number of shared memory bank conflicts during P loading. The number of bank conflicts depends on the memory access stride. For instance, for split-Ks = 32 and max seq length = 8192, we observed that only 4 out of 32 banks are accessed in parallel (memory access stride = 256). From Figure 14, when all threads access element 0, threads that have the same threadIdx.x % 4 access the same bank.

Figure 15 P fragment in shared memory before swizzling
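The 4-out-of-32 figure can be reproduced with a small bank model. The sketch assumes 32 banks of 4 bytes each, FP32 elements, and a fragment layout in which lane i reads its element 0 from row i // 4 and column 2 * (i % 4). The layout is an assumption made for this illustration (the actual one is shown in Figure 14), chosen to be consistent with the statement that lanes with the same threadIdx.x % 4 hit the same bank.

```python
NUM_BANKS = 32  # shared memory banks, each 4 bytes wide

def banks_touched(stride_words, lanes=32):
    """Distinct banks hit when each lane reads element 0 of its fragment.

    Assumed layout: lane -> (row = lane // 4, col = 2 * (lane % 4)).
    """
    banks = {((lane // 4) * stride_words + 2 * (lane % 4)) % NUM_BANKS
             for lane in range(lanes)}
    return len(banks)

before = banks_touched(256)  # stride = 8192 tokens / 32 splits
```

Because the stride of 256 words is a multiple of 32, the row index drops out of the bank calculation and only the four distinct column offsets remain.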

Solution:

We shuffle the layout of the P loads/stores in shared memory in a way that avoids bank conflicts. In other words, we store the QK^T results (C fragment) and load them (P fragment) using a swizzled layout. Moreover, instead of using the original memory access stride, which depends on the number of tokens per thread block, we use the fragment’s column size as the stride, which is constant. Thus, the loads and stores of the P fragment are always contiguous.

The new layouts for the C and P fragments are shown in Figure 16. With the new layout, it is guaranteed that 16 banks are being accessed in parallel as shown in Figure 17.

Figure 16 The swizzled layouts of C and P fragments

Figure 17 P fragment in shared memory after swizzling

Table 14 Performance of Optimization 9 for INT4 GQA (row-wise quantization)

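| Batch size | FD time (us) | CU baseline time (us) | CU Opt 9 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 9 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 98 | 262 | 250 | 365 | 1.39 | 1.46 |
| 64 | 234 | 257 | 167 | 305 | 278 | 429 | 1.41 | 1.54 |
| 128 | 432 | 455 | 299 | 331 | 314 | 479 | 1.45 | 1.52 |
| 256 | 815 | 866 | 549 | 351 | 331 | 521 | 1.48 | 1.58 |
| 512 | 1581 | 1659 | 1060 | 362 | 345 | 540 | 1.49 | 1.56 |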

Table 15 Performance of Optimization 9 for INT4 GQA (group-wise quantization with num groups = 4)

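| Batch size | FD time (us) | CUDA_WMMA Opt 6 time (us) | CUDA_WMMA Opt 9 time (us) | FD bandwidth (GB/s) | CUDA_WMMA Opt 6 bandwidth (GB/s) | CUDA_WMMA Opt 9 bandwidth (GB/s) | Speedup vs FD | Speedup vs CUDA_WMMA Opt 6 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 129 | 116 | 105 | 325 | 364 | 400 | 1.23 | 1.10 |
| 64 | 219 | 195 | 174 | 385 | 431 | 484 | 1.26 | 1.12 |
| 128 | 392 | 347 | 302 | 429 | 484 | 558 | 1.30 | 1.15 |
| 256 | 719 | 638 | 560 | 468 | 527 | 601 | 1.28 | 1.14 |
| 512 | 1375 | 1225 | 1065 | 489 | 550 | 632 | 1.29 | 1.15 |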

Optimization 10: Pad Shared Memory for INT4 Dequantization

Problem Analysis:

Once the kernel reads the INT4 K or V cache from global memory, it performs dequantization and stores the results (BF16) in the shared memory. Then, the BF16 data is loaded to the WMMA fragment from shared memory (via the WMMA interface). We observed a large number of bank conflicts for both K and V accesses. For instance, for K stores, only 4 out of 32 banks are being accessed in parallel. For K loads, 16 banks are being accessed in parallel. The same also occurs for V stores and loads. See the figures in the solution section.

Solution:

We pad the shared memory to reduce bank conflicts. Specifically, we pad each row by 2. That is, the row stride of K becomes F_K + 2 and the row stride of V becomes F_N + 2 (F_K and F_N are the fixed widths of the K and V WMMA fragments, respectively). With this optimization, we are able to reduce the number of bank conflicts by 1.8x, as shown in Figure 18.

Figure 18 Bank conflicts before and after Optimization 10

After Optimization 10, for K stores, 32 banks are being accessed in parallel (shown in Figure 19), while for K loads, 29 banks are accessed in parallel (shown in Figure 20).

Figure 19 K fragment store shared memory layout without and with padding

Figure 20 K fragment load shared memory layout without and with padding
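A simple bank model shows why padding each row by 2 helps the K stores. The sketch assumes 32 banks of 4 bytes, BF16 elements, a fragment width of F_K = 16, and a store pattern in which each of the 32 lanes writes the same column of a different row. Neither the value of F_K nor the store pattern is stated in the text; under these assumptions the model reproduces the 4-bank and 32-bank figures quoted above for K stores.

```python
NUM_BANKS = 32
WORD_BYTES = 4
BF16_BYTES = 2

def banks_for_column_access(row_stride_elems, lanes=32):
    """Distinct banks hit when lane i touches the same column of row i."""
    banks = {(lane * row_stride_elems * BF16_BYTES // WORD_BYTES) % NUM_BANKS
             for lane in range(lanes)}
    return len(banks)

F_K = 16  # assumed fragment width in BF16 elements

unpadded = banks_for_column_access(F_K)    # row stride of 8 words
padded = banks_for_column_access(F_K + 2)  # row stride of 9 words
```

An 8-word stride shares a factor of 8 with the 32 banks, so rows collide on 4 banks; a 9-word stride is coprime with 32, so the 32 rows land on 32 different banks.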

Table 16 Performance of Optimization 10 for INT4 GQA (row-wise quantization)

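| Batch size | FD time (us) | CU baseline time (us) | CU Opt 10 time (us) | FD bandwidth (GB/s) | CU baseline bandwidth (GB/s) | CU Opt 10 bandwidth (GB/s) | Speedup vs FD | Speedup vs CU baseline |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 137 | 143 | 94 | 262 | 250 | 380 | 1.45 | 1.52 |
| 64 | 234 | 257 | 151 | 305 | 278 | 475 | 1.55 | 1.71 |
| 128 | 432 | 455 | 266 | 331 | 314 | 538 | 1.63 | 1.71 |
| 256 | 815 | 866 | 489 | 351 | 331 | 586 | 1.67 | 1.77 |
| 512 | 1581 | 1659 | 930 | 362 | 345 | 616 | 1.70 | 1.79 |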

Table 17 Performance of Optimization 10 for INT4 GQA (group-wise quantization with num groups = 4)

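| Batch size | FD time (us) | CUDA_WMMA Opt 6 time (us) | CUDA_WMMA Opt 10 time (us) | FD bandwidth (GB/s) | CUDA_WMMA Opt 6 bandwidth (GB/s) | CUDA_WMMA Opt 10 bandwidth (GB/s) | Speedup vs FD | Speedup vs CUDA_WMMA Opt 6 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 129 | 116 | 99 | 325 | 364 | 425 | 1.31 | 1.17 |
| 64 | 219 | 195 | 161 | 385 | 431 | 523 | 1.36 | 1.21 |
| 128 | 392 | 347 | 282 | 429 | 484 | 598 | 1.39 | 1.23 |
| 256 | 719 | 638 | 509 | 468 | 527 | 662 | 1.41 | 1.25 |
| 512 | 1375 | 1225 | 965 | 489 | 550 | 698 | 1.43 | 1.27 |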

Performance Evaluation

Microbenchmark results

We also evaluated BF16 GQA performance using our optimized kernel (as shown in Table 19). CU still performs generally worse than FD and FA for BF16. This is expected since our optimizations are INT4 focused.

While INT4 GQA is still not as efficient as BF16 GQA (see the achieved bandwidths), it is important to note that when comparing FD BF16 GQA performance against CU INT4 GQA performance, we can see that the latency of INT4 is smaller than that of BF16.

Table 19 Performance of BF16 GQA and INT4 GQA after CU optimizations

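On A100

Time (us)

| Batch size | BF16 FD | BF16 FA | BF16 CU before | BF16 CU after | INT4 FD | INT4 FA | INT4 CU before | INT4 CU after |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 139 | 133 | 183 | 163 | 137 | n/a | 143 | 94 |
| 64 | 245 | 229 | 335 | 276 | 234 | n/a | 257 | 151 |
| 128 | 433 | 555 | 596 | 517 | 432 | n/a | 455 | 266 |
| 256 | 826 | 977 | 1127 | 999 | 815 | n/a | 866 | 489 |
| 512 | 1607 | 1670 | 2194 | 1879 | 1581 | n/a | 1659 | 930 |

Effective Bandwidth (GB/s)

| Batch size | BF16 FD | BF16 FA | BF16 CU before | BF16 CU after | INT4 FD | INT4 FA | INT4 CU before | INT4 CU after |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 965 | 1012 | 736 | 824 | 262 | n/a | 250 | 380 |
| 64 | 1097 | 1175 | 802 | 972 | 305 | n/a | 278 | 475 |
| 128 | 1240 | 968 | 901 | 1039 | 331 | n/a | 314 | 538 |
| 256 | 1301 | 1100 | 954 | 1075 | 351 | n/a | 331 | 586 |
| 512 | 1338 | 1287 | 980 | 1144 | 362 | n/a | 345 | 616 |

On H100

Time (us)

| Batch size | BF16 FD | BF16 FA | BF16 CU before | BF16 CU after | INT4 FD | INT4 FA | INT4 CU before | INT4 CU after |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 91 | 90 | 114 | 100 | 70 | n/a | 96 | 64 |
| 64 | 148 | 146 | 200 | 183 | 113 | n/a | 162 | 101 |
| 128 | 271 | 298 | 361 | 308 | 205 | n/a | 294 | 170 |
| 256 | 515 | 499 | 658 | 556 | 389 | n/a | 558 | 306 |
| 512 | 1000 | 1011 | 1260 | 1066 | 756 | n/a | 1066 | 575 |

Effective Bandwidth (GB/s)

| Batch size | BF16 FD | BF16 FA | BF16 CU before | BF16 CU after | INT4 FD | INT4 FA | INT4 CU before | INT4 CU after |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 32 | 1481 | 1496 | 1178 | 1341 | 511 | n/a | 371 | 560 |
| 64 | 1815 | 1840 | 1345 | 1470 | 631 | n/a | 443 | 710 |
| 128 | 1982 | 1802 | 1487 | 1743 | 699 | n/a | 487 | 844 |
| 256 | 2087 | 2156 | 1634 | 1934 | 736 | n/a | 513 | 935 |
| 512 | 2150 | 2127 | 1706 | 2015 | 757 | n/a | 537 | 996 |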

E2E results

We evaluated our optimized INT4 GQA kernel in Llama 2 70B on 8 H100 GPUs. We ran the model end-to-end but report only the decode latency. We used FP8 FFN (feed forward network) to emphasize the attention performance in the decoding phase. We varied the batch size from 1 to 256 and the context length from 2,048 (2K) to 16,384 (16K). The E2E performance results are shown in Figure 21.

Figure 21 Meta Llama 2 decode latency (ms) comparison (BF16 GQA runs out of memory in large batch size configurations)
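
A rough estimate of the KV cache size shows why BF16 runs out of memory at large batch sizes. The sketch below assumes the publicly documented Llama 2 70B attention shape (80 layers, 8 KV heads, head dimension 128) and 80 GB of HBM per H100. It ignores the model weights, activations, and the scale and zero-point values that INT4 quantization stores alongside the cache. The function name `kv_cache_bytes` is illustrative.

```python
def kv_cache_bytes(batch_size, context_len, bytes_per_element,
                   num_layers=80, num_kv_heads=8, head_dim=128):
    """Total KV cache size in bytes; the factor of 2 covers both K and V."""
    elements = (2 * num_layers * num_kv_heads * head_dim
                * context_len * batch_size)
    return elements * bytes_per_element


GIB = 2 ** 30
total_hbm_bytes = 8 * 80e9  # assumption: 8 GPUs with 80 GB of HBM each

bf16_bytes = kv_cache_bytes(batch_size=256, context_len=16384, bytes_per_element=2)
int4_bytes = kv_cache_bytes(batch_size=256, context_len=16384, bytes_per_element=0.5)

print(f"BF16 KV cache: {bf16_bytes / GIB:.0f} GiB")
print(f"INT4 KV cache: {int4_bytes / GIB:.0f} GiB")
print(f"BF16 fits in total HBM: {bf16_bytes < total_hbm_bytes}")
print(f"INT4 fits in total HBM: {int4_bytes < total_hbm_bytes}")
```

Under these assumptions, the BF16 cache at batch size 256 and 16K context exceeds the combined HBM of the eight GPUs, while the INT4 cache is four times smaller.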

Code

If you are interested, please check out our code here. If you have any questions, please feel free to open an issue on GitHub, and we will be happy to help. Your contributions are welcome!

Ready, Set, Contribute: PyTorch Docathon Kickoff H1 2024

The PyTorch Docathon is now live! This event is dedicated to enhancing the quality of the PyTorch documentation with the invaluable assistance of our community. Our hope with this Docathon is to simplify the process for new users to get started with PyTorch, guide them in effectively utilizing its features, and ultimately expedite the transition from research to production in machine learning.

JOIN THE KICK-OFF EVENT
on June 4th at 10 AM PT

Event Details

  • June 4: Kick-off – join a 30-minute livestream kick-off event on Discord on June 4th at 10 AM PT here. If you can’t join the kick-off event, watch our welcome video on YouTube
  • June 4-June 16: Submissions and Feedback
  • June 17-18: Final Reviews
  • June 20: Winner Announcements

How to Contribute

Review the Docathon H1 2024 issue in the pytorch/pytorch or pytorch/tutorials repo, which contains all the necessary information on participating in the Docathon and highlights the specific issues to work on. Remember to sign the CLA in your first PR and adhere to the Code of Conduct guidelines.

Read the Code of Conduct

Take a moment to review the PyTorch code of conduct found here. This document outlines the expectations for behavior and communication within our team, and it is important that everyone is aware of and adheres to these guidelines.

Join our Discord

This channel serves as the main communication hub during the Docathon. You can join it by using this link:

JOIN DISCORD SERVER

When you first join the server, you will have limited access. To gain full access to our Discord PyTorch Docathon Channel:

  1. Enter the server and navigate to the #self-roles channel.
  2. In the #self-roles channel, click on the ‘Join Docathon’ button in the relevant post to assign yourself the docathon role.
  3. After assigning the role, you will see the ‘PyTorch Docathon H1 2024 Section’ in the left-hand menu for discussions.
  4. To help prevent spam, we ask that you change your server username to your GitHub username or the email username you registered with.

Explore the GitHub Issues

All the Docathon issues are posted on GitHub. You can find them by the docathon-h1-2024 label in the following participating repositories:

The issues are categorized into three levels of difficulty: easy, medium, and advanced. If this is your first time contributing to PyTorch, we recommend starting with an issue at the easy level.

Prizes for Winners

We will have a leaderboard throughout the duration of the Docathon. The more you contribute, the higher you’ll get on the board! Our top three winners will get free admission to PyTorch Conference 2024.

Thank you to our Partners

This year, we’re thrilled to work with the PyTorch Teams at Meta, Google and Snowflake to help us put on a successful event. We’ll also be at Snowflake Dev Day on June 6 where you can hear from Meta’s Matthias Reso, and check out our PyTorch booth.

Happy contributing!

Maximizing Training Throughput Using PyTorch FSDP and Torch.compile

Recently, we demonstrated how FSDP and selective activation checkpointing can be used to achieve 57% MFU (Model Flops Utilization) for training a 7B model on A100 GPUs. We also demonstrated how it can train a high quality model, which we open sourced as Granite 7B base model on Hugging Face Hub under the Apache v2.0 license.

We continued our quest to improve the utilization of GPUs by leveraging torch.compile. Using torch.compile and the selective activation checkpointing from our previous work, we achieve an MFU of 68% for the 7B model on A100 GPUs! torch.compile improves training MFU by between 10% and 23% for various model sizes.

This blog is organized into three parts: (1) Challenges addressed in order to train using torch.compile, (2) Numerical parity of compile with no-compile, and (3) MFU report.

We open sourced all the code and updated it in the fms-fsdp repository. We are also working with Team PyTorch at Meta to contribute these to the newly released torch titan repository for pre-training.

Challenges of using torch.compile

torch.compile is a graph compilation technique that improves GPU utilization. For details on how torch.compile works, we refer readers to the recent PyTorch paper and associated tutorials. A key challenge in getting torch.compile to perform well is to minimize (or eliminate) graph breaks. We initially started with the Llama implementation provided by Meta, but compiling it caused too many graph breaks, resulting in reduced training throughput.

Several portions of the model architecture had to be fixed, with the most important one being the positional embedding layer (RoPE). The typical RoPE implementation uses complex numbers, which was not supported in torch.compile at the time of testing. We implemented RoPE using einops while maintaining parity with the original model architecture implementation. We had to properly cache the frequencies so that we did not run into graph breaks within the RoPE implementation.

Compiling an FSDP model does result in graph breaks, which the PyTorch team at Meta is working to remove. However, these graph breaks as of PyTorch 2.3 are at FSDP unit boundaries and do not affect throughput significantly.

When using custom kernels, we need to wrap each kernel by exposing its API to torch.compile. This involves indicating which parameters are modified in-place, how they are modified, and what shapes and strides their return values will have based on the inputs. In our case, SDPA Flash attention is already integrated appropriately and we were able to get that kernel to work with torch.compile with no graph breaks.

We also noticed that when increasing the amount of data from 2T to 6T tokens, the data loader became a bottleneck. A key reason for this is the fact that previously, we implemented document shuffling in our dataloader naively, by having each worker maintain a list of shuffled document pointers.

With the larger dataset, these pointer lists were growing to hundreds of thousands of entries per worker. Maintaining pointer lists at this scale became expensive enough that CPU contention throttled our training throughput. We re-implemented document shuffling without any pointer lists using a Linear Congruential Generator (LCG). An LCG is a pseudorandom number generator algorithm that implements a random walk over a population, providing sampling without replacement.

We leveraged the same idea to produce implicit bijective mappings from ordered to shuffled document indices. This enables us to shrink those annoying lists of hundreds of thousands of pointers down to a single integer state for the LCG. This eliminated 80% of the bottleneck and provided a significant boost to our performance. We will devote a separate blog to go into all the details of our performant pre-training data loader.
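
The idea can be sketched as follows. This is a minimal illustration and not the data loader's actual code: the function name `lcg_permutation`, the constants `a` and `c`, and the power-of-two modulus are assumptions chosen so that the generator has a full period. Values that fall outside the document range are skipped, so every document index is visited exactly once while the only state carried between steps is a single integer.

```python
def lcg_permutation(n, seed=0, a=5, c=1):
    """Yield every index in range(n) exactly once, in a shuffled order.

    Uses state -> (a * state + c) % m with m the smallest power of two >= n.
    For a power-of-two modulus, the LCG has full period m when c is odd and
    a % 4 == 1, so each value in range(m) appears once per cycle.
    """
    if n <= 0:
        return
    m = 1
    while m < n:
        m <<= 1
    state = seed % m
    emitted = 0
    while emitted < n:
        if state < n:  # skip values that do not correspond to a document
            yield state
            emitted += 1
        state = (a * state + c) % m


# Shuffle 10 documents without materializing a pointer list.
order = list(lcg_permutation(10, seed=3))
print(order)
```

The constants used here are small for readability and do not produce a high-quality shuffle; a production loader would pick the multiplier and increment more carefully.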

Numerical Parity of torch.compile and torch.no-compile

We had previously observed parity issues when training with compile and no-compile options, with one of these being related to the use of SDPA. After a few days of intense debugging sessions between the PyTorch teams at Meta and IBM, we were able to achieve parity between PyTorch compile and no-compile modes. To document and verify this parity, we take a mini-Llama model architecture of 1.4B size and train it to 100B tokens in four variations – no-compile, compile with no activation checkpointing, compile with selective activation checkpointing, and compile with full activation checkpointing.

We plot the loss curves and gradient norm for these options below:

Figure 1: Loss curve and gradient norm for various compile options

Further, we run the lm-evaluation-harness and compare the various model scores on different benchmarks and observe no major differences between compile and no-compile, which is shown below.

Figure 2: lm-evaluation-harness comparison of various benchmarks between compile and no-compile

We observe from all these results that compile, in all its variants, matches the no-compile option, thus demonstrating parity between compile and no-compile.

MFU report

Finally, like our previous blog, we compute the MFU for four different model sizes on two clusters. One cluster is 128 A100 GPUs with 400 Gbps inter-node connectivity, and the other is 464 H100 GPUs with 3.2 Tbps inter-node connectivity. We use the selective activation checkpointing that we covered in the prior blog in addition to compile. We capture the results in the table below.

| Model size | Batch size | MFU no-compile | MFU compile | Percentage gain (%) |
| --- | --- | --- | --- | --- |
| 7B | 2 | 0.57 | 0.68 | 20 |
| 13B | 2 | 0.51 | 0.60 | 17 |
| 34B | 2 | 0.47 | 0.54 | 15 |
| 70B | 2 | 0.50 | 0.55 | 10 |

Table 1: MFU results with compile and no compile for Llama2 model architectures on 128 A100 80GB GPUs with 400Gbps internode interconnect

| Model size | Batch size | MFU no-compile | MFU compile | Percentage gain (%) |
| --- | --- | --- | --- | --- |
| 7B | 2 | 0.37 | 0.45 | 21 |
| 13B | 2 | 0.35 | 0.43 | 23 |
| 34B | 2 | 0.32 | 0.38 | 19 |
| 70B | 2 | 0.32 | 0.38 | 19 |

Table 2: MFU results with compile and no compile for Llama2 model architectures on 464 H100 80GB GPUs with 3.2Tbps internode interconnect

We also had an internal production run on 448 GPUs using a Llama2 7B architecture. Using compile and selective activation checkpointing, with a global batch size of 3.7M, we trained for 4T tokens in 13 days 10 hours!
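
For a sense of scale, the figures in the paragraph above translate into a throughput estimate. The calculation below is back-of-the-envelope arithmetic only: it treats "4T tokens" as exactly 4e12, assumes the global batch size of 3.7M is measured in tokens, and ignores any downtime during the run.

```python
total_tokens = 4e12            # "4T tokens", treated as exact
num_gpus = 448
global_batch_tokens = 3.7e6    # assumption: batch size is counted in tokens

wall_clock_seconds = (13 * 24 + 10) * 3600  # 13 days 10 hours

tokens_per_second = total_tokens / wall_clock_seconds
tokens_per_second_per_gpu = tokens_per_second / num_gpus
optimizer_steps = total_tokens / global_batch_tokens

print(f"cluster throughput: {tokens_per_second:,.0f} tokens/s")
print(f"per-GPU throughput: {tokens_per_second_per_gpu:,.0f} tokens/s")
print(f"approximate optimizer steps: {optimizer_steps:,.0f}")
```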

During training, the data center cooling had to kick in with extra air conditioning and our training team was alerted to this, since we were using the GPUs quite effectively ☺

One key observation from Tables 1 and 2 is that the MFU numbers do not scale linearly with model size. There are two possible explanations that we are actively investigating. One is the scalability of FSDP as model size increases, including when tensor parallelism needs to be enabled to use the GPU more effectively. The other is batch size, which can be increased further to get better MFU. We plan to explore FSDP v2 and selective operator checkpointing along with the tensor parallel feature to study the scaling laws of FSDP with model size.

Future Work

We plan to start testing FSDP v2, which will be released as part of PyTorch 2.4. FSDP2 provides per-parameter sharding and a selective operator checkpointing feature that can potentially provide even better memory-compute tradeoffs.

We have also been engaged with the PyTorch team at Meta to evaluate the new asynchronous checkpointing feature that can further improve the GPU utilization by reducing the time to write checkpoints.

We are exploring extending various Triton kernels currently used in inference to perform backward operations to gain speedups beyond inference only.

Finally, as recent work on use of fp8 is emerging, we plan to explore how we can even further accelerate model training using the new data type that promises a 2x acceleration.

Acknowledgements

There are several teams that have been involved in reaching this proof point and we would like to thank the teams across Meta and IBM. Specifically, we extend our gratitude to the Meta PyTorch distributed and compiler teams and IBM Research.

Multiple people were extensively involved in the effort of achieving torch.compile numerical parity with our models, and we wish to acknowledge the key folks involved in this effort: Animesh Jain and Less Wright at Meta, and Linsong Chu, Davis Wertheimer, Brian Vaughan, Antoni i Viros Martin, Mudhakar Srivatsa, and Raghu Ganti at IBM Research.

Special thanks to Stas Bekman, who provided extensive feedback and helped improve this blog. Their insights have been invaluable in highlighting key aspects of optimizing the training and exploring further enhancements.

Achieving Sustainability Goals with PyTorch and Intel AI

This post was contributed by Intel AI in partnership with the PyTorch Foundation.

In 2017, the UN Global Compact emphasized digital technology, particularly open source, as crucial for achieving Sustainable Development Goals (SDGs), projecting a potential $2.1 trillion boost to the tech sector by 2030. The SDGs, part of the “2030 Agenda for Sustainable Development,” address global prosperity across various sectors.

The Linux Foundation’s Sustainability Initiative aligns projects with sustainable development goals. By assessing project impact, resources can be better allocated for enhancement. Intel is also a contributor to this initiative, and recently presented three use cases with PyTorch and Intel AI to address UN SDG-aligned issues.

Sustainability Goals

SDG 15: Life on Land

  • Using a bone likelihood map to pinpoint dinosaur bones, which paves the way for transfer learning to tackle contemporary challenges like wildfire prediction.
  • Employing transfer learning for wildfire prediction and generating data with Stable Diffusion.

SDG 9: Industry, Innovation, Infrastructure

  • Identifying crucial minerals, oil, and gas through subsurface models.

Here are the key highlights from the workshops. Read below for a summary, and be sure to watch the full workshop videos and visit the GitHub repositories.

Session 1: Introduction to Dinosaur Bone Bed Maps

Bob Chesebrough recently led a PyTorch workshop demonstrating how to create a dinosaur bone bed map for Dinosaur National Monument. He shared footage of his discoveries and explained his AI-driven approach, utilizing geological data to pinpoint possible bone-rich areas.

Attendees learned to set up JupyterLab, access the training section, and launch a BASH shell. Bob’s classification model, applied to aerial images, facilitated heatmap generation to identify potential bone locations, refined through field data. The GitHub repo “Jurassic” guided participants through directory setup and model optimization steps.

Rahul Unnikrishnan Nair demonstrated the use of PyTorch, focusing on performance enhancements. The workshop covered modeling best practices, such as data transformations, class distribution, dropout layers, and efficient training methods. Training and scoring procedures were examined, with a focus on model accuracy and transportability to other regions. Heatmap creation involved cutting images into tiles, considering context for accurate environmental identification.

Watch the full workshop video here and visit the GitHub repository to access the code sample and experiment with the code using Intel® Extension for PyTorch. Try it out with PyTorch and explore what works best for you. Happy dinosaur bone hunting!

Session 2: Seismic Data to Subsurface Models with OpenFWI: Training an AI Model with PyTorch

Seismic exploration is crucial for subsurface imaging in mineral and oil/gas exploration. Full waveform inversion (FWI) recreates subsurface sound wave velocities, akin to ultrasound for the Earth.

Ben Consolvo, an AI Software Engineering Manager at Intel, presented training AI models directly from seismic data using PyTorch on Intel high-performance processors. FWI, though accurate, is computationally intensive and relies on precise initial models. AI models offer an alternative approach, learning directly from data without the need for precise initializations. Ben explained the challenges of AI models, highlighting the need for diverse datasets and the potential use of CPUs for fine-tuning. He also discussed FWI’s surprising medical applications.

Watch the full video here and go to the paper for more details. The GitHub repo is OpenFWI.

Session 3: Using PyTorch to Aid Wildfire Prediction

Forest fires pose significant threats to ecosystems, wildlife, and communities. Machine learning presents a promising approach to enhance prediction accuracy. In this Earth Day webinar, Bob Chesebrough and Rahul Unnikrishnan Nair demonstrated image analysis techniques using the MODIS dataset which was used to predict early forest fire probabilities. Through fine-tuning a ResNet18 model with the Intel® Extension for PyTorch, pre-trained models were adjusted with aerial photos, utilizing geo-spatial and color data for fire risk assessment.

The presenters emphasized the temporal and geographical filtering required for dataset analysis and showcased images from fire-affected areas like Paradise, CA. They highlighted the model’s adaptability to different hardware configurations, along with the use of Stable Diffusion for data synthesis when real datasets were unavailable. They also encouraged the audience to experiment with PyTorch for early fire detection by extending a challenge to leverage these tools for critical predictive tasks. Join them in this endeavor to enhance wildfire prevention and protection efforts.

Watch the full video here and go to the paper for more details. The GitHub repo is ForestFirePrediction.

About the Intel Speakers

Bob Chesebrough, Sr Solutions Architect

Bob Chesebrough has over three decades of industry experience in software development and AI solution engineering for Fortune 100 companies and national laboratories. He is also a hobbyist who has logged over 800 miles and 1000 hours in the field finding dinosaur bones. He and his sons discovered an important fossil of the only known crocodilian from the Jurassic in New Mexico. They have also discovered and logged into the museum 2000+ bone localities and described a new mass bone bed in New Mexico.

Rahul Unnikrishnan Nair, Architect in Applied AI and the Engineering Lead at Intel® Liftoff

In his current role in the Intel® Liftoff for Startups program, Rahul Nair brings his extensive experience in applied AI and engineering to mentor early-stage AI startups. His dedication lies in helping these startups transform their innovative ideas into fully-fledged, market-ready products with a strong emphasis on use-case-driven, practical engineering and optimization.

Ben Consolvo, AI Software Engineering Manager

Ben Consolvo is an AI Solutions Engineering Manager at Intel. He has been building a team and a program around Intel’s AI technology paired with Intel’s hardware offerings. He brings a background and passion in data science, particularly in deep learning (DL) and computer vision. He has applied his skills in DL in the cybersecurity industry to automatically identify phishing websites, as well as to the oil and gas industry to identify subsurface features for geophysical imaging.

Kelli Belcher, AI Solutions Engineer

Kelli Belcher is an AI Solutions Engineer at Intel with over 5 years of experience across the financial services, healthcare, and tech industries. In her current role, Kelli helps build Machine Learning solutions using Intel’s portfolio of open AI software tools. Kelli has experience with Python, R, SQL, and Tableau, and holds a Master of Science in Data Analytics from the University of Texas.

Speeding up ViTs using Block Sparsity

TLDR: We show promising results of up to a 1.46x speedup with <2% drop in accuracy on float32 Vision Transformers on A100 GPUs by applying block sparsity on MLP module’s weights. This approach can potentially be applied to other types of transformers including large language models. Our implementation and benchmarks to reproduce our results are available at https://github.com/pytorch-labs/superblock.

Introduction

PyTorch has landed a lot of improvements to CUDA kernels that implement block sparse matrix multiplications. Recent updates to PyTorch can lead to up to a 4.8x speedup on large matrix multiplication shapes with high sparsity levels over dense baselines.

In this blog, we show the promising results of applying block sparsity on weights of linear layers of MLP (multi-layer perceptron) layers in vision transformers (ViTs) and show end-to-end model speedups on A100 Nvidia GPUs.

As a recap, block sparsity sparsifies weights in tiles of blocks of predetermined size, rather than sparsifying individual elements. This particular sparsity pattern is interesting because it is amenable to GPU acceleration via fast sparse kernels. For more information about the differences between different sparsity patterns, or about sparsity as a whole, please check out torchao.

Illustrations of different types of sparsity.

Approach

Our approach can be broken down into two distinct steps:

  1. Training the model from scratch using block sparse masks subnets.
  2. Folding these masks into our weights to accelerate them for inference.

We explain our training and inference steps below.

Training

Starting with an uninitialized Vision Transformer, we apply random trainable masks with a specified block size and sparsity level on the weights of the output projection linear layer of the attention blocks, the weights of the two linear layers inside the MLP, a.k.a. FFN (feed forward network), and the final linear classification layer. The forward pass during training follows the supermask approach: each mask is converted to a binary map using a threshold tuned to the sparsity requirement, e.g., if we want 80% sparsity, the threshold is automatically tuned to keep the top 20% of weights. The masks consist of square blocks of <block size>x<block size> elements, where <block size> is a hyperparameter. The priority of the weights depends on the mask value, or score, which is trained. We multiply the binary mask of each layer with its weights to sparsify the model.

Illustration of the Supermask sparsification approach
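
The thresholding step can be sketched in a few lines. The code below is a simplified, framework-free illustration under stated assumptions, not the superblock implementation: it uses nested Python lists instead of tensors, one score per block, and the hypothetical helper names `block_supermask` and `apply_mask`. If several blocks tie at the threshold score, slightly more blocks than requested are kept.

```python
def block_supermask(block_scores, sparsity, block_size):
    """Turn per-block scores into an element-level binary mask.

    Blocks whose score is among the top (1 - sparsity) fraction are kept (1);
    all other blocks are zeroed (0).
    """
    flat = sorted((s for row in block_scores for s in row), reverse=True)
    num_keep = max(1, round(len(flat) * (1.0 - sparsity)))
    threshold = flat[num_keep - 1]  # lowest score that is still kept

    mask = []
    for score_row in block_scores:
        mask_row = []
        for score in score_row:
            mask_row.extend([1 if score >= threshold else 0] * block_size)
        # Repeat the expanded row once per row inside the block.
        mask.extend([list(mask_row) for _ in range(block_size)])
    return mask


def apply_mask(weights, mask):
    """Element-wise product of a weight matrix and its binary mask."""
    return [[w * m for w, m in zip(w_row, m_row)]
            for w_row, m_row in zip(weights, mask)]


# 2 x 5 grid of block scores, 80% sparsity, 2x2 blocks: 2 of 10 blocks survive.
scores = [[0.9, 0.1, 0.4, 0.3, 0.2],
          [0.5, 0.8, 0.6, 0.7, 0.05]]
mask = block_supermask(scores, sparsity=0.8, block_size=2)
weights = [[1.0] * 10 for _ in range(4)]
sparse_weights = apply_mask(weights, mask)
for row in mask:
    print(row)
```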

Inference

After training, the dense weights can be turned into sparse weights by multiplying them with the mask, and stored for inference. At this stage, although the weights have a high percentage of zero values, they are still stored in dense format. We use PyTorch’s to_sparse_bsr() API to convert the weights to Block Sparse Row (BSR) format, which stores only the non-zero values and the indices of their blocks. This step only needs to be done once and the results can be cached for runtime.

During runtime, no changes in code are required. We just pass any input tensor to the model, and when the forward() function of a sparsified linear layer is invoked, PyTorch takes care of invoking the optimized matrix multiplication for block sparse weights. This should work for A100 as well as H100 NVIDIA GPUs.
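
To make the storage format concrete, the sketch below builds the three arrays that a block sparse row layout keeps: row pointers, block column indices, and the non-zero blocks themselves. It is a plain-Python illustration of the format, not PyTorch's implementation; the function name `dense_to_bsr` is hypothetical, and the dimension check reflects the requirement that matrix dimensions be multiples of the block size.

```python
def dense_to_bsr(dense, block_size):
    """Convert a dense matrix (list of lists) to a block sparse row layout.

    Returns (crow_indices, col_indices, values), where values holds only the
    blocks that contain at least one non-zero element.
    """
    n_rows, n_cols = len(dense), len(dense[0])
    if n_rows % block_size or n_cols % block_size:
        raise ValueError("matrix dimensions must be multiples of block_size")

    crow_indices = [0]   # crow_indices[i + 1] - crow_indices[i] = blocks in block row i
    col_indices = []     # block column index of each stored block
    values = []          # the stored blocks, each block_size x block_size
    for block_row in range(n_rows // block_size):
        rows = dense[block_row * block_size:(block_row + 1) * block_size]
        for block_col in range(n_cols // block_size):
            start = block_col * block_size
            block = [row[start:start + block_size] for row in rows]
            if any(v != 0 for block_line in block for v in block_line):
                col_indices.append(block_col)
                values.append(block)
        crow_indices.append(len(col_indices))
    return crow_indices, col_indices, values


dense = [[0, 0, 1, 2],
         [0, 0, 3, 4],
         [5, 6, 0, 0],
         [7, 8, 0, 0]]
crow_indices, col_indices, values = dense_to_bsr(dense, block_size=2)
print(crow_indices, col_indices, values)
```

Only two of the four 2x2 blocks are stored here, which is where both the memory savings and the skipped computation come from.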

Results: Microbenchmarks

To validate the viability of block sparsity from a performance standpoint, we first ran a series of microbenchmarks using this simple script. Using the linear shapes from ViT-b, we compared the speedup of our block sparse kernels across a single linear layer as we varied the sparsity level and block size of the weight matrix.

We run using PyTorch 2.3.0.dev20240305+cu121 nightly on NVIDIA A100s and report the speedup of each sparsity configuration compared to dense baseline. We observed positive speedups when block size >=32 or sparsity level >= 0.8 for float32, while for bfloat16 we observe smaller speedups and usually for block size 64 and higher sparsities. Hence, for end-to-end speedups on the model, we will focus in this blog on float32 and leave bfloat16 for future work.

Micro benchmarking results on linear layers of ViT-b-16.

Results: Vision Transformers

Once we confirmed that we were able to show speedups over the linear layers, we focused on showing end-to-end speedups on ViT_B_16.

We trained this model from scratch on ImageNet dataset using the standard ViT_B_16 recipe. We show speedups for sparsifying MLP modules and leave sparsifying weights of input and output projections of attention for future work.

We looked at wall-clock inference speedup, focusing on batch size 256. We found that:

  • For 90% sparsity we can get 1.24x, 1.37x, 1.65x speedups for block sizes 16, 32, and 64 respectively.
  • To obtain a speedup, the minimum sparsities for block sizes 16, 32, and 64 are 0.86, 0.82, and 0.7 respectively. Hence, as expected, the larger the block size, the lower the sparsity we need to obtain a speedup.

We note a limitation of the to_sparse_bsr() API: layer dimensions need to be multiples of the block size. Since the dimensions of the last FC classification layer in ViT were not multiples of the block size, that layer was not converted to BSR representation in our experiments.

Speedup on ViT-b-16 with batch size 256 on MLP modules across different batch sparsities and block sizes.

We also explored the speedup for different batch sizes for 90% sparsity. We observed a speedup over the baseline for batch sizes starting from 16 and upwards. While bigger block sizes have bigger speedups at the largest batch sizes, the smallest possible batch size to obtain >1 speedup is smaller for smaller block sizes.

We believe on-device hardware can obtain speedups for batch size 1 as they – unlike server GPUs – can be fully utilized at such small batch sizes.

Speedup on ViT-b-16 with 90% sparsity on MLP modules across different batch sizes and block sizes.

Looking at the Top-1 accuracy on the ImageNet-blurred test set of the sparsified models for different block sizes and sparsities, we see a few expected results:

  • low levels of sparsity (<=70%) have no meaningful regression in accuracy
  • mid levels of sparsity (>=80% and <90%) have limited regression in accuracy
  • high levels of sparsity (>=90%) remove so many weights that accuracy is significantly impacted

More research could be done to improve accuracies of higher sparsities and larger block sizes. We hope that the block sparsity support in PyTorch and the illustrated speedups in this blog will encourage researchers to explore more accurate sparsification approaches.

Accuracies on training ViT-b-16 on ImageNet-blurred using the SuperMask approach.

Next Steps

We have shown promising speedups for block sparsifying the MLP modules of ViT in float32 precision. There is still more work to be done in order to observe speedups on bfloat16, and we hope to make progress on that soon. Possible next steps to further optimize block sparsity on vision transformers and transformers in general:

  • Perform block sparsity on attention input and output projections.
  • Perform block sparsity during finetuning rather than training from scratch.
  • Perform further optimizations on the matmul kernels for ViT’s linear operator specific shapes (especially for 80% and lower sparsity).
  • Combine with other optimizations such as int8 and torch.compile()
  • Explore other weight sparsification algorithms, e.g., Spartan, to improve accuracy
  • Explore selecting weights to sparsify (e.g., specific transformer layers)

Please reach out to melhoushi@meta.com if you have questions or are interested in contributing to block sparsification!

Additionally if you’re broadly interested in sparsity please feel free to reach out to @jcaip / jessecai@meta.com and please come check out torchao, a community we’re building for architecture optimization techniques like quantization and sparsity.

Deep Learning Energy Measurement and Optimization

This post is authored by Jae-Won Chung, a PhD student at the University of Michigan and the lead of the ML.ENERGY Initiative.

Deep learning consumes quite a bit of energy. For instance, training a single 200B LLM on AWS p4d instances consumed around 11.9 GWh (source: CIDR 2024 keynote), which is an amount that can single-handedly power more than a thousand average US households for a year.

Zeus is an open-source toolbox for measuring and optimizing the energy consumption of deep learning workloads. Our goal is to make energy optimization based on accurate measurements as easy as possible for diverse deep learning workloads and setups by offering composable tools with minimal assumptions.

Zeus largely provides two types of tools:

  1. Programmatic and command line GPU energy measurement tools
  2. Several energy optimization tools that find the best ML and/or GPU configurations

Zeus can benefit those who would like to

  • measure and optimize their electricity cost
  • reduce heat dissipation from their GPUs (by lowering power draw)
  • report energy usage from research and development
  • reduce carbon footprint from electricity usage

Part 1: Measuring Energy

Just like performance optimization, accurate measurement is the basis of effective energy optimization. Popular proxies for estimating power consumption like the maximum power draw of the hardware can sometimes be vastly off compared to actual measurement.

To make energy measurement as easy and transparent as possible, the core utility Zeus offers is the ZeusMonitor class. Let’s take a look at the actual snippet:

from zeus.monitor import ZeusMonitor

# `model`, `criterion`, `optim`, and `train_dataloader` are assumed to be
# defined elsewhere, as in any standard PyTorch training script.

# All four GPUs are measured simultaneously.
monitor = ZeusMonitor(gpu_indices=[0,1,2,3])

# Measure total time and energy within the window.
monitor.begin_window("training")
for e in range(100):

    # Measurement windows can arbitrarily be overlapped.
    monitor.begin_window("epoch")
    for x, y in train_dataloader:
        optim.zero_grad()  # clear gradients left over from the previous step
        y_hat = model(x)
        loss = criterion(y_hat, y)  # prediction first, then target
        loss.backward()
        optim.step()
    measurement = monitor.end_window("epoch")
    print(f"Epoch {e}: {measurement.time} s, {measurement.total_energy} J")

measurement = monitor.end_window("training")
print(f"Entire training: {measurement.time} s, {measurement.total_energy} J")

What you see above is a typical PyTorch training loop which uses four GPUs for data parallel training. Inside, we created an instance of ZeusMonitor and passed in a list of GPU indices to monitor. Then, using the monitor, we can measure the time and energy consumption of arbitrary execution windows within the training script by pairing calls to begin_window and end_window. Multiple windows can overlap and nest in arbitrary ways without affecting the measurement of each, as long as their names are different.

ZeusMonitor adds very little overhead – typically single digit milliseconds – around the window. This allows ZeusMonitor to be used in various applications. For instance:

  • The ML.ENERGY Leaderboard: The first open-source benchmark on how much energy LLM text generation consumes.
  • The ML.ENERGY Colosseum: An online service that lets users compare LLM responses side-by-side based on response quality and energy consumption.

See our blog post for a deeper technical dive into accurate GPU energy measurement.

Part 2: Optimizing Energy

Let me introduce you to two of the energy optimizers provided by Zeus.

GlobalPowerLimitOptimizer

GPUs allow users to configure their maximum power draw, called the power limit. Typically, as you lower the GPU’s power limit from the default maximum, computation may get slightly slower, but you’ll save disproportionately more energy. The GlobalPowerLimitOptimizer in Zeus automatically finds the optimal GPU power limit globally across all GPUs.

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

# The optimizer measures time and energy through the ZeusMonitor.
monitor = ZeusMonitor(gpu_indices=[0,1,2,3])
plo = GlobalPowerLimitOptimizer(monitor)

for e in range(100):
    plo.on_epoch_begin()
    for x, y in train_dataloader:
        plo.on_step_begin()

        optim.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optim.step()

        plo.on_step_end()
    plo.on_epoch_end()

In our familiar PyTorch training loop, we have instantiated GlobalPowerLimitOptimizer and passed it an instance of the ZeusMonitor, through which the optimizer sees the GPUs. Then, we just need to let the optimizer know about training progress (step and epoch boundaries), and the optimizer will transparently do all the necessary profiling and converge to the optimal power limit.

If you’re using the HuggingFace Trainer or SFTTrainer, integration is even easier:

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import HFGlobalPowerLimitOptimizer

# ZeusMonitor actually auto-detects CUDA_VISIBLE_DEVICES.
monitor = ZeusMonitor()
pl_optimizer = HFGlobalPowerLimitOptimizer(monitor)

# Pass in the optimizer as a Trainer callback. Also works for SFTTrainer.
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    ...,
    callbacks=[pl_optimizer],
)

The HFGlobalPowerLimitOptimizer wraps GlobalPowerLimitOptimizer so that it automatically detects step and epoch boundaries. We have example integrations here, including running Gemma 7B supervised fine-tuning with QLoRA.

Now, we know how to integrate the optimizer, but what is the optimal power limit? We know different users can have different preferences regarding trading off time and energy, so we allow users to specify an OptimumSelector (basically the Strategy Pattern) to express their needs.

# Built-in strategies for selecting the optimal power limit.
from zeus.optimizer.power_limit import (
    GlobalPowerLimitOptimizer,
    Time,
    Energy,
    MaxSlowdownConstraint,
)

# Minimize energy while tolerating at most 10% slowdown.
plo = GlobalPowerLimitOptimizer(
    monitor,
    MaxSlowdownConstraint(factor=1.1),
)

Some of the built-in strategies include “Minimize time” (Time, this might still reduce the power limit from the default since some workloads exhibit almost no slowdown even on lower power limits), “Minimize energy” (Energy), “Somewhere in between” (ZeusCost), and “Minimize energy given maximum slowdown” (MaxSlowdownConstraint). Users can also create their own optimum selectors as needed.
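
To make the “minimize energy given maximum slowdown” idea concrete, here is a minimal sketch of the selection step alone. It is not Zeus’s implementation: the function name select_power_limit, the shape of the profile dictionary, and all of the numbers are assumptions made for illustration.

```python
def select_power_limit(profile, max_slowdown):
    """Pick the power limit with the lowest energy among those that are fast enough.

    profile maps a power limit (W) to a (time in s, energy in J) pair per step.
    max_slowdown is the tolerated slowdown factor relative to the fastest
    power limit, e.g. 1.1 for "at most 10% slower".
    """
    fastest = min(t for t, _ in profile.values())
    feasible = {
        pl: energy
        for pl, (t, energy) in profile.items()
        if t <= max_slowdown * fastest
    }
    # The fastest power limit is always feasible, so this is never empty.
    return min(feasible, key=feasible.get)


# Made-up profiling results: lower power limits are slower but use less energy.
profile = {
    300: (1.00, 280.0),
    250: (1.04, 245.0),
    200: (1.09, 215.0),
    150: (1.30, 190.0),
}
best = select_power_limit(profile, max_slowdown=1.1)
print(f"Selected power limit: {best} W")
```

With these numbers, 150 W uses the least energy but is 30% slower, so it is excluded, and 200 W is chosen. Setting max_slowdown to 1.0 reduces the rule to “Minimize time”, and a very large value reduces it to “Minimize energy”.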

PipelineFrequencyOptimizer

The pipeline frequency optimizer, based on our research paper Perseus, is our latest work on energy optimization for large model training, like GPT-3. Perseus can reduce the energy consumption of large model training with no or negligible training throughput degradation. We’ll briefly talk about how.

one iteration of training with four stage pipeline parallelism

The above is a visualization of one iteration of training with four stage pipeline parallelism running with the 1F1B schedule. Each box is either a forward or a backward computation, and is colored with its power consumption.

The key observation here is that when models are partitioned into pipeline stages, it’s very difficult to slice them into perfectly equal sizes. This leads to forward/backward boxes of varying widths and therefore idle time between boxes. Notice that the narrower boxes could run slightly slower than the wider boxes without changing the overall critical path (blue line) at all.

one iteration of training with four stage pipeline parallelism

That’s what Perseus automatically does. Based on profiling, it identifies computation boxes that are not on the critical path and figures out the precise amount of slowdown for each box that minimizes energy consumption. When done correctly, computations we slowed down will consume less power & energy, but the overall iteration time of the pipeline does not change.
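
The notion of computations that are not on the critical path can be made concrete with a slack calculation on a dependency graph. The sketch below is a generic critical-path computation, not Perseus’s actual algorithm; the function name, node names, and durations are made up for illustration. A computation with positive slack can be slowed down by up to that amount without lengthening the iteration.

```python
from graphlib import TopologicalSorter  # Python 3.9+


def compute_slack(durations, deps):
    """Return (slack per node, total time) for a DAG of computations.

    durations maps node -> execution time.
    deps maps node -> set of nodes that must finish before it starts.
    """
    graph = {n: deps.get(n, set()) for n in durations}
    order = list(TopologicalSorter(graph).static_order())

    # Forward pass: earliest time each node can finish.
    earliest_finish = {}
    for n in order:
        start = max((earliest_finish[p] for p in graph[n]), default=0.0)
        earliest_finish[n] = start + durations[n]
    total = max(earliest_finish.values())

    # Backward pass: latest time each node may finish without delaying the end.
    successors = {n: [] for n in durations}
    for n, preds in graph.items():
        for p in preds:
            successors[p].append(n)
    latest_finish = {}
    for n in reversed(order):
        latest_finish[n] = min(
            (latest_finish[s] - durations[s] for s in successors[n]),
            default=total,
        )

    slack = {n: latest_finish[n] - earliest_finish[n] for n in durations}
    return slack, total


# Two forward computations of unequal size feed one backward computation.
durations = {"forward_a": 2, "forward_b": 1, "backward": 3}
deps = {"backward": {"forward_a", "forward_b"}}
slack, iteration_time = compute_slack(durations, deps)
print(slack, iteration_time)
```

Here forward_b has one time unit of slack, so it could run more slowly (and at lower power) while the iteration time stays the same.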

See our guide to get started with Perseus!

Final Words

For users who run their own on-premise compute, energy consumption and the resulting electricity bill are not something that can be easily overlooked. On a larger scale, energy consumption is not just about electricity bills, but also about data center power delivery. With thousands of GPUs running in clusters, finding stable, affordable, and sustainable electricity sources to power data centers is becoming increasingly challenging. Finding ways to reduce energy disproportionately more than slowdown leads to lower average power consumption, which can help with the power delivery challenge.

With Zeus, we hope to take the first step towards deep learning energy measurement and optimization.

Introducing depyf: mastering torch.compile with ease

We are thrilled to introduce depyf, a new project to the PyTorch ecosystem designed to help users understand, learn, and adapt to torch.compile!

Motivation

torch.compile is a cornerstone of PyTorch 2.x, offering a straightforward path to accelerate machine learning workflows with just a single line of code for both training and inference. The mere inclusion of @torch.compile can dramatically enhance the performance of your code. However, identifying the optimal insertion point for torch.compile is not easy, not to mention the complexity of adjusting various knobs for maximum efficiency.

The intricacies of the torch.compile stack, encompassing Dynamo, AOTAutograd, Inductor, and more, present a steep learning curve. These components, essential for deep learning performance optimization, can be daunting without a solid foundation in the subject.

Note: For an introductory example of how torch.compile works, please refer to this walk-through explanation.

A common tool: TORCH_COMPILE_DEBUG

To demystify torch.compile, the common approach involves leveraging the TORCH_COMPILE_DEBUG environment variable. While it provides more information, deciphering the output remains a formidable task.

For example, when we have the following code:

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   main()

When we run it with TORCH_COMPILE_DEBUG=1 python test.py, we get a directory named torch_compile_debug/run_2024_02_05_23_02_45_552124-pid_9520, under which there are these files:

.
├── torchdynamo
│   └── debug.log
└── torchinductor
   ├── aot_model___0_debug.log
   ├── aot_model___10_debug.log
   ├── aot_model___11_debug.log
   ├── model__4_inference_10.1
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   ├── model__5_inference_11.2
   │   ├── fx_graph_readable.py
   │   ├── fx_graph_runnable.py
   │   ├── fx_graph_transformed.py
   │   ├── ir_post_fusion.txt
   │   ├── ir_pre_fusion.txt
   │   └── output_code.py
   └── model___9.0
       ├── fx_graph_readable.py
       ├── fx_graph_runnable.py
       ├── fx_graph_transformed.py
       ├── ir_post_fusion.txt
       ├── ir_pre_fusion.txt
       └── output_code.py

The generated files and logs often raise more questions than they answer, leaving developers puzzled over the meaning and relationships within the data. Common puzzles for TORCH_COMPILE_DEBUG include:

  • What does model__4_inference_10.1 mean?
  • I have one function but three model__xxx directories. How do they correspond to each other?
  • What is all that LOAD_GLOBAL stuff in debug.log?
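
For context on the last question: LOAD_GLOBAL is a CPython bytecode instruction, and the Dynamo log is full of such instructions because Dynamo works at the bytecode level. The standard-library dis module shows where they come from. The small function below is made up for illustration, and the exact instruction list varies between Python versions.

```python
import dis


def add_one_to_abs(a):
    # `abs` is not a local variable, so CPython looks it up with LOAD_GLOBAL.
    return abs(a) + 1


# Collect the names of the bytecode instructions of the function.
opnames = [ins.opname for ins in dis.get_instructions(add_one_to_abs)]
print(opnames)

# Print the full, human-readable disassembly.
dis.dis(add_one_to_abs)
```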

A better tool: depyf comes to the rescue

Let’s see how depyf can help developers resolve the above challenges. To use depyf, simply execute pip install depyf or follow the project page https://github.com/thuml/depyf to install the latest version, and then wrap the main code in a with depyf.prepare_debug block.

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()

After executing python test.py, depyf will produce a directory named depyf_debug_dir (the argument of the prepare_debug function). The directory contains these files:

.
├── __compiled_fn_0 AFTER POST GRAD 0.py
├── __compiled_fn_0 Captured Graph 0.py
├── __compiled_fn_0 Forward graph 0.py
├── __compiled_fn_0 kernel 0.py
├── __compiled_fn_3 AFTER POST GRAD 0.py
├── __compiled_fn_3 Captured Graph 0.py
├── __compiled_fn_3 Forward graph 0.py
├── __compiled_fn_3 kernel 0.py
├── __compiled_fn_4 AFTER POST GRAD 0.py
├── __compiled_fn_4 Captured Graph 0.py
├── __compiled_fn_4 Forward graph 0.py
├── __compiled_fn_4 kernel 0.py
├── __transformed_code_0_for_torch_dynamo_resume_in_toy_example_at_8.py
├── __transformed_code_0_for_toy_example.py
├── __transformed_code_1_for_torch_dynamo_resume_in_toy_example_at_8.py
└── full_code_for_toy_example_0.py

And there are two obvious benefits:

  1. The long and difficult-to-understand torchdynamo/debug.log is gone. Its content is cleaned up and shown as human-readable source code in full_code_for_xxx.py and __transformed_code_{n}_for_xxx.py. It is worth noting that the most tedious and difficult job of depyf is decompiling the bytecode inside torchdynamo/debug.log into Python source code, which frees developers from the intimidating internals of Python.
  2. The correspondence between function names and computation graphs is preserved. For example, in __transformed_code_0_for_toy_example.py, we can see a function named __compiled_fn_0, and we immediately know that its corresponding computation graphs are in __compiled_fn_0_xxx.py, because they share the same __compiled_fn_0 prefix.

Starting with full_code_for_xxx.py and following the functions involved, users will have a clear view of what torch.compile does to their code.

One more thing: step-through debuggability

Stepping through code line by line using debuggers is a great way to understand how code works. However, under TORCH_COMPILE_DEBUG, those files are only for users’ information and cannot be executed with the data users care about.

Note: By “debug”, we mean the process of inspecting and improving a program, rather than correcting buggy code.

A standout feature of depyf is its capability to facilitate step-through debugging for torch.compile: all of the files it generates are linked with runtime code objects inside the Python interpreter, and we can set breakpoints in these files. The usage is simple: just add one context manager, with depyf.debug(), and it should do the trick:

# test.py
import torch
from torch import _dynamo as torchdynamo
from typing import List

@torch.compile
def toy_example(a, b):
   x = a / (torch.abs(a) + 1)
   if b.sum() < 0:
       b = b * -1
   return x * b

def main():
   for _ in range(100):
       toy_example(torch.randn(10), torch.randn(10))

if __name__ == "__main__":
   import depyf
   with depyf.prepare_debug("depyf_debug_dir"):
       main()
   with depyf.debug():
       main()

Just one caveat: the workflow for debugging torch.compile deviates from the standard debugging workflow. With torch.compile, much of the code is dynamically generated. Therefore, we need to:

  1. Launch the program.
  2. Wait for the program to exit the with depyf.prepare_debug("depyf_debug_dir") block, at which point the generated code is available in depyf_debug_dir.
  3. Let the program enter the with depyf.debug() block, which automatically sets a breakpoint internally so that the program pauses.
  4. Navigate to depyf_debug_dir and set breakpoints in the generated files.
  5. Continue running the code, and the debugger will hit these breakpoints!

depyf screenshot

Here is a screenshot of what it looks like. All code and tensor variables are live, and we can inspect any variable, and step through the code, as in our daily debugging workflow now! The only difference is that we are debugging torch.compile generated code rather than human-written code.

Conclusion

torch.compile serves as an invaluable tool for accelerating PyTorch code effortlessly. However, for those looking to delve deeper into torch.compile, whether to leverage its full potential or to integrate custom operations, the learning curve can be very steep. depyf is designed to lower this barrier, offering a user-friendly experience to understand, learn, and adapt to torch.compile.

Do explore depyf and experience its benefits firsthand! The project is open-source and readily available at https://github.com/thuml/depyf. Installation is straightforward via pip install depyf. We hope depyf can enhance everyone’s development workflow with torch.compile.

Enhancing Deep Learning Workflows: PyTorch Ecosystem Tools

Welcome to the thriving PyTorch ecosystem, where a wealth of tools and libraries await, purpose-built to elevate your experience in deep learning as a developer or researcher. The Ecosystem Tools pages host many projects from experts spanning academia, industry, application development, and machine learning.

Initially, PyTorch aimed to establish a thriving community, enabling developers to access each other’s tools, engage in meaningful discussions, and explore the wealth of resources available within the community.

Today, the PyTorch ecosystem has grown to feature over 100 projects tailored to your needs, providing robust support, enhanced speed, and effortless integration with PyTorch. If your project aligns with our mission, we invite you to submit it and join this dynamic ecosystem.

New this month, we’ve moved all of our Ecosystem blogs over to our PyTorch.org website to host a space where our community can show off the latest innovations with our users. Read on to hear about the latest projects in the ecosystem!

Explore the Latest Tools and Frameworks in the Ecosystem

As we continue into 2024, we’re thrilled to showcase an impressive array of ecosystem tools that significantly enrich the PyTorch community. These tools cover a wide range of domains, including pose estimation, profiling, and even quantum computing. Let’s explore each one to witness firsthand how they are reshaping the PyTorch landscape, opening up exciting possibilities for developers.

Anomalib

Anomalib is a deep learning library that aims to collect state-of-the-art anomaly detection algorithms for benchmarking on both public and private datasets. Anomalib provides several ready-to-use implementations of anomaly detection algorithms described in the recent literature, as well as a set of tools that facilitate the development and implementation of custom models. The library has a strong focus on image-based anomaly detection, where the goal of the algorithm is to identify anomalous images, or anomalous pixel regions within images in a dataset. Anomalib is constantly updated with the latest algorithms and training/inference extensions.

Diffusers

Diffusers is Hugging Face’s library of state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. It offers ready-to-use inference pipelines, interchangeable noise schedulers, and pretrained model components that can serve as building blocks for custom diffusion systems. With Diffusers, developers can run or fine-tune diffusion models in PyTorch with only a few lines of code.

Pomegranate

Pomegranate is a versatile machine learning library that integrates seamlessly with PyTorch. It provides a wide range of probabilistic models and tools for probabilistic modeling tasks. Pomegranate empowers users to build complex models such as hidden Markov models (HMMs), Bayesian networks, and Gaussian mixture models (GMMs). By combining the strengths of PyTorch and Pomegranate, developers can leverage the power of deep learning and probabilistic modeling to tackle various machine learning challenges.

PyPose

PyPose is a robotics-oriented, PyTorch-based library that combines deep perceptual models with physics-based optimization. It provides differentiable Lie group and Lie algebra operations together with second-order optimizers, which makes it well suited to robot learning tasks such as pose estimation, SLAM, and control. By leveraging PyTorch’s flexibility and performance, PyPose simplifies the process of building such systems.

PyPOTS

PyPOTS is a Python toolbox for data mining on partially-observed time series with PyTorch. It includes SOTA models for imputation, classification, clustering, and forecasting on incomplete (irregularly-sampled) multivariate time series with missing values.

OctoML Profiler

OctoML Profiler is a performance profiling tool that aids in optimizing PyTorch models. This tool helps developers identify performance bottlenecks and inefficiencies within their deep learning models. By providing insights into memory usage, compute time, and data movement, the OctoML Profiler enables developers to fine-tune their models for improved efficiency. With this valuable feedback, developers can optimize their models for deployment on various hardware platforms.

OpenCompass

OpenCompass is a one-stop platform for large model evaluation that aims to provide a fair, open, and reproducible benchmark. Its main features include comprehensive support for models and datasets, efficient distributed evaluation, diversified evaluation paradigms, a modular design with high extensibility, and an experiment management and reporting mechanism.

Renate

Renate is a PyTorch-based library for automatically retraining neural networks. It uses continual learning algorithms so that a model can be updated as new data arrives without forgetting what it learned previously. By using Renate, developers can keep deployed models up to date while saving significant time and resources.

RoMa

RoMa is a standalone library to handle rotation representations with PyTorch (rotation matrices, quaternions, rotation vectors, etc). It aims for robustness, ease-of-use, and efficiency.

Substra

Substra is an open source federated learning (FL) software. It enables the training and validation of machine learning models on distributed datasets. It provides a flexible Python interface and a web application to run federated learning training at scale. Substra’s main usage is in production environments. It has already been deployed and used by hospitals and biotech companies. Substra can also be used on a single machine to perform FL simulations and debug code.

TorchQuantum

TorchQuantum is a powerful library that combines the PyTorch framework with quantum computing concepts. It enables developers to explore quantum machine learning algorithms and build hybrid classical-quantum models. By integrating the principles of quantum computing into PyTorch, TorchQuantum opens up new possibilities for solving complex problems that traditional deep learning approaches may struggle with.

TIAToolbox

TIAToolbox (Tissue Image Analytics Toolbox) is a PyTorch-based library for computational pathology. It offers tools for reading whole slide images, extracting patches, normalizing stains, and running deep learning models for tasks such as tissue classification and nucleus segmentation. With TIAToolbox, researchers can build reproducible digital pathology pipelines with little custom code.

torchdistill

torchdistill is a coding-free framework built on PyTorch for reproducible deep learning and knowledge distillation studies. The framework is designed to enable users to design experiments by declarative PyYAML configuration files and supports high-level module abstractions.

TorchOpt

TorchOpt is a PyTorch library focused on optimization algorithms for deep learning. It provides a collection of state-of-the-art optimization techniques, such as stochastic gradient descent (SGD) variants, adaptive learning rate methods, and optimization schedules. TorchOpt empowers developers to fine-tune their models efficiently, converge faster, and achieve better performance in various deep learning tasks.

USB

USB, the Unified Semi-supervised learning Benchmark, is a PyTorch-based toolkit for training and evaluating semi-supervised learning algorithms. It provides standardized datasets and evaluation protocols across computer vision, natural language processing, and audio classification tasks to facilitate fair and accurate comparisons between methods. By using USB, researchers and developers can benchmark their methods against state-of-the-art approaches.

Zeus

Zeus is the current state-of-the-art in deep learning energy measurement and optimization. It has monitor components that allow users to measure GPU energy consumption and optimizer components that automatically optimize DNN or GPU knobs based on measurements from the monitor component.

Be Part of Our Ecosystem

Our diverse ecosystem tools are instrumental in PyTorch’s success. They provide essential support for tasks such as anomaly detection, probabilistic modeling, performance profiling, model evaluation, federated learning, quantum computing, knowledge distillation, and energy optimization.

Leveraging these tools empowers developers and researchers to accelerate their deep learning workflows and unlock new possibilities in the field of AI.

Have a tool that would be a good fit for the PyTorch Ecosystem? If you can answer the below questions, we’d love for you to submit your tool for review.

  1. Does your project complement PyTorch, enhancing user experience, introducing new capabilities, or accelerating training and inference processes?
    • Examples could include visualization tools, a kernel library or a framework that sits on top to enable research in a particular area such as NLP.
  2. Is the project ready for broad developer usage?
    • For example, is the project stable, will it be maintained, and is there adequate supporting infrastructure, documentation, and technical support to allow a developer to successfully use it?

Thank you to all of our contributors and collaborators in our ecosystem! Here’s to a great 2024.

A Hitchhiker’s Guide to Speculative Decoding

Speculative decoding is an optimization technique for inference that makes educated guesses about future tokens while generating the current token, all within a single forward pass. It incorporates a verification mechanism to ensure the correctness of these speculated tokens, thereby guaranteeing that the overall output of speculative decoding is identical to that of vanilla decoding. Optimizing the cost of inference of large language models (LLMs) is arguably one of the most critical factors in reducing the cost of generative AI and increasing its adoption. Towards this goal, various inference optimization techniques are available, including custom kernels, dynamic batching of input requests, and quantization of large models.

In this blog post, we provide a guide to speculative decoding and demonstrate how it can coexist with other optimizations. We are proud to open source the following, which includes the first speculator for Llama3 models:

  1. Speculator models for Meta Llama3 8B, IBM Granite 7B lab, Meta Llama2 13B, and Meta Code Llama2 13B.
  2. The code for inference via IBM’s fork of HF TGI.
  3. The code for training your own speculators and corresponding recipes.

We have deployed these speculators in an internal production-grade environment with thousands of daily users and observed 2x speedup on language models – Llama3 8B, Llama2 13B, and IBM Granite 7B and 3x speedup on IBM’s Granite 20B code models. We provide a detailed explanation of our approach in this technical report and are planning in-depth analysis in an upcoming ArXiv paper.

Speculative decoding: Inference

We run IBM TGIS in our internal production environment that has optimizations such as continuous batching, fused kernels, and quantization kernels. To enable speculative decoding in TGIS, we modified the paged attention kernel from vLLM. In what follows, we will describe the key changes to the inference engine to enable speculative decoding.

Speculative decoding is based on the premise that the model is powerful enough to predict multiple tokens in a single forward pass. However, the current inference servers are optimized to predict only a single token at a time. In our approach, we attach multiple speculative heads (in addition to the usual one) to the LLM to predict N+1-, N+2-, N+3-th … token. For example, 3 heads will predict 3 additional tokens. Details of the speculator architecture are explained in a later part of this blog. There are two challenges to achieve efficiency and correctness during inference – one is to predict without replicating KV-cache and the other is to verify that the predictions match the original model’s outcomes.

In a typical generation loop, after the prompt is processed in a single forward step, a sequence length of 1 (next token predicted) is fed into the forward pass of the model along with the kv-cache. In a naive speculative decoding implementation, each speculative head would have its own kv-cache, but instead we modify the paged attention kernel developed in the vLLM project to enable efficient kv-cache maintenance. This ensures that throughput does not reduce at larger batch sizes. Further, we modify the attention masks to enable verification of the N+1’th token and thus enable speculative decoding without deviating from the original model’s output. The details of this implementation are captured here.
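
The acceptance rule behind the verification step can be sketched in a few lines. This covers only the token-matching logic under greedy decoding, not the attention-mask or paged-attention changes described above; the function name and the token IDs are made up for illustration.

```python
def accept_tokens(speculated, base_predictions):
    """Return the tokens to emit after one speculative forward pass.

    speculated: tokens proposed by the speculative heads, in order.
    base_predictions: the base model's greedy choice at each position, where
        base_predictions[i] is its next token given the context followed by
        speculated[:i]. It has one more entry than `speculated`.
    """
    if len(base_predictions) != len(speculated) + 1:
        raise ValueError("expected one base prediction per speculated token, plus one")

    accepted = []
    for i, token in enumerate(speculated):
        if token != base_predictions[i]:
            break  # First mismatch: discard this and all later speculated tokens.
        accepted.append(token)

    # The base model's own prediction at the first unverified position is
    # always valid, so every pass emits at least one token.
    accepted.append(base_predictions[len(accepted)])
    return accepted


# Made-up token IDs: the first two guesses match, the third does not.
print(accept_tokens([5, 9, 2], [5, 9, 7, 4]))
```

Because every emitted token equals what the base model itself would have chosen at that position, the output matches non-speculative greedy decoding; speculation only changes how many tokens each forward pass yields.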

Results

We illustrate the speedup obtained with Meta’s chat version of Llama2 13B using a simple prompt.

Figure 2: Visual illustration of the non-speculative generation (left) compared to speculative generation (right)

We deployed the above solution in an internal production environment. The figure below reports two metrics – time to first token (TTFT) and inter-token latency (ITL) with different numbers of concurrent users (which is captured in the numbers on the graph lines). We observe that the speculative decoding version is nearly twice as fast for the Llama2 13B chat model and nearly thrice as fast for the Granite 20B code model compared to the non-speculative version for all batch sizes. We observe similar behavior for the smaller models – IBM’s Granite 7B and Meta Llama3 8B models.

Figure 3: Time to first token (TTFT – left) and Inter-token latency (ITL – right) for Llama 13B with number of concurrent users indicated on the graph

Figure 4: Time to first token (TTFT – left) and Inter-token latency (ITL – right) for Granite 20B Code with number of concurrent users indicated on the graph

Note on efficiency

We performed numerous experiments to determine the right configuration for speculator training. The main considerations are:

  1. Speculator architecture: The current approach allows the number of heads to be modified, which maps to the number of tokens that we can look ahead. Increasing the number of heads also increases the extra compute needed and the complexity of training. In practice, we find that 3-4 heads work well for language models, whereas code models can reap benefits from 6-8 heads.
  2. Compute: Increasing the number of heads increases compute in two ways: a single forward pass takes longer, and more tokens are processed per pass. If the additional heads are not accurate, this extra compute is wasted, which increases latency and reduces throughput.
  3. Memory: The increased compute is offset by the round trips to HBM that are saved on each forward pass. Note that if we get 3 lookahead tokens correct, we have saved three round trips to HBM.

We settled on 3-4 heads for the language models and 6-8 heads for the code models and across different model sizes ranging from 7B to 20B, we observed significant latency improvements without throughput loss compared to non-speculative decoding. We begin to observe throughput reduction beyond a batch size of 64, which happens rarely in practice.
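
The trade-off between head count and wasted compute can be illustrated with a back-of-the-envelope estimate of tokens emitted per forward pass. This is a simplified model, not a measurement from our experiments: it assumes head k is only useful if all earlier heads were accepted, and the per-head acceptance probabilities are made up.

```python
def expected_tokens_per_pass(head_accept_probs):
    """Expected tokens emitted per forward pass.

    head_accept_probs[k] is the assumed probability that head k's token is
    accepted, given that all earlier heads' tokens were accepted.
    """
    expected = 1.0     # The base model always contributes one token.
    prefix_prob = 1.0  # Probability that all heads so far were accepted.
    for p in head_accept_probs:
        prefix_prob *= p
        expected += prefix_prob
    return expected


# Made-up acceptance probabilities that decay for later heads.
for heads in ([0.8], [0.8, 0.6], [0.8, 0.6, 0.4], [0.8, 0.6, 0.4, 0.2]):
    print(len(heads), "heads ->", round(expected_tokens_per_pass(heads), 3), "tokens per pass")
```

Under these assumptions each extra head adds less than the one before it, while its compute cost is paid on every pass, which is consistent with the number of useful heads leveling off.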

Speculative decoding: Training

There are two broad approaches for speculative decoding, one is to leverage a smaller model (e.g., Llama 7B as a speculator for Llama 70B) and the other is to attach speculator heads (and train them). In our experiments, we find the approach of attaching speculator heads to be more effective both in model quality and latency gains.

Speculator architecture

Medusa made speculative decoding popular; their approach is to add a head to the existing model which is then trained to do speculation. We modify the Medusa architecture by making the “heads” hierarchical, where each head stage predicts a single token and then feeds it to the next head stage. These multi-stage heads are depicted in the below figure. We are exploring ways of minimizing the embeddings table by sharing these across the multiple stages and base model.

Figure 4: A simple architecture diagram for a 3-headed multi-stage speculator. Z is the state from the base model.

Speculator training

We have a two-phase approach to training a speculator for efficiency reasons. In the first phase, we train on small batches with long sequence lengths (4k tokens) and use the standard causal LM approach for training. In phase 2, we use large batches with short sequence lengths (256 tokens) generated from the base model. In this training phase, we tune the heads to match the output of the base model. Through numerous experiments, we find that a 5:2 ratio of steps for phase 1 vs phase 2 works well. We depict the progress of these phases in the below figure. We use PyTorch FSDP and IBM FMS for the training of speculators.

Figure 5: Per-head training loss curves for Llama2-13B speculator training, phase 1 and 2
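
To make the 5:2 step ratio concrete, here is a minimal sketch of how a total step budget could be split between the two phases. The helper `two_phase_schedule`, the `Phase` record, the total of 21,000 steps, and reading "4k tokens" as 4096 are illustrative assumptions, not values from our training recipes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Phase:
    name: str
    steps: int
    seq_len: int
    data: str


def two_phase_schedule(total_steps: int, ratio=(5, 2)) -> list:
    """Split total_steps between phase 1 and phase 2 according to ratio."""
    p1, p2 = ratio
    phase1_steps = total_steps * p1 // (p1 + p2)
    phase2_steps = total_steps - phase1_steps  # remainder goes to phase 2
    return [
        # Phase 1: small batches, long sequences, standard causal LM training.
        Phase("phase 1", phase1_steps, 4096, "training text (causal LM)"),
        # Phase 2: large batches, short sequences generated by the base model.
        Phase("phase 2", phase2_steps, 256, "base model generations"),
    ]


schedule = two_phase_schedule(21_000)  # hypothetical step budget
for phase in schedule:
    print(phase.name, phase.steps, phase.seq_len)
# phase 1 15000 4096
# phase 2 6000 256
```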

Conclusion and Future Work

Through this blog, we are releasing a new approach for speculative decoding and the following assets:

  1. Models for improving the inter-token latencies for a range of models – Llama3 8B, Llama2 13B, Granite 7B, and CodeLlama 13B
  2. Production quality code for inference
  3. Recipes for training speculators

We are working on training speculators for Llama3 70B and Mistral models, and we invite the community to contribute and help improve our framework. We would also love to work with major open source serving frameworks such as vLLM and TGI to contribute our speculative decoding approach back to the community.

Acknowledgements

Several teams helped us get to these latency improvements for inference. We would like to thank the vLLM team for creating the paged attention kernel in a clean and reusable manner. We extend our gratitude to the PyTorch team at Meta, who provided feedback on this blog and continue to help with the optimal usage of PyTorch. Special thanks to our internal production teams at IBM Research, who took this prototype to production and hardened it. A shout out to Stas Bekman for his insightful comments on the blog, which led to an improved explanation of the tradeoffs between compute, memory, and speculator effectiveness.

The paged attention kernel was integrated into IBM FMS by Josh Rosenkranz and Antoni Viros i Martin. The speculator architecture and training were done by Davis Wertheimer, Pavithra Ranganathan, and Sahil Suneja. The integration of the modeling code with the inference server was done by Thomas Parnell, Nick Hill, and Prashant Gupta.

Announcing PyTorch Docathon June, 2024

We are thrilled to announce the upcoming PyTorch Docathon in June! The Docathon, akin to a hackathon, is an event dedicated to enhancing the quality of the PyTorch documentation with the invaluable assistance of our community. Documentation is a vital component of any technology. By refining it, we can simplify the process for new users to get started with PyTorch, guide them in effectively utilizing its features, and ultimately expedite the transition from research to production in machine learning. See our previous events here and here.

Why Participate

The Docathon is an inclusive event designed to be accessible to newcomers, requiring only a basic understanding of Python, PyTorch, and Machine Learning, with some tasks not even requiring these skills. It offers a rewarding experience as participants can see the direct impact of their contributions on the project’s usability and accessibility. The Docathon promotes a collaborative environment, allowing participants to work with other contributors and PyTorch maintainers, fostering the exchange of ideas and networking. It also provides a rich learning experience, offering the opportunity to explore PyTorch modules, update docstrings, and test tutorials.

Event Details

June 4: Kick-off
June 4-June 16: Submissions and Feedback
June 17-18: Final Reviews
June 20: Winner Announcements

Further details for the Docathon will be announced at the Kick-off call on June 4.

Please register to join this year’s event.
