Skip to content

Commit fb011f0

Browse files
integrate new float8 quantization primitives into AQT
ghstack-source-id: 9aacd39f35698b23f70c6270aa67e95aff24e29a ghstack-comment-id: 2608090492 Pull Request resolved: #1598
1 parent 598469e commit fb011f0

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

torchao/dtypes/affine_quantized_tensor.py

+35-9
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,22 @@
44

55
import torch
66

7-
from torchao.dtypes.utils import (
8-
AQTTensorImpl,
9-
Layout,
10-
PlainLayout,
11-
)
7+
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
128
from torchao.quantization.quant_primitives import (
139
FP8_TYPES,
1410
MappingType,
1511
ZeroPointDomain,
1612
choose_qparams_affine,
13+
choose_qparams_affine_float8,
1714
choose_qparams_affine_floatx,
1815
choose_qparams_and_quantize_affine_hqq,
1916
dequantize_affine,
2017
dequantize_affine_floatx,
2118
quantize_affine,
19+
quantize_affine_float8,
2220
quantize_affine_floatx,
2321
)
24-
from torchao.utils import (
25-
TORCH_VERSION_AT_LEAST_2_5,
26-
TorchAOBaseTensor,
27-
)
22+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
2823

2924
logger = logging.getLogger(__name__)
3025
aten = torch.ops.aten
@@ -422,6 +417,36 @@ def from_hp_to_fpx(
422417
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
423418
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)
424419

420+
@classmethod
def from_hp_to_float8(
    cls,
    input_float: torch.Tensor,
    target_dtype: torch.dtype,
    block_size: Tuple[int, ...],
    _layout: Layout = PlainLayout(),
):
    """Build an affine-quantized tensor from a high-precision input via float8 quantization.

    Args:
        input_float: high-precision tensor to quantize.
        target_dtype: destination dtype; must be one of ``FP8_TYPES``.
        block_size: quantization granularity passed through to the constructor.
        _layout: layout used to post-process and pack the quantized data.

    Returns:
        An instance of ``cls`` wrapping the float8 tensor impl, carrying the
        input's original shape and dtype.
    """
    assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
    # Keep the pre-quantization shape so the wrapper can report/restore it.
    hp_shape = input_float.shape
    # NOTE(review): target_dtype is passed twice here — the second positional
    # slot presumably selects a scale dtype; confirm that a float8 dtype (rather
    # than e.g. float32) is really intended for the scale.
    fp8_scale = choose_qparams_affine_float8(input_float, target_dtype, target_dtype)
    packed = quantize_affine_float8(input_float, fp8_scale, target_dtype)
    packed = _layout.post_process(packed)
    impl_ctr = get_tensor_impl_constructor(type(_layout))
    impl = impl_ctr(packed, fp8_scale, None, _layout)
    return cls(impl, block_size, hp_shape, dtype=input_float.dtype)
449+
425450
@property
def _layout(self) -> Layout:
    """Layout object of this tensor, delegated to the underlying tensor impl."""
    return self.tensor_impl._layout
@@ -477,6 +502,7 @@ def _apply_fn_to_data(self, fn):
477502
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
478503
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
479504
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
505+
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
480506
# experimental will be merged in to floatx
481507
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
482508

0 commit comments

Comments
 (0)