Skip to content

Commit 0729151

Browse files
add separate quantization primitives for float8
ghstack-source-id: 50780aa ghstack-comment-id: 2608048970 Pull Request resolved: #1597
1 parent 32d9b0b commit 0729151

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

torchao/quantization/quant_primitives.py

+71
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010

1111
import torch
1212

13+
from torchao.float8.float8_utils import (
14+
ScalingGranularity,
15+
)
16+
from torchao.float8.float8_utils import (
17+
tensor_to_scale as tensor_to_float8_scale,
18+
)
1319
from torchao.prototype.custom_fp_utils import (
1420
_f32_to_floatx_unpacked,
1521
_floatx_unpacked_to_f32,
@@ -39,6 +45,9 @@
3945
"MappingType",
4046
"ZeroPointDomain",
4147
"TorchAODType",
48+
"choose_qparams_affine_float8",
49+
"quantize_affine_float8",
50+
"dequantize_affine_float8",
4251
]
4352

4453

@@ -1300,3 +1309,65 @@ def dequantize_affine_floatx(
13001309
tensor = tensor * scale.float().view(-1, 1)
13011310
tensor = tensor.to(dtype=output_dtype)
13021311
return tensor
1312+
1313+
1314+
def choose_qparams_affine_float8(
1315+
tensor: torch.Tensor, float8_dtype: torch.dtype
1316+
) -> torch.Tensor:
1317+
"""
1318+
Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
1319+
1320+
Args:
1321+
tensor (torch.Tensor): Input tensor to be quantized.
1322+
float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
1323+
"""
1324+
# NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
1325+
# https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
1326+
scale = tensor_to_float8_scale(
1327+
tensor,
1328+
float8_dtype,
1329+
scaling_granularity=ScalingGranularity.AXISWISE,
1330+
axiswise_dim=1,
1331+
)
1332+
return scale
1333+
1334+
1335+
def quantize_affine_float8(
1336+
tensor: torch.Tensor,
1337+
scale: torch.Tensor,
1338+
float8_dtype: torch.dtype,
1339+
) -> torch.Tensor:
1340+
"""
1341+
Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
1342+
1343+
Args:
1344+
tensor (torch.Tensor): Input tensor to be quantized.
1345+
scale (torch.Tensor): Scaling factor for the quantization.
1346+
float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
1347+
"""
1348+
# Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
1349+
# upcasted to `float32` to multiply with the scale
1350+
# In order to match numerics between eager and compile, we upcast manually here.
1351+
tensor_scaled = tensor.to(torch.float32) * scale
1352+
max_value = torch.finfo(float8_dtype).max
1353+
tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
1354+
fp8_tensor = tensor_clamped.to(float8_dtype)
1355+
return fp8_tensor
1356+
1357+
1358+
def dequantize_affine_float8(
1359+
tensor: torch.Tensor,
1360+
scale: torch.Tensor,
1361+
output_dtype: torch.dtype = torch.float32,
1362+
) -> torch.Tensor:
1363+
"""
1364+
Dequantizes the float8 tensor to float32 tensor.
1365+
1366+
Args:
1367+
tensor (torch.Tensor): Input float8 tensor to be dequantized.
1368+
scale (torch.Tensor): Scaling factor for the dequantization.
1369+
output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
1370+
"""
1371+
fp8_tensor = tensor.to(torch.float32)
1372+
hp_tensor = fp8_tensor / scale
1373+
return hp_tensor.to(output_dtype)

0 commit comments

Comments
 (0)