
Commit 733ecec

add separate quantization primitives for float8
ghstack-source-id: bc601f2
ghstack-comment-id: 2608048970
Pull Request resolved: #1597
1 parent 32d9b0b commit 733ecec
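
At a glance, the commit adds three tensorwise float8 primitives to torchao/quantization/quant_primitives.py and a test comparing them against the existing generic affine primitives. Signature summary only, restated from the diff below (bodies elided):

import torch

def choose_qparams_affine_float8(
    tensor: torch.Tensor, float8_dtype: torch.dtype
) -> torch.Tensor: ...  # amax-based tensorwise fp32 scale

def quantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype,
) -> torch.Tensor: ...  # divide by scale, clamp to the float8 range, cast to float8_dtype

def dequantize_affine_float8(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor: ...  # multiply by scale, cast back to output_dtype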

File tree: 2 files changed, +129 −3 lines

test/quantization/test_quant_primitives.py (+60 −3)
@@ -9,16 +9,21 @@
 import unittest

 import torch
+from parameterized import parameterized

 from torchao.dtypes.utils import is_device
+from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     dequantize_affine,
+    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
+    MappingType,
     quantize_affine,
+    quantize_affine_float8,
+    ZeroPointDomain,
 )

 # TODO: remove test for utils?
@@ -29,11 +34,11 @@
     quantize_activation_per_token_absmax,
 )
 from torchao.utils import (
+    is_fbcode,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    is_fbcode,
 )

 _SEED = 1234
@@ -838,6 +843,58 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)

+    @parameterized.expand(
+        [
+            (
+                torch.float32,
+                torch.float8_e4m3fn,
+            ),
+        ]
+    )
+    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
+        input = torch.randn(10, 10)
+
+        # float8 quantization primitives
+        scale = choose_qparams_affine_float8(input, float8_dtype)
+        quantized = quantize_affine_float8(input, scale, float8_dtype)
+        dequantized = dequantize_affine_float8(quantized, scale, hp_dtype)
+
+        # reference implementation using generic primitives
+        expected_scale, _ = choose_qparams_affine(
+            input,
+            MappingType.SYMMETRIC,
+            input.shape,
+            float8_dtype,
+            eps=float8_eps,  # use same EPS as float8 training
+            scale_dtype=torch.float32,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+        )
+        expected_quantized = quantize_affine(
+            input,
+            input.shape,
+            scale,
+            output_dtype=float8_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+        expected_dequantized = dequantize_affine(
+            expected_quantized,
+            input.shape,
+            scale,
+            input_dtype=float8_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+
+        self.assertTrue(torch.equal(expected_scale, scale))
+        torch.testing.assert_close(expected_quantized, quantized)
+        torch.testing.assert_close(expected_dequantized, dequantized)
+

 if __name__ == "__main__":
     unittest.main()
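
The comparison against the generic affine path works because, for a symmetric float8 format, (quant_max − quant_min) / 2 equals quant_max, so the tensorwise scale from choose_qparams_affine_float8 reduces to amax / fp8_max. A minimal standalone sketch (not part of this commit) of that arithmetic:

import torch

fp8_dtype = torch.float8_e4m3fn
finfo = torch.finfo(fp8_dtype)
# symmetric range, e.g. [-448, 448] for float8_e4m3fn, so (max - min) / 2 == max
assert (finfo.max - finfo.min) / 2 == finfo.max

x = torch.randn(10, 10)
amax = torch.max(-torch.min(x), torch.max(x))  # same reduction as the new primitive
scale = (amax / finfo.max).to(torch.float32)
print(scale)  # the value both the float8 primitive and the generic path should agree on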

torchao/quantization/quant_primitives.py (+69)
@@ -39,6 +39,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]

@@ -1300,3 +1303,69 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor, using tensorwise granularity.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # only tensorwise scaling is supported for now
+    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
+    min_val_neg = torch.min(tensor)
+    max_val_pos = torch.max(tensor)
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    return scale.to(dtype=torch.float32)
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to divide by the scale, since scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) / scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to multiply with the scale, since scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor * scale
+    return hp_tensor.to(output_dtype)
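
A hedged end-to-end usage sketch of the three new primitives (not part of the commit; the shapes and the error printout are illustrative only):

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(128, 256)
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)   # tensorwise fp32 scale
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)  # scaled, clamped, cast to float8
x_hat = dequantize_affine_float8(x_fp8, scale, torch.float32)  # multiplied back by the scale

rel_err = ((x - x_hat).norm() / x.norm()).item()
print(f"tensorwise float8 round-trip relative error: {rel_err:.4f}")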
