
Commit d035740

replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs

ghstack-source-id: 43890bf2fd3b4d9cc251b4ea614de6ff8d93735b
ghstack-comment-id: 2608105249
Pull Request resolved: #1599

1 parent fb011f0 · commit d035740
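
For callers, the practical effect of this change is that float8 affine quantization now goes through the dedicated to_affine_quantized_float8 constructor and the scale_dtype argument is dropped. A minimal sketch mirroring the calls rewritten in quant_api.py below; the shape and dtype are illustrative only:

import torch
from torchao.dtypes import Float8Layout, to_affine_quantized_float8

weight = torch.randn(128, 256, dtype=torch.bfloat16)  # illustrative weight

# Before: to_affine_quantized_floatx(..., scale_dtype=torch.float32, ...)
# After this commit the float8-specific constructor is used and scale_dtype is gone:
fp8_weight = to_affine_quantized_float8(
    input_float=weight,
    block_size=(1, weight.shape[1]),  # per-row block, as in float8_weight_only
    target_dtype=torch.float8_e4m3fn,
    _layout=Float8Layout(mm_config=None),
)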

File tree: 3 files changed (+19 -31 lines)

torchao/dtypes/__init__.py (+4 -7)

@@ -1,16 +1,15 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
 )
-from .floatx import (
-    Float8Layout,
-)
+from .floatx import Float8Layout
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
@@ -24,10 +23,7 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
-from .utils import (
-    Layout,
-    PlainLayout,
-)
+from .utils import Layout, PlainLayout
 
 __all__ = [
     "NF4Tensor",
@@ -38,6 +34,7 @@
     "to_affine_quantized_fpx",
     "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
+    "to_affine_quantized_float8",
     "to_marlinqqq_quantized_intx",
     "Layout",
     "PlainLayout",

torchao/dtypes/affine_quantized_tensor.py (+4 -4)

@@ -6,18 +6,18 @@
 
 from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
-    FP8_TYPES,
-    MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
     choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
+    FP8_TYPES,
+    MappingType,
     quantize_affine,
     quantize_affine_float8,
     quantize_affine_floatx,
+    ZeroPointDomain,
 )
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
 
@@ -29,6 +29,7 @@
     "register_layout",
     "to_affine_quantized_intx",
     "to_affine_quantized_floatx",
+    "to_affine_quantized_float8",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
@@ -430,7 +431,6 @@ def from_hp_to_float8(
         scale = choose_qparams_affine_float8(
             input_float,
             target_dtype,
-            target_dtype,
         )
         fp8_data = quantize_affine_float8(
             input_float,
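
The last hunk above removes a duplicated target_dtype argument from the choose_qparams_affine_float8 call inside from_hp_to_float8, which computes a scale from the high-precision input and then quantizes with it. A rough sketch of that two-step flow; the full argument order of quantize_affine_float8 is not shown in this diff and is assumed here:

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(64, 64, dtype=torch.bfloat16)  # illustrative input
target_dtype = torch.float8_e4m3fn

scale = choose_qparams_affine_float8(x, target_dtype)  # target_dtype passed once after the fix
fp8_data = quantize_affine_float8(x, scale, target_dtype)  # (tensor, scale, dtype) order assumed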

torchao/quantization/quant_api.py (+11 -20)

@@ -36,6 +36,7 @@
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_affine_quantized_float8,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
@@ -66,21 +67,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +908,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
 
@@ -934,15 +929,13 @@ def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     The actual matmul will be computed in original precision of the weight tensor.
 
     """
-    from torchao.dtypes import to_affine_quantized_floatx
 
     def apply_float8wo_quant(weight):
         block_size = (1, weight.shape[1])
-        return to_affine_quantized_floatx(
+        return to_affine_quantized_float8(
            input_float=weight,
            block_size=block_size,
            target_dtype=weight_dtype,
-            scale_dtype=None,
            _layout=Float8Layout(mm_config=None),
        )
 
@@ -1016,11 +1009,10 @@ def _input_activation_quant_func_fp8(
 
     block_size = get_block_size(x.shape, activation_granularity)
     if scale is None:
-        activation = to_affine_quantized_floatx(
+        activation = to_affine_quantized_float8(
            input_float=x,
            block_size=block_size,
            target_dtype=activation_dtype,
-            scale_dtype=torch.float32,
            _layout=Float8Layout(mm_config=None),  # Config is stored on weight
        )
    else:
@@ -1102,11 +1094,10 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
        ), "PerRow quantization only works for bfloat16 precision input weight"
 
        block_size = get_block_size(weight.shape, weight_granularity)
-        quantized_weight = to_affine_quantized_floatx(
+        quantized_weight = to_affine_quantized_float8(
            input_float=weight,
            block_size=block_size,
            target_dtype=weight_dtype,
-            scale_dtype=torch.float32,
            _layout=Float8Layout(mm_config=mm_config),
        )
 
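
Since the public entry points (float8_weight_only and the float8 dynamic-activation path) keep their signatures and only swap the internal constructor, existing call sites should be unaffected. A minimal sketch of exercising the updated path through the quantize_ API, assuming a GPU with bfloat16 support; the model and shapes are illustrative:

import torch
from torchao.quantization import float8_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to(torch.bfloat16).cuda()

# Weight-only float8: internally this now builds the weight with
# to_affine_quantized_float8 instead of to_affine_quantized_floatx.
quantize_(model, float8_weight_only(weight_dtype=torch.float8_e4m3fn))

out = model(torch.randn(8, 256, dtype=torch.bfloat16, device="cuda"))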