     PlainLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
-    UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
+    UintxLayout,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.inference import Float8MMConfig
⋮
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
-    TORCH_VERSION_AT_LEAST_2_6,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
 )

-from .autoquant import AutoQuantizableLinearWeight, autoquant
+from .autoquant import autoquant, AutoQuantizableLinearWeight
 from .GPTQ import (
     Int4WeightOnlyGPTQQuantizer,
     Int4WeightOnlyQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )

     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

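As the warning in this hunk suggests, the deprecated helper maps directly onto the layout kwarg of int8_dynamic_activation_int8_weight. A minimal migration sketch, not part of this diff, assuming torchao's quantize_ entry point and a CUDA-capable setup:

import torch

from torchao.dtypes import SemiSparseLayout
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

# Toy model; the 2:4 semi-sparse kernels expect CUDA and half/bfloat16 weights.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Old (to be deprecated):
#   quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
# Replacement recommended by the warning above:
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))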
@@ -938,11 +933,10 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):

     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )

@@ -1016,11 +1010,10 @@ def _input_activation_quant_func_fp8(

     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1095,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         ), "PerRow quantization only works for bfloat16 precision input weight"

         block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=mm_config),
         )

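The same to_affine_quantized_float8 substitution applies to the dynamic float8 weight path above. A sketch of the corresponding public config, float8_dynamic_activation_float8_weight, with the per-row granularity that the bfloat16 assertion in this hunk guards (same assumptions as the earlier sketches; not part of this diff):

import torch

from torchao.quantization import float8_dynamic_activation_float8_weight, quantize_
from torchao.quantization.granularity import PerRow

# PerRow weight granularity requires bfloat16 weights, per the assert above;
# PerTensor() is the more permissive alternative.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))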