Commit bc1530b

Q dq layout (#1642)
* add q-dq layout for ET
* up
* up
* up
* up
* up
* up
* up
1 parent 4df4d03 commit bc1530b

File tree

4 files changed: +249 -155 lines changed

.github/workflows/torchao_experimental_test.yml

+2 -1

@@ -35,8 +35,9 @@ jobs:
         conda activate venv
         pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104"
         pip install numpy
+        pip install pytest
         USE_CPP=1 pip install .
     - name: Run tests
       run: |
         conda activate venv
-        python torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py
+        pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

torchao/experimental/q_dq_layout.py

+61
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
    register_aqt_quantized_linear_dispatch,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)

import sys

handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


from torchao.dtypes.utils import PlainLayout


class QDQLayout(PlainLayout):
    pass


from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl


@register_layout(QDQLayout)
class _Impl(PlainAQTTensorImpl):
    pass


def _linear_check(input_tensor, weight_tensor, bias):
    layout = weight_tensor.tensor_impl.get_layout()
    return isinstance(layout, QDQLayout)


def _linear_impl(input_tensor, weight_tensor, bias):
    if isinstance(input_tensor, AffineQuantizedTensor):
        input_tensor = input_tensor.dequantize()
    if isinstance(weight_tensor, AffineQuantizedTensor):
        weight_tensor = weight_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


register_aqt_quantized_linear_dispatch(
    _linear_check,
    _linear_impl,
)
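
Editor's note: QDQLayout dispatches linear to dequantize-then-F.linear, so the quantize/dequantize ops stay visible to torch.export (the q-dq representation targeted at ExecuTorch). A minimal usage sketch, not part of the commit; APIs and parameter values mirror the tests below:

import torch

from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,    # intx weight dtype (int1..int8 covered by the tests)
        granularity=PerGroup(64),   # group size 64
        has_weight_zeros=False,
        layout=QDQLayout(),         # linear dispatches to dequantize + F.linear
    ),
)
out = model(torch.randn(1, 512))  # numerics should match PlainLayout (see test_accuracy)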
torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

+186
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import itertools
import tempfile
import unittest

import torch
from torch.testing import FileCheck

from torchao.dtypes import PlainLayout
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import (
    int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import (
    PerGroup,
    PerRow,
)
from torchao.quantization.quant_api import quantize_
from torchao.utils import unwrap_tensor_subclass


class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
    def test_accuracy(self):
        """
        Checks the accuracy of different layouts by comparing the results to PlainLayout()
        """
        m = 1
        n = 1071
        k = 4096
        activations = torch.randn(m, k)
        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

        reference_layout = PlainLayout()
        test_layouts = [
            PackedLinearInt8DynamicActivationIntxWeightLayout(),
            QDQLayout(),
        ]
        test_weight_dtypes = [
            torch.int1,
            torch.int2,
            torch.int3,
            torch.int4,
            torch.int5,
            torch.int6,
            torch.int7,
            torch.int8,
        ]
        test_has_weight_zeros = [True, False]
        test_granularities = [PerGroup(128), PerRow()]
        for layout, weight_dtype, has_weight_zeros, granularity in itertools.product(
            test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities
        ):
            quantized_model = copy.deepcopy(model)
            quantize_(
                quantized_model,
                int8_dynamic_activation_intx_weight(
                    weight_dtype=weight_dtype,
                    granularity=granularity,
                    has_weight_zeros=has_weight_zeros,
                    layout=layout,
                ),
            )

            quantized_model_reference = copy.deepcopy(model)
            quantize_(
                quantized_model_reference,
                int8_dynamic_activation_intx_weight(
                    weight_dtype=weight_dtype,
                    granularity=granularity,
                    has_weight_zeros=has_weight_zeros,
                    layout=reference_layout,
                ),
            )

            with torch.no_grad():
                result = quantized_model(activations)
                expected_result = quantized_model_reference(activations)
            self.assertTrue(torch.allclose(result, expected_result, atol=1e-6))

    def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
        self,
    ):
        """
        Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with
        torch.export.export, torch.compile, and AOTI.
        """
        granularity = PerRow()
        m = 3
        k0 = 512
        k1 = 256
        k2 = 128
        k3 = 1024
        weight_dtype = torch.int4
        has_weight_zeros = True
        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, k2, bias=False),
            torch.nn.Linear(k2, k3, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)

        quantize_(
            model,
            int8_dynamic_activation_intx_weight(
                weight_dtype=weight_dtype,
                granularity=granularity,
                has_weight_zeros=has_weight_zeros,
                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
            ),
        )
        eager_results = model(activations)

        unwrapped_model = copy.deepcopy(model)
        unwrap_tensor_subclass(model)

        # Export
        exported = torch.export.export(model, (activations,), strict=True)
        exported_results = exported.module()(activations)
        self.assertTrue(torch.allclose(eager_results, exported_results))

        # Compile
        compiled = torch.compile(unwrapped_model)
        with torch.no_grad():
            compiled_results = compiled(activations)
        self.assertTrue(torch.allclose(eager_results, compiled_results))

        # AOTI
        with tempfile.TemporaryDirectory() as tmpdirname:
            package_path = f"{tmpdirname}/model.pt2"
            torch._inductor.aoti_compile_and_package(
                exported, package_path=package_path
            )
            fn = torch._inductor.aoti_load_package(package_path)
            aoti_results = fn(activations)
            self.assertTrue(torch.allclose(eager_results, aoti_results))

    def test_export_QDQLayout(self):
        """
        Checks that models quantized with QDQLayout() export as expected
        """
        granularity = PerGroup(64)
        weight_dtype = torch.int4
        has_weight_zeros = False
        layers = [
            torch.nn.Linear(512, 256, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        activations = torch.randn(1, 512, dtype=torch.float32)

        quantize_(
            model,
            int8_dynamic_activation_intx_weight(
                weight_dtype=weight_dtype,
                granularity=granularity,
                has_weight_zeros=has_weight_zeros,
                layout=QDQLayout(),
            ),
        )
        eager_results = model(activations)

        unwrap_tensor_subclass(model)
        exported = torch.export.export(model, (activations,), strict=True)
        exported_results = exported.module()(activations)
        self.assertTrue(torch.allclose(eager_results, exported_results))

        expected_lines = [
            "torch.ops.quant.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int32, -128, 127, None, torch.float32, torch.int32)",
            "torch.ops.quant.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int32, -128, 127)",
            "torch.ops.quant.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int32, -128, 127)",
            "torch.ops.quant.dequantize_affine.default(p_fn_0_parametrizations_weight_original0, [1, 64], p_fn_0_parametrizations_weight_original1, None, torch.int32, -8, 7, 'NONE')",
            "torch.ops.aten.linear.default(dequantize_affine, dequantize_affine_1)",
        ]
        for line in expected_lines:
            FileCheck().check_count(line, 1, exactly=True).run(
                exported.graph_module.code
            )
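
Editor's note: the FileCheck assertions in test_export_QDQLayout match substrings of the Python code that torch.export generates for the graph, so the expected q-dq pattern (choose_qparams -> quantize_affine -> dequantize_affine -> aten.linear) can be inspected directly. A one-line sketch, assuming `exported` is the ExportedProgram produced above:

# Print the generated graph code that FileCheck matches against.
print(exported.graph_module.code)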
