|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import copy |
| 8 | +import itertools |
| 9 | +import tempfile |
| 10 | +import unittest |
| 11 | + |
| 12 | +import torch |
| 13 | +from torch.testing import FileCheck |
| 14 | + |
| 15 | +from torchao.dtypes import PlainLayout |
| 16 | +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( |
| 17 | + PackedLinearInt8DynamicActivationIntxWeightLayout, |
| 18 | +) |
| 19 | +from torchao.experimental.q_dq_layout import QDQLayout |
| 20 | +from torchao.experimental.quant_api import ( |
| 21 | + int8_dynamic_activation_intx_weight, |
| 22 | +) |
| 23 | +from torchao.quantization.granularity import ( |
| 24 | + PerGroup, |
| 25 | + PerRow, |
| 26 | +) |
| 27 | +from torchao.quantization.quant_api import quantize_ |
| 28 | +from torchao.utils import unwrap_tensor_subclass |
| 29 | + |
| 30 | + |
| 31 | +class TestInt8DynamicActivationIntxWeight(unittest.TestCase): |
| 32 | + def test_accuracy(self): |
| 33 | + """ |
| 34 | + Checks the accuracy of different layouts by comparing the results to PlainLayout() |
| 35 | + """ |
| 36 | + m = 1 |
| 37 | + n = 1071 |
| 38 | + k = 4096 |
| 39 | + activations = torch.randn(m, k) |
| 40 | + model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) |
| 41 | + |
| 42 | + reference_layout = PlainLayout() |
| 43 | + test_layouts = [ |
| 44 | + PackedLinearInt8DynamicActivationIntxWeightLayout(), |
| 45 | + QDQLayout(), |
| 46 | + ] |
| 47 | + test_weight_dtypes = [ |
| 48 | + torch.int1, |
| 49 | + torch.int2, |
| 50 | + torch.int3, |
| 51 | + torch.int4, |
| 52 | + torch.int5, |
| 53 | + torch.int6, |
| 54 | + torch.int7, |
| 55 | + torch.int8, |
| 56 | + ] |
| 57 | + test_has_weight_zeros = [True, False] |
| 58 | + test_granularities = [PerGroup(128), PerRow()] |
| 59 | + for layout, weight_dtype, has_weight_zeros, granularity in itertools.product( |
| 60 | + test_layouts, test_weight_dtypes, test_has_weight_zeros, test_granularities |
| 61 | + ): |
| 62 | + quantized_model = copy.deepcopy(model) |
| 63 | + quantize_( |
| 64 | + quantized_model, |
| 65 | + int8_dynamic_activation_intx_weight( |
| 66 | + weight_dtype=weight_dtype, |
| 67 | + granularity=granularity, |
| 68 | + has_weight_zeros=has_weight_zeros, |
| 69 | + layout=layout, |
| 70 | + ), |
| 71 | + ) |
| 72 | + |
| 73 | + quantized_model_reference = copy.deepcopy(model) |
| 74 | + quantize_( |
| 75 | + quantized_model_reference, |
| 76 | + int8_dynamic_activation_intx_weight( |
| 77 | + weight_dtype=weight_dtype, |
| 78 | + granularity=granularity, |
| 79 | + has_weight_zeros=has_weight_zeros, |
| 80 | + layout=reference_layout, |
| 81 | + ), |
| 82 | + ) |
| 83 | + |
| 84 | + with torch.no_grad(): |
| 85 | + result = quantized_model(activations) |
| 86 | + expected_result = quantized_model_reference(activations) |
| 87 | + self.assertTrue(torch.allclose(result, expected_result, atol=1e-6)) |
| 88 | + |
| 89 | + def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout( |
| 90 | + self, |
| 91 | + ): |
| 92 | + """ |
| 93 | + Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with |
| 94 | + torch.export.export, torch.compile, and AOTI. |
| 95 | + """ |
| 96 | + granularity = PerRow() |
| 97 | + m = 3 |
| 98 | + k0 = 512 |
| 99 | + k1 = 256 |
| 100 | + k2 = 128 |
| 101 | + k3 = 1024 |
| 102 | + weight_dtype = torch.int4 |
| 103 | + has_weight_zeros = True |
| 104 | + layers = [ |
| 105 | + torch.nn.Linear(k0, k1, bias=False), |
| 106 | + torch.nn.Linear(k1, k2, bias=False), |
| 107 | + torch.nn.Linear(k2, k3, bias=False), |
| 108 | + ] |
| 109 | + model = torch.nn.Sequential(*layers) |
| 110 | + activations = torch.randn(2, 1, m, k0, dtype=torch.float32) |
| 111 | + |
| 112 | + quantize_( |
| 113 | + model, |
| 114 | + int8_dynamic_activation_intx_weight( |
| 115 | + weight_dtype=weight_dtype, |
| 116 | + granularity=granularity, |
| 117 | + has_weight_zeros=has_weight_zeros, |
| 118 | + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), |
| 119 | + ), |
| 120 | + ) |
| 121 | + eager_results = model(activations) |
| 122 | + |
| 123 | + unwrapped_model = copy.deepcopy(model) |
| 124 | + unwrap_tensor_subclass(model) |
| 125 | + |
| 126 | + # Export |
| 127 | + exported = torch.export.export(model, (activations,), strict=True) |
| 128 | + exported_results = exported.module()(activations) |
| 129 | + self.assertTrue(torch.allclose(eager_results, exported_results)) |
| 130 | + |
| 131 | + # Compile |
| 132 | + compiled = torch.compile(unwrapped_model) |
| 133 | + with torch.no_grad(): |
| 134 | + compiled_results = compiled(activations) |
| 135 | + self.assertTrue(torch.allclose(eager_results, compiled_results)) |
| 136 | + |
| 137 | + # AOTI |
| 138 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 139 | + package_path = f"{tmpdirname}/model.pt2" |
| 140 | + torch._inductor.aoti_compile_and_package( |
| 141 | + exported, package_path=package_path |
| 142 | + ) |
| 143 | + fn = torch._inductor.aoti_load_package(package_path) |
| 144 | + aoti_results = fn(activations) |
| 145 | + self.assertTrue(torch.allclose(eager_results, aoti_results)) |
| 146 | + |
| 147 | + def test_export_QDQLayout(self): |
| 148 | + """ |
| 149 | + Checks that models quantized with TestQDQLayout() export as expected |
| 150 | + """ |
| 151 | + granularity = PerGroup(64) |
| 152 | + weight_dtype = torch.int4 |
| 153 | + has_weight_zeros = False |
| 154 | + layers = [ |
| 155 | + torch.nn.Linear(512, 256, bias=False), |
| 156 | + ] |
| 157 | + model = torch.nn.Sequential(*layers) |
| 158 | + activations = torch.randn(1, 512, dtype=torch.float32) |
| 159 | + |
| 160 | + quantize_( |
| 161 | + model, |
| 162 | + int8_dynamic_activation_intx_weight( |
| 163 | + weight_dtype=weight_dtype, |
| 164 | + granularity=granularity, |
| 165 | + has_weight_zeros=has_weight_zeros, |
| 166 | + layout=QDQLayout(), |
| 167 | + ), |
| 168 | + ) |
| 169 | + eager_results = model(activations) |
| 170 | + |
| 171 | + unwrap_tensor_subclass(model) |
| 172 | + exported = torch.export.export(model, (activations,), strict=True) |
| 173 | + exported_results = exported.module()(activations) |
| 174 | + self.assertTrue(torch.allclose(eager_results, exported_results)) |
| 175 | + |
| 176 | + expected_lines = [ |
| 177 | + "torch.ops.quant.choose_qparams_affine.default(input_1, 'ASYMMETRIC', [1, 512], torch.int32, -128, 127, None, torch.float32, torch.int32)", |
| 178 | + "torch.ops.quant.quantize_affine.default(input_1, [1, 512], getitem, getitem_1, torch.int32, -128, 127)", |
| 179 | + "torch.ops.quant.dequantize_affine.default(quantize_affine, [1, 512], getitem, getitem_1, torch.int32, -128, 127)", |
| 180 | + "torch.ops.quant.dequantize_affine.default(p_fn_0_parametrizations_weight_original0, [1, 64], p_fn_0_parametrizations_weight_original1, None, torch.int32, -8, 7, 'NONE')", |
| 181 | + "torch.ops.aten.linear.default(dequantize_affine, dequantize_affine_1)", |
| 182 | + ] |
| 183 | + for line in expected_lines: |
| 184 | + FileCheck().check_count(line, 1, exactly=True).run( |
| 185 | + exported.graph_module.code |
| 186 | + ) |
0 commit comments