
Copy and adapt pt2e quantization code to torchao #2048


Merged · 15 commits · Apr 15, 2025
4 changes: 2 additions & 2 deletions test/integration/test_integration.py
@@ -2038,8 +2038,8 @@ def forward(self, x):
         self.assertTrue(torch.equal(after_export, ref))
         if api is _int8da_int4w_api:
             targets = [n.target for n in model.graph.nodes]
-            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
-            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+            self.assertTrue(torch.ops.torchao.choose_qparams_affine.default in targets)
+            self.assertTrue(torch.ops.torchao.quantize_affine.default in targets)
             self.assertFalse(torch.ops.aten.narrow.default in targets)
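
With the pt2e code vendored into torchao, the affine quantization primitives are registered under the torch.ops.torchao namespace rather than torch.ops.quant, which is what this test change asserts. A minimal sketch of the updated check, assuming the torchao int8-dynamic-activation/int4-weight eager API (`model` and `example_inputs` here are placeholders, not names from this PR):

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int4_weight

    # Quantize in eager mode, then export and look for the torchao-namespaced ops.
    quantize_(model, int8_dynamic_activation_int4_weight())
    ep = torch.export.export(model, example_inputs)
    targets = [n.target for n in ep.graph.nodes]
    assert torch.ops.torchao.choose_qparams_affine.default in targets
    assert torch.ops.torchao.quantize_affine.default in targets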


324 changes: 324 additions & 0 deletions test/quantization/pt2e/test_duplicate_dq.py
@@ -0,0 +1,324 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

# Owner(s): ["oncall: quantization"]
# ruff: noqa: F841
import copy
import unittest
from typing import Any

import torch
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests

from torchao.quantization.pt2e.observer import (
    HistogramObserver,
    MinMaxObserver,
    PlaceholderObserver,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
    Quantizer,
    SharedQuantizationSpec,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
)
from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7

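# torch.export.export_for_training is only available on torch 2.5+.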
if TORCH_VERSION_AT_LEAST_2_5:
    from torch.export import export_for_training


class TestHelperModules:
    class Conv2dWithObsSharingOps(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)
            self.hardtanh = torch.nn.Hardtanh()
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv(x)
            x = self.adaptive_avg_pool2d(x)
            x = self.hardtanh(x)
            x = x.view(-1, 3)
            x = self.linear(x)
            return x

    class Conv2dWithSharedDQ(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.conv1(x)
            z = x.view(-1, 3)
            w = self.linear(z)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x * 2
            return w, add_output, extra_output

    class ModuleForDifferentQconfig(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 1)
            self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))

        def forward(self, x):
            x = self.conv1(x)
            w = self.adaptive_avg_pool2d(x)

            y = self.conv2(x)
            add_output = x + y

            extra_output = x + 2
            return w, add_output, extra_output


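# Dequantize ops emitted by convert_pt2e; the duplicate-DQ pass operates on these.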
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
class TestDuplicateDQPass(QuantizationTestCase):
    def _test_duplicate_dq(
        self,
        model,
        example_inputs,
        quantizer,
    ):
        m_eager = model.eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = export_for_training(m, example_inputs, strict=True).module()

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m)

        pt2_quant_output = m(*example_inputs)
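        # Every dq node that feeds an annotated (quantized) op must have
        # exactly one user: the duplicate-DQ pass clones any shared dq node
        # so that each quantized consumer gets its own copy.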
        for n in m.graph.nodes:
            annotation = n.meta.get("quantization_annotation", None)
            if annotation is not None:
                for arg in n.args:
                    if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS:
                        self.assertEqual(len(arg.users.keys()), 1)

    def test_no_need_for_duplicate_dq(self):
        """
        Model under test
        conv2d -> avgpool -> hardtanh -> linear
        Check quantization tags on conv2d, avgpool and linear are correctly set
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithObsSharingOps(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_simple_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the
        first conv2d feeds the next conv2d, add, and view_copy + linear.
        All three uses are quantized, so the DQ node is duplicated to
        give each of those three uses its own copy.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_no_add_quant_duplicate_dq(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> view_copy --> linear
             |
              -----> mul
        There should be three dq nodes because the output of the
        first conv2d feeds the next conv2d and view_copy + linear,
        both of which are quantized. However, the skip connections to
        add and mul are not quantized, so the DQ node is not duplicated
        for those two uses.
        """

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                OP_TO_ANNOTATOR["linear"](gm, quantization_config)
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.Conv2dWithSharedDQ(),
            example_inputs,
            BackendAQuantizer(),
        )

    def test_avgpool_use_different_qconfig(self):
        """
        Model under test
        conv2d -> conv2d -> add
             |          |
              --------->
             |
              -----> adaptive_avgpool2d (different qconfig)
             |
              -----> add
        Expected output graph
        conv2d -> dq -> conv2d -> add
             |                 |
              -------> dq ----->
             |
              -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig)
             |
              -> dq -----> add
        """

        def _get_uint8_quantization_config():
            act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]
            act_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
                    eps=2**-12
                ),
            )
            weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                MinMaxObserver
            )

            extra_args: dict[str, Any] = {"eps": 2**-12}
            weight_quantization_spec = QuantizationSpec(
                dtype=torch.uint8,
                quant_min=0,
                quant_max=255,
                qscheme=torch.per_tensor_affine,
                ch_axis=0,
                is_dynamic=False,
                observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
                    **extra_args
                ),
            )

            bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (  # noqa: F821
                PlaceholderObserver
            )
            bias_quantization_spec = QuantizationSpec(
                dtype=torch.float,
                observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr,
            )
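            # QuantizationConfig fields are, in order: input activation,
            # output activation, weight, and bias specs.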
            quantization_config = QuantizationConfig(
                act_quantization_spec,
                act_quantization_spec,
                weight_quantization_spec,
                bias_quantization_spec,
            )
            return quantization_config

        class BackendAQuantizer(Quantizer):
            def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
                backend_string = "BackendA"
                quantization_config = get_symmetric_quantization_config(
                    is_per_channel=True
                )
                avgpool_qconfig = _get_uint8_quantization_config()
                OP_TO_ANNOTATOR["conv"](gm, quantization_config)
                OP_TO_ANNOTATOR["add"](gm, quantization_config)
                for n in gm.graph.nodes:
                    if n.op == "call_function" and n.target == torch.ops.aten.mean.dim:
                        qspec = avgpool_qconfig.input_activation
                        input_act = n.args[0]
                        output_qspec = SharedQuantizationSpec((input_act, n))
                        n.meta["quantization_annotation"] = QuantizationAnnotation(
                            input_qspec_map={input_act: qspec},
                            output_qspec=output_qspec,
                            _annotated=True,
                        )

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        example_inputs = (torch.randn(1, 3, 5, 7),)
        self._test_duplicate_dq(
            TestHelperModules.ModuleForDifferentQconfig(),
            example_inputs,
            BackendAQuantizer(),
        )


if __name__ == "__main__":
    run_tests()