@@ -20,7 +20,11 @@
     write_json_result_ossci,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    get_model_size_in_bytes,
+)
 
 torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
 torch.backends.cuda.enable_cudnn_sdp(True)

@@ -553,26 +557,37 @@ def ffn_or_attn_only(mod, fqn):
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         elif "int8_dynamic_activation_intx_weight" in quantization:
-            from torchao.experimental.quant_api import (
-                int8_dynamic_activation_intx_weight,
-            )
-            from torchao.quantization.granularity import PerGroup
-
+            assert (
+                TORCH_VERSION_AT_LEAST_2_6
+            ), "int8_dynamic_activation_intx_weight requires torch2.6+"
             assert (
                 precision == torch.float32
             ), "int8_dynamic_activation_intx_weight requires using precision=torch.float32"
 
+            from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+            from torchao.quantization.granularity import PerAxis, PerGroup
+            from torchao.quantization.quant_api import (
+                Int8DynamicActivationIntxWeightConfig,
+                ZeroPointDomain,
+            )
+
             # Quantize model
             _quant_args = quantization.split("-")
             weight_dtype = getattr(torch, f"int{_quant_args[1]}")
-            granularity = PerGroup(int(_quant_args[2]))
+            group_size = int(_quant_args[2])
+            granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
             has_weight_zeros = bool(_quant_args[3])
             quantize_(
                 model,
-                int8_dynamic_activation_intx_weight(
+                Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=weight_dtype,
-                    granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    weight_granularity=granularity,
+                    weight_zero_point_domain=ZeroPointDomain.INT
+                    if has_weight_zeros
+                    else ZeroPointDomain.NONE,
+                    weight_mapping_type=MappingType.ASYMMETRIC,
+                    weight_scale_dtype=torch.bfloat16,
+                    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
         elif "float8wo" in quantization:
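
For context, this hunk migrates the benchmark from the experimental int8_dynamic_activation_intx_weight quantizer to the config-based Int8DynamicActivationIntxWeightConfig API. Below is a minimal standalone sketch of the new call path, assuming torch 2.6+, a float32 model, and a torchao build that exports these symbols; the toy model, 4-bit weight dtype, and group size of 32 are illustrative choices, not values taken from the diff.

    # Minimal sketch (not from the diff) of the config-based API this commit adopts.
    # Assumes torch 2.6+ and a torchao build that exports these symbols.
    import torch
    import torch.nn as nn

    from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
    from torchao.quantization import quantize_
    from torchao.quantization.granularity import PerAxis, PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        ZeroPointDomain,
    )
    from torchao.quantization.quant_primitives import MappingType

    # Toy float32 model standing in for the llama model in the benchmark.
    model = nn.Sequential(nn.Linear(256, 256)).to(torch.float32)

    # The benchmark derives these from a quantization string of the form
    # "int8_dynamic_activation_intx_weight-<nbit>-<group_size>-<has_weight_zeros>".
    group_size = 32  # illustrative; 0 would select per-channel granularity below
    granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)

    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,  # diff computes getattr(torch, f"int{nbit}")
            weight_granularity=granularity,
            weight_zero_point_domain=ZeroPointDomain.INT,  # NONE when no weight zeros
            weight_mapping_type=MappingType.ASYMMETRIC,
            weight_scale_dtype=torch.bfloat16,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
        ),
    )

Note the PerGroup/PerAxis fallback the diff introduces: with group_size == 0 it uses PerAxis(0), i.e. one scale per output channel rather than one per group.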