@@ -10,6 +10,12 @@
 
 import torch
 
+from torchao.float8.float8_utils import (
+    ScalingGranularity,
+)
+from torchao.float8.float8_utils import (
+    tensor_to_scale as tensor_to_float8_scale,
+)
 from torchao.prototype.custom_fp_utils import (
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
@@ -39,6 +45,9 @@
     "MappingType",
     "ZeroPointDomain",
     "TorchAODType",
+    "choose_qparams_affine_float8",
+    "quantize_affine_float8",
+    "dequantize_affine_float8",
 ]
 
 
@@ -1300,3 +1309,65 @@ def dequantize_affine_floatx(
     tensor = tensor * scale.float().view(-1, 1)
     tensor = tensor.to(dtype=output_dtype)
     return tensor
+
+
+def choose_qparams_affine_float8(
+    tensor: torch.Tensor, float8_dtype: torch.dtype
+) -> torch.Tensor:
+    """
+    Calculates the float8 scaling factor for the given high precision tensor, using axiswise granularity along axis=1 (see note below).
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # NOTE: quantization primitives are hardcoded to use axiswise granularity w/ axis=1 right now:
+    # https://github.com/pytorch/ao/blob/5d1444bdef6df15eb89c4c5716ede1c5f8677798/torchao/dtypes/affine_quantized_tensor.py#L416
+    scale = tensor_to_float8_scale(
+        tensor,
+        float8_dtype,
+        scaling_granularity=ScalingGranularity.AXISWISE,
+        axiswise_dim=1,
+    )
+    return scale
+
+
+def quantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    float8_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor to be quantized.
+        scale (torch.Tensor): Scaling factor for the quantization.
+        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
+    """
+    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+    # upcast to `float32` to multiply with the scale. To match numerics between eager and
+    # compile, we upcast manually here.
+    tensor_scaled = tensor.to(torch.float32) * scale
+    max_value = torch.finfo(float8_dtype).max
+    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
+    fp8_tensor = tensor_clamped.to(float8_dtype)
+    return fp8_tensor
+
+
+def dequantize_affine_float8(
+    tensor: torch.Tensor,
+    scale: torch.Tensor,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """
+    Dequantizes the float8 tensor to a high precision tensor of the given output_dtype.
+
+    Args:
+        tensor (torch.Tensor): Input float8 tensor to be dequantized.
+        scale (torch.Tensor): Scaling factor for the dequantization.
+        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
+    """
+    fp8_tensor = tensor.to(torch.float32)
+    hp_tensor = fp8_tensor / scale
+    return hp_tensor.to(output_dtype)
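
For reference, a minimal round-trip sketch of the three new primitives. It assumes the functions are importable from `torchao.quantization.quant_primitives` (the module this diff edits) and a PyTorch build with `torch.float8_e4m3fn` and `torch.compile` support; the shapes and the final check are illustrative, not part of this change.

```python
import torch

from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

# 2D input: the scale is computed axiswise along axis=1, per the note in
# choose_qparams_affine_float8, so it broadcasts as (rows, 1) against (rows, cols).
x = torch.randn(16, 32, dtype=torch.float32)

scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
x_fp8 = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
x_dq = dequantize_affine_float8(x_fp8, scale, output_dtype=torch.float32)

# Round-trip error is bounded by the precision of the e4m3 format.
print((x - x_dq).abs().max())

# The manual float32 upcast in quantize_affine_float8 is meant to keep eager and
# compiled numerics aligned; a quick sanity check (upcast before comparing, since
# torch.testing.assert_close does not accept float8 inputs directly):
quantize_compiled = torch.compile(quantize_affine_float8)
y_eager = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
y_compiled = quantize_compiled(x, scale, torch.float8_e4m3fn)
torch.testing.assert_close(y_eager.to(torch.float32), y_compiled.to(torch.float32))
```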