
Commit 9606219

replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs
ghstack-source-id: cba5e1c
ghstack-comment-id: 2608105249
Pull Request resolved: #1599
1 parent fae690c commit 9606219

File tree

6 files changed (+30 -63 lines changed):

docs/source/api_ref_dtypes.rst
torchao/dtypes/__init__.py
torchao/dtypes/affine_quantized_tensor.py
torchao/prototype/quantization/autoquant_v2.py
torchao/quantization/autoquant.py
torchao/quantization/quant_api.py

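For callers outside this diff, the change amounts to renaming the constructor and dropping the scale_dtype argument, which the float8 path no longer takes at these call sites. A minimal before/after sketch (the tensor shape, fp8 dtype, and layout values are illustrative, mirroring the float8_weight_only call site later in this commit):

import torch
from torchao.dtypes import Float8Layout, to_affine_quantized_float8

weight = torch.randn(128, 256, dtype=torch.bfloat16)
block_size = (1, weight.shape[1])  # one scale per output row, as in float8_weight_only

# Before this commit (constructor now removed from the quantization APIs):
# quantized = to_affine_quantized_floatx(
#     input_float=weight,
#     block_size=block_size,
#     target_dtype=torch.float8_e4m3fn,
#     scale_dtype=torch.float32,
#     _layout=Float8Layout(mm_config=None),
# )

# After this commit:
quantized = to_affine_quantized_float8(
    input_float=weight,
    block_size=block_size,
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),
)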

docs/source/api_ref_dtypes.rst

+1 -1

@@ -13,7 +13,7 @@ torchao.dtypes
     to_nf4
     to_affine_quantized_intx
     to_affine_quantized_intx_static
-    to_affine_quantized_floatx
+    to_affine_quantized_float8
     to_affine_quantized_floatx_static
     to_affine_quantized_fpx
     NF4Tensor

torchao/dtypes/__init__.py

+4 -9

@@ -1,16 +1,14 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized_floatx,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .floatx import (
-    Float8Layout,
-)
+from .floatx import Float8Layout
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
@@ -24,10 +22,7 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
-from .utils import (
-    Layout,
-    PlainLayout,
-)
+from .utils import Layout, PlainLayout

 __all__ = [
     "NF4Tensor",
@@ -36,8 +31,8 @@
     "to_affine_quantized_intx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_fpx",
-    "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
+    "to_affine_quantized_float8",
     "to_marlinqqq_quantized_intx",
     "Layout",
     "PlainLayout",

torchao/dtypes/affine_quantized_tensor.py

+1 -2

@@ -28,7 +28,7 @@
     "AffineQuantizedTensor",
     "register_layout",
     "to_affine_quantized_intx",
-    "to_affine_quantized_floatx",
+    "to_affine_quantized_float8",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
@@ -430,7 +430,6 @@ def from_hp_to_float8(
         scale = choose_qparams_affine_float8(
             input_float,
             target_dtype,
-            target_dtype,
         )
         fp8_data = quantize_affine_float8(
             input_float,
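
The second hunk also drops a duplicated target_dtype argument in from_hp_to_float8's call to choose_qparams_affine_float8. The corrected flow, roughly, is sketched below (the import path for the primitive is an assumption; it is not part of this diff):

import torch
# Assumed location of the helper used by from_hp_to_float8; not shown in this diff.
from torchao.quantization.quant_primitives import choose_qparams_affine_float8

input_float = torch.randn(16, 32)
target_dtype = torch.float8_e4m3fn

# target_dtype is passed exactly once; the duplicated positional argument is gone.
scale = choose_qparams_affine_float8(input_float, target_dtype)
# scale then feeds quantize_affine_float8(input_float, ...) as in the hunk above.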

torchao/prototype/quantization/autoquant_v2.py

+6 -14

@@ -27,14 +27,8 @@
 from torchao.quantization.autoquant import (
     AutoQuantizableLinearWeight as AutoQuantizableLinearWeightV1,
 )
-from torchao.quantization.granularity import (
-    PerRow,
-    PerTensor,
-)
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.granularity import PerRow, PerTensor
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -991,7 +985,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8

         # weight settings
@@ -1015,12 +1009,11 @@ def get_per_token_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
@@ -1040,7 +1033,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8

         # weight settings
@@ -1058,12 +1051,11 @@ def get_weight_block_size(x):
             activation_dtype=input_target_dtype,
         )
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls

torchao/quantization/autoquant.py

+6 -14

@@ -18,10 +18,7 @@
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.quantization.utils import (
     compute_error,
     quantize_activation_per_token_absmax,
@@ -34,10 +31,7 @@
     is_sm_at_least_90,
 )

-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
@@ -969,7 +963,7 @@ class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(AQMixin, BFloat16Ten
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8

         # weight settings
@@ -995,12 +989,11 @@ def get_per_token_block_size(x):
         }
         block_size = get_weight_block_size(weight)

-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = to_linear_activation_quantized(
             weight, input_quant_func, quant_kwargs=input_quant_kwargs
@@ -1025,7 +1018,7 @@ class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
     @classmethod
     def from_float(cls, weight):
         # avoid circular dep
-        from torchao.dtypes import to_affine_quantized_floatx
+        from torchao.dtypes import to_affine_quantized_float8
         from torchao.quantization.quant_api import _input_activation_quant_func_fp8

         # weight settings
@@ -1043,12 +1036,11 @@ def get_weight_block_size(x):
             "activation_dtype": input_target_dtype,
         }
         block_size = get_weight_block_size(weight)
-        weight = to_affine_quantized_floatx(
+        weight = to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
             _layout=_layout,
-            scale_dtype=torch.float32,
         )
         weight = super(
             AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight, cls

torchao/quantization/quant_api.py

+12 -23

@@ -36,7 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
-    to_affine_quantized_floatx,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
     to_marlinqqq_quantized_intx,
@@ -66,21 +66,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.

     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )

     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())

@@ -934,15 +928,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     The actual matmul will be computed in original precision of the weight tensor.

     """
-    from torchao.dtypes import to_affine_quantized_floatx

     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
             input_float=weight,
             block_size=block_size,
             target_dtype=weight_dtype,
-            scale_dtype=None,
             _layout=Float8Layout(mm_config=None),
         )

@@ -1016,11 +1008,10 @@ def _input_activation_quant_func_fp8(

     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
             input_float=x,
             block_size=block_size,
             target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
             _layout=Float8Layout(mm_config=None),  # Config is stored on weight
         )
     else:
@@ -1102,11 +1093,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
            ), "PerRow quantization only works for bfloat16 precision input weight"

        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
            input_float=weight,
            block_size=block_size,
            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
            _layout=Float8Layout(mm_config=mm_config),
        )

@@ -1157,11 +1147,10 @@ def apply_float8_static_activation_quant(weight: torch.Tensor):
        if not _fp8_mm_compat(weight):
            return weight
        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
            input_float=weight,
            block_size=block_size,
            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
            _layout=Float8Layout(mm_config=mm_config),
        )
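
The public configuration entry points keep their signatures in this commit; the float8 workflows are still driven through quantize_. A usage sketch, assuming float8_weight_only and float8_dynamic_activation_float8_weight are exported from torchao.quantization and a float8-capable GPU is available:

import torch
from torchao.quantization import (
    float8_dynamic_activation_float8_weight,
    float8_weight_only,
    quantize_,
)
from torchao.quantization.granularity import PerRow

# Weight-only float8: weights are stored via to_affine_quantized_float8,
# matmuls still run in the weight's original precision.
m_wo = torch.nn.Sequential(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda"))
quantize_(m_wo, float8_weight_only(weight_dtype=torch.float8_e4m3fn))

# Dynamic activation + weight float8 with per-row scales; bfloat16 weights are
# required for PerRow, per the assertion in apply_float8_dynamic_activation_quant.
m_dyn = torch.nn.Sequential(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda"))
quantize_(m_dyn, float8_dynamic_activation_float8_weight(granularity=PerRow()))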
