
Commit 952ab42

[Fix]: Enable SYMMETRIC_NO_CLIPPING_ERR mapping type and tests
Signed-off-by: Nikhil Gupta <[email protected]>
1 parent ca8a5f1 commit 952ab42

2 files changed: +86 -2 lines


torchao/experimental/quant_api.py (+2 -2)
@@ -614,13 +614,13 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
         if layout.target == Target.ATEN:
             if weight_dtype != torch.int4 or \
                has_weight_zeros != True or \
-               weight_mapping_type != MappingType.SYMMETRIC:
+               weight_mapping_type == MappingType.ASYMMETRIC:
                 raise NotImplementedError(
                     f"target 'aten' requires:\n"
                     f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
                     f"- has_weight_zeros to be True,\n"
                     f"- weight_dtype to be torch.int4,\n"
-                    f"- weight_mapping_type to be MappingType.SYMMETRIC"
+                    f"- weight_mapping_type to be MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR"
                 )
             assert TORCH_VERSION_AT_LEAST_2_6, "aten target requires torch version >= 2.6.0"
             if torch.backends.kleidiai.is_available():
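For context: with the relaxed check above, both MappingType.SYMMETRIC and MappingType.SYMMETRIC_NO_CLIPPING_ERR now reach the ATEN path. The _NO_CLIPPING_ERR variant chooses symmetric scales so that neither end of the signed quantized range clips. A minimal sketch of the idea (illustrative only, not torchao's exact implementation; int4 bounds of -8/7 and per-row scales are assumed):

    import torch

    def symmetric_scales(w, quant_min=-8, quant_max=7):
        # Plain SYMMETRIC: one scale from the max magnitude. Because
        # |quant_min| > quant_max for int4 (8 vs. 7), the largest positive
        # value can round past quant_max and clip.
        sym = w.abs().amax(dim=-1, keepdim=True) / ((quant_max - quant_min) / 2)

        # SYMMETRIC_NO_CLIPPING_ERR: derive a candidate scale from each end
        # of the range separately and keep the larger, so neither end clips.
        smin = w.amin(dim=-1, keepdim=True).clamp(max=0) / quant_min
        smax = w.amax(dim=-1, keepdim=True).clamp(min=0) / quant_max
        return sym, torch.maximum(smin, smax)

    w = torch.randn(4, 256)              # e.g. four rows of a weight matrix
    scale_sym, scale_nce = symmetric_scales(w)

Taking the larger of the two per-end candidate scales guarantees both the most negative and most positive weight land inside [quant_min, quant_max], at the cost of a slightly coarser grid.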
New test file (+84 -0)
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch

from torchao.dtypes import PlainLayout
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)
from torchao.experimental.quant_api import (
    int8_dynamic_activation_intx_weight,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quant_primitives import MappingType


class TestPackedLinearInt8DynamicActivationIntxWeightLayoutAten(unittest.TestCase):
    def test_accuracy(self):
        """
        Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by
        comparing its results to those of a reference model that uses PlainLayout().
        """
        granularities = [PerRow()]
        m = 32
        n = 128
        k = 256
        activations = torch.randn(m, k)
        weight_mapping_type = MappingType.SYMMETRIC_NO_CLIPPING_ERR
        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

        for weight_dtype in [torch.int4]:
            for has_weight_zeros in [True]:
                for granularity in granularities:
                    print(
                        f"Testing weight_dtype={weight_dtype}, "
                        f"has_weight_zeros={has_weight_zeros}, granularity={granularity}"
                    )
                    quantized_model = copy.deepcopy(model)
                    quantize_(
                        quantized_model,
                        int8_dynamic_activation_intx_weight(
                            weight_dtype=weight_dtype,
                            granularity=granularity,
                            has_weight_zeros=has_weight_zeros,
                            weight_mapping_type=weight_mapping_type,
                            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
                                target="aten"
                            ),
                        ),
                    )

                    quantized_model_reference = copy.deepcopy(model)
                    quantize_(
                        quantized_model_reference,
                        int8_dynamic_activation_intx_weight(
                            weight_dtype=weight_dtype,
                            granularity=granularity,
                            has_weight_zeros=has_weight_zeros,
                            layout=PlainLayout(),
                        ),
                    )

                    with torch.no_grad():
                        res = quantized_model(activations)
                        ref = quantized_model_reference(activations)

                    # Mean relative error between the packed-layout result and
                    # the PlainLayout reference.
                    mean_err = ((res - ref).abs() / ref).mean()
                    self.assertTrue(mean_err < 0.04)


if __name__ == "__main__":
    unittest.main()
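Because the file ends in a unittest.main() guard, the new test can be run directly; the path below is hypothetical, since this view does not show the new file's name:

    # hypothetical path -- the diff above omits the new file's location
    python torchao/experimental/tests/test_aten_no_clipping_err.py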
