Skip to content

Commit 7cc4944

Browse files
integrate new float8 quantization primitives into AQT
ghstack-source-id: c1deeeb; ghstack-comment-id: 2608090492; Pull Request resolved: #1598
1 parent ac3dc8d commit 7cc4944

File tree

2 files changed

+43
-12
lines changed

2 files changed

+43
-12
lines changed

torchao/dtypes/affine_quantized_tensor.py

+38-9
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,22 @@
44

55
import torch
66

7-
from torchao.dtypes.utils import (
8-
AQTTensorImpl,
9-
Layout,
10-
PlainLayout,
11-
)
7+
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
128
from torchao.quantization.quant_primitives import (
139
FP8_TYPES,
1410
MappingType,
1511
ZeroPointDomain,
1612
choose_qparams_affine,
13+
choose_qparams_affine_float8,
1714
choose_qparams_affine_floatx,
1815
choose_qparams_and_quantize_affine_hqq,
1916
dequantize_affine,
2017
dequantize_affine_floatx,
2118
quantize_affine,
19+
quantize_affine_float8,
2220
quantize_affine_floatx,
2321
)
24-
from torchao.utils import (
25-
TORCH_VERSION_AT_LEAST_2_5,
26-
TorchAOBaseTensor,
27-
)
22+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
2823

2924
logger = logging.getLogger(__name__)
3025
aten = torch.ops.aten
@@ -422,6 +417,39 @@ def from_hp_to_fpx(
422417
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
423418
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
424419

420+
@classmethod
421+
def from_hp_to_float8(
422+
cls,
423+
input_float: torch.Tensor,
424+
target_dtype: torch.dtype,
425+
block_size: Tuple[int, ...],
426+
_layout: Layout = PlainLayout(),
427+
):
428+
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
429+
430+
# to avoid circular dependency
431+
from torchao.dtypes.floatx import Float8AQTTensorImpl
432+
433+
original_shape = input_float.shape
434+
scale = choose_qparams_affine_float8(
435+
input_float,
436+
target_dtype,
437+
target_dtype,
438+
)
439+
fp8_data = quantize_affine_float8(
440+
input_float,
441+
scale,
442+
target_dtype,
443+
)
444+
fp8_data = _layout.post_process(fp8_data)
445+
tensor_impl = Float8AQTTensorImpl(fp8_data, scale, None, _layout)
446+
return cls(
447+
tensor_impl,
448+
block_size,
449+
original_shape,
450+
dtype=input_float.dtype,
451+
)
452+
425453
@property
426454
def _layout(self) -> Layout:
427455
return self.tensor_impl._layout
@@ -477,6 +505,7 @@ def _apply_fn_to_data(self, fn):
477505
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
478506
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
479507
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
508+
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
480509
# experimental will be merged in to floatx
481510
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
482511

torchao/quantization/quant_primitives.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import math
8-
from enum import auto, Enum
8+
from enum import Enum, auto
99
from typing import Callable, Dict, List, Optional, Tuple, Union
1010

1111
import torch
1212

1313
from torchao.float8.float8_utils import (
1414
ScalingGranularity,
15+
)
16+
from torchao.float8.float8_utils import (
1517
tensor_to_scale as tensor_to_float8_scale,
1618
)
1719
from torchao.prototype.custom_fp_utils import (
@@ -20,11 +22,11 @@
2022
_n_ones,
2123
)
2224
from torchao.utils import (
23-
_is_float8_type,
24-
_register_custom_op,
2525
TORCH_VERSION_AT_LEAST_2_3,
2626
TORCH_VERSION_AT_LEAST_2_5,
2727
TORCH_VERSION_AT_LEAST_2_6,
28+
_is_float8_type,
29+
_register_custom_op,
2830
)
2931

3032
__all__ = [

0 commit comments

Comments
 (0)