add separate quantization primitives for float8
ghstack-source-id: 50780aa701de01474ce520235f576909528141c6
ghstack-comment-id: 2608048970
Pull Request resolved: #1597
danielvegamyhre committed Jan 22, 2025
1 parent 32d9b0b commit 6b982ab
Showing 1 changed file with 71 additions and 0 deletions.
torchao/quantization/quant_primitives.py
@@ -10,6 +10,12 @@

import torch

from torchao.float8.float8_utils import (
    ScalingGranularity,
    tensor_to_scale as tensor_to_float8_scale,
)
from torchao.prototype.custom_fp_utils import (
    _f32_to_floatx_unpacked,
    _floatx_unpacked_to_f32,
@@ -39,6 +45,9 @@
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"choose_qparams_affine_float8",
"quantize_affine_float8",
"dequantize_affine_float8",
]


@@ -1300,3 +1309,65 @@ def dequantize_affine_floatx(
    tensor = tensor * scale.float().view(-1, 1)
    tensor = tensor.to(dtype=output_dtype)
    return tensor


def choose_qparams_affine_float8(
    tensor: torch.Tensor, float8_dtype: torch.dtype
) -> torch.Tensor:
    """
    Calculates the float8 scaling factor for the given high precision tensor,
    using axiswise granularity along axis 1 (one scale per row).

    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).

    Returns:
        torch.Tensor: Scaling factor for the quantization.
    """
    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
    scale = tensor_to_float8_scale(
        tensor,
        float8_dtype,
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=1,
    )
    return scale
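
# Example (illustrative sketch, not part of the diff; the shapes and dtype
# below are assumptions). With axiswise scaling along axis 1, the returned
# scale is expected to hold one entry per row:
#
#   x = torch.randn(128, 256, dtype=torch.bfloat16)
#   scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
#   # scale.shape == (128, 1)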


def quantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype,
) -> torch.Tensor:
    """
    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.

    Args:
        tensor (torch.Tensor): Input tensor to be quantized.
        scale (torch.Tensor): Scaling factor for the quantization.
        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).

    Returns:
        torch.Tensor: Quantized tensor in the given float8 dtype.
    """
    # Note: when the line below is compiled with `torch.compile`, `tensor` is
    # automatically upcast to `float32` to multiply with the scale. To match
    # numerics between eager and compile, we upcast manually here.
    tensor_scaled = tensor.to(torch.float32) * scale
    max_value = torch.finfo(float8_dtype).max
    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
    fp8_tensor = tensor_clamped.to(float8_dtype)
    return fp8_tensor
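
# Example (illustrative sketch, continuing from the scale computed above).
# Scaled values outside the float8 dtype's representable range saturate at
# +/- torch.finfo(float8_dtype).max before the cast:
#
#   x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
#   assert x_fp8.dtype == torch.float8_e4m3fn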


def dequantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Dequantizes the float8 tensor to a high precision tensor in the specified output dtype.

    Args:
        tensor (torch.Tensor): Input float8 tensor to be dequantized.
        scale (torch.Tensor): Scaling factor for the dequantization.
        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).

    Returns:
        torch.Tensor: Dequantized tensor in the given output dtype.
    """
    # Upcast to float32 before dividing by the scale to preserve precision.
    fp8_tensor = tensor.to(torch.float32)
    hp_tensor = fp8_tensor / scale
    return hp_tensor.to(output_dtype)
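
# Example (illustrative sketch): a full quantize/dequantize roundtrip with the
# three functions above. The result should approximate the input up to float8
# rounding error; the tolerances here are loose assumptions, not guarantees:
#
#   x = torch.randn(16, 32)
#   scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
#   x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
#   x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)
#   torch.testing.assert_close(x_hp, x, atol=0.1, rtol=0.1)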
