@@ -495,6 +495,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
 from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
     to_packedlinearint8dynamicactivationintxweight_quantized_intx,
+    Target,
 )
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
@@ -512,10 +513,9 @@ def int8_dynamic_activation_intx_weight(
     weight_dtype: torch.dtype = torch.int4,
     granularity: Union[PerRow, PerGroup] = PerGroup(128),
     has_weight_zeros: bool = False,
-    target: str = "native",
     weight_mapping_type=MappingType.ASYMMETRIC,
     act_mapping_type=MappingType.ASYMMETRIC,
-    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),  # PlainLayout() also works, but will be slow
+    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"),  # PlainLayout() also works, but will be slow
 ):
     """
     Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers.
@@ -539,19 +539,16 @@ def int8_dynamic_activation_intx_weight(
     - act_mapping_type must be MappingType.ASYMMETRIC
     """

-    if target == "aten":
-        if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \
-            weight_dtype != torch.int4 or \
-            has_weight_zeros != True or \
-            weight_mapping_type != MappingType.SYMMETRIC:
-            raise NotImplementedError(
-                f"target 'aten' requires:\n"
-                f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
-                f"- has_weight_zeros to be True,\n"
-                f"- weight_dtype to be torch.int4,\n"
-                f"- weight_mapping_type to be MappingType.SYMMETRIC"
+    def is_torchao_op_skippable(layout):
+        return (
+            isinstance(layout, PlainLayout) or
+            (
+                isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and
+                layout.target == Target.ATEN
             )
-    elif not isinstance(layout, PlainLayout):
+        )
+
+    if not is_torchao_op_skippable(layout):
         try:
             torch.ops.torchao._pack_8bit_act_4bit_weight
         except AttributeError:
@@ -577,7 +574,7 @@ def int8_dynamic_activation_intx_weight(
     )
     bit_width = dtype_to_bit_width[weight_dtype]
     layout_arg = layout
-    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == "aten"
+    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN

     def apply(weight, bias: Optional[torch.Tensor] = None):
         if isinstance(granularity, PerGroup):
@@ -612,13 +609,23 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
             bit_width=bit_width,
             group_size=group_size,
             has_weight_zeros=has_weight_zeros,
-            target=target,
+            target="aten" if layout.target == Target.ATEN else "native",
         )
-        if target == "aten":
+        if layout.target == Target.ATEN:
+            if weight_dtype != torch.int4 or \
+                has_weight_zeros != True or \
+                weight_mapping_type != MappingType.SYMMETRIC:
+                raise NotImplementedError(
+                    f"target 'aten' requires:\n"
+                    f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
+                    f"- has_weight_zeros to be True,\n"
+                    f"- weight_dtype to be torch.int4,\n"
+                    f"- weight_mapping_type to be MappingType.SYMMETRIC"
+                )
             assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
             if torch.backends.kleidiai.is_available():
                 if isinstance(granularity, PerGroup):
-                    scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype
+                    scale_dtype = torch.bfloat16  # KleidiAI kernel requires bfloat16 scale_dtype
             tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx

         quantizer_args = [weight,
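
Net effect for callers: the backend is no longer selected with a separate target kwarg on int8_dynamic_activation_intx_weight; it now travels with the layout (PackedLinearInt8DynamicActivationIntxWeightLayout(target=...), compared internally against Target.ATEN). A minimal usage sketch, illustrative only and not part of this commit; the quantize_ entry point and the import paths for PerGroup and MappingType are assumed from mainline torchao and may differ by version:

# Illustrative sketch only -- not part of this commit. Import paths for
# quantize_, PerGroup, and MappingType are assumed and may vary by version.
import torch
from torchao.quantization.quant_api import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_primitives import MappingType
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256))

# Before: int8_dynamic_activation_intx_weight(..., target="aten")
# After:  the target is carried by the layout instead.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(128),
        has_weight_zeros=True,                      # required by the 'aten' target
        weight_mapping_type=MappingType.SYMMETRIC,  # required by the 'aten' target
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)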