     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
...
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )

     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

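For reference, a minimal sketch of the replacement path the warning points to. This is not part of the diff; it assumes a CUDA device with semi-structured (2:4) sparse kernel support, and the toy module stands in for a model whose Linear weights have already been pruned to a 2:4 pattern.

import torch
from torchao.dtypes import SemiSparseLayout
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

# Toy stand-in; real usage assumes the Linear weights are already 2:4 sparse.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().half()

# Replacement for the deprecated int8_dynamic_activation_int8_semi_sparse_weight():
# pass the sparse layout through the layout kwarg instead.
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))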
@@ -934,15 +929,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     The actual matmul will be computed in original precision of the weight tensor.

     """
-    from torchao.dtypes import to_affine_quantized_floatx

     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )

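A minimal usage sketch for the code path above, assuming a PyTorch build with float8 dtypes and a CUDA device; the toy model is illustrative only.

import torch
from torchao.quantization import float8_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Weights are stored in float8_e4m3fn; per the docstring, the matmul itself
# still runs in the weight's original (here bfloat16) precision.
quantize_(model, float8_weight_only(weight_dtype=torch.float8_e4m3fn))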
@@ -1016,11 +1009,10 @@ def _input_activation_quant_func_fp8(

     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1094,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
            ), "PerRow quantization only works for bfloat16 precision input weight"

         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )

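For context, a hedged sketch of how the dynamic float8 path that these last two hunks feed into is typically invoked, assuming the enclosing public entry point is float8_dynamic_activation_float8_weight (not shown in this diff) and an fp8-capable GPU (compute capability 8.9+, e.g. H100).

import torch
from torchao.quantization import (
    PerRow,
    float8_dynamic_activation_float8_weight,
    quantize_,
)

# PerRow granularity requires bfloat16 weights, matching the assert in the hunk above.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))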