Commit 47f96f1

add separate quantization primitives for float8 (#1597)
1 parent 6b472e5 commit 47f96f1

File tree

2 files changed: +137 -0 lines changed

test/quantization/test_quant_primitives.py

+70
@@ -9,16 +9,21 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
 from torchao.dtypes.utils import is_device
+from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     dequantize_affine,
+    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
     quantize_affine,
+    quantize_affine_float8,
 )
 
 # TODO: remove test for utils?

@@ -838,6 +843,71 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    @parameterized.expand(
+        [
+            (
+                torch.float32,
+                torch.float8_e4m3fn,
+            ),
+            (
+                torch.float32,
+                torch.float8_e5m2,
+            ),
+            (
+                torch.bfloat16,
+                torch.float8_e4m3fn,
+            ),
+            (
+                torch.bfloat16,
+                torch.float8_e5m2,
+            ),
+        ]
+    )
+    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
+        input = torch.randn(10, 10)
+
+        # float8 quantization primitives
+        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
+        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
+        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)
+
+        # reference implementation using generic primitives
+        expected_scale, _ = choose_qparams_affine(
+            input,
+            MappingType.SYMMETRIC,
+            input.shape,
+            float8_dtype,
+            eps=float8_eps,  # use same EPS as float8 training
+            scale_dtype=torch.float32,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+        )
+        expected_quantized = quantize_affine(
+            input,
+            input.shape,
+            scale,
+            output_dtype=float8_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+        expected_dequantized = dequantize_affine(
+            expected_quantized,
+            input.shape,
+            scale,
+            input_dtype=float8_dtype,
+            output_dtype=hp_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+
+        self.assertTrue(torch.equal(expected_scale, scale))
+        torch.testing.assert_close(expected_quantized, quantized)
+        torch.testing.assert_close(expected_dequantized, dequantized)
+
 
 if __name__ == "__main__":
     unittest.main()
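An aside on why the test above can assert `torch.equal(expected_scale, scale)`: with a tensorwise symmetric mapping over a float8 range that is symmetric about zero, the generic affine scale `amax / ((quant_max - quant_min) / 2)` reduces to `amax / quant_max`. A minimal standalone sketch of that arithmetic (plain torch only, not part of this commit):

```python
import torch

x = torch.randn(10, 10)
for float8_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    quant_min = torch.finfo(float8_dtype).min  # -448.0 for e4m3fn
    quant_max = torch.finfo(float8_dtype).max  # +448.0 for e4m3fn

    amax = x.abs().max()

    # scale as computed by choose_qparams_affine_float8 (tensorwise granularity)
    scale_float8 = (amax / (float(quant_max - quant_min) / 2)).to(torch.float32)

    # the float8 range is symmetric about zero, so this is just amax / quant_max,
    # which is what the generic symmetric affine path yields for block_size == x.shape
    scale_symmetric = (amax / quant_max).to(torch.float32)

    assert torch.equal(scale_float8, scale_symmetric)
```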

torchao/quantization/quant_primitives.py

+67
@@ -39,6 +39,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]
 
@@ -1300,3 +1303,67 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor, using tensorwise granularity.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # only tensorwise scaling is supported for now:
+    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
+    min_val_neg = torch.min(tensor)
+    max_val_pos = torch.max(tensor)
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    return scale.to(dtype=torch.float32)
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to multiply with the scale, since the scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) / scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to divide by the scale, since the scale is fp32 for float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor * scale
+    return hp_tensor.to(output_dtype)
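For context, a minimal usage sketch of the three new primitives added above, assuming a torchao build that already includes this commit; the shapes and dtypes below are illustrative only:

```python
import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(128, 256, dtype=torch.bfloat16)

# one tensorwise fp32 scale for the whole tensor
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)

# quantize to float8, then reconstruct at the original precision
x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
x_hat = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.bfloat16)

print(scale.item(), (x - x_hat).abs().max().item())  # reconstruction error stays small
```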
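The docstring notes about manually upcasting to `float32` exist so that eager and `torch.compile` execution produce the same float8 values; the sketch below is one hedged way to sanity-check that, assuming this commit is installed and `torch.compile` works on the current backend:

```python
import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(64, 64, dtype=torch.bfloat16)
scale = choose_qparams_affine_float8(x)

# run the same quantization path in eager mode and under torch.compile
eager_fp8 = quantize_affine_float8(x, scale)
compiled_fp8 = torch.compile(quantize_affine_float8)(x, scale)

# compare in float32, since direct float8 comparison support varies across torch versions
torch.testing.assert_close(eager_fp8.to(torch.float32), compiled_fp8.to(torch.float32))
```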
