Skip to content

Commit 41357fa

Browse files
integrate new float8 quantization primitives into AQT
ghstack-source-id: 0a9fcdb
ghstack-comment-id: 2608090492
Pull Request resolved: #1598
1 parent 102d4a4 commit 41357fa

File tree

1 file changed

+38
-9
lines changed

1 file changed

+38
-9
lines changed

torchao/dtypes/affine_quantized_tensor.py

+38-9
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,22 @@
44

55
import torch
66

7-
from torchao.dtypes.utils import (
8-
AQTTensorImpl,
9-
Layout,
10-
PlainLayout,
11-
)
7+
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
128
from torchao.quantization.quant_primitives import (
139
FP8_TYPES,
1410
MappingType,
1511
ZeroPointDomain,
1612
choose_qparams_affine,
13+
choose_qparams_affine_float8,
1714
choose_qparams_affine_floatx,
1815
choose_qparams_and_quantize_affine_hqq,
1916
dequantize_affine,
2017
dequantize_affine_floatx,
2118
quantize_affine,
19+
quantize_affine_float8,
2220
quantize_affine_floatx,
2321
)
24-
from torchao.utils import (
25-
TORCH_VERSION_AT_LEAST_2_5,
26-
TorchAOBaseTensor,
27-
)
22+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
2823

2924
logger = logging.getLogger(__name__)
3025
aten = torch.ops.aten
@@ -422,6 +417,39 @@ def from_hp_to_fpx(
422417
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
423418
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
424419

420+
@classmethod
def from_hp_to_float8(
    cls,
    input_float: torch.Tensor,
    target_dtype: torch.dtype,
    block_size: Tuple[int, ...],
    _layout: Layout = PlainLayout(),
):
    """Quantize a high-precision tensor to float8 and wrap it in ``cls``.

    Args:
        input_float: Tensor to quantize. Its shape and dtype are recorded
            on the returned object.
        target_dtype: Float8 dtype to quantize to; must be in ``FP8_TYPES``.
        block_size: Forwarded unchanged to the ``cls`` constructor. It is
            not passed to the scale/quantize primitives called here.
        _layout: Layout whose ``post_process`` is applied to the quantized
            data and which is handed to ``Float8AQTTensorImpl``.

    Returns:
        An instance of ``cls`` built around a ``Float8AQTTensorImpl``.

    Raises:
        AssertionError: If ``target_dtype`` is not in ``FP8_TYPES``.
    """
    assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"

    # Imported here rather than at module level to avoid a circular dependency.
    from torchao.dtypes.floatx import Float8AQTTensorImpl

    # Capture the shape up front, before any quantization primitive runs.
    hp_shape = input_float.shape

    # NOTE(review): target_dtype is passed twice here, unlike the
    # quantize_affine_float8 call below -- confirm against the signature of
    # choose_qparams_affine_float8 that this is intended.
    scale = choose_qparams_affine_float8(input_float, target_dtype, target_dtype)

    quantized = quantize_affine_float8(input_float, scale, target_dtype)
    quantized = _layout.post_process(quantized)

    # Third argument is None (presumably the zero point -- verify against
    # Float8AQTTensorImpl's constructor).
    impl = Float8AQTTensorImpl(quantized, scale, None, _layout)
    return cls(impl, block_size, hp_shape, dtype=input_float.dtype)
452+
425453
@property
def _layout(self) -> Layout:
    """Return the layout stored on the wrapped tensor implementation."""
    impl = self.tensor_impl
    return impl._layout
@@ -477,6 +505,7 @@ def _apply_fn_to_data(self, fn):
477505
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
478506
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
479507
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
508+
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
480509
# experimental; will be merged into floatx
481510
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
482511

0 commit comments

Comments
 (0)