@@ -36,7 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
-    to_affine_quantized_floatx,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
@@ -66,21 +66,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dynamic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())"""
+    )

     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

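Note on the hunk above: the deprecation message steers callers to the `layout` kwarg of `int8_dynamic_activation_int8_weight`. A minimal sketch of that replacement, assuming the usual `quantize_` entry point and a toy half-precision CUDA model (the model, sizes, and dtype here are illustrative assumptions, not part of this change):

```python
import torch.nn as nn

from torchao.dtypes import SemiSparseLayout
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

# Toy model; 2:4 sparse kernels generally expect half/bfloat16 weights on CUDA.
model = nn.Sequential(nn.Linear(1024, 1024)).half().cuda()

# Replacement for the deprecated int8_dynamic_activation_int8_semi_sparse_weight():
# pass the 2:4 sparse layout through the layout kwarg instead.
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
```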
@@ -934,15 +928,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     The actual matmul will be computed in original precision of the weight tensor.

     """
-    from torchao.dtypes import to_affine_quantized_floatx

     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )

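The hunk above moves the weight-only float8 path from `to_affine_quantized_floatx` (with an explicit `scale_dtype=None`) to the new `to_affine_quantized_float8`, which no longer takes that argument. For context, a sketch of how the public config is typically applied, assuming the `quantize_` entry point and an illustrative bfloat16 model:

```python
import torch
import torch.nn as nn

from torchao.quantization import float8_weight_only, quantize_

# Illustrative model; float8 kernels assume a recent CUDA GPU.
model = nn.Sequential(nn.Linear(2048, 2048)).to(torch.bfloat16).cuda()

# Weights are stored as float8_e4m3fn; the matmul itself still runs in the
# original (bfloat16) precision, as the docstring notes.
quantize_(model, float8_weight_only(weight_dtype=torch.float8_e4m3fn))
```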
@@ -1016,11 +1008,10 @@ def _input_activation_quant_func_fp8(

     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
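Same rename in the dynamic-scale activation path: when no static scale is supplied, the activation is quantized on the fly and the `scale_dtype` argument disappears. A hedged, standalone sketch of the new call shape (the tensor, device, and the "per-tensor block size equals the full shape" convention are assumptions for illustration, mirroring what `get_block_size(x.shape, PerTensor())` would return):

```python
import torch

from torchao.dtypes import Float8Layout, to_affine_quantized_float8

x = torch.randn(16, 2048, dtype=torch.bfloat16, device="cuda")

activation = to_affine_quantized_float8(
    input_float=x,
    block_size=tuple(x.shape),             # one scale for the whole tensor (per-tensor)
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),  # mm config is stored on the weight side
)
```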
@@ -1102,11 +1093,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
             ), "PerRow quantization only works for bfloat16 precision input weight"

         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )

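The weight closure of `float8_dynamic_activation_float8_weight` gets the same substitution; the surrounding assert still requires bfloat16 weights when per-row scales are requested. A sketch of the public config this code backs, assuming `quantize_` and that `PerRow` is exported from `torchao.quantization`:

```python
import torch
import torch.nn as nn

from torchao.quantization import (
    PerRow,
    float8_dynamic_activation_float8_weight,
    quantize_,
)

# bfloat16 is required by the PerRow assert shown in the hunk above.
model = nn.Sequential(nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# Row-wise float8 scales for both activations and weights.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```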
@@ -1157,11 +1147,10 @@ def apply_float8_static_activation_quant(weight: torch.Tensor):
         if not _fp8_mm_compat(weight):
             return weight
         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )

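The static-activation variant is updated the same way. A sketch of how it might be invoked, assuming `float8_static_activation_float8_weight` is exported from `torchao.quantization` and that an activation scale has already been computed by some calibration pass (the scale value below is a placeholder, not a recommendation):

```python
import torch
import torch.nn as nn

from torchao.quantization import float8_static_activation_float8_weight, quantize_

model = nn.Sequential(nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# Static per-tensor activation scale, e.g. obtained offline via calibration.
act_scale = torch.tensor(0.5, dtype=torch.float32, device="cuda")

quantize_(model, float8_static_activation_float8_weight(scale=act_scale))
```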