
Commit 408d065

add separate quantization primitives for float8
ghstack-source-id: b5340379b49bdab5e00e4c27d7444dc7d7f1acd7
ghstack-comment-id: 2608048970
Pull Request resolved: #1597
1 parent 32d9b0b commit 408d065

File tree

1 file changed (+72, -3 lines)

torchao/quantization/quant_primitives.py

@@ -5,22 +5,26 @@
 # LICENSE file in the root directory of this source tree.

 import math
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
+from torchao.float8.float8_utils import (
+    ScalingGranularity,
+    tensor_to_scale as tensor_to_float8_scale,
+)

 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
     _n_ones,
 )
 from torchao.utils import (
+    _is_float8_type,
+    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    _is_float8_type,
-    _register_custom_op,
 )

 __all__ = [
@@ -39,6 +43,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]

@@ -1300,3 +1307,65 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
+    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
+    scale = tensor_to_float8_scale(
+        tensor,
+        float8_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=1,
+    )
+    return scale
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to multiply with the scale.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) * scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to float32 tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor / scale
+    return hp_tensor.to(output_dtype)
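
For context, a minimal round-trip sketch (not part of this commit) of how the three new primitives compose. It assumes a PyTorch build with float8 dtypes, a torchao install that includes this change, and a 2D input, since the scale is currently computed axiswise along dim=1 (see the NOTE in the diff); the variable names and shapes below are illustrative only.

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

# 2D high precision input; the scale will have shape (16, 1) with the current axiswise setting.
x = torch.randn(16, 32, dtype=torch.bfloat16)

scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)

# Quantize: upcast to float32, multiply by scale, clamp to the float8 range, cast to float8.
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)

# Dequantize: upcast to float32, divide by the same scale, cast to output_dtype.
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.bfloat16)

print(x_fp8.dtype, x_hp.dtype)  # torch.float8_e4m3fn torch.bfloat16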
