 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout
 from torchao.quantization.quant_primitives import (
-    ZeroPointDomain,
     MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
 )
-
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_6,
 )
@@ -41,12 +40,14 @@
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 
+
 class Target(Enum):
     """Enum that indicates the backend target"""
 
     NATIVE = auto()
     ATEN = auto()
 
+
 def target_from_str(target: str) -> Target:
     if target.lower() == "native":
         return Target.NATIVE
@@ -55,6 +56,7 @@ def target_from_str(target: str) -> Target:
     else:
         raise ValueError(f"Invalid target: {target}")
 
+
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
     bit_width: Optional[int]
     group_size: Optional[int]
@@ -157,7 +159,10 @@ def from_plain(
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
         assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
-        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
+        assert layout.target in {
+            Target.NATIVE,
+            Target.ATEN,
+        }, f"Unexpected target: {layout.target}"
 
         # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
         # when AOTI supports int
@@ -167,10 +172,14 @@ def from_plain(
         k_tensor = torch.empty(0, k, dtype=torch.int8)
 
         if layout.target == Target.ATEN:
-            assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "aten target is requires torch version > 2.6.0"
             int_data = int_data.add(8)
-            int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8)
-            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
+            int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
+            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(
+                int_data, scale, bias, layout.group_size, k, n
+            )
             return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)
 
         if layout.has_weight_zeros:
@@ -248,12 +257,11 @@ def __tensor_unflatten__(
 def _linear_check(input_tensor, weight_tensor, bias):
     layout = weight_tensor.tensor_impl.get_layout()
     return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
-        bias is None or layout.target == Target.ATEN # Aten target allows bias
+        bias is None or layout.target == Target.ATEN  # Aten target allows bias
     )
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-
     def _impl_2d_native(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
@@ -299,14 +307,13 @@ def _impl_2d_aten(input_tensor, weight_tensor):
         group_size = weight_tensor.tensor_impl.get_layout().group_size
         packed_weight = weight_tensor.tensor_impl.packed_weight
         return torch.ops.aten._dyn_quant_matmul_4bit(
-            input_tensor, packed_weight, group_size, k_, n)
+            input_tensor, packed_weight, group_size, k_, n
+        )
 
     target = weight_tensor.tensor_impl.get_layout().target
 
     if target == Target.ATEN:
-        assert (
-            TORCH_VERSION_AT_LEAST_2_6 == 1
-        ), "Target.ATEN requires torch >= 2.6.0"
+        assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
         _impl_2d = _impl_2d_aten
     elif target == Target.NATIVE:
         _impl_2d = _impl_2d_native
@@ -327,6 +334,7 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     res = res.reshape(*lead_shape, m, n)
     return res
 
+
 register_aqt_quantized_linear_dispatch(
     _linear_check,
     _linear_impl,
@@ -354,12 +362,17 @@ def from_hp_to_intx(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
         use_hqq: bool = False,
-        bias: Optional[torch.Tensor] = None
+        bias: Optional[torch.Tensor] = None,
     ):
-        assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
+        assert (
+            use_hqq == False
+        ), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
         assert isinstance(
-            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
-        assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
+        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
+        assert (
+            _layout.target == Target.ATEN
+        ), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
 
@@ -405,4 +418,7 @@ def from_hp_to_intx(
             dtype=input_float.dtype,
         )
 
-to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = (
+    PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
+)
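
Side note on the packing change in from_plain above: for the ATEN target, the int4 weights are shifted from [-8, 7] into [0, 15] and two values are packed per byte, with odd columns in the high nibble and even columns in the low nibble, before being handed to torch.ops.aten._dyn_quant_pack_4bit_weight. The sketch below isolates just that nibble-packing step with toy values and no torchao or aten dependency; the unpacking half is not part of the diff and is included here only to verify the round trip.

import torch

# Toy stand-in for a quantized weight row: signed int4 values in [-8, 7].
int_data = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)

# Shift into [0, 15] so each value fits in an unsigned nibble (mirrors int_data.add(8)).
unsigned = int_data.add(8)

# Pack two columns per byte: odd columns go to the high nibble, even columns to the low nibble.
packed = (unsigned[:, 1::2] << 4 | unsigned[:, ::2]).to(torch.uint8)

# Unpack again purely as a sanity check of the layout.
low = (packed & 0x0F).to(torch.int8) - 8
high = (packed >> 4).to(torch.int8) - 8
assert torch.equal(low, int_data[:, ::2])
assert torch.equal(high, int_data[:, 1::2])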