
Commit 1b35144

replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs
ghstack-source-id: 059b697
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
1 parent 26a0a50 commit 1b35144
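
For reference, every call site touched by this commit changes the same way: the generic to_affine_quantized_floatx constructor is swapped for the float8-specific to_affine_quantized_float8, and the explicit scale_dtype argument is dropped. A minimal before/after sketch, assuming a torchao build that includes this change (the weight tensor and dtype below are illustrative placeholders; the keyword arguments follow the call sites in the diff):

    import torch
    from torchao.dtypes import Float8Layout, to_affine_quantized_float8

    weight = torch.randn(256, 512, dtype=torch.bfloat16)  # illustrative weight
    block_size = (1, weight.shape[1])  # per-row blocks, as in float8_weight_only

    # Before this commit the equivalent call was
    #   to_affine_quantized_floatx(..., scale_dtype=None, ...)
    # The float8-specific constructor no longer takes scale_dtype.
    quantized_weight = to_affine_quantized_float8(
        input_float=weight,
        block_size=block_size,
        target_dtype=torch.float8_e4m3fn,
        _layout=Float8Layout(mm_config=None),
    )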

File tree

1 file changed (+11, −20 lines)


torchao/quantization/quant_api.py

+11-20
@@ -36,6 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
@@ -66,21 +67,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
@@ -934,15 +929,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     The actual matmul will be computed in original precision of the weight tensor.
 
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )
 
@@ -1016,11 +1009,10 @@ def _input_activation_quant_func_fp8(
 
     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1094,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         ), "PerRow quantization only works for bfloat16 precision input weight"
 
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
            input_float=weight,
            block_size=block_size,
            target_dtype=weight_dtype,
-           scale_dtype=torch.float32,
            _layout=Float8Layout(mm_config=mm_config),
        )

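The public quantization entry points are unchanged by this refactor; float8 quantization is still applied through quantize_. A short usage sketch, assuming a torchao build that includes this change (the toy model below is illustrative):

    import torch
    from torchao.quantization.quant_api import float8_weight_only, quantize_

    # Toy model; quantize_ replaces the Linear weights in place.
    model = torch.nn.Sequential(torch.nn.Linear(512, 256, dtype=torch.bfloat16))

    # Weight-only float8: weights are stored as float8_e4m3fn, and the matmul
    # runs in the original precision (see float8_weight_only in the diff above).
    quantize_(model, float8_weight_only(weight_dtype=torch.float8_e4m3fn))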