
Commit ca8a5f1

[Refactor]: Move target attribute to Layout Class & fix target checks

Signed-off-by: Nikhil Gupta <[email protected]>

1 parent 8fdf6a9
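
The change to the public entry point, illustrated with a hypothetical call site (only the keyword arguments and the layout's target handling below come from this commit; the surrounding values are illustrative):

# Before: the backend was selected by a separate string argument.
int8_dynamic_activation_intx_weight(
    weight_dtype=torch.int4,
    has_weight_zeros=True,
    weight_mapping_type=MappingType.SYMMETRIC,
    target="aten",
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)

# After: the target lives on the layout and is checked against the Target enum internally.
int8_dynamic_activation_intx_weight(
    weight_dtype=torch.int4,
    has_weight_zeros=True,
    weight_mapping_type=MappingType.SYMMETRIC,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
)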

1 file changed: torchao/experimental/quant_api.py (+25 -18 lines)

@@ -495,6 +495,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
 from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
     to_packedlinearint8dynamicactivationintxweight_quantized_intx,
+    Target,
 )
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
@@ -512,10 +513,9 @@ def int8_dynamic_activation_intx_weight(
     weight_dtype: torch.dtype = torch.int4,
     granularity: Union[PerRow, PerGroup] = PerGroup(128),
     has_weight_zeros: bool = False,
-    target: str = "native",
     weight_mapping_type=MappingType.ASYMMETRIC,
     act_mapping_type=MappingType.ASYMMETRIC,
-    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow
+    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"), # PlainLayout() also works, but will be slow
 ):
     """
     Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers.
@@ -539,19 +539,16 @@ def int8_dynamic_activation_intx_weight(
     - act_mapping_type must be MappingType.ASYMMETRIC
     """
 
-    if target == "aten":
-        if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) or \
-            weight_dtype != torch.int4 or \
-            has_weight_zeros != True or \
-            weight_mapping_type != MappingType.SYMMETRIC:
-            raise NotImplementedError(
-                f"target 'aten' requires:\n"
-                f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
-                f"- has_weight_zeros to be True,\n"
-                f"- weight_dtype to be torch.int4,\n"
-                f"- weight_mapping_type to be MappingType.SYMMETRIC"
+    def is_torchao_op_skippable(layout):
+        return (
+            isinstance(layout, PlainLayout) or
+            (
+                isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and
+                layout.target == Target.ATEN
             )
-    elif not isinstance(layout, PlainLayout):
+        )
+
+    if not is_torchao_op_skippable(layout):
         try:
             torch.ops.torchao._pack_8bit_act_4bit_weight
         except AttributeError:
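
The hunk above folds the old aten/PlainLayout branches into one predicate. A rough sketch of how it evaluates (layout constructions are illustrative, assuming the layout maps the target string to the Target enum):

is_torchao_op_skippable(PlainLayout())                                                       # True: plain layout never reaches the torchao kernels
is_torchao_op_skippable(PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"))    # True: aten target skips the torch.ops.torchao check
is_torchao_op_skippable(PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"))  # False: native target requires torch.ops.torchao._pack_8bit_act_4bit_weight
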
@@ -577,7 +574,7 @@ def int8_dynamic_activation_intx_weight(
     )
     bit_width = dtype_to_bit_width[weight_dtype]
     layout_arg = layout
-    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target=="aten"
+    propagate_bias = isinstance(layout_arg, PackedLinearInt8DynamicActivationIntxWeightLayout) and layout_arg.target == Target.ATEN
 
     def apply(weight, bias: Optional[torch.Tensor] = None):
         if isinstance(granularity, PerGroup):
@@ -612,13 +609,23 @@ def apply(weight, bias: Optional[torch.Tensor] = None):
                 bit_width=bit_width,
                 group_size=group_size,
                 has_weight_zeros=has_weight_zeros,
-                target=target,
+                target="aten" if layout.target == Target.ATEN else "native",
             )
-            if target == "aten":
+            if layout.target == Target.ATEN:
+                if weight_dtype != torch.int4 or \
+                    has_weight_zeros != True or \
+                    weight_mapping_type != MappingType.SYMMETRIC:
+                    raise NotImplementedError(
+                        f"target 'aten' requires:\n"
+                        f"- layout to be PackedLinearInt8DynamicActivationIntxWeightLayout,\n"
+                        f"- has_weight_zeros to be True,\n"
+                        f"- weight_dtype to be torch.int4,\n"
+                        f"- weight_mapping_type to be MappingType.SYMMETRIC"
+                    )
                 assert TORCH_VERSION_AT_LEAST_2_6, f"aten target is requires torch version > 2.6.0"
                 if torch.backends.kleidiai.is_available():
                     if isinstance(granularity, PerGroup):
-                        scale_dtype = torch.bfloat16 # KleidiAI kernel requires bfloat16 scale_dtype
+                        scale_dtype = torch.bfloat16  # KleidiAI kernel requires bfloat16 scale_dtype
             tensor_quantizer = to_packedlinearint8dynamicactivationintxweight_quantized_intx
 
         quantizer_args = [weight,

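A minimal end-to-end sketch of the revised API, assuming the usual torchao entry points (quantize_, PerGroup, MappingType) and a toy model; only the int8_dynamic_activation_intx_weight keyword arguments and the layout target handling come from this diff:

import torch
import torch.nn as nn

from torchao.quantization import quantize_  # assumed entry point, not part of this diff
from torchao.quantization.granularity import PerGroup  # assumed import path
from torchao.quantization.quant_primitives import MappingType  # assumed import path
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

model = nn.Sequential(nn.Linear(512, 512))

# Per the checks in this diff, the "aten" target requires torch.int4 weights,
# has_weight_zeros=True, MappingType.SYMMETRIC weight mapping, and torch >= 2.6.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(128),
        has_weight_zeros=True,
        weight_mapping_type=MappingType.SYMMETRIC,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
    ),
)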