You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -19,19 +19,19 @@ torchao just works with `torch.compile()` and `FSDP2` over most PyTorch models o
19
19
20
20
### Post Training Quantization
21
21
22
-
Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/torchao/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py)
22
+
Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/README.md), sparsity [here](torchao/sparsity/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py)
23
23
24
24
For inference, we have the option of
25
25
1. Quantize only the weights: works best for memory bound models
26
26
2. Quantize the weights and activations: works best for compute bound models
27
-
2. Quantize the activations and weights and sparsify the weight
27
+
3. Quantize the activations and weights and sparsify the weight
28
28
29
29
```python
30
30
from torchao.quantization.quant_api import (
31
31
quantize_,
32
32
int8_dynamic_activation_int8_weight,
33
+
float8_dynamic_activation_float8_weight
33
34
int4_weight_only,
34
-
int8_weight_only
35
35
)
36
36
quantize_(m, int4_weight_only())
37
37
```
@@ -52,11 +52,12 @@ We also provide a developer facing API so you can implement your own quantizatio
52
52
53
53
We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference.
54
54
55
-
In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
55
+
In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md#kv-cache-quantization---memory-efficient-inference)
56
+
56
57
57
58
### Quantization Aware Training
58
59
59
-
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
60
+
Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with [Torchtune](https://github.com/pytorch/torchtune/blob/main/recipes/quantization.md#quantization-aware-training-qat), we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
60
61
61
62
```python
62
63
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
@@ -96,6 +97,8 @@ We've added support for semi-structured 2:4 sparsity with **6% end-to-end speedu
96
97
The code change is a 1 liner with the full example available [here](torchao/sparsity/training/)
97
98
98
99
```python
100
+
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
1.`torch.compile`: A key design principle for us is composability as in any new dtype or layout we provide needs to work with our compiler. It shouldn't matter if the kernels are written in pure PyTorch, CUDA, C++, or Triton - things should just work! So we write the dtype, layout, or bit packing logic in pure PyTorch and code-generate efficient kernels.
123
-
3.[FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization.
123
+
## Installation
124
124
125
-
The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you.
125
+
`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch, see [getting started](https://pytorch.org/get-started/locally/) for more details.
126
126
127
-
## Custom Kernels
127
+
Install the stable release (recommended):
128
+
```bash
129
+
pip install torchao
130
+
```
128
131
129
-
We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()` so if you love writing kernels but hate packaging them so they work all operating systems and cuda versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow
1.[fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
132
-
2.[2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
133
-
3.[int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
137
+
# Different CUDA versions
138
+
pip install torchao --index-url https://download.pytorch.org/whl/cu121 # CUDA 12.1
139
+
pip install torchao --index-url https://download.pytorch.org/whl/cu118 # CUDA 11.8
140
+
pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only
134
141
135
-
If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697)
`torch.compile`: A key design principle for us is composability - any custom dtype or memory layout should work with our compiler. We enable kernel implementations in PyTorch, CUDA, C++, or Triton. This allows researchers and engineers to start with high-level dtype and layout logic in pure PyTorch, then progressively optimize performance by implementing lower-level kernels as needed, while maintaining compatibility with the compile infrastructure.
139
151
140
-
Things we're excited about but need more time to cook in the oven
152
+
[FSDP2](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md): Historically most quantization has been done for inference, there is now a thriving area of research combining distributed algorithms and quantization.
141
153
142
-
1.[MX](torchao/prototype/mx_formats) training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
143
-
2.[Int8 Quantized Training](https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training): We're trying out full int8 training. This is easy to use with `quantize_(model, int8_weight_only_quantized_training())`. This work is prototype as the memory benchmarks are not compelling yet.
144
-
3.[IntX](https://github.com/pytorch/ao/tree/main/torchao/dtypes/uintx): We've managed to support all the ints by doing some clever bitpacking in pure PyTorch and then compiling it. This work is prototype as unfortunately without some more investment in either the compiler or low-bit kernels, int4 is more compelling than any smaller dtype
145
-
4.[Bitnet](https://github.com/pytorch/ao/blob/main/torchao/prototype/dtypes/bitnet.py): Mostly this is very cool to people on the team. This is prototype because how useful these kernels are is highly dependent on better hardware and kernel support.
154
+
The best example we have combining the composability of lower bit dtype with compile and fsdp is [NF4](torchao/dtypes/nf4tensor.py) which we used to implement the [QLoRA](https://www.youtube.com/watch?v=UvRl4ansfCg) algorithm. So if you're doing research at the intersection of this area we'd love to hear from you.
146
155
147
-
## Installation
156
+
## Custom Kernels
148
157
149
-
`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch.
158
+
We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
150
159
151
-
Stable release from Pypi which will default to CUDA 12.1
160
+
1.[fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
161
+
2.[2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
162
+
3.[int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
152
163
153
-
```Shell
154
-
pip install torchao
155
-
```
164
+
If you believe there's other CUDA kernels we should be taking a closer look at please leave a comment on [this issue](https://github.com/pytorch/ao/issues/697) or feel free to contribute directly to the repo.
156
165
157
-
Stable Release from the PyTorch index
158
-
```Shell
159
-
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
160
-
```
161
166
162
-
Nightly Release
163
-
```Shell
164
-
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
165
-
```
167
+
## Prototype Features
166
168
167
-
For *most* developers you probably want to skip building custom C++/CUDA extensions for faster iteration
169
+
Check out our [prototype directory](torchao/prototype/README.md) where we experiment with cutting-edge model optimization techniques for both training and inference. If you're interested in contributing experimental research or want to explore feel free to open an issue or PR.
168
170
169
-
```Shell
170
-
USE_CPP=0 pip install -e .
171
-
```
172
171
173
172
## OSS Integrations
174
173
175
174
We're also fortunate to be integrated into some of the leading open-source libraries including
176
175
1. Hugging Face transformers with a [builtin inference backend](https://huggingface.co/docs/transformers/main/quantization/torchao) and [low bit optimizers](https://github.com/huggingface/transformers/pull/31865)
177
-
2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao)
176
+
2. Hugging Face diffusers best practices with torch.compile and torchao in a standalone repo [diffusers-torchao](https://github.com/huggingface/diffusers/blob/main/docs/source/en/quantization/torchao.md)
178
177
3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
179
-
4.[TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes
178
+
4.[TorchTune](https://pytorch.org/torchtune/main/tutorials/qlora_finetune.html?highlight=qlora) for our QLoRA and QAT recipes
180
179
5.[torchchat](https://github.com/pytorch/torchchat) for post training quantization
181
-
6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).
180
+
6. SGLang for LLM serving: [usage](https://docs.sglang.ai/backend/server_arguments.html#server-arguments) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).
182
181
183
182
## Videos
184
183
*[Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)
0 commit comments