|
16 | 16 | * Zeros: N/A
|
17 | 17 | """
|
18 | 18 |
|
| 19 | +from enum import Enum, auto |
19 | 20 | from typing import Dict, Union
|
20 | 21 |
|
21 | 22 | import torch
|
|
53 | 54 | unpack_uint4,
|
54 | 55 | )
|
55 | 56 |
|
# TODO(later): read from somewhere else?
# Bit-field widths (sign / exponent / mantissa) for the floating-point formats
# handled below; used when manipulating fp32 bit patterns and computing the
# mantissa rounding offset for each target element dtype.
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23  # IEEE-754 binary32: 1 sign, 8 exponent, 23 mantissa bits
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1  # fp4 e2m1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3  # fp6 e2m3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2  # fp6 e3m2
EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3  # fp8 e4m3
EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2  # fp8 e5m2
| 64 | + |
| 65 | + |
class ScaleCalculationMode(Enum):
    """
    Methods for calculating the MX block scaling factor.

    FLOOR: recommended by the OCP MX Spec 1.0; uses
        X = 2^floor(log2(max_abs(v)) - max_exp).
        It can overflow for large values and behaves poorly for gradient
        quantization.
    CEIL: avoids FLOOR's overflow issue, but small values may be flushed
        to 0 by the larger scaling factor; uses
        X = 2^ceil(log2(max_abs(v)) - max_exp).
    EVEN: a trade-off between FLOOR and CEIL; rounds the mantissa of
        max_abs(v) before taking the floor, i.e.
        X = 2^(floor(log2(rounding(max_abs(v))) - max_exp)).
        It provides better accuracy for MX4 training than FLOOR or CEIL.

    NOTE(review): `to_mx` currently defaults to FLOOR (the OCP-spec
    behavior), not EVEN.
    """

    FLOOR = auto()
    CEIL = auto()
    EVEN = auto()
| 82 | + |
56 | 83 |
|
57 | 84 | def to_mx(
|
58 | 85 | data_hp: torch.Tensor,
|
59 | 86 | elem_dtype: Union[torch.dtype, str],
|
60 | 87 | block_size: int,
|
| 88 | + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, |
61 | 89 | ):
|
62 | 90 | """
|
63 | 91 | Takes a high precision tensor and converts to MX scale and raw data, in
|
@@ -88,25 +116,45 @@ def to_mx(
|
88 | 116 | # where the values are zero.
|
89 | 117 | eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
|
90 | 118 |
|
91 |
| - # Find largest power of 2 less than or equal to max_abs. |
92 |
| - largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps)) |
93 |
| - |
94 | 119 | # Set X to be the largest power-of-two less than or equal to
|
95 | 120 | # max_abs(v), divided by the largest power of two representable
|
96 |
| - # in the element data type |
| 121 | + # in the element data type, and get the mbits at the same time |
97 | 122 | if elem_dtype == torch.float8_e4m3fn:
|
98 | 123 | target_max_pow2 = F8E4M3_MAX_POW2
|
| 124 | + mbits = MBITS_F8_E4M3 |
99 | 125 | elif elem_dtype == torch.float8_e5m2:
|
100 | 126 | target_max_pow2 = F8E5M2_MAX_POW2
|
| 127 | + mbits = MBITS_F8_E5M2 |
101 | 128 | elif elem_dtype == DTYPE_FP6_E2M3:
|
102 | 129 | target_max_pow2 = F6_E2M3_MAX_POW2
|
| 130 | + mbits = MBITS_F6_E2M3 |
103 | 131 | elif elem_dtype == DTYPE_FP6_E3M2:
|
104 | 132 | target_max_pow2 = F6_E3M2_MAX_POW2
|
| 133 | + mbits = MBITS_F6_E3M2 |
105 | 134 | elif elem_dtype == DTYPE_FP4:
|
106 | 135 | target_max_pow2 = F4_E2M1_MAX_POW2
|
| 136 | + mbits = MBITS_F4_E2M1 |
107 | 137 | else:
|
108 |
| - raise AssertionError("unsupported") |
109 |
| - scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2 |
| 138 | + raise AssertionError("unsupported element dtype") |
| 139 | + |
| 140 | + # rounding before calculating the largest power of 2 |
| 141 | + # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) |
| 142 | + if scaling_mode == ScaleCalculationMode.EVEN: |
| 143 | + nan_mask = torch.isnan(max_abs) |
| 144 | + max_abs = max_abs.to(torch.float32).view(torch.int32) |
| 145 | + val_to_add = 1 << (MBITS_F32 - mbits - 1) |
| 146 | + mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32 |
| 147 | + max_abs = (max_abs + val_to_add) & mask |
| 148 | + max_abs = max_abs.view(torch.float32) |
| 149 | + max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device) |
| 150 | + |
| 151 | + # Calculate the scale for different modes |
| 152 | + if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): |
| 153 | + scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2 |
| 154 | + elif scaling_mode == ScaleCalculationMode.CEIL: |
| 155 | + scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2 |
| 156 | + else: |
| 157 | + raise AssertionError("unsupported scaling calculation mode") |
110 | 158 |
|
111 | 159 | # Clamp to exponents that can be represented in e8m0
|
112 | 160 | scale_e8m0_unbiased = torch.clamp(
|
@@ -270,15 +318,17 @@ class ToMXConstrFunc(torch.autograd.Function):
|
270 | 318 | """
|
271 | 319 |
|
272 | 320 | @staticmethod
|
273 |
| - def forward(ctx, data_hp, elem_dtype, block_size): |
274 |
| - scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size) |
| 321 | + def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): |
| 322 | + scale_e8m0_biased, data_lp = to_mx( |
| 323 | + data_hp, elem_dtype, block_size, scaling_mode |
| 324 | + ) |
275 | 325 | return MXTensor(
|
276 | 326 | scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
|
277 | 327 | )
|
278 | 328 |
|
279 | 329 | @staticmethod
|
280 | 330 | def backward(ctx, g):
|
281 |
| - return g, None, None |
| 331 | + return g, None, None, None |
282 | 332 |
|
283 | 333 |
|
284 | 334 | @torch._dynamo.allow_in_graph
|
@@ -392,8 +442,9 @@ def to_mx(
|
392 | 442 | data_hp: torch.Tensor,
|
393 | 443 | elem_dtype: Union[torch.dtype, str],
|
394 | 444 | block_size: int = BLOCK_SIZE_DEFAULT,
|
| 445 | + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, |
395 | 446 | ):
|
396 |
| - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size) |
| 447 | + return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) |
397 | 448 |
|
398 | 449 | def __tensor_flatten__(self):
|
399 | 450 | ctx = {
|
|
0 commit comments