Add Float8QuantizedTensor (AQT subclass) and replace to_affine_quantized_floatx with to_affine_quantized_float8 in quantization APIs #1599

Closed
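
In short: float8 gets its own AffineQuantizedTensor subclass (Float8QuantizedTensor) with a dedicated from_hp_to_float8 constructor, and the public alias to_affine_quantized_floatx is replaced by to_affine_quantized_float8. A rough before/after sketch of a call site; the weight shape, block size, and keyword spelling of the old call are illustrative rather than taken from this PR:

import torch
from torchao.dtypes import to_affine_quantized_float8
from torchao.dtypes.floatx import Float8Layout

weight = torch.randn(128, 256, dtype=torch.bfloat16)

# Before: the generic floatx entry point, e.g.
#   to_affine_quantized_floatx(weight, block_size=..., target_dtype=torch.float8_e4m3fn, _layout=Float8Layout())
# After: a dedicated float8 constructor that returns the new subclass.
qweight = to_affine_quantized_float8(
    weight,
    target_dtype=torch.float8_e4m3fn,
    block_size=weight.shape,  # illustrative; a single per-tensor scale
    _layout=Float8Layout(),
)
print(type(qweight).__name__)  # Float8QuantizedTensor
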
2 changes: 1 addition & 1 deletion docs/source/api_ref_dtypes.rst
@@ -13,7 +13,7 @@ torchao.dtypes
to_nf4
to_affine_quantized_intx
to_affine_quantized_intx_static
to_affine_quantized_floatx
to_affine_quantized_float8
to_affine_quantized_floatx_static
to_affine_quantized_fpx
NF4Tensor
13 changes: 4 additions & 9 deletions torchao/dtypes/__init__.py
@@ -1,16 +1,13 @@
from . import affine_quantized_tensor_ops
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_intx,
to_affine_quantized_intx_static,
)
from .floatx import (
Float8Layout,
)
from .floatx import Float8Layout, Float8QuantizedTensor, to_affine_quantized_float8
from .nf4tensor import NF4Tensor, to_nf4
from .uintx import (
BlockSparseLayout,
@@ -24,20 +21,18 @@
UintxLayout,
to_marlinqqq_quantized_intx,
)
from .utils import (
Layout,
PlainLayout,
)
from .utils import Layout, PlainLayout

__all__ = [
"NF4Tensor",
"to_nf4",
"AffineQuantizedTensor",
"Float8QuantizedTensor",
"to_affine_quantized_intx",
"to_affine_quantized_intx_static",
"to_affine_quantized_fpx",
"to_affine_quantized_floatx",
"to_affine_quantized_floatx_static",
"to_affine_quantized_float8",
"to_marlinqqq_quantized_intx",
"Layout",
"PlainLayout",
92 changes: 22 additions & 70 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -10,13 +10,10 @@
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_float8,
choose_qparams_affine_floatx,
choose_qparams_and_quantize_affine_hqq,
dequantize_affine,
dequantize_affine_floatx,
quantize_affine,
quantize_affine_float8,
quantize_affine_floatx,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor
@@ -28,7 +25,6 @@
"AffineQuantizedTensor",
"register_layout",
"to_affine_quantized_intx",
"to_affine_quantized_floatx",
"to_affine_quantized_intx_static",
"to_affine_quantized_floatx_static",
"to_affine_quantized_fpx",
@@ -121,40 +117,28 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
if output_dtype is None:
output_dtype = self.dtype

from torchao.dtypes.floatx import FloatxTensorCoreLayout

if isinstance(self._layout, FloatxTensorCoreLayout):
int_data, scale = self.tensor_impl.get_plain()
return dequantize_affine_floatx(
int_data,
scale,
self._layout.ebits,
self._layout.mbits,
output_dtype=output_dtype,
)
else:
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout
data, scale, zero_point = self.tensor_impl.get_plain()
dq = dequantize_affine(
data,
self.block_size,
scale,
zero_point,
data.dtype,
self.quant_min,
self.quant_max,
self.zero_point_domain,
output_dtype=output_dtype,
)
from torchao.dtypes.uintx import TensorCoreTiledLayout

if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq
if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

def __tensor_flatten__(self):
return ["tensor_impl"], [
@@ -272,7 +256,7 @@ def from_hp_to_intx(
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
return cls(
tensor_impl,
@@ -417,36 +401,6 @@ def from_hp_to_fpx(
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype)

@classmethod
def from_hp_to_float8(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
block_size: Tuple[int, ...],
_layout: Layout = PlainLayout(),
):
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
original_shape = input_float.shape
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
return cls(
tensor_impl,
block_size,
original_shape,
dtype=input_float.dtype,
)

@property
def _layout(self) -> Layout:
return self.tensor_impl._layout
@@ -500,9 +454,7 @@ def _apply_fn_to_data(self, fn):

to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
to_affine_quantized_float8 = AffineQuantizedTensor.from_hp_to_float8
# experimental, will be merged into floatx
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

8 changes: 2 additions & 6 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -3,9 +3,7 @@
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
)
from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
from torchao.dtypes.floatx.float8_layout import (
_linear_fp8_act_fp8_weight_check,
_linear_fp8_act_fp8_weight_impl,
@@ -52,9 +50,7 @@
_linear_bf16_act_uint4_weight_impl,
)
from torchao.quantization.quant_primitives import dequantize_affine
from torchao.utils import (
fill_defaults,
)
from torchao.utils import fill_defaults

logger = logging.getLogger(__name__)

18 changes: 16 additions & 2 deletions torchao/dtypes/floatx/__init__.py
@@ -1,4 +1,12 @@
from .float8_layout import Float8Layout
from .float8_layout import (
Float8Layout,
Float8QuantizedTensor,
_linear_fp8_act_fp8_weight_check,
_linear_fp8_act_fp8_weight_impl,
_linear_fp_act_fp8_weight_check,
_linear_fp_act_fp8_weight_impl,
to_affine_quantized_float8,
)
from .floatx_tensor_core_layout import (
FloatxTensorCoreLayout,
from_scaled_tc_floatx,
@@ -7,7 +15,13 @@

__all__ = [
"FloatxTensorCoreLayout",
"Float8Layout",
"Float8QuantizedTensor",
"to_scaled_tc_floatx",
"from_scaled_tc_floatx",
"Float8Layout",
"to_affine_quantized_float8",
"_linear_fp8_act_fp8_weight_check",
"_linear_fp8_act_fp8_weight_impl",
"_linear_fp_act_fp8_weight_check",
"_linear_fp_act_fp8_weight_impl",
]
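
For downstream code, the float8 entry points are exported both from this module and re-exported from torchao.dtypes; a minimal import sketch using only names present in the __all__ lists in this diff:

from torchao.dtypes import Float8QuantizedTensor, to_affine_quantized_float8
from torchao.dtypes.floatx import Float8Layout

The underscore-prefixed dispatch helpers are presumably exported so internal dispatch registration code can reuse them; most callers only need the names above.
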
74 changes: 64 additions & 10 deletions torchao/dtypes/floatx/float8_layout.py
Expand Up @@ -18,6 +18,12 @@
addmm_float8_unwrapped_inference,
preprocess_data,
)
from torchao.quantization.quant_primitives import (
FP8_TYPES,
choose_qparams_affine_float8,
dequantize_affine_float8,
quantize_affine_float8,
)
from torchao.utils import _is_float8_type, fill_defaults

aten = torch.ops.aten
Expand Down Expand Up @@ -209,19 +215,64 @@ def __repr__(self):
)


class Float8QuantizedTensor(AffineQuantizedTensor):
"""
Float8 quantized tensor subclass which inherits from AffineQuantizedTensor.
"""

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
if output_dtype is None:
output_dtype = self.dtype
int_data, scale, _ = self.tensor_impl.get_plain()
return dequantize_affine_float8(
int_data,
scale,
output_dtype=output_dtype,
)

@classmethod
def from_hp_to_float8(
cls,
input_float: torch.Tensor,
target_dtype: torch.dtype,
block_size: Tuple[int, ...],
_layout: Layout = Float8Layout(),
):
assert target_dtype in FP8_TYPES, f"Unsupported dtype {target_dtype} for float8"
original_shape = input_float.shape
scale = choose_qparams_affine_float8(
input_float,
target_dtype,
)
fp8_data = quantize_affine_float8(
input_float,
scale,
target_dtype,
)
fp8_data = _layout.post_process(fp8_data)
tensor_impl_ctr = cls.get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(fp8_data, scale, None, _layout)
return cls(
tensor_impl,
block_size,
original_shape,
dtype=input_float.dtype,
)


##########################
# Float8 Dispatch Kernels
##########################


def _linear_fp8_act_fp8_weight_check(
input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
bias: Optional[torch.Tensor],
) -> bool:
def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
def check_aqt(aqt: Union[torch.Tensor, Float8QuantizedTensor]) -> bool:
return (
isinstance(aqt, AffineQuantizedTensor)
isinstance(aqt, Float8QuantizedTensor)
and isinstance(aqt._layout, Float8Layout)
and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
Expand All @@ -241,8 +292,8 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):


def _linear_fp8_act_fp8_weight_impl(
input_tensor: "AffineQuantizedTensor",
weight_tensor: "AffineQuantizedTensor",
input_tensor: "Float8QuantizedTensor",
weight_tensor: "Float8QuantizedTensor",
bias: Optional[torch.Tensor],
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
Expand Down Expand Up @@ -285,8 +336,8 @@ def _linear_fp8_act_fp8_weight_impl(


def _linear_fp_act_fp8_weight_check(
input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
input_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
weight_tensor: Union[torch.Tensor, "Float8QuantizedTensor"],
bias: Optional[torch.Tensor],
) -> bool:
return (
Expand All @@ -295,7 +346,7 @@ def _linear_fp_act_fp8_weight_check(
and input_tensor.is_floating_point()
and
# weight is float8 quantized affine quantized tensor
isinstance(weight_tensor, AffineQuantizedTensor)
isinstance(weight_tensor, Float8QuantizedTensor)
and isinstance(weight_tensor._layout, Float8Layout)
and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and (
Expand All @@ -307,7 +358,10 @@ def _linear_fp_act_fp8_weight_check(

def _linear_fp_act_fp8_weight_impl(
input_tensor: torch.Tensor,
weight_tensor: "AffineQuantizedTensor",
weight_tensor: "Float8QuantizedTensor",
bias: Optional[torch.Tensor],
):
return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)


to_affine_quantized_float8 = Float8QuantizedTensor.from_hp_to_float8
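
A minimal round-trip sketch of the new subclass, assuming the constructor and dequantize paths behave as defined in this file (the tensor shape and per-tensor block_size are illustrative):

import torch
from torchao.dtypes.floatx import Float8QuantizedTensor, to_affine_quantized_float8

x = torch.randn(64, 64, dtype=torch.float32)
# block_size == x.shape corresponds to a single per-tensor scale from choose_qparams_affine_float8
qx = to_affine_quantized_float8(x, torch.float8_e4m3fn, block_size=x.shape)
assert isinstance(qx, Float8QuantizedTensor)

x_dq = qx.dequantize(torch.float32)
# fp8 e4m3 round-trip error should be small but nonzero
print((x - x_dq).abs().max().item())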