Commit 102d4a4

add separate quantization primitives for float8
ghstack-source-id: ec5a5c2 ghstack-comment-id: 2608048970 Pull Request resolved: #1597
1 parent 32d9b0b commit 102d4a4

2 files changed: +142 -6 lines changed


test/quantization/test_quant_primitives.py

Lines changed: 73 additions & 3 deletions
@@ -9,16 +9,21 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
 from torchao.dtypes.utils import is_device
+from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
     choose_qparams_affine,
+    choose_qparams_affine_float8,
     dequantize_affine,
+    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
+    MappingType,
     quantize_affine,
+    quantize_affine_float8,
+    ZeroPointDomain,
 )
 
 # TODO: remove test for utils?
@@ -29,11 +34,11 @@
     quantize_activation_per_token_absmax,
 )
 from torchao.utils import (
+    is_fbcode,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    is_fbcode,
 )
 
 _SEED = 1234
@@ -838,6 +843,71 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)
 
+    @parameterized.expand(
+        [
+            (
+                torch.float32,
+                torch.float8_e4m3fn,
+            ),
+            (
+                torch.float32,
+                torch.float8_e5m2,
+            ),
+            (
+                torch.bfloat16,
+                torch.float8_e4m3fn,
+            ),
+            (
+                torch.bfloat16,
+                torch.float8_e5m2,
+            ),
+        ]
+    )
+    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
+        input = torch.randn(10, 10)
+
+        # float8 quantization primitives
+        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
+        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
+        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)
+
+        # reference implementation using generic primitives
+        expected_scale, _ = choose_qparams_affine(
+            input,
+            MappingType.SYMMETRIC,
+            input.shape,
+            float8_dtype,
+            eps=float8_eps,  # use same EPS as float8 training
+            scale_dtype=torch.float32,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+        )
+        expected_quantized = quantize_affine(
+            input,
+            input.shape,
+            scale,
+            output_dtype=float8_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+        expected_dequantized = dequantize_affine(
+            expected_quantized,
+            input.shape,
+            scale,
+            input_dtype=float8_dtype,
+            output_dtype=hp_dtype,
+            quant_min=torch.finfo(float8_dtype).min,
+            quant_max=torch.finfo(float8_dtype).max,
+            zero_point=None,
+            zero_point_domain=None,
+        )
+
+        self.assertTrue(torch.equal(expected_scale, scale))
+        torch.testing.assert_close(expected_quantized, quantized)
+        torch.testing.assert_close(expected_dequantized, dequantized)
+
 
 if __name__ == "__main__":
     unittest.main()
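For readers skimming the test: the two paths above should produce the same tensorwise scale because, for a symmetric float8 range, (quant_max - quant_min) / 2 equals quant_max, so both reduce to amax / quant_max. A minimal standalone sketch of that arithmetic (not part of the diff), using only torch and the finfo constants for torch.float8_e4m3fn:

import torch

x = torch.randn(10, 10)
amax = x.abs().max()

finfo = torch.finfo(torch.float8_e4m3fn)  # min = -448.0, max = 448.0

# scale as computed by choose_qparams_affine_float8 (tensorwise, symmetric)
scale_float8_path = (amax / ((finfo.max - finfo.min) / 2)).to(torch.float32)

# what the generic symmetric affine path reduces to with quant_min/quant_max set to finfo.min/finfo.max
scale_generic_path = (amax / finfo.max).to(torch.float32)

assert torch.equal(scale_float8_path, scale_generic_path)  # both are amax / 448.0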

torchao/quantization/quant_primitives.py

Lines changed: 69 additions & 3 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import math
-from enum import Enum, auto
+from enum import auto, Enum
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -16,11 +16,11 @@
     _n_ones,
 )
 from torchao.utils import (
+    _is_float8_type,
+    _register_custom_op,
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
-    _is_float8_type,
-    _register_custom_op,
 )
 
 __all__ = [
@@ -39,6 +39,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]
 
 
@@ -1300,3 +1303,66 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor, using tensorwise granularity.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # only tensorwise scaling is supported for now
+    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
+    min_val_neg = torch.min(tensor)
+    max_val_pos = torch.max(tensor)
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    return scale.to(dtype=torch.float32)
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to divide by the scale, since the scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) / scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to multiply by the scale, since the scale is an fp32 tensor in float8 quantization.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor * scale
+    return hp_tensor.to(output_dtype)
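A minimal usage sketch of the three primitives added above, mirroring what the new test exercises; it assumes a torchao build that includes this commit:

import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(10, 10)

# tensorwise scale, then a float8 round trip
scale = choose_qparams_affine_float8(x, float8_dtype=torch.float8_e4m3fn)
x_fp8 = quantize_affine_float8(x, scale, float8_dtype=torch.float8_e4m3fn)
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

print(x_fp8.dtype)                    # torch.float8_e4m3fn
print((x - x_hp).abs().max().item())  # small reconstruction error from float8 rounding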
