docs/quantization.md
## Experimental TorchAO lowbit kernels

### Use

#### linear:a8wxdq
The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.

You should expect high performance on ARM CPU if groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
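For illustration, a minimal `--quantize` entry for this scheme looks like the following (the bitwidth and groupsize values here are example choices, not recommendations):

```
{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 32, "has_weight_zeros": false}}
```

This is the Q4_0-like setting mentioned above; the full generate commands below show how such a config is passed on the command line.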

#### embedding:wx
The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros.

You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
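An `embedding:wx` entry takes the same JSON form and can be combined with `linear:a8wxdq` in a single `--quantize` argument; a minimal illustrative entry (values are examples only) is:

```
{"embedding:wx": {"bitwidth": 2, "groupsize": 32}}
```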

### Setup
To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.

From the torchchat root directory, run
```
sh torchchat/utils/scripts/build_torchao_ops.sh
```

This should take about 10 seconds to complete.

Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
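For example, the flag can be appended when invoking the runner build script; the exact script name and argument form below are assumptions based on torchchat's runner build scripts and may differ in your checkout:

```
# Assumed invocations: check torchchat/utils/scripts/build_native.sh for the exact arguments
sh torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
sh torchchat/utils/scripts/build_native.sh et link_torchao_ops
```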

Below we show how to use the new kernels.

#### Eager mode
```
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
```

#### torch.compile
```
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
```

Note: only the ExecuTorch C++ runner in torchchat, built using the setup instructions above, can run the exported *.pte file; it will not work with the `python torchchat.py generate` command.
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32, but lowbit integers
if get_precision() != torch.float32:
    print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
    set_precision(torch.float32)

# We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat