
Commit 88cd9c7

Copy and adapt pt2e quantization code to torchao (#2048)
* Copy and adapt pt2e quantization code to torchao

  Summary: First step of dev-discuss.pytorch.org/t/torch-ao-quantization-migration-plan/2810. The core logic of pt2e is duplicated here, and the corresponding tests are ported as well. Previous unfinished migration: #1916. Next step: move Meta-internal callsites to depend on torchao. Docs will be migrated separately.

  Test Plan: pytest test/quantization/pt2e_flow

* fix import
* update
* import
* import
* import version check
* import
* guard pytorch version for op
* torchao_quant namespace
* rename namespace to torchao
* trigger CI
* work around "'test_int4.quantize_per_tensor_int4' has no overload name 'default'" issue
* fix
* debug why the test failed
* skip test due to outdated nightly
1 parent: 663a95d
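For orientation, here is a hedged sketch of the import-surface change this migration implies; it assumes the modules copied in this commit mirror their `torch.ao.quantization` counterparts (the new paths are taken from the ported tests below):

```python
# Hedged sketch of the import-path change this migration implies.
# Before (upstream PyTorch):
#   from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
# After this commit (paths as used in the ported test file below):
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import Quantizer
```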


47 files changed (+24653, -28 lines)

Diff for: test/integration/test_integration.py (+2, -2)

```diff
@@ -2043,8 +2043,8 @@ def forward(self, x):
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int4w_api:
             targets = [n.target for n in model.graph.nodes]
-            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
-            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+            self.assertTrue(torch.ops.torchao.choose_qparams_affine.default in targets)
+            self.assertTrue(torch.ops.torchao.quantize_affine.default in targets)
             self.assertFalse(torch.ops.aten.narrow.default in targets)
```
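This hunk reflects the "rename namespace to torchao" step from the commit message: the affine quantization primitives move from the `quant` op namespace to `torchao`. A minimal sketch of what downstream graph-matching code would check after this change (`exported` is a hypothetical variable name):

```python
# Minimal sketch, assuming `exported` is a torch.fx.GraphModule produced by
# torch.export from a model quantized with the int8-dynamic-activation /
# int4-weight API, as in the test above.
targets = {n.target for n in exported.graph.nodes}
assert torch.ops.torchao.choose_qparams_affine.default in targets  # was torch.ops.quant.*
assert torch.ops.torchao.quantize_affine.default in targets
```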

Diff for: test/quantization/pt2e/test_duplicate_dq.py (new file, +324 lines)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Owner(s): ["oncall: quantization"]
# ruff: noqa: F841
import copy
import unittest
from typing import Any

import torch
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

from torchao.quantization.pt2e.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7

if TORCH_VERSION_AT_LEAST_2_5:
    from torch.export import export_for_training


class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x + 2
            return w, add_output, extra_output


_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export_for_training(m, example_inputs, strict=True).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        pt2_quant_output = m(*example_inputs)
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                for arg in n.args:
                    if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
                        self.assertEqual(len(arg.users.keys()), 1)

    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because output for the
        first conv2d is fed to next conv2d, add, and view_copy + linear.
        All three are quantized.
        Thus DQ node is not duplicated for those three uses
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because output for the
        first conv2d is fed to next conv2d, and view_copy + linear.
        Both are quantized.
        However the skip connection to add and mul are not quantized.
        Thus DQ node is not duplicated for those two uses
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> adaptive_avgpool2d (different qconfig)
             |
              -----> add
        output
        conv2d -> dq -> conv2d -> add
             |                    |
              -------> dq ------->
             |
              -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
              -> dq -----> add
        """

        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                MinMaxObserver
            )

            extra_args: dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )

            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )


if __name__ == "__main__":
    run_tests()
```
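The invariant these tests check, via `_test_duplicate_dq`, is that after `convert_pt2e` each dequantize node feeding an annotated op has exactly one user, i.e. DQ nodes are duplicated per consumer where needed. The flow they exercise is the standard pt2e sequence (capture → prepare → calibrate → convert), now driven from torchao imports. A minimal, hedged sketch of that flow, assuming torch 2.7+ and that the ported `xnnpack_quantizer` module keeps the upstream `XNNPACKQuantizer` class:

```python
import torch
from torch.export import export_for_training

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,  # assumed to be ported alongside get_symmetric_quantization_config
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True)
)

m = export_for_training(model, example_inputs, strict=True).module()  # program capture
m = prepare_pt2e(m, quantizer)  # insert observers per the quantizer's annotations
m(*example_inputs)              # calibrate on example data
m = convert_pt2e(m)             # lower observers to quantize/dequantize ops
```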
