
Commit 44951b2

add separate quantization primitives for float8
ghstack-source-id: 51628a9d0c9bcdc03a77b1ddcb5ab002f49f856e
ghstack-comment-id: 2608048970
Pull Request resolved: #1597
1 parent 32d9b0b commit 44951b2

File tree: 2 files changed, +78 −13 lines


torchao/quantization/quant_api.py (+7 −13)
@@ -66,21 +66,13 @@
     Int8DynActInt4WeightGPTQQuantizer,
     Int8DynActInt4WeightQuantizer,
 )
-from .granularity import (
-    PerRow,
-    PerTensor,
-)
+from .granularity import PerRow, PerTensor
 from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .qat import (
-    intx_quantization_aware_training,
-)
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
+from .qat import intx_quantization_aware_training
+from .quant_primitives import MappingType, ZeroPointDomain
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
     Int8DynamicallyQuantizedLinearWeight,
@@ -915,10 +907,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
+    warnings.warn(
+        """int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
 
     from torchao.dtypes import SemiSparseLayout
-    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
+    int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
+    )
 
     return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
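For reference, a minimal sketch of the replacement usage suggested by the deprecation warning above. The toy model and the use of `quantize_` as the entry point are illustrative assumptions (they are not part of this diff); the `layout` kwarg usage follows the warning text.

```python
import torch

from torchao.dtypes import SemiSparseLayout
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

# hypothetical fp16 model on GPU; the 2:4 sparse int8 kernels require CUDA
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# instead of int8_dynamic_activation_int8_semi_sparse_weight(), pass the
# 2:4 sparse layout through the `layout` kwarg:
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
```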
torchao/quantization/quant_primitives.py (+71)
@@ -10,6 +10,12 @@
 
 import torch
 
+from torchao.float8.float8_utils import (
+    ScalingGranularity,
+)
+from torchao.float8.float8_utils import (
+    tensor_to_scale as tensor_to_float8_scale,
+)
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -39,6 +45,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]
 
 
@@ -1300,3 +1309,65 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
+    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
+    scale = tensor_to_float8_scale(
+        tensor,
+        float8_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=1,
+    )
+    return scale
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcasted to `float32` to multiply with the scale.
+    # In order to match numerics between eager and compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) * scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to float32 tensor.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor / scale
+    return hp_tensor.to(output_dtype)
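A minimal round-trip sketch using the three primitives added above; the tensor shape, dtype, and float8 format are illustrative assumptions, not values taken from this diff:

```python
import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

# illustrative 2D high-precision tensor (e.g., a weight of shape [out_features, in_features])
x = torch.randn(16, 32, dtype=torch.float32)

# 1. compute the float8 scale (axiswise over dim=1, per the NOTE in choose_qparams_affine_float8)
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)

# 2. multiply by the scale, clamp to the float8 range, and cast down to float8
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)

# 3. cast back up and divide by the scale to recover an approximation of x
x_hp = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

print((x - x_hp).abs().max())  # quantization error should be small
```

Note that although the `choose_qparams_affine_float8` docstring mentions tensorwise granularity, the scale is currently computed axiswise over dim=1, as the NOTE comment in the diff points out, so `scale` above has shape `(16, 1)` and broadcasts against `x`.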
