
Commit 76c1cd2

bump torchao pin (#1318)
* bump torchao pin
* update pin
* update pin
* merge conflict
1 parent 7d5ba09 commit 76c1cd2

File tree

5 files changed: +52 −47 lines changed

.github/workflows/pull.yml

+6-27
@@ -1092,32 +1092,11 @@ jobs:
         id: install-torchao-ops
         run: |
           bash torchchat/utils/scripts/build_torchao_ops.sh
-      - name: Set git shas
-        id: setup-hash
-        run: |
-          export TORCHCHAT_ROOT=${PWD}
-          echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
-      - name: Load or install ET
-        id: install-et
-        uses: actions/cache@v4
-        with:
-          path: |
-            ./et-build
-            ./torchchat/utils/scripts/install_et.sh
-          key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }}
-      - if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
-        continue-on-error: true
+      - name: Install ET
         run: |
           echo "Installing ExecuTorch"
+          export TORCHCHAT_ROOT=${PWD}
           bash torchchat/utils/scripts/install_et.sh
-      - name: Install ExecuTorch python
-        run: |
-          echo "Install ExecuTorch python"
-          export TORCHCHAT_ROOT=$PWD
-          export ET_BUILD_DIR="et-build"
-          ENABLE_ET_PYBIND="${1:-true}"
-          source "torchchat/utils/scripts/install_utils.sh"
-          install_executorch_python_libs $ENABLE_ET_PYBIND
       - name: Install runner
         run: |
           echo "Installing runner"
@@ -1132,14 +1111,14 @@ jobs:
           wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           export PRMT="Once upon a time in a land far away"
           echo "Generate eager"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           echo "Generate compile"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
           echo "Export and run ET (C++ runner)"
-          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Export and run AOTI (C++ runner)"
-          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Generate AOTI"
           python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"
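For context on the CI change: the `--quantize` flag takes a JSON object keyed by quantizer name, and the updated jobs now pair `embedding:wx` with `linear:a8wxdq` instead of using `linear:a8wxdq` alone. Below is a minimal sketch of how such a config string can be decoded and sanity-checked; `parse_quantize_arg` is a hypothetical helper written for illustration, not torchchat code.

```
import json

# The --quantize argument used in the updated CI jobs, as a JSON string.
# Keys are quantizer names; values are per-scheme options.
QUANTIZE_ARG = (
    '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, '
    '"linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
)

def parse_quantize_arg(arg: str) -> dict:
    """Hypothetical helper: decode the JSON and sanity-check the options."""
    config = json.loads(arg)
    for scheme, opts in config.items():
        bitwidth = opts.get("bitwidth")
        assert bitwidth in range(1, 8), f"{scheme}: unsupported bitwidth {bitwidth}"
        assert opts.get("groupsize", 0) > 0, f"{scheme}: groupsize must be positive"
    return config

if __name__ == "__main__":
    for scheme, opts in parse_quantize_arg(QUANTIZE_ARG).items():
        print(scheme, opts)
```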

docs/quantization.md

+15-8
@@ -121,22 +121,29 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 ## Experimental TorchAO lowbit kernels
 
 ### Use
-The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
+
+#### linear:a8wxdq
+The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
 It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
 
-You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, or 5 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+You should expect high performance on ARM CPU if groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+
+#### embedding:wx
+The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros.
+
+You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
 
 From the torchchat root directory, run
 ```
 sh torchchat/utils/scripts/build_torchao_ops.sh
 ```
 
-This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
+This should take about 10 seconds to complete.
 
 Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.
 
@@ -156,17 +163,17 @@ Below we show how to use the new kernels. Except for ExecuTorch, you can specif
 
 #### Eager mode
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
 ```
 
 #### torch.compile
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
 ```
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
 OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
 ```
 
@@ -178,7 +185,7 @@ OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cac
 
 #### ExecuTorch
 ```
-python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
+python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-pte llama3_1.pte
 ```
 
 Note: only the ExecuTorch C++ runner in torchchat, built using the setup instructions above, can run the exported *.pte file. It will not work with the `python torchchat.py generate` command.
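To make the scheme descriptions above concrete, here is a minimal fake-quantization sketch of groupwise weight quantization in plain PyTorch; `quantize_groupwise` is an illustrative helper written here, not the torchao kernel API. The scales-only branch (has_weight_zeros: false) matches the symmetric mode described for linear:a8wxdq, and the scales-and-zeros branch is the asymmetric mode that embedding:wx always uses.

```
import torch

def quantize_groupwise(
    weight: torch.Tensor, bitwidth: int, groupsize: int, has_weight_zeros: bool
) -> torch.Tensor:
    """Fake-quantize `weight`: split it into groups of `groupsize` values,
    quantize each group to `bitwidth` bits, then dequantize back to float."""
    assert weight.numel() % groupsize == 0
    groups = weight.reshape(-1, groupsize)
    if has_weight_zeros:
        # Scales and zeros (asymmetric): map [min, max] onto [0, 2^b - 1].
        qmin, qmax = 0, 2**bitwidth - 1
        lo = groups.min(dim=1, keepdim=True).values
        hi = groups.max(dim=1, keepdim=True).values
        scale = (hi - lo).clamp(min=1e-8) / (qmax - qmin)
        zero = torch.round(-lo / scale)
        q = torch.clamp(torch.round(groups / scale) + zero, qmin, qmax)
        deq = (q - zero) * scale
    else:
        # Scales only (symmetric): signed range [-2^(b-1), 2^(b-1) - 1].
        qmax = 2 ** (bitwidth - 1) - 1
        scale = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
        q = torch.clamp(torch.round(groups / scale), -qmax - 1, qmax)
        deq = q * scale
    return deq.reshape(weight.shape)

# Example with the settings used in the updated docs (bitwidth 3, groupsize 128).
w = torch.randn(8, 128)
w_q = quantize_groupwise(w, bitwidth=3, groupsize=128, has_weight_zeros=False)
print(f"max abs error: {(w - w_q).abs().max():.4f}")
```

The rounding error printed at the end is the quantity that grows as bitwidth shrinks or groupsize grows, which is why the docs trade off {bitwidth: 3, groupsize: 128} against the earlier {bitwidth: 4, groupsize: 256} examples.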

install/.pins/torchao-pin.txt

+1-1
@@ -1 +1 @@
-49b1fb61c8b8eceda755579a2fd92c756d822de2
+c8f1174a06dcc0102849c8348ca6573bde8847a9

torchchat/utils/quantize.py

+29-9
@@ -45,14 +45,15 @@
     find_multiple,
     get_device_str,
     get_precision,
+    set_precision,
     name_to_dtype,
     state_dict_device,
     use_et_backend,
 )
 
 
 # Flag for whether the a8wxdq quantizer is available.
-a8wxdq_load_error: Optional[Exception] = None
+torchao_experimental_load_error: Optional[Exception] = None
 
 #########################################################################
 ### handle arg validation ###
@@ -115,6 +116,13 @@ def quantize_model(
             if not support_tensor_subclass:
                 unwrap_tensor_subclass(model)
             continue
+
+        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
+            # These quantizers require float32 input weights. Note that after quantization,
+            # the weights will no longer be float32, but lowbit integers
+            if get_precision() != torch.float32:
+                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
+                set_precision(torch.float32)
 
     # We set global precision from quantize options if it is specified at cli.py:485
     # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
@@ -887,24 +895,35 @@ def quantized_model(self) -> nn.Module:
 
 try:
     import importlib.util
-    import sys
     import os
+    import sys
+
     torchao_build_path = f"{os.getcwd()}/torchao-build"
 
     # Try loading quantizer
    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
         "torchao_experimental_quant_api",
         f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
     )
-    torchao_experimental_quant_api = importlib.util.module_from_spec(torchao_experimental_quant_api_spec)
+    torchao_experimental_quant_api = importlib.util.module_from_spec(
+        torchao_experimental_quant_api_spec
+    )
     sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
-    torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
-    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
-    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+    torchao_experimental_quant_api_spec.loader.exec_module(
+        torchao_experimental_quant_api
+    )
+    from torchao_experimental_quant_api import (
+        Int8DynActIntxWeightLinearQuantizer,
+        IntxWeightEmbeddingQuantizer,
+    )
+
+    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
+    quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
 
     # Try loading custom op
     try:
         import glob
+
         libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/libtorchao_ops_aten.*")
         libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
         torch.ops.load_library(libs[0])
@@ -915,8 +934,9 @@ def quantized_model(self) -> nn.Module:
 except Exception as e:
     class ErrorHandler(QuantHandler):
         def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
-            global a8wxdq_load_error
-            raise Exception(f"Note: Failed to load torchao experimental a8wxdq quantizer with error: {a8wxdq_load_error}")
+            global torchao_experimental_load_error
+            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")
 
-    a8wxdq_load_error = e
+    torchao_experimental_load_error = e
     quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
+    quantizer_class_dict["embedding:wx"] = ErrorHandler

torchchat/utils/scripts/install_utils.sh

+1-2
@@ -191,7 +191,6 @@ install_torchao_aten_ops() {
     cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
         -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
         -DCMAKE_BUILD_TYPE="Release" \
-        -DTORCHAO_OP_TARGET="aten" \
         -S . \
         -B ${CMAKE_OUT_DIR} -G Ninja
     cmake --build ${CMAKE_OUT_DIR} --target install --config Release
@@ -207,7 +206,7 @@ install_torchao_executorch_ops() {
     cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
         -DCMAKE_INSTALL_PREFIX=${CMAKE_OUT_DIR} \
         -DCMAKE_BUILD_TYPE="Release" \
-        -DTORCHAO_OP_TARGET="executorch" \
+        -DTORCHAO_BUILD_EXECUTORCH_OPS=ON \
         -DEXECUTORCH_INCLUDE_DIRS="${EXECUTORCH_INCLUDE_DIRS}" \
         -DEXECUTORCH_LIBRARIES="${EXECUTORCH_LIBRARIES}" \
         -S . \
