From ab8d5b56ddbdd1af7d6bf17bd5c74b33394e2662 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 22 Jan 2025 11:27:16 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torchao/dtypes/affine_quantized_tensor.py | 44 ++++++++++++++++++----- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index e7aca34c5f..506e8f0174 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -4,27 +4,22 @@ import torch -from torchao.dtypes.utils import ( - AQTTensorImpl, - Layout, - PlainLayout, -) +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout from torchao.quantization.quant_primitives import ( FP8_TYPES, MappingType, ZeroPointDomain, choose_qparams_affine, + choose_qparams_affine_float8, choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, dequantize_affine, dequantize_affine_floatx, quantize_affine, + quantize_affine_float8, quantize_affine_floatx, ) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_5, - TorchAOBaseTensor, -) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor logger = logging.getLogger(__name__) aten = torch.ops.aten @@ -422,6 +417,36 @@ def from_hp_to_fpx( tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) + @classmethod + def from_hp_to_float8( + cls, + input_float: torch.Tensor, + target_dtype: torch.dtype, + block_size: Tuple[int, ...], + _layout: Layout = PlainLayout(), + ): + assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8" + original_shape = input_float.shape + scale = choose_qparams_affine_float8( + input_float, + target_dtype, + target_dtype, + ) + fp8_data = quantize_affine_float8( + input_float, + scale, + target_dtype, + ) + fp8_data = _layout.post_process(fp8_data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + dtype=input_float.dtype, + ) + @property def _layout(self) -> Layout: return self.tensor_impl._layout @@ -477,6 +502,7 @@ def _apply_fn_to_data(self, fn): to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static +to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8 # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx From 76c7bde837e36587cf4acbc938d9ce78dc68d43f Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Thu, 23 Jan 2025 13:02:52 -0800 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torchao/dtypes/affine_quantized_tensor.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 506e8f0174..06b5668fb7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -6,18 +6,18 @@ from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout from torchao.quantization.quant_primitives import ( - FP8_TYPES, - MappingType, - ZeroPointDomain, choose_qparams_affine, choose_qparams_affine_float8, choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, dequantize_affine, dequantize_affine_floatx, + FP8_TYPES, + MappingType, quantize_affine, quantize_affine_float8, quantize_affine_floatx, + ZeroPointDomain, ) from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor @@ -426,6 +426,10 @@ def from_hp_to_float8( _layout: Layout = PlainLayout(), ): assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8" + + # to avoid circular dependency + from torchao.dtypes.floatx import Float8AQTTensorImpl + original_shape = input_float.shape scale = choose_qparams_affine_float8( input_float, @@ -438,8 +442,7 @@ def from_hp_to_float8( target_dtype, ) fp8_data = _layout.post_process(fp8_data) - tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) - tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout) + tensor_impl = Float8AQTTensorImpl(fp8_data, scale, None, _layout) return cls( tensor_impl, block_size,