21
21
)
22
22
from torchao .quantization .quant_primitives import (
23
23
MappingType ,
24
+ ZeroPointDomain ,
24
25
)
25
26
26
27
@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
74
75
eps = torch .finfo (torch .float32 ).eps ,
75
76
scale_dtype = torch .float ,
76
77
zero_point_dtype = torch .int ,
77
- zero_point_domain = None ,
78
+ zero_point_domain = ZeroPointDomain . NONE ,
78
79
)
79
80
example_inputs = [
80
81
torch .randn (10 , 2048 ),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
93
94
eps = torch .finfo (torch .float32 ).eps ,
94
95
scale_dtype = torch .float ,
95
96
zero_point_dtype = torch .int ,
96
- zero_point_domain = None ,
97
+ zero_point_domain = ZeroPointDomain . NONE ,
97
98
)
98
99
for example_input in example_inputs :
99
100
obs (example_input )
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
108
109
eps = torch .finfo (torch .float32 ).eps ,
109
110
scale_dtype = torch .float ,
110
111
zero_point_dtype = torch .int ,
111
- zero_point_domain = None ,
112
+ zero_point_domain = ZeroPointDomain . NONE ,
112
113
)
113
114
example_inputs = [
114
115
torch .randn (10 , 2048 ),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
127
128
eps = torch .finfo (torch .float32 ).eps ,
128
129
scale_dtype = torch .float ,
129
130
zero_point_dtype = torch .int ,
130
- zero_point_domain = None ,
131
+ zero_point_domain = ZeroPointDomain . NONE ,
131
132
)
132
133
example_inputs = [
133
134
torch .randn (10 , 2048 ),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
155
156
eps = torch .finfo (torch .float32 ).eps ,
156
157
scale_dtype = torch .float ,
157
158
zero_point_dtype = torch .int ,
158
- zero_point_domain = None ,
159
+ zero_point_domain = ZeroPointDomain . NONE ,
159
160
)
160
161
if observe_weight :
161
162
weight_observer = AffineQuantizedMinMaxObserver (
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
165
166
eps = torch .finfo (torch .float32 ).eps ,
166
167
scale_dtype = torch .float ,
167
168
zero_point_dtype = torch .int ,
168
- zero_point_domain = None ,
169
+ zero_point_domain = ZeroPointDomain . NONE ,
169
170
)
170
171
else :
171
172
weight_observer = None
@@ -199,7 +200,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
199
200
input_scale .item (),
200
201
max_val / max_fp8 ,
201
202
)
202
- self .assertIsNotNone (input_zero_point )
203
+ self .assertIsNone (input_zero_point )
203
204
204
205
if observe_weight :
205
206
weight_observer = linear .weight .weight_observer
@@ -210,7 +211,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
210
211
atol = 5e-5 ,
211
212
rtol = 0.0 ,
212
213
)
213
- self .assertIsNotNone (weight_zero_point )
214
+ self .assertIsNone (weight_zero_point )
214
215
else :
215
216
self .assertIsNone (linear .weight .weight_observer )
216
217
0 commit comments