Skip to content

Commit 7b0d2ce

Browse files
authored
Consolidate ZeroPointDomain.NONE & None zero point domains (#1556)
* Fix ZeroPointDomain.NONE support & make it default for da8w8 weights
* Fix bug & apply review recommendations
* Throw exceptions when None zero_point_domain is used
* Use ZeroPointDomain.NONE for weight in int8_dynamic_activation_int8_weight
* Rebase with the latest main branch
* Fix typo
1 parent abd41e5 commit 7b0d2ce

10 files changed

+171
-64
lines changed

test/integration/test_integration.py

+39-8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import unittest
13+
from functools import partial
1314

1415
import torch
1516
import torch.nn as nn
@@ -48,6 +49,7 @@
4849
quantize_,
4950
)
5051
from torchao.quantization.quant_primitives import (
52+
MappingType,
5153
dequantize_affine,
5254
)
5355
from torchao.quantization.smoothquant import (
@@ -102,6 +104,8 @@
102104

103105
COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
104106

107+
ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]
108+
105109
COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
106110

107111

@@ -121,9 +125,18 @@ def _int8wo_groupwise_api(mod):
121125
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
122126

123127

124-
def _int8da_int8w_api(mod):
128+
def _int8da_int8w_api(
129+
mod,
130+
act_mapping_type=MappingType.SYMMETRIC,
131+
):
125132
if TORCH_VERSION_AT_LEAST_2_4:
126-
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
133+
quantize_(
134+
mod,
135+
int8_dynamic_activation_int8_weight(
136+
act_mapping_type=act_mapping_type,
137+
),
138+
set_inductor_config=False,
139+
)
127140
if not TORCH_VERSION_AT_LEAST_2_5:
128141
unwrap_tensor_subclass(mod)
129142
else:
@@ -962,25 +975,43 @@ def _test_lin_weight_subclass_api_impl(
962975
mod[0].weight.tensor_impl.get_plain()
963976

964977
test = mod(x)
978+
965979
self.assertGreater(
966980
SQNR(ref_f, test),
967981
min_sqnr,
968-
f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
982+
f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}",
969983
)
970984

971985
mod_qc = torch.compile(mod, mode="max-autotune")
972986
test_comp = mod_qc(x)
973987
self.assertGreater(
974988
SQNR(ref_f, test_comp),
975989
min_sqnr,
976-
f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
990+
f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}",
977991
)
978992

979-
@parameterized.expand(COMMON_DEVICE_DTYPE)
980-
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
981-
self._test_lin_weight_subclass_api_impl(
982-
_int8da_int8w_api, device, 35, test_dtype=dtype
993+
@parameterized.expand(
994+
list(
995+
itertools.product(
996+
COMMON_DEVICES,
997+
COMMON_DTYPES,
998+
ACT_MAPPING_TYPES,
999+
)
1000+
)
1001+
)
1002+
def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping):
1003+
if (
1004+
not TORCH_VERSION_AT_LEAST_2_5
1005+
and dtype in (torch.float16, torch.bfloat16)
1006+
and act_mapping is MappingType.ASYMMETRIC
1007+
and device == "cpu"
1008+
):
1009+
self.skipTest("Inductor-CPU codegen issue fixed in torch 2.5")
1010+
api = partial(
1011+
_int8da_int8w_api,
1012+
act_mapping_type=act_mapping,
9831013
)
1014+
self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)
9841015

9851016
@parameterized.expand(COMMON_DEVICE_DTYPE)
9861017
@unittest.skipIf(is_fbcode(), "broken in fbcode")

test/quantization/test_observer.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from torchao.quantization.quant_primitives import (
2323
MappingType,
24+
ZeroPointDomain,
2425
)
2526

2627

@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
7475
eps=torch.finfo(torch.float32).eps,
7576
scale_dtype=torch.float,
7677
zero_point_dtype=torch.int,
77-
zero_point_domain=None,
78+
zero_point_domain=ZeroPointDomain.NONE,
7879
)
7980
example_inputs = [
8081
torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
9394
eps=torch.finfo(torch.float32).eps,
9495
scale_dtype=torch.float,
9596
zero_point_dtype=torch.int,
96-
zero_point_domain=None,
97+
zero_point_domain=ZeroPointDomain.NONE,
9798
)
9899
for example_input in example_inputs:
99100
obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
108109
eps=torch.finfo(torch.float32).eps,
109110
scale_dtype=torch.float,
110111
zero_point_dtype=torch.int,
111-
zero_point_domain=None,
112+
zero_point_domain=ZeroPointDomain.NONE,
112113
)
113114
example_inputs = [
114115
torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
127128
eps=torch.finfo(torch.float32).eps,
128129
scale_dtype=torch.float,
129130
zero_point_dtype=torch.int,
130-
zero_point_domain=None,
131+
zero_point_domain=ZeroPointDomain.NONE,
131132
)
132133
example_inputs = [
133134
torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
155156
eps=torch.finfo(torch.float32).eps,
156157
scale_dtype=torch.float,
157158
zero_point_dtype=torch.int,
158-
zero_point_domain=None,
159+
zero_point_domain=ZeroPointDomain.NONE,
159160
)
160161
if observe_weight:
161162
weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
165166
eps=torch.finfo(torch.float32).eps,
166167
scale_dtype=torch.float,
167168
zero_point_dtype=torch.int,
168-
zero_point_domain=None,
169+
zero_point_domain=ZeroPointDomain.NONE,
169170
)
170171
else:
171172
weight_observer = None
@@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
199200
input_scale.item(),
200201
max_val / max_fp8,
201202
)
202-
self.assertIsNotNone(input_zero_point)
203+
self.assertIsNone(input_zero_point)
203204

204205
if observe_weight:
205206
weight_observer = linear.weight.weight_observer
@@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
210211
atol=5e-5,
211212
rtol=0.0,
212213
)
213-
self.assertIsNotNone(weight_zero_point)
214+
self.assertIsNone(weight_zero_point)
214215
else:
215216
self.assertIsNone(linear.weight.weight_observer)
216217

test/quantization/test_quant_primitives.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,55 @@ def test_fake_quantize_affine_cachemask(self):
843843
torch.testing.assert_close(dequantized, fake_quantized)
844844
torch.testing.assert_close(expected_mask, mask)
845845

846+
def test_none_zero_point_domain(self):
847+
"""A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
848+
input = torch.randn(10, 256)
849+
mapping_type = MappingType.SYMMETRIC
850+
dtype = torch.int8
851+
block_size = (1, 128)
852+
quant_min = None
853+
quant_max = None
854+
eps = 1e-6
855+
scale_dtype = torch.float32
856+
zero_point_dtype = torch.int64
857+
try:
858+
_, zero_point = choose_qparams_affine(
859+
input,
860+
mapping_type,
861+
block_size,
862+
dtype,
863+
quant_min,
864+
quant_max,
865+
eps,
866+
scale_dtype=scale_dtype,
867+
zero_point_dtype=zero_point_dtype,
868+
preserve_zero=True,
869+
zero_point_domain=None,
870+
)
871+
except ValueError:
872+
# This exception was expected
873+
# Now test for ZeroPointDomain.NONE
874+
_, zero_point = choose_qparams_affine(
875+
input,
876+
mapping_type,
877+
block_size,
878+
dtype,
879+
quant_min,
880+
quant_max,
881+
eps,
882+
scale_dtype=scale_dtype,
883+
zero_point_dtype=zero_point_dtype,
884+
preserve_zero=True,
885+
zero_point_domain=ZeroPointDomain.NONE,
886+
)
887+
self.assertTrue(zero_point is None)
888+
else:
889+
# An exception should have been thrown for zero_point_domain None
890+
self.assertTrue(
891+
False,
892+
msg="A runtime exception should have been thrown for zero_point_domain None",
893+
)
894+
846895
@parameterized.expand(
847896
[
848897
(
@@ -890,7 +939,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
890939
quant_min=torch.finfo(float8_dtype).min,
891940
quant_max=torch.finfo(float8_dtype).max,
892941
zero_point=None,
893-
zero_point_domain=None,
942+
zero_point_domain=ZeroPointDomain.NONE,
894943
)
895944
expected_dequantized = dequantize_affine(
896945
expected_quantized,
@@ -901,7 +950,7 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
901950
quant_min=torch.finfo(float8_dtype).min,
902951
quant_max=torch.finfo(float8_dtype).max,
903952
zero_point=None,
904-
zero_point_domain=None,
953+
zero_point_domain=ZeroPointDomain.NONE,
905954
)
906955

907956
self.assertTrue(torch.equal(expected_scale, scale))

torchao/dtypes/affine_quantized_tensor.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def __new__(
8181
dtype=None,
8282
strides=None,
8383
):
84+
if zero_point_domain is None:
85+
raise ValueError("please use ZeroPointDomain.NONE instead of None")
8486
kwargs = {}
8587
kwargs["device"] = tensor_impl.device
8688
kwargs["layout"] = (
@@ -199,7 +201,7 @@ def from_hp_to_intx(
199201
scale_dtype: Optional[torch.dtype] = None,
200202
zero_point_dtype: Optional[torch.dtype] = None,
201203
preserve_zero: bool = True,
202-
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
204+
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
203205
_layout: Layout = PlainLayout(),
204206
use_hqq: bool = False,
205207
):
@@ -258,8 +260,7 @@ def from_hp_to_intx(
258260
zero_point_domain,
259261
)
260262
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
261-
# TODO should probably consolidate ZeroPointDomain.NONE and None
262-
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
263+
if zero_point_domain == ZeroPointDomain.NONE:
263264
zero_point = None
264265
data = quantize_affine(
265266
input_float,
@@ -296,14 +297,15 @@ def from_hp_to_intx_static(
296297
target_dtype: torch.dtype,
297298
quant_min: Optional[int] = None,
298299
quant_max: Optional[int] = None,
299-
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
300+
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
300301
_layout: Layout = PlainLayout(),
301302
):
302303
"""Create an integer AffineQuantizedTensor from a high precision tensor using static parameters."""
304+
if zero_point_domain is None:
305+
raise ValueError("please use ZeroPointDomain.NONE instead of None")
306+
elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
307+
raise ValueError("zero_point should be None when zero_point_domain is NONE")
303308
if target_dtype not in FP8_TYPES:
304-
assert (
305-
zero_point_domain is not None
306-
), "zero_point_domain must be specified for non-fp8 types"
307309
assert (
308310
zero_point is not None
309311
), "zero_point must be specified for non-fp8 types"
@@ -359,7 +361,7 @@ def from_hp_to_floatx(
359361
scale_dtype=scale_dtype,
360362
zero_point_dtype=None,
361363
preserve_zero=True,
362-
zero_point_domain=None,
364+
zero_point_domain=ZeroPointDomain.NONE,
363365
_layout=_layout,
364366
use_hqq=False,
365367
)
@@ -387,7 +389,7 @@ def from_hp_to_floatx_static(
387389
target_dtype=target_dtype,
388390
quant_min=math.ceil(torch.finfo(target_dtype).min),
389391
quant_max=math.ceil(torch.finfo(target_dtype).max),
390-
zero_point_domain=None,
392+
zero_point_domain=ZeroPointDomain.NONE,
391393
_layout=_layout,
392394
)
393395
else:

torchao/dtypes/uintx/marlin_qqq_tensor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ def from_hp_to_intx(
5454
block_size: Tuple[int, ...],
5555
quant_min: Optional[int] = None,
5656
quant_max: Optional[int] = None,
57-
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
57+
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
5858
_layout: Optional[Layout] = None,
5959
):
6060
"""Converts a floating point tensor to a Marlin QQQ quantized tensor."""
61+
if zero_point_domain is None:
62+
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
6163
original_shape = input_float.shape
6264
input_float = _layout.pre_process(input_float)
6365
nbits = int(math.log2(quant_max - quant_min + 1))

torchao/quantization/observer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,12 @@ def __init__(
104104
scale_dtype: Optional[torch.dtype] = None,
105105
zero_point_dtype: Optional[torch.dtype] = None,
106106
preserve_zero: bool = True,
107-
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
107+
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
108108
):
109109
super().__init__()
110110
assert granularity is not None, "granularity is None"
111-
111+
if zero_point_domain is None:
112+
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
112113
self.mapping_type = mapping_type
113114
self.target_dtype = target_dtype
114115
self.granularity = granularity

torchao/quantization/qat/affine_fake_quantized_tensor.py

+5
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def forward(
4141
preserve_zero: bool = True,
4242
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
4343
) -> "AffineFakeQuantizedTensor":
44+
if zero_point_domain is None:
45+
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
46+
4447
def apply_fake_quant_fn(t: torch.Tensor):
4548
assert isinstance(t, AffineFakeQuantizedTensor)
4649
qmin, qmax = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
@@ -158,6 +161,8 @@ def from_float(
158161
preserve_zero: bool = True,
159162
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
160163
):
164+
if zero_point_domain is None:
165+
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
161166
return _ToAffineFakeQuantized.apply(
162167
original_input,
163168
mapping_type,

torchao/quantization/qat/api.py

+2
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def __init__(
9696
group_size: Optional[int] = None,
9797
is_symmetric: Optional[bool] = None,
9898
):
99+
if zero_point_domain is None:
100+
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
99101
self.dtype = dtype
100102
self.granularity = self._get_granularity(granularity, group_size)
101103
self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)

0 commit comments

Comments (0)