|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# Owner(s): ["oncall: quantization"] |
| 8 | +# ruff: noqa: F841 |
| 9 | +import copy |
| 10 | +import unittest |
| 11 | +from typing import Any |
| 12 | + |
| 13 | +import torch |
| 14 | +from torch.testing._internal.common_quantization import QuantizationTestCase |
| 15 | +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests |
| 16 | + |
| 17 | +from torchao.quantization.pt2e.observer import ( |
| 18 | + HistogramObserver, |
| 19 | + MinMaxObserver, |
| 20 | + PlaceholderObserver, |
| 21 | +) |
| 22 | +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e |
| 23 | +from torchao.quantization.pt2e.quantizer import ( |
| 24 | + QuantizationAnnotation, |
| 25 | + QuantizationSpec, |
| 26 | + Quantizer, |
| 27 | + SharedQuantizationSpec, |
| 28 | +) |
| 29 | +from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import ( |
| 30 | + get_symmetric_quantization_config, |
| 31 | +) |
| 32 | +from torchao.quantization.pt2e.quantizer.xnnpack_quantizer_utils import ( |
| 33 | + OP_TO_ANNOTATOR, |
| 34 | + QuantizationConfig, |
| 35 | +) |
| 36 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_7 |
| 37 | + |
| 38 | +if TORCH_VERSION_AT_LEAST_2_5: |
| 39 | + from torch.export import export_for_training |
| 40 | + |
| 41 | + |
| 42 | +class TestHelperModules: |
| 43 | + class Conv2dWithObsSharingOps(torch.nn.Module): |
| 44 | + def __init__(self) -> None: |
| 45 | + super().__init__() |
| 46 | + self.conv = torch.nn.Conv2d(3, 3, 3) |
| 47 | + self.hardtanh = torch.nn.Hardtanh() |
| 48 | + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) |
| 49 | + self.linear = torch.nn.Linear(3, 3) |
| 50 | + |
| 51 | + def forward(self, x): |
| 52 | + x = self.conv(x) |
| 53 | + x = self.adaptive_avg_pool2d(x) |
| 54 | + x = self.hardtanh(x) |
| 55 | + x = x.view(-1, 3) |
| 56 | + x = self.linear(x) |
| 57 | + return x |
| 58 | + |
| 59 | + class Conv2dWithSharedDQ(torch.nn.Module): |
| 60 | + def __init__(self) -> None: |
| 61 | + super().__init__() |
| 62 | + self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| 63 | + self.conv2 = torch.nn.Conv2d(3, 3, 1) |
| 64 | + self.linear = torch.nn.Linear(3, 3) |
| 65 | + |
| 66 | + def forward(self, x): |
| 67 | + x = self.conv1(x) |
| 68 | + z = x.view(-1, 3) |
| 69 | + w = self.linear(z) |
| 70 | + |
| 71 | + y = self.conv2(x) |
| 72 | + add_output = x + y |
| 73 | + |
| 74 | + extra_output = x * 2 |
| 75 | + return w, add_output, extra_output |
| 76 | + |
| 77 | + class ModuleForDifferentQconfig(torch.nn.Module): |
| 78 | + def __init__(self) -> None: |
| 79 | + super().__init__() |
| 80 | + self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| 81 | + self.conv2 = torch.nn.Conv2d(3, 3, 1) |
| 82 | + self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) |
| 83 | + |
| 84 | + def forward(self, x): |
| 85 | + x = self.conv1(x) |
| 86 | + w = self.adaptive_avg_pool2d(x) |
| 87 | + |
| 88 | + y = self.conv2(x) |
| 89 | + add_output = x + y |
| 90 | + |
| 91 | + extra_output = x + 2 |
| 92 | + return w, add_output, extra_output |
| 93 | + |
| 94 | + |
| 95 | +_DEQUANTIZE_OPS = [ |
| 96 | + torch.ops.quantized_decomposed.dequantize_per_tensor.default, |
| 97 | + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, |
| 98 | + torch.ops.quantized_decomposed.dequantize_per_channel.default, |
| 99 | +] |
| 100 | + |
| 101 | + |
| 102 | +@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") |
| 103 | +@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+") |
| 104 | +class TestDuplicateDQPass(QuantizationTestCase): |
| 105 | + def _test_duplicate_dq( |
| 106 | + self, |
| 107 | + model, |
| 108 | + example_inputs, |
| 109 | + quantizer, |
| 110 | + ): |
| 111 | + m_eager = model.eval() |
| 112 | + |
| 113 | + # program capture |
| 114 | + m = copy.deepcopy(m_eager) |
| 115 | + m = export_for_training(m, example_inputs, strict=True).module() |
| 116 | + |
| 117 | + m = prepare_pt2e(m, quantizer) |
| 118 | + # Calibrate |
| 119 | + m(*example_inputs) |
| 120 | + m = convert_pt2e(m) |
| 121 | + |
| 122 | + pt2_quant_output = m(*example_inputs) |
| 123 | + for n in m.graph.nodes: |
| 124 | + annotation = n.meta.get("quantization_annotation", None) |
| 125 | + if annotation is not None: |
| 126 | + for arg in n.args: |
| 127 | + if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS: |
| 128 | + self.assertEqual(len(arg.users.keys()), 1) |
| 129 | + |
| 130 | + def test_no_need_for_duplicate_dq(self): |
| 131 | + """ |
| 132 | + Model under test |
| 133 | + conv2d -> avgpool -> hardtanh -> linear |
| 134 | + Check quantization tags on conv2d, avgpool and linear are correctly set |
| 135 | + """ |
| 136 | + |
| 137 | + class BackendAQuantizer(Quantizer): |
| 138 | + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 139 | + backend_string = "BackendA" |
| 140 | + quantization_config = get_symmetric_quantization_config( |
| 141 | + is_per_channel=True |
| 142 | + ) |
| 143 | + OP_TO_ANNOTATOR["linear"](gm, quantization_config) |
| 144 | + OP_TO_ANNOTATOR["conv"](gm, quantization_config) |
| 145 | + OP_TO_ANNOTATOR["adaptive_avg_pool2d"](gm, quantization_config) |
| 146 | + |
| 147 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 148 | + pass |
| 149 | + |
| 150 | + example_inputs = (torch.randn(1, 3, 5, 7),) |
| 151 | + self._test_duplicate_dq( |
| 152 | + TestHelperModules.Conv2dWithObsSharingOps(), |
| 153 | + example_inputs, |
| 154 | + BackendAQuantizer(), |
| 155 | + ) |
| 156 | + |
| 157 | + def test_simple_duplicate_dq(self): |
| 158 | + """ |
| 159 | + Model under test |
| 160 | + conv2d -> conv2d -> add |
| 161 | + | | |
| 162 | + ---------> |
| 163 | + | |
| 164 | + -----> view_copy --> linear |
| 165 | + | |
| 166 | + -----> mul |
| 167 | + There should be three dq nodes because output for the |
| 168 | + first conv2d is fed to next conv2d, add, and view_copy + linear. |
| 169 | + All three are quantized. |
| 170 | + Thus DQ node is not duplicated for those three uses |
| 171 | + """ |
| 172 | + |
| 173 | + class BackendAQuantizer(Quantizer): |
| 174 | + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 175 | + backend_string = "BackendA" |
| 176 | + quantization_config = get_symmetric_quantization_config( |
| 177 | + is_per_channel=True |
| 178 | + ) |
| 179 | + OP_TO_ANNOTATOR["linear"](gm, quantization_config) |
| 180 | + OP_TO_ANNOTATOR["conv"](gm, quantization_config) |
| 181 | + OP_TO_ANNOTATOR["add"](gm, quantization_config) |
| 182 | + |
| 183 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 184 | + pass |
| 185 | + |
| 186 | + example_inputs = (torch.randn(1, 3, 5, 7),) |
| 187 | + self._test_duplicate_dq( |
| 188 | + TestHelperModules.Conv2dWithSharedDQ(), |
| 189 | + example_inputs, |
| 190 | + BackendAQuantizer(), |
| 191 | + ) |
| 192 | + |
| 193 | + def test_no_add_quant_duplicate_dq(self): |
| 194 | + """ |
| 195 | + Model under test |
| 196 | + conv2d -> conv2d -> add |
| 197 | + | | |
| 198 | + ---------> |
| 199 | + | |
| 200 | + -----> view_copy --> linear |
| 201 | + | |
| 202 | + -----> mul |
| 203 | + There should be three dq nodes because output for the |
| 204 | + first conv2d is fed to next conv2d, and view_copy + linear. |
| 205 | + Both are quantized. |
| 206 | + However the skip connection to add and mul are not quantized. |
| 207 | + Thus DQ node is not duplicated for those two uses |
| 208 | + """ |
| 209 | + |
| 210 | + class BackendAQuantizer(Quantizer): |
| 211 | + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 212 | + backend_string = "BackendA" |
| 213 | + quantization_config = get_symmetric_quantization_config( |
| 214 | + is_per_channel=True |
| 215 | + ) |
| 216 | + OP_TO_ANNOTATOR["linear"](gm, quantization_config) |
| 217 | + OP_TO_ANNOTATOR["conv"](gm, quantization_config) |
| 218 | + |
| 219 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 220 | + pass |
| 221 | + |
| 222 | + example_inputs = (torch.randn(1, 3, 5, 7),) |
| 223 | + self._test_duplicate_dq( |
| 224 | + TestHelperModules.Conv2dWithSharedDQ(), |
| 225 | + example_inputs, |
| 226 | + BackendAQuantizer(), |
| 227 | + ) |
| 228 | + |
| 229 | + def test_avgpool_use_different_qconfig(self): |
| 230 | + """ |
| 231 | + Model under test |
| 232 | + conv2d -> conv2d -> add |
| 233 | + | | |
| 234 | + ---------> |
| 235 | + | |
| 236 | + -----> adaptive_avgpool2d (different qconfig) |
| 237 | + | |
| 238 | + -----> add |
| 239 | + output |
| 240 | + conv2d -> dq -> conv2d -> add |
| 241 | + | | |
| 242 | + -------> dq -----> |
| 243 | + | |
| 244 | + -> dq -> q -> dq -----> adaptive_avgpool2d (different qconfig) |
| 245 | + | |
| 246 | + -> dq -----> add |
| 247 | + """ |
| 248 | + |
| 249 | + def _get_uint8_quantization_config(): |
| 250 | + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] |
| 251 | + act_quantization_spec = QuantizationSpec( |
| 252 | + dtype=torch.uint8, |
| 253 | + quant_min=0, |
| 254 | + quant_max=255, |
| 255 | + qscheme=torch.per_tensor_affine, |
| 256 | + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( |
| 257 | + eps=2**-12 |
| 258 | + ), |
| 259 | + ) |
| 260 | + weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821 |
| 261 | + MinMaxObserver |
| 262 | + ) |
| 263 | + |
| 264 | + extra_args: dict[str, Any] = {"eps": 2**-12} |
| 265 | + weight_quantization_spec = QuantizationSpec( |
| 266 | + dtype=torch.uint8, |
| 267 | + quant_min=0, |
| 268 | + quant_max=255, |
| 269 | + qscheme=torch.per_tensor_affine, |
| 270 | + ch_axis=0, |
| 271 | + is_dynamic=False, |
| 272 | + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args( |
| 273 | + **extra_args |
| 274 | + ), |
| 275 | + ) |
| 276 | + |
| 277 | + bias_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( # noqa: F821 |
| 278 | + PlaceholderObserver |
| 279 | + ) |
| 280 | + bias_quantization_spec = QuantizationSpec( |
| 281 | + dtype=torch.float, |
| 282 | + observer_or_fake_quant_ctr=bias_observer_or_fake_quant_ctr, |
| 283 | + ) |
| 284 | + quantization_config = QuantizationConfig( |
| 285 | + act_quantization_spec, |
| 286 | + act_quantization_spec, |
| 287 | + weight_quantization_spec, |
| 288 | + bias_quantization_spec, |
| 289 | + ) |
| 290 | + return quantization_config |
| 291 | + |
| 292 | + class BackendAQuantizer(Quantizer): |
| 293 | + def annotate(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 294 | + backend_string = "BackendA" |
| 295 | + quantization_config = get_symmetric_quantization_config( |
| 296 | + is_per_channel=True |
| 297 | + ) |
| 298 | + avgpool_qconfig = _get_uint8_quantization_config() |
| 299 | + OP_TO_ANNOTATOR["conv"](gm, quantization_config) |
| 300 | + OP_TO_ANNOTATOR["add"](gm, quantization_config) |
| 301 | + for n in gm.graph.nodes: |
| 302 | + if n.op == "call_function" and n.target == torch.ops.aten.mean.dim: |
| 303 | + qspec = avgpool_qconfig.input_activation |
| 304 | + input_act = n.args[0] |
| 305 | + output_qspec = SharedQuantizationSpec((input_act, n)) |
| 306 | + n.meta["quantization_annotation"] = QuantizationAnnotation( |
| 307 | + input_qspec_map={input_act: qspec}, |
| 308 | + output_qspec=output_qspec, |
| 309 | + _annotated=True, |
| 310 | + ) |
| 311 | + |
| 312 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 313 | + pass |
| 314 | + |
| 315 | + example_inputs = (torch.randn(1, 3, 5, 7),) |
| 316 | + self._test_duplicate_dq( |
| 317 | + TestHelperModules.ModuleForDifferentQconfig(), |
| 318 | + example_inputs, |
| 319 | + BackendAQuantizer(), |
| 320 | + ) |
| 321 | + |
| 322 | + |
| 323 | +if __name__ == "__main__": |
| 324 | + run_tests() |
0 commit comments