
 import torch

-from torchao.dtypes.utils import (
-    AQTTensorImpl,
-    Layout,
-    PlainLayout,
-)
+from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
     FP8_TYPES,
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
     dequantize_affine,
     dequantize_affine_floatx,
     quantize_affine,
+    quantize_affine_float8,
     quantize_affine_floatx,
 )
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
-    TorchAOBaseTensor,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor

 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
@@ -422,6 +417,36 @@ def from_hp_to_fpx(
         tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
         return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

+    @classmethod
+    def from_hp_to_float8(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        block_size: Tuple[int, ...],
+        _layout: Layout = PlainLayout(),
+    ):
+        assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
+        original_shape = input_float.shape
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+            target_dtype,
+        )
+        fp8_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+        fp8_data = _layout.post_process(fp8_data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            dtype=input_float.dtype,
+        )
+
     @property
     def _layout(self) -> Layout:
         return self.tensor_impl._layout
@@ -477,6 +502,7 @@ def _apply_fn_to_data(self, fn):
 to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
 to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
+to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

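A minimal usage sketch of the new `to_affine_quantized_float8` alias introduced by this diff. The import path, the choice of `torch.float8_e4m3fn`, and the per-tensor `block_size` are assumptions for illustration; only the constructor name and its `(input_float, target_dtype, block_size)` parameters come from the diff above.

import torch
# Assumed module path; the diff does not show the file's module name.
from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_float8

weight = torch.randn(128, 256, dtype=torch.bfloat16)

# target_dtype must be one of FP8_TYPES (torch.float8_e4m3fn is assumed here);
# using the full tensor shape as block_size gives per-tensor scaling.
fp8_weight = to_affine_quantized_float8(
    weight,
    target_dtype=torch.float8_e4m3fn,
    block_size=weight.shape,
)

# The wrapper keeps the input's logical shape and original dtype.
print(fp8_weight.shape, fp8_weight.dtype)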