diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index e3ac420de7..b3b0cfdf67 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -337,64 +337,6 @@ def from_hp_to_intx_static(
             dtype=input_float.dtype,
         )
 
-    @classmethod
-    def from_hp_to_floatx(
-        cls,
-        input_float: torch.Tensor,
-        block_size: Tuple[int, ...],
-        target_dtype: torch.dtype,
-        _layout: Layout,
-        scale_dtype: Optional[torch.dtype] = None,
-    ):
-        """Convert a high precision tensor to a float8 quantized tensor."""
-        if target_dtype in FP8_TYPES:
-            return cls.from_hp_to_intx(
-                input_float=input_float,
-                mapping_type=MappingType.SYMMETRIC,
-                block_size=block_size,
-                target_dtype=target_dtype,
-                quant_min=math.ceil(torch.finfo(target_dtype).min),
-                quant_max=math.ceil(torch.finfo(target_dtype).max),
-                eps=torch.finfo(torch.float32).eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=None,
-                preserve_zero=True,
-                zero_point_domain=None,
-                _layout=_layout,
-                use_hqq=False,
-            )
-        else:
-            raise NotImplementedError(
-                f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
-            )
-
-    @classmethod
-    def from_hp_to_floatx_static(
-        cls,
-        input_float: torch.Tensor,
-        scale: torch.Tensor,
-        block_size: Tuple[int, ...],
-        target_dtype: torch.dtype,
-        _layout: Layout,
-    ):
-        """Create a float8 AffineQuantizedTensor from a high precision tensor using static parameters."""
-        if target_dtype in FP8_TYPES:
-            return cls.from_hp_to_intx_static(
-                input_float=input_float,
-                scale=scale,
-                zero_point=None,
-                block_size=block_size,
-                target_dtype=target_dtype,
-                quant_min=math.ceil(torch.finfo(target_dtype).min),
-                quant_max=math.ceil(torch.finfo(target_dtype).max),
-                zero_point_domain=None,
-                _layout=_layout,
-            )
-        else:
-            raise NotImplementedError(
-                f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
-            )
-
     @classmethod
     def from_hp_to_fpx(
         cls,
diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py
index 5a7e1924b3..7d15f5e9e6 100644
--- a/torchao/dtypes/floatx/float8_layout.py
+++ b/torchao/dtypes/floatx/float8_layout.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
+import math
 
 import torch
 from torch.utils._python_dispatch import (
@@ -11,6 +12,7 @@
     AffineQuantizedTensor,
     register_layout,
 )
+from torchao.dtypes.nf4tensor import implements
 from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
 from torchao.float8.inference import (
     Float8MMConfig,
@@ -18,8 +20,13 @@
     addmm_float8_unwrapped_inference,
     preprocess_data,
 )
-from torchao.utils import _is_float8_type, fill_defaults
-
+from torchao.utils import _is_float8_type, fill_defaults, TorchAOBaseTensor
+from torchao.quantization.quant_primitives import (
+    FP8_TYPES,
+    MappingType,
+    choose_qparams_affine_float8,
+    quantize_affine_float8,
+)
 
 aten = torch.ops.aten
 
@@ -34,13 +41,16 @@ class Float8Layout(Layout):
     mm_config: Optional[Float8MMConfig] = None
 
 
-@register_layout(Float8Layout)
-class Float8AQTTensorImpl(AQTTensorImpl):
+class Float8Tensor(TorchAOBaseTensor):
     """
-    TensorImpl for float8 layout affine quantized tensor
+    Float8 Tensor is a subclass of torch.Tensor that supports float8 data types.
+    It stores the quantized float8 data together with its quantization scale.
 
-    Note: technically we should not create a new layout for float8 we should merge this into
-    plain layout
+    Attributes:
+        float8_data (torch.Tensor): The float8 data tensor.
+        scale (torch.Tensor): The scale tensor.
+        transposed (bool): Whether the tensor is transposed or not.
+        _layout (Layout): The layout of the tensor.
     """
 
     float8_data: torch.Tensor
@@ -52,7 +62,7 @@ def __new__(
         float8_data: torch.Tensor,
         scale: torch.Tensor,
         transposed: bool,
-        _layout: Layout,
+        _layout: Layout = Float8Layout(),
     ):
         kwargs = {}
         kwargs["device"] = float8_data.device
@@ -69,7 +79,7 @@ def __init__(
         float8_data: torch.Tensor,
         scale: torch.Tensor,
         transposed: bool,
-        _layout: Layout,
+        _layout: Layout = Float8Layout(),
    ):
         self.float8_data = float8_data
         self.scale = scale
@@ -108,84 +118,20 @@ def __tensor_unflatten__(
         ) = tensor_attributes
         return cls(float8_data, scale, transposed, _layout)
 
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-        elif func is aten.clone.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-            )
-        elif func is aten.t.default:
-            """we don't need to repack the weight and just rely on external
-            shape being changed and record the status of transpose/no-transpose
-            """
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
-        elif func is aten.slice.Tensor:
-            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            if dim == 0:
-                # TODO: scale replecation should be dependent on block size
-                if self.scale.ndim == 1:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        args[0]._apply_fn_to_data(
-                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
-                        ),
-                    )
-                elif self.scale.ndim == 0:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        Float8AQTTensorImpl(
-                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
-                            self.scale,
-                            None,
-                            self._layout,
-                        ),
-                    )
-                else:
-                    raise NotImplementedError(
-                        f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
-                    )
-            elif dim == 1:
-                return return_and_correct_aliasing(
-                    func,
-                    args,
-                    kwargs,
-                    Float8AQTTensorImpl(
-                        aten.slice.Tensor(
-                            self.float8_data, dim, start, end, step
-                        ).contiguous(),
-                        self.scale,
-                        None,
-                        self._layout,
-                    ),
-                )
-            else:
-                raise NotImplementedError(
-                    f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
-                )
-        else:
-            raise NotImplementedError(
-                f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported"
-            )
-
-    __torch_function__ = torch._C._disabled_torch_function_impl
+    def __repr__(self):
+        float8_data, scale, _ = self.get_plain()
+        _layout = self.get_layout()
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"float8_data={float8_data},\n"
+            f"scale={scale},\n"
+            f"transposed={self.transposed}, "
+            f"_layout={_layout})"
+        )
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         return self.float8_data, self.scale, None
 
-    def get_layout(self) -> Layout:
-        return self._layout
-
     @classmethod
     def from_plain(
         cls,
@@ -203,15 +149,120 @@ def from_plain(
         ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}"
         return cls(data, scale, False, _layout)
 
-    def __repr__(self):
-        float8_data, scale, _ = self.get_plain()
-        _layout = self.get_layout()
-        return (
-            f"{self.__class__.__name__}(\n"
-            f"float8_data={float8_data},\n"
-            f"scale={scale},\n"
-            f"transposed={self.transposed}, "
-            f"_layout={_layout})"
+    @classmethod
+    def from_hp_to_floatx(
+        cls,
+        input_float: torch.Tensor,
+        target_dtype: torch.dtype,
+        _layout: Layout = Float8Layout(),
+    ):
+        """Convert a high precision tensor to a float8 quantized tensor."""
+        if target_dtype not in FP8_TYPES:
+            raise NotImplementedError(
+                f"Unsupported dtype {target_dtype} for from_hp_to_floatx"
+            )
+        scale = choose_qparams_affine_float8(
+            input_float,
+            target_dtype,
+        )
+        float_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+
+        return cls(
+            float_data,
+            scale,
+            False,
+            _layout,
+        )
+
+    @classmethod
+    def from_hp_to_floatx_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        target_dtype: torch.dtype,
+        _layout: Layout,
+    ):
+        """Create a Float8Tensor from a high precision tensor using a precomputed (static) scale."""
+        if target_dtype not in FP8_TYPES:
+            raise NotImplementedError(
+                f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static"
+            )
+        float_data = quantize_affine_float8(
+            input_float,
+            scale,
+            target_dtype,
+        )
+
+        return cls(
+            float_data,
+            scale,
+            False,
+            _layout,
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+
+@implements(aten.t.default)
+def _(func, types, args, kwargs):
+    """we don't need to repack the weight and just rely on external
+    shape being changed and record the status of transpose/no-transpose
+    """
+    args[0].transposed = not args[0].transposed
+    return return_and_correct_aliasing(func, args, kwargs, args[0])
+
+
+@implements(aten.slice.Tensor)
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    if dim == 0:
+        # TODO: scale replication should be dependent on block size
+        if self.scale.ndim == 1:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                ),
+            )
+        elif self.scale.ndim == 0:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                Float8Tensor(
+                    aten.slice.Tensor(self.float8_data, dim, start, end, step),
+                    self.scale,
+                    self.transposed,
+                    self._layout,
+                ),
+            )
+        else:
+            raise NotImplementedError(
+                f"Float8Tensor dispatch: attempting to run {func}, with scale ndim={self.scale.ndim}, that is not supported"
+            )
+    elif dim == 1:
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            Float8Tensor(
+                aten.slice.Tensor(
+                    self.float8_data, dim, start, end, step
+                ).contiguous(),
+                self.scale,
+                self.transposed,
+                self._layout,
+            ),
+        )
+    else:
+        raise NotImplementedError(
+            f"Float8Tensor dispatch: attempting to run {func}, with dim={dim}, that is not supported"
         )
 
 
@@ -317,3 +368,7 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
+
+
+to_quantized_float8 = Float8Tensor.from_hp_to_floatx
+to_quantized_float8_static = Float8Tensor.from_hp_to_floatx_static
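
For reference, a minimal usage sketch (not part of the diff) of the constructors this patch adds. It assumes this branch of torchao is installed, that torchao.dtypes.floatx.float8_layout exports Float8Tensor and to_quantized_float8 exactly as defined above, and that the installed PyTorch build provides torch.float8_e4m3fn; treat it as an illustration of the intended call pattern, not as documented API.

import torch

from torchao.dtypes.floatx.float8_layout import Float8Tensor, to_quantized_float8

# High precision weight to quantize.
w = torch.randn(128, 256, dtype=torch.bfloat16)

# Dynamic path: the scale is derived from the input via choose_qparams_affine_float8.
w_fp8 = to_quantized_float8(w, torch.float8_e4m3fn)
print(w_fp8.float8_data.dtype)  # expected: torch.float8_e4m3fn
print(w_fp8.scale)

# Static path: reuse a precomputed scale; _layout has no default here, so pass one explicitly.
w_fp8_static = Float8Tensor.from_hp_to_floatx_static(
    w, w_fp8.scale, torch.float8_e4m3fn, w_fp8._layout
)
assert w_fp8_static.float8_data.dtype == torch.float8_e4m3fn

The module-level aliases at the end of the diff expose the same two constructors as plain functions, so call sites can quantize without importing the class.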