Commit 04476b5

add separate quantization primitives for float8
ghstack-source-id: f973b55 ghstack-comment-id: 2608048970 Pull Request resolved: #1597
1 parent 32d9b0b commit 04476b5

2 files changed: +126 −3 lines changed

test/quantization/test_quant_primitives.py (+57 lines)
@@ -9,16 +9,21 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
 from torchao.dtypes.utils import is_device
+from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     dequantize_affine,
+    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
     quantize_affine,
+    quantize_affine_float8,
 )
 
 # TODO: remove test for utils?
@@ -838,6 +843,58 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    @parameterized.expand(
+        [
+            (
+                torch.float32,
+                torch.float8_e4m3fn,
+            ),
+        ]
+    )
+    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
+        input = torch.randn(10, 10)
+
+        # float8 quantization primitives
+        scale = choose_qparams_affine_float8(input, float8_dtype)
+        quantized = quantize_affine_float8(input, scale, float8_dtype)
+        dequantized = dequantize_affine_float8(quantized, scale, hp_dtype)
+
+        # reference implementation using generic primitives
+        expected_scale, _ = choose_qparams_affine(
+            input,
+            MappingType.SYMMETRIC,
+            input.shape,
+            float8_dtype,
+            eps=float8_eps,  # use same EPS as float8 training
+            scale_dtype=torch.float32,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+        )
+        expected_quantized = quantize_affine(
+            input,
+            input.shape,
+            scale,
+            output_dtype=float8_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+        expected_dequantized = dequantize_affine(
+            expected_quantized,
+            input.shape,
+            scale,
+            input_dtype=float8_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+
+        self.assertTrue(torch.equal(expected_scale, scale))
+        torch.testing.assert_close(expected_quantized, quantized)
+        torch.testing.assert_close(expected_dequantized, dequantized)
+
 
 if __name__ == "__main__":
     unittest.main()
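Why the exact-equality assertion on the scale can hold: the float8 path computes a tensorwise symmetric scale of amax / ((quant_max - quant_min) / 2), and for inputs whose amax is above eps the generic symmetric path reduces to the same value. A small sanity-check sketch (variable names are illustrative; the finfo.max value shown in the comment is for torch.float8_e4m3fn):

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
x = torch.randn(10, 10)
amax = x.abs().max()

# tensorwise symmetric scale, as computed by both code paths in the test above
scale = amax / ((finfo.max - finfo.min) / 2)  # equals amax / finfo.max, since the float8 range is symmetric
print(float(finfo.max), float(scale))  # finfo.max is 448.0 for float8_e4m3fn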

torchao/quantization/quant_primitives.py (+69, −3 lines)
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -16,11 +16,11 @@
     _n_ones,
 )
 from torchao.utils import (
+    _is_float8_type,
+    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    _is_float8_type,
-    _register_custom_op,
 )
 
 __all__ = [
@@ -39,6 +39,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]
 
 
@@ -1300,3 +1303,66 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor, using tensorwise granularity.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # only tensorwise scaling is supported for now:
+    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
+    min_val_neg = torch.min(tensor)
+    max_val_pos = torch.max(tensor)
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    return scale.to(dtype=torch.float32)
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to divide by the scale, since scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) / scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to multiply with the scale, since scale is an fp32 tensor for float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor * scale
+    return hp_tensor.to(output_dtype)
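For context, a minimal round-trip sketch using the three new primitives, mirroring the test above (the input shape and the e4m3fn dtype are illustrative choices, not requirements):

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(10, 10)

# tensorwise scale -> float8 quantize -> dequantize back to float32
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
x_hp = dequantize_affine_float8(x_fp8, scale, torch.float32)

print(x_fp8.dtype, x_hp.dtype)   # torch.float8_e4m3fn torch.float32
print((x - x_hp).abs().max())    # round-trip error, bounded by float8 resolution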
