39 | 39 |     "MappingType",
40 | 40 |     "ZeroPointDomain",
41 | 41 |     "TorchAODType",
| 42 | +    "choose_qparams_affine_float8",
| 43 | +    "quantize_affine_float8",
| 44 | +    "dequantize_affine_float8",
42 | 45 | ]
43 | 46 |
44 | 47 |
@@ -1300,3 +1303,67 @@ def dequantize_affine_floatx(
1300 | 1303 |     tensor = tensor * scale.float().view(-1, 1)
1301 | 1304 |     tensor = tensor.to(dtype=output_dtype)
1302 | 1305 |     return tensor
| 1306 | +
| 1307 | +
| 1308 | +def choose_qparams_affine_float8(
| 1309 | +    tensor: torch.Tensor,
| 1310 | +    float8_dtype: torch.dtype = torch.float8_e4m3fn,
| 1311 | +) -> torch.Tensor:
| 1312 | +    """
| 1313 | +    Calculates the float8 scaling factor for the given high-precision tensor, using tensorwise granularity.
| 1314 | +
| 1315 | +    Args:
| 1316 | +        tensor (torch.Tensor): Input tensor to be quantized.
| 1317 | +        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
| 1318 | +    """
| 1319 | +    # Only tensorwise scaling is supported for now.
| 1320 | +    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
| 1321 | +    min_val_neg = torch.min(tensor)
| 1322 | +    max_val_pos = torch.max(tensor)
| 1323 | +    max_val_pos = torch.max(-min_val_neg, max_val_pos)
| 1324 | +    scale = max_val_pos / (float(quant_max - quant_min) / 2)
| 1325 | +    return scale.to(dtype=torch.float32)
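As a sanity check of the formula above, a minimal sketch (assuming `choose_qparams_affine_float8` is in scope): for `torch.float8_e4m3fn`, `torch.finfo` reports min/max of -448/448, so `(quant_max - quant_min) / 2` reduces to 448 and the scale is simply `amax / 448`:

```python
import torch

x = torch.randn(64, 64) * 10
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)

# Same computation the helper performs: amax of the tensor divided by 448.
amax = torch.max(-x.min(), x.max())
assert torch.equal(scale, amax / 448.0)
```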
| 1326 | +
| 1327 | +
| 1328 | +def quantize_affine_float8(
| 1329 | +    tensor: torch.Tensor,
| 1330 | +    scale: torch.Tensor,
| 1331 | +    float8_dtype: torch.dtype = torch.float8_e4m3fn,
| 1332 | +) -> torch.Tensor:
| 1333 | +    """
| 1334 | +    Quantizes the high-precision floating point tensor to a float8 tensor, using the given scaling factor.
| 1335 | +
| 1336 | +    Args:
| 1337 | +        tensor (torch.Tensor): Input tensor to be quantized.
| 1338 | +        scale (torch.Tensor): Scaling factor for the quantization.
| 1339 | +        float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
| 1340 | +    """
| 1341 | +    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
| 1342 | +    # upcast to `float32` to divide by the scale, since scale is an fp32 tensor in float8 quantization.
| 1343 | +    # In order to match numerics between eager and compile, we upcast manually here.
| 1344 | +    tensor_scaled = tensor.to(torch.float32) / scale
| 1345 | +    max_value = torch.finfo(float8_dtype).max
| 1346 | +    tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
| 1347 | +    fp8_tensor = tensor_clamped.to(float8_dtype)
| 1348 | +    return fp8_tensor
| 1349 | +
| 1350 | +
| 1351 | +def dequantize_affine_float8(
| 1352 | +    tensor: torch.Tensor,
| 1353 | +    scale: torch.Tensor,
| 1354 | +    output_dtype: torch.dtype = torch.float32,
| 1355 | +) -> torch.Tensor:
| 1356 | +    """
| 1357 | +    Dequantizes the float8 tensor to a high-precision tensor.
| 1358 | +
| 1359 | +    Args:
| 1360 | +        tensor (torch.Tensor): Input float8 tensor to be dequantized.
| 1361 | +        scale (torch.Tensor): Scaling factor for the dequantization.
| 1362 | +        output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
| 1363 | +    """
| 1364 | +    # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
| 1365 | +    # upcast to `float32` to multiply with the scale, since scale is an fp32 tensor in float8 quantization.
| 1366 | +    # In order to match numerics between eager and compile, we upcast manually here.
| 1367 | +    fp8_tensor = tensor.to(torch.float32)
| 1368 | +    hp_tensor = fp8_tensor * scale
| 1369 | +    return hp_tensor.to(output_dtype)
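Putting the three helpers together, a minimal round-trip sketch (assuming all three are importable from this module):

```python
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
scale = choose_qparams_affine_float8(x, torch.float8_e4m3fn)
q = quantize_affine_float8(x, scale, torch.float8_e4m3fn)
x_hat = dequantize_affine_float8(q, scale, output_dtype=torch.bfloat16)

# The round trip is lossy: the error is on the order of the float8 step size at this scale.
print((x - x_hat).abs().max())
```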