You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -17,21 +17,29 @@ torchao just works with `torch.compile()` and `FSDP2` over most PyTorch models o
17
17
18
18
## Inference
19
19
20
+
Our optimizations deliver significant speedups and memory savings:
21
+
22
+
-**INT4 Weight-Only Quantization**: 2x higher throughput (201 vs 107 tokens/sec) with 65% less memory (4.9GB vs 13.9GB) on LLaMA-2-7B
23
+
-**Float8 Dynamic Quantization**: Demonstrates 53.88% speedup on Flux.1-Dev* and 27.33% speedup on CogVideoX-5b on H100 GPU while preserving image quality
24
+
-**INT4 + 2:4 Sparsity**: 2.4x throughput increase (226 vs 95 tokens/sec) with 80% memory reduction (5.3GB vs 16.4GB peak memory) on LLaMA-3-8B
25
+
26
+
For detailed benchmarks across models and techniques, see our [quantization documentation](torchao/quantization/README.md).
27
+
20
28
### Post Training Quantization
21
29
22
-
Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/torchao/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py)
30
+
Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/README.md), sparsity [here](torchao/sparsity/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py)
23
31
24
32
For inference, we have the option of
25
33
1. Quantize only the weights: works best for memory bound models
26
34
2. Quantize the weights and activations: works best for compute bound models
27
-
2. Quantize the activations and weights and sparsify the weight
35
+
3. Quantize the activations and weights and sparsify the weight
28
36
29
37
```python
30
-
from torchao.quantization.quant_apiimport (
38
+
from torchao.quantization import (
31
39
quantize_,
32
40
int8_dynamic_activation_int8_weight,
41
+
float8_dynamic_activation_float8_weight,
33
42
int4_weight_only,
34
-
int8_weight_only
35
43
)
36
44
quantize_(m, int4_weight_only())
37
45
```
@@ -52,13 +60,14 @@ We also provide a developer facing API so you can implement your own quantizatio
52
60
53
61
We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference.
54
62
55
-
In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
63
+
In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md#kv-cache-quantization---memory-efficient-inference)
64
+
56
65
57
66
## Training
58
67
59
68
### Quantization Aware Training
60
69
61
-
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/). For more details, please see the [QAT README](./torchao/quantization/qat/README.md).
70
+
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with [Torchtune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
62
71
63
72
```python
64
73
from torchao.quantization import (
@@ -107,6 +116,8 @@ We've added support for semi-structured 2:4 sparsity with **6% end-to-end speedu
107
116
The code change is a 1 liner with the full example available [here](torchao/sparsity/training/)
108
117
109
118
```python
119
+
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
1.`torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels.
134
-
3.[FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization.
142
+
## Installation
135
143
136
-
The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you.
144
+
`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch, see [getting started](https://pytorch.org/get-started/locally/) for more details.
137
145
138
-
## Custom Kernels
146
+
Install the stable release (recommended):
147
+
```bash
148
+
pip install torchao
149
+
```
139
150
140
-
We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow
1.[fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
143
-
2.[2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
144
-
3.[int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
156
+
# Different CUDA versions
157
+
pip install torchao --index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8
158
+
pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only
145
159
146
-
If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697)
`torch.compile`: A key design principle for us is composability - any custom dtype or memory layout should work with our compiler. We enable kernel implementations in PyTorch, CUDA, C++, or Triton. This allows researchers and engineers to start with high-level dtype and layout logic in pure PyTorch, then progressively optimize performance by implementing lower-level kernels as needed, while maintaining compatibility with the compile infrastructure.
150
169
151
-
Things we're excited about but need more time to cook in the oven
170
+
[FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization.
152
171
153
-
1.[MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
154
-
2.[Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet.
155
-
3.[IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype
156
-
4.[Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support.
172
+
The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you.
157
173
158
-
## Installation
174
+
## Custom Kernels
159
175
160
-
`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch.
176
+
We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
161
177
162
-
Stable release from Pypi which will default to CUDA 12.1
178
+
1.[fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
179
+
2.[2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
180
+
3.[int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
163
181
164
-
```Shell
165
-
pip install torchao
166
-
```
182
+
If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) or feel free to contribute directly to the repo.
167
183
168
-
Stable Release from the PyTorch index
169
-
```Shell
170
-
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
171
-
```
172
184
173
-
Nightly Release
174
-
```Shell
175
-
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
176
-
```
185
+
## Prototype Features
177
186
178
-
For *most* developers you probably want to skip building custom C++/CUDA extensions for faster iteration
187
+
Check out our [prototype directory](torchao/prototype/README.md) where we experiment with cutting-edge model optimization techniques for both training and inference. If you're interested in contributing experimental research or want to explore feel free to open an issue or PR.
179
188
180
-
```Shell
181
-
USE_CPP=0 pip install -e .
182
-
```
183
189
184
190
## OSS Integrations
185
191
186
192
We're also fortunate to be integrated into some of the leading open-source libraries including
187
193
1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
188
-
2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao)
194
+
2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
189
195
3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
190
-
4.[TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes
196
+
4.[TorchTune](https://pytorch.org/torchtune/main/tutorials/qlora_finetune.html?highlight=qlora) for our QLoRA and QAT recipes
191
197
5.[torchchat](https://github.com/pytorch/torchchat) for post training quantization
192
-
6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).
198
+
6. SGLang for LLM serving: [usage](https://docs.sglang.ai/backend/server_arguments.html#server-arguments) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).
193
199
194
200
## Videos
195
201
*[Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)
0 commit comments