---
layout: blog_detail
title: "High-Performance Low-Bit Operators for PyTorch"
author: Scott Roy, Digant Desai, Kimish Patel
---

We are excited to announce the addition of embedding operators with low-bit weights (1-8 bit) and linear operators with 8-bit dynamically quantized activations and low-bit weights (1-8 bit) for Arm CPUs in TorchAO, PyTorch’s native low-precision library. These operators work seamlessly across all PyTorch surfaces, including eager, torch.compile, AOTI, and ExecuTorch, and are [available to use in torchchat](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels).

In developing these linear operators, our focus was on **code sharing between PyTorch and ExecuTorch** and on establishing a clear boundary between the higher-level operator and the lower-level kernel. This design **allows third-party vendors to easily swap in their own kernels**. We also set out to **create a place and infrastructure to experiment** with new CPU quantization ideas and test them across the PyTorch ecosystem.


## Universal low-bit kernels

CPUs have no native hardware support for low-bit arithmetic, so low-bit values must be unpacked into a supported type before compute. In what we call universal kernels, we explicitly separate the logic that unpacks low-bit values into int8 values from the int8 GEMV kernel logic, keeping the two modular. We started with an 8-bit kernel, for example, this [1x8 8-bit GEMV kernel](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h#L64) that uses the Arm neondot instruction. Within this 8-bit kernel, we invoke an [inlined unpacking routine](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot-impl.h#L169) to convert low-bit values into int8 values. The unpacking routine is force-inlined and templated on the low bit width. Our experiments showed no performance difference between using a separate force-inlined unpacking routine and embedding the unpacking code directly in the kernel.

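To make this concrete, the sketch below (plain scalar C++ with hypothetical names; the actual TorchAO kernels use NEON intrinsics and also handle groupwise scales and zero points) shows the shape of a universal kernel: a force-inlined unpacking routine templated on the bit width feeds int8 values into a GEMV loop that never needs to know how the weights were packed.

```cpp
#include <cstddef>
#include <cstdint>

// Unpack n values stored with nbit bits each (LSB-first) into int8,
// recentering the unsigned codes around zero. Force-inlined so the compiler
// folds it into whichever kernel calls it.
template <int nbit>
__attribute__((always_inline)) inline void unpack_to_int8(
    int8_t* out, const uint8_t* packed, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    uint32_t v = 0;
    for (int b = 0; b < nbit; ++b) {
      size_t bit = i * nbit + b;
      v |= ((packed[bit / 8] >> (bit % 8)) & 1u) << b;
    }
    out[i] = static_cast<int8_t>(static_cast<int>(v) - (1 << (nbit - 1)));
  }
}

// Minimal 1xK GEMV body for one output channel: the int8 compute is identical
// for every bit width; only the unpack step is templated. Groupwise scales,
// zero points, and the NEON dot-product (neondot) are omitted for brevity.
template <int weight_nbit>
float gemv_1xK(const int8_t* activations, float activation_scale,
               const uint8_t* packed_weights, float weight_scale, size_t K) {
  int32_t acc = 0;
  int8_t wbuf[16];
  for (size_t k = 0; k < K; k += 16) {  // assumes K is a multiple of 16
    unpack_to_int8<weight_nbit>(wbuf, packed_weights + (k * weight_nbit) / 8, 16);
    for (size_t j = 0; j < 16; ++j) {
      acc += static_cast<int32_t>(activations[k + j]) * wbuf[j];
    }
  }
  return acc * activation_scale * weight_scale;
}
```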
The advantage of this modular design is improved development speed and code maintainability. After writing the 8-bit kernel, we quickly achieved full low-bit coverage by writing [simple bitpacking routines](https://github.com/pytorch/ao/tree/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/bitpacking). In fact, the developers who worked on the bitpacking routines did not need to be experts in GEMV/GEMM kernel writing. We also reused the same bitpacking routines from the linear kernels [within the embedding kernels](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/embedding/embedding.h#L161). In the future, we could reuse the same bitpacking routines for universal GEMM kernels or for kernels based on fma or i8mm instructions.


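As a rough illustration of why these routines are easy to contribute in isolation, a bitpacking routine amounts to something like the following (a scalar stand-in with a hypothetical name; the real TorchAO routines use vector-friendly layouts tuned for NEON):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack n unsigned codes, each in [0, 2^nbit), into a dense buffer that uses
// nbit bits per value (LSB-first) -- the inverse of the unpacking sketch above.
// Writing or vectorizing a routine like this needs no GEMV/GEMM knowledge.
template <int nbit>
std::vector<uint8_t> pack_lowbit(const uint8_t* vals, size_t n) {
  std::vector<uint8_t> packed((n * nbit + 7) / 8, 0);
  for (size_t i = 0; i < n; ++i) {
    for (int b = 0; b < nbit; ++b) {
      size_t bit = i * nbit + b;
      packed[bit / 8] |= ((vals[i] >> b) & 1u) << (bit % 8);
    }
  }
  return packed;
}
```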
## Shared code between PyTorch and ExecuTorch

To achieve shared code between PyTorch and ExecuTorch, we wrote kernels [using raw pointers instead of PyTorch tensors](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/kernels/cpu/aarch64/linear/linear.h). Moreover, we implemented the [linear operator in a header](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h#L259) that is included in separate [PyTorch](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp) and [ExecuTorch](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w4s.cpp) operator registration code. By using only features common to both ATen and ExecuTorch tensors, we ensured compatibility between the two frameworks. For multi-threaded compute, we introduced [torchao::parallel_1d](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/parallel.h#L13), which compiles to either [at::parallel_for](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/parallel-aten-impl.h) or [ExecuTorch’s threadpool](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/parallel-executorch-impl.h) based on compile-time flags.


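A minimal sketch of this compile-time dispatch is shown below. The flag name is a made-up stand-in (the real selection uses TorchAO's own build flags), and the ExecuTorch branch, which forwards the same loop to ExecuTorch's threadpool, is stubbed out as a serial loop here.

```cpp
#include <cstdint>

#ifdef TORCHAO_SKETCH_BUILD_ATEN  // illustrative flag name, not the real one

#include <ATen/Parallel.h>

// PyTorch build: hand the loop to ATen's intra-op thread pool.
template <typename F>
void parallel_1d(int64_t begin, int64_t end, const F& f) {
  at::parallel_for(begin, end, /*grain_size=*/1, [&](int64_t b, int64_t e) {
    for (int64_t i = b; i < e; ++i) {
      f(i);
    }
  });
}

#else

// ExecuTorch build: the real implementation submits the same loop to
// ExecuTorch's threadpool; a serial loop stands in for it in this sketch.
template <typename F>
void parallel_1d(int64_t begin, int64_t end, const F& f) {
  for (int64_t i = begin; i < end; ++i) {
    f(i);
  }
}

#endif

// Operator code is written once against parallel_1d, e.g.:
//   parallel_1d(0, num_tiles, [&](int64_t tile) { compute_tile(tile); });
```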
## Swappable kernels

Our design for the higher-level multi-threaded linear operator is agnostic to the lower-level single-threaded kernels, allowing third-party vendors to swap in their own implementations. The interface between the operator and kernel is defined by a [ukernel config](https://github.com/pytorch/ao/blob/299aacd0ab0e0cce376f56e18e5bb585d517b2e1/torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h#L14), which specifies kernel function pointers for preparing activation data, preparing weight data, and running the kernel. The operator, responsible for tiling and scheduling, interacts with kernels solely through this config.


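Conceptually, the config is a small bundle of function pointers plus the tile sizes the operator needs for scheduling, along the lines of the sketch below (field names and signatures are illustrative, not the exact TorchAO definitions):

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative ukernel config: the operator allocates packed buffers, calls the
// prepare_* entry points, and tiles/schedules calls to `kernel`, without ever
// looking inside the kernel. A vendor kernel that fills in these function
// pointers can be dropped in without touching the operator code.
struct ukernel_config {
  // Sizes of the packed activation/weight buffers, so the operator can allocate them.
  size_t (*activation_data_size)(int m, int k, int group_size);
  size_t (*weight_data_size)(int n, int k, int group_size);

  // Quantize/pack activations and weights into the kernel's preferred layout.
  void (*prepare_activation_data)(void* out, int m, int k, int group_size,
                                  const float* activations);
  void (*prepare_weight_data)(void* out, int n, int k, int group_size,
                              const int8_t* weight_qvals,
                              const float* weight_scales);

  // Single-threaded compute over one output tile.
  void (*kernel)(float* output, int output_stride, int m, int n, int k,
                 int group_size, const void* weight_data,
                 const void* activation_data);

  // Register tile sizes the operator uses when partitioning work across threads.
  int mr;
  int nr;
};
```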
## Performance

In the table below, we show Llama 3.1 8B token generation performance using 6 CPU threads on an M1 MacBook Pro with 32GB of RAM.


<table class="table table-bordered">
  <tr>
    <td><strong>Bitwidth x</strong></td>
    <td><strong>torch.compile (Decode tokens/sec)</strong></td>
    <td><strong>ExecuTorch (Decode tokens/sec)</strong></td>
    <td><strong>ExecuTorch PTE size (GiB)</strong></td>
  </tr>
  <tr><td>1</td><td>24.18</td><td>17.86</td><td>1.46</td></tr>
  <tr><td>2</td><td>27.02</td><td>19.65</td><td>2.46</td></tr>
  <tr><td>3</td><td>21.01</td><td>22.25</td><td>3.46</td></tr>
  <tr><td>4</td><td>19.51</td><td>19.47</td><td>4.47</td></tr>
  <tr><td>5</td><td>14.78</td><td>16.34</td><td>5.47</td></tr>
  <tr><td>6</td><td>12.80</td><td>13.61</td><td>6.47</td></tr>
  <tr><td>7</td><td>8.16</td><td>11.73</td><td>7.48</td></tr>
</table>


Results were collected on an M1 MacBook Pro (with 8 performance cores and 2 efficiency cores) with 32GB of RAM and 6 threads, [using torchchat](https://github.com/pytorch/torchchat). In each test, 128 tokens were generated (max-seq-length 128). For each bit width x, the embedding layer was groupwise quantized to x bits with a group size of 32. In the linear layers, activations were dynamically quantized per token to 8 bits, and weights were groupwise quantized to x bits with a group size of 256. Our focus here is performance, and we do not report accuracy or perplexity numbers. Depending on the model, lower bit widths may require quantization-aware training, quantizing a model with a mixture of bit widths, or adjusting the group sizes to reach acceptable accuracy.


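For reference, the groupwise weight quantization described above can be sketched as follows. This is a symmetric, scale-only variant shown purely for illustration (the function name is hypothetical, and the actual TorchAO quantizers support additional options such as weight zero points):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize one group of weights to signed nbit integers sharing a single scale.
// Symmetric (no zero point) and scalar, for illustration only.
void quantize_group(const float* w, int group_size, int nbit,
                    int8_t* qvals, float* scale) {
  float max_abs = 0.0f;
  for (int i = 0; i < group_size; ++i) {
    max_abs = std::max(max_abs, std::fabs(w[i]));
  }
  const int qmax = (1 << (nbit - 1)) - 1;  // e.g.  7 for 4-bit
  const int qmin = -(1 << (nbit - 1));     // e.g. -8 for 4-bit
  *scale = max_abs > 0.0f ? max_abs / qmax : 1.0f;
  for (int i = 0; i < group_size; ++i) {
    int q = static_cast<int>(std::nearbyint(w[i] / *scale));
    qvals[i] = static_cast<int8_t>(std::clamp(q, qmin, qmax));
  }
}

// With group size 256 in the linear layers (32 for embeddings), a row of K
// weights carries K / 256 scales; the int8 qvals are then bit-packed to x bits.
```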
## Try them out and contribute!

If you want to see the new low-bit kernels in action, give them a try by [setting up torchchat](https://github.com/pytorch/torchchat/tree/main) and [quantizing and running an LLM locally using the kernels](https://github.com/pytorch/torchchat/blob/main/docs/quantization.md#experimental-torchao-lowbit-kernels).

If you want to help contribute, consider working on one of the following areas:

* [Add universal low-bit GEMM kernels](https://github.com/pytorch/ao/issues/1394) for Arm CPU, reusing the same bitpacking routines from the universal GEMV kernels.
* [Improve runtime selection](https://github.com/pytorch/ao/issues/1376) of ukernel configs based on ISA, packing format, and activation shape.
* Add low-bit kernels for other CPU ISAs like x86.
* Integrate third-party libraries like [KleidiAI](https://gitlab.arm.com/kleidi/kleidiai) with the operator framework.