diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index e587d4bc2b..906dbec73c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -10,6 +10,12 @@ import torch +from torchao.float8.float8_utils import ( + ScalingGranularity, +) +from torchao.float8.float8_utils import ( + tensor_to_scale as tensor_to_float8_scale, +) from torchao.prototype.custom_fp_utils import ( _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, @@ -39,6 +45,9 @@ "MappingType", "ZeroPointDomain", "TorchAODType", + "choose_qparams_affine_float8", + "quantize_affine_float8", + "dequantize_affine_float8", ] @@ -1300,3 +1309,65 @@ def dequantize_affine_floatx( tensor = tensor * scale.float().view(-1, 1) tensor = tensor.to(dtype=output_dtype) return tensor + + +def choose_qparams_affine_float8( + tensor: torch.Tensor, float8_dtype: torch.dtype +) -> torch.Tensor: + """ + Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + """ + # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now: + # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416 + scale = tensor_to_float8_scale( + tensor, + float8_dtype, + scaling_granularity=ScalingGranularity.AXISWISE, + axiswise_dim=1, + ) + return scale + + +def quantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + float8_dtype: torch.dtype, +) -> torch.Tensor: + """ + Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor. + + Args: + tensor (torch.Tensor): Input tensor to be quantized. + scale (torch.Tensor): Scaling factor for the quantization. + float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2). + """ + # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically + # upcasted to `float32` to multiply with the scale + # In order to match numerics between eager and compile, we upcast manually here. + tensor_scaled = tensor.to(torch.float32) * scale + max_value = torch.finfo(float8_dtype).max + tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value) + fp8_tensor = tensor_clamped.to(float8_dtype) + return fp8_tensor + + +def dequantize_affine_float8( + tensor: torch.Tensor, + scale: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Dequantizes the float8 tensor to float32 tensor. + + Args: + tensor (torch.Tensor): Input float8 tensor to be dequantized. + scale (torch.Tensor): Scaling factor for the dequantization. + output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32). + """ + fp8_tensor = tensor.to(torch.float32) + hp_tensor = fp8_tensor / scale + return hp_tensor.to(output_dtype)