Commit 04d1186

Move Int8DynamicActivationIntxWeightConfig out of experimental (#1968)
1 parent 9516764 commit 04d1186

12 files changed: +599 −463 lines

torchao/_models/llama/generate.py

Lines changed: 25 additions & 10 deletions
@@ -20,7 +20,11 @@
     write_json_result_ossci,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    get_model_size_in_bytes,
+)
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
 torch.backends.cuda.enable_cudnn_sdp(True)
@@ -553,26 +557,37 @@ def ffn_or_attn_only(mod, fqn):
         group_size = int(_quant_args[2])
         quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
     elif "int8_dynamic_activation_intx_weight" in quantization:
-        from torchao.experimental.quant_api import (
-            int8_dynamic_activation_intx_weight,
-        )
-        from torchao.quantization.granularity import PerGroup
-
+        assert (
+            TORCH_VERSION_AT_LEAST_2_6
+        ), "int8_dynamic_activation_intx_weight requires torch2.6+"
         assert (
             precision == torch.float32
         ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32"
 
+        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+            ZeroPointDomain,
+        )
+
         # Quantize model
         _quant_args = quantization.split("-")
         weight_dtype = getattr(torch, f"int{_quant_args[1]}")
-        granularity = PerGroup(int(_quant_args[2]))
+        group_size = int(_quant_args[2])
+        granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
         has_weight_zeros = bool(_quant_args[3])
         quantize_(
             model,
-            int8_dynamic_activation_intx_weight(
+            Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=weight_dtype,
-                granularity=granularity,
-                has_weight_zeros=has_weight_zeros,
+                weight_granularity=granularity,
+                weight_zero_point_domain=ZeroPointDomain.INT
+                if has_weight_zeros
+                else ZeroPointDomain.NONE,
+                weight_mapping_type=MappingType.ASYMMETRIC,
+                weight_scale_dtype=torch.bfloat16,
+                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
             ),
         )
     elif "float8wo" in quantization:

torchao/dtypes/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -25,7 +25,6 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
-    to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight,
     to_marlinqqq_quantized_intx,
 )
 from .utils import (

torchao/dtypes/uintx/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -17,7 +17,6 @@
 )
 from .packed_linear_int8_dynamic_activation_intx_weight_layout import (
     PackedLinearInt8DynamicActivationIntxWeightLayout,
-    to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight,
 )
 from .q_dq_layout import (
     QDQLayout,
@@ -43,7 +42,6 @@
     "MarlinQQQTensor",
     "to_marlinqqq_quantized_intx",
     "CutlassInt4PackedLayout",
-    "to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
     "PackedLinearInt8DynamicActivationIntxWeightLayout",
     "QDQLayout",
 ]
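
The export changes above only drop the `to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight` re-export; the layout class itself remains available from `torchao.dtypes`, which the generate.py import in this commit relies on. A small sanity-check sketch (the `print` is purely illustrative):

```python
# After this commit the layout class is still re-exported from torchao.dtypes;
# the to_affine_quantized_* constructor helper no longer is (per the diffs above).
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout

layout = PackedLinearInt8DynamicActivationIntxWeightLayout()
print(type(layout).__name__)  # PackedLinearInt8DynamicActivationIntxWeightLayout
```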
