Copy and adapt pt2e quantization code to torchao #2048
Merged
Commits (15)
ac8f45d  Copy and adapt pt2e quantization code to torchao  (jerryzh168)
7364836  fix import  (jerryzh168)
af29555  update  (jerryzh168)
c8cbda8  import  (jerryzh168)
0922c09  import  (jerryzh168)
172008a  import version check  (jerryzh168)
628ad2e  import  (jerryzh168)
f88c204  guard pytorch version for op  (jerryzh168)
029b2cb  torchao_quant namespace  (jerryzh168)
4c16dde  rename namespace to torchao  (jerryzh168)
0d0474f  trigger CI  (jerryzh168)
4154766  working around 'test_int4.quantize_per_tensor_int4' has no overload n…  (jerryzh168)
b4f196a  fix  (jerryzh168)
8075634  debug why the test failed  (jerryzh168)
0050dce  skip test due to out dated nightly  (jerryzh168)
@@ -0,0 +1,324 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Owner(s): ["oncall: quantization"]
# ruff: noqa: F841
import copy
import unittest
from typing import Any

import torch
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

from torchao.quantization.pt2e.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7

if TORCH_VERSION_AT_LEAST_2_5:
    from torch.export import export_for_training

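# Helper modules whose graphs exercise sharing and duplication of
# dequantize (DQ) nodes: one output feeding several consumers, some
# quantized and some not.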
class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x + 2
            return w, add_output, extra_output

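# Decomposed dequantize overloads checked by the assertions below.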
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]

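# Tests for the duplicate-DQ pass: after convert, each quantized consumer of
# a value should see its own dequantize node rather than sharing one.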
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export_for_training(m, example_inputs, strict=True).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        pt2_quant_output = m(*example_inputs)
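        # After convert_pt2e, any dequantize node feeding an annotated op
        # should have been duplicated so that it has exactly one user.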
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                for arg in n.args:
                    if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
                        self.assertEqual(len(arg.users.keys()), 1)

    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check that the quantization tags on conv2d, avgpool, and linear
        are set correctly.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |       |
              -------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the
        first conv2d is fed to the next conv2d, add, and view_copy + linear.
        All three uses are quantized, so the DQ node is duplicated
        once per use.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |       |
              -------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the
        first conv2d is fed to the next conv2d and to view_copy + linear;
        both of those uses are quantized.
        However, the skip connections to add and mul are not quantized,
        so the DQ node is not duplicated for those two uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |       |
              -------->
             |
              -----> adaptive_avgpool2d (different qconfig)
             |
              -----> add
        output
        conv2d -> dq -> conv2d -> add
             |               |
              -------> dq -->
             |
              -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
              -> dq -----> add
        """

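        # A separate uint8 qconfig, used only for the avgpool node so that it
        # differs from the symmetric config applied to the rest of the graph.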
        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                MinMaxObserver
            )

            extra_args: dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )

            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )

if __name__ == "__main__":
    run_tests()
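
For reviewers unfamiliar with the pt2e flow these tests exercise, here is a minimal sketch of the export -> prepare -> calibrate -> convert round trip, using only APIs imported in this diff; ConvOnlyQuantizer and the one-layer model are hypothetical stand-ins for the BackendAQuantizer pattern above.

import torch
from torch.export import export_for_training

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import Quantizer
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR


class ConvOnlyQuantizer(Quantizer):
    # Hypothetical quantizer for illustration only: annotates conv ops with
    # the symmetric per-channel config, mirroring BackendAQuantizer above.
    def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        config = get_symmetric_quantization_config(is_per_channel=True)
        OP_TO_ANNOTATOR["conv"](gm, config)

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass


model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3)).eval()
example_inputs = (torch.randn(1, 3, 8, 8),)

# export -> prepare (insert observers) -> calibrate -> convert (insert q/dq)
m = export_for_training(model, example_inputs, strict=True).module()
m = prepare_pt2e(m, ConvOnlyQuantizer())
m(*example_inputs)  # one calibration pass
m = convert_pt2e(m)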