@@ -21,12 +21,11 @@
 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout
 from torchao.quantization.quant_primitives import (
-    ZeroPointDomain,
     MappingType,
+    ZeroPointDomain,
     choose_qparams_affine,
     quantize_affine,
 )
-
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_6,
 )
@@ -41,12 +40,14 @@
 handler.setFormatter(formatter)
 logger.addHandler(handler)
 
+
 class Target(Enum):
     """Enum that indicates the backend target"""
 
     NATIVE = auto()
     ATEN = auto()
 
+
 def target_from_str(target: str) -> Target:
     if target.lower() == "native":
         return Target.NATIVE
@@ -55,6 +56,7 @@ def target_from_str(target: str) -> Target:
     else:
         raise ValueError(f"Invalid target: {target}")
 
+
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
     bit_width: Optional[int]
     group_size: Optional[int]
@@ -157,7 +159,10 @@ def from_plain(
     ):
         assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
         assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
-        assert layout.target in {Target.NATIVE, Target.ATEN}, f"Unexpected target: {layout.target}"
+        assert layout.target in {
+            Target.NATIVE,
+            Target.ATEN,
+        }, f"Unexpected target: {layout.target}"
 
         # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor
         # when AOTI supports int
@@ -167,10 +172,14 @@ def from_plain(
         k_tensor = torch.empty(0, k, dtype=torch.int8)
 
         if layout.target == Target.ATEN:
-            assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "aten target is requires torch version > 2.6.0"
             int_data = int_data.add(8)
-            int_data = (int_data[::,1::2] << 4 | int_data[::,::2] ).to(torch.uint8)
-            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(int_data, scale, bias, layout.group_size, k, n)
+            int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
+            packed_weight = torch.ops.aten._dyn_quant_pack_4bit_weight(
+                int_data, scale, bias, layout.group_size, k, n
+            )
             return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor)
 
         if layout.has_weight_zeros:
@@ -248,12 +257,11 @@ def __tensor_unflatten__(
 def _linear_check(input_tensor, weight_tensor, bias):
     layout = weight_tensor.tensor_impl.get_layout()
     return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and (
-        bias is None or layout.target == Target.ATEN # Aten target allows bias
+        bias is None or layout.target == Target.ATEN  # Aten target allows bias
     )
 
 
 def _linear_impl(input_tensor, weight_tensor, bias):
-
     def _impl_2d_native(input_tensor, weight_tensor):
         assert input_tensor.dim() == 2
         assert weight_tensor.dim() == 2
@@ -299,14 +307,13 @@ def _impl_2d_aten(input_tensor, weight_tensor):
         group_size = weight_tensor.tensor_impl.get_layout().group_size
         packed_weight = weight_tensor.tensor_impl.packed_weight
         return torch.ops.aten._dyn_quant_matmul_4bit(
-            input_tensor, packed_weight, group_size, k_, n)
+            input_tensor, packed_weight, group_size, k_, n
+        )
 
     target = weight_tensor.tensor_impl.get_layout().target
 
     if target == Target.ATEN:
-        assert (
-            TORCH_VERSION_AT_LEAST_2_6 == 1
-        ), "Target.ATEN requires torch >= 2.6.0"
+        assert TORCH_VERSION_AT_LEAST_2_6 == 1, "Target.ATEN requires torch >= 2.6.0"
         _impl_2d = _impl_2d_aten
     elif target == Target.NATIVE:
         _impl_2d = _impl_2d_native
@@ -327,6 +334,7 @@ def _impl_2d_aten(input_tensor, weight_tensor):
     res = res.reshape(*lead_shape, m, n)
     return res
 
+
 register_aqt_quantized_linear_dispatch(
     _linear_check,
     _linear_impl,
@@ -354,12 +362,17 @@ def from_hp_to_intx(
         zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         _layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(),
         use_hqq: bool = False,
-        bias: Optional[torch.Tensor] = None
+        bias: Optional[torch.Tensor] = None,
     ):
-        assert use_hqq == False, f"PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
+        assert (
+            use_hqq == False
+        ), "PackedLinearInt8DynamicActivationIntxWeightTensor can not support HQQ optimization"
         assert isinstance(
-            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
-        assert _layout.target == Target.ATEN, f"PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
+            _layout, PackedLinearInt8DynamicActivationIntxWeightLayout
+        ), f"PackedLinearInt8DynamicActivationIntxWeightTensor can only support PackedLinearInt8DynamicActivationIntxWeightLayout(). Provided {_layout}"
+        assert (
+            _layout.target == Target.ATEN
+        ), "PackedLinearInt8DynamicActivationIntxWeightTensor requires target 'aten'."
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
 
@@ -405,4 +418,7 @@ def from_hp_to_intx(
             dtype=input_float.dtype,
         )
 
-to_packedlinearint8dynamicactivationintxweight_quantized_intx = PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
+
+to_packedlinearint8dynamicactivationintxweight_quantized_intx = (
+    PackedLinearInt8DynamicActivationIntxWeightAtenTensor.from_hp_to_intx
+)