Commit 42fe7b9
mx: add ceil and RNE rounding modes to the cast from fp32 to e8m0

Summary:

Why we want this: the newly released cuBLAS 12.8 documentation uses the RNE rounding mode for the cast to e8m0, and we want to properly emulate this cast.

This is a copy-pasta of #516, with the modifications being keeping FLOOR as the default mode and removing e3m0; credit to NicoleMayer for the original code. I don't have a way to check bitwise equivalency with the most recent cuBLAS version yet, but will come back and add tests when I do.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-comment-id: 2613597310
ghstack-source-id: c607b6c
ghstack-comment-id: 2623497069

Pull Request resolved: #1643
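For orientation: e8m0 is an exponent-only format (8 exponent bits, no mantissa, bias 127 in the OCP MX spec), so casting an fp32 scale to e8m0 amounts to picking a power of two. Below is a minimal scalar sketch of what the two rounding directions in the title mean; it is written for this note and is not taken from cuBLAS or from this PR.

```python
import math

def fp32_to_e8m0_floor(x: float) -> int:
    # FLOOR: largest power of two <= x; take the exponent below.
    return math.floor(math.log2(x)) + 127  # e8m0 stores a biased exponent

def fp32_to_e8m0_ceil(x: float) -> int:
    # CEIL: smallest power of two >= x; never under-scales the block.
    return math.ceil(math.log2(x)) + 127

print(fp32_to_e8m0_floor(3.0))  # 128, which encodes 2^1
print(fp32_to_e8m0_ceil(3.0))   # 129, which encodes 2^2
```

RNE (round to nearest, ties to even) sits between these two; the diff below implements it as the EVEN scale calculation mode.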
1 parent 4d1c774 commit 42fe7b9

2 files changed (+76 -13)

test/prototype/mx_formats/test_mx_tensor.py (+15 -3)

```diff
@@ -18,6 +18,7 @@
 from torchao.prototype.mx_formats.mx_tensor import (
     E8M0_EXPONENT_NAN_VAL,
     MXTensor,
+    ScaleCalculationMode,
     to_dtype,
 )
 from torchao.quantization.utils import compute_error
@@ -47,8 +48,10 @@ def run_before_and_after_tests():
     torch._dynamo.reset()
 
 
-def _test_mx(data_hp, elem_dtype, block_size):
-    data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size)
+def _test_mx(
+    data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR
+):
+    data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode)
     data_mx_dq = data_mx.to_dtype(data_hp.dtype)
 
     def assert_sqnr_gt_threshold(orig, new, threshold):
@@ -61,7 +64,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         assert sqnr >= threshold
 
     if elem_dtype is torch.float8_e4m3fn:
-        assert_sqnr_gt_threshold(data_hp, data_mx_dq, 20.0)
+        assert_sqnr_gt_threshold(data_hp, data_mx_dq, 18.0)
     else:
         assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0)
 
@@ -74,6 +77,15 @@ def test_hello_world(elem_dtype):
     _test_mx(data, elem_dtype, block_size)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
+@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
+def test_realistic_numerics(elem_dtype, scale_calculation_mode):
+    data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    block_size = 32
+    _test_mx(data, elem_dtype, block_size, scale_calculation_mode)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_all_zeros(elem_dtype):
```
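For reference, the new test sweeps every ScaleCalculationMode; the extended helper can also be invoked directly, along these lines (a sketch assuming a CUDA device, as in the test above):

```python
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
_test_mx(data, torch.float8_e4m3fn, 32, ScaleCalculationMode.EVEN)
```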

torchao/prototype/mx_formats/mx_tensor.py (+61 -10)

```diff
@@ -16,6 +16,7 @@
  * Zeros: N/A
 """
 
+from enum import Enum, auto
 from typing import Dict, Union
 
 import torch
@@ -53,11 +54,38 @@
     unpack_uint4,
 )
 
+# TODO(later): read from somewhere else?
+SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
+EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
+EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
+EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
+EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
+EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2
+
+
+class ScaleCalculationMode(Enum):
+    """
+    Enum representing the different methods for calculating MX block scaling.
+    There are three methods available:
+    FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
+        It can result in overflow issues for large values and is bad for gradient quantization.
+    CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
+        It uses X = 2^ceil(log2(max_abs(v))-max_exp).
+    EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
+        It provides better accuracy for MX4 training compared to FLOOR and CEIL.
+    By default, we use the FLOOR method.
+    """
+
+    FLOOR = auto()
+    CEIL = auto()
+    EVEN = auto()
+
 
 def to_mx(
     data_hp: torch.Tensor,
     elem_dtype: Union[torch.dtype, str],
     block_size: int,
+    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
 ):
     """
     Takes a high precision tensor and converts to MX scale and raw data, in
```
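To make the three formulas in the docstring concrete, here is an illustrative comparison (not part of the PR) for a block whose max_abs sits just below a power of two. The value target_max_pow2 = 8 is an assumption for e4m3, matching the F8E4M3_MAX_POW2 constant selected in the next hunk:

```python
import torch

max_abs = torch.tensor(1.95)
target_max_pow2 = 8  # assumed e4m3 value of F8E4M3_MAX_POW2

floor_exp = torch.floor(torch.log2(max_abs)) - target_max_pow2  # tensor(-8.)
ceil_exp = torch.ceil(torch.log2(max_abs)) - target_max_pow2    # tensor(-7.)
# EVEN first rounds max_abs at the element dtype's mantissa width
# (1.95 -> 2.0 for e4m3), then floors, so it agrees with CEIL here;
# for max_abs = 1.9 the rounding leaves the binade alone and it agrees with FLOOR.
```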
```diff
@@ -88,25 +116,45 @@ def to_mx(
     # where the values are zero.
     eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
 
-    # Find largest power of 2 less than or equal to max_abs.
-    largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))
-
     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
-    # in the element data type
+    # in the element data type, and get the mbits at the same time
     if elem_dtype == torch.float8_e4m3fn:
         target_max_pow2 = F8E4M3_MAX_POW2
+        mbits = MBITS_F8_E4M3
     elif elem_dtype == torch.float8_e5m2:
         target_max_pow2 = F8E5M2_MAX_POW2
+        mbits = MBITS_F8_E5M2
     elif elem_dtype == DTYPE_FP6_E2M3:
         target_max_pow2 = F6_E2M3_MAX_POW2
+        mbits = MBITS_F6_E2M3
     elif elem_dtype == DTYPE_FP6_E3M2:
         target_max_pow2 = F6_E3M2_MAX_POW2
+        mbits = MBITS_F6_E3M2
     elif elem_dtype == DTYPE_FP4:
         target_max_pow2 = F4_E2M1_MAX_POW2
+        mbits = MBITS_F4_E2M1
     else:
-        raise AssertionError("unsupported")
-    scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2
+        raise AssertionError("unsupported element dtype")
+
+    # rounding before calculating the largest power of 2
+    # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
+    if scaling_mode == ScaleCalculationMode.EVEN:
+        nan_mask = torch.isnan(max_abs)
+        max_abs = max_abs.to(torch.float32).view(torch.int32)
+        val_to_add = 1 << (MBITS_F32 - mbits - 1)
+        mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
+        max_abs = (max_abs + val_to_add) & mask
+        max_abs = max_abs.view(torch.float32)
+        max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
+
+    # Calculate the scale for different modes
+    if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
+        scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
+    elif scaling_mode == ScaleCalculationMode.CEIL:
+        scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
+    else:
+        raise AssertionError("unsupported scaling calculation mode")
 
     # Clamp to exponents that can be represented in e8m0
     scale_e8m0_unbiased = torch.clamp(
```
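The EVEN branch above operates on the raw fp32 bits: it adds half a unit in the last place of the target mantissa width, then masks away the entire fp32 mantissa, so max_abs either carries into the next binade or floors to the current one. A scalar re-implementation, for illustration only (the PR's tensor code is in the hunk above):

```python
import struct

SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23  # same constants as in the diff

def round_max_abs_even(x: float, mbits: int) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 1 << (MBITS_F32 - mbits - 1)  # half a ULP at the target precision
    bits &= ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32  # keep sign + exponent
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(round_max_abs_even(1.95, 3))  # 2.0: carries into the next exponent, like CEIL
print(round_max_abs_even(1.90, 3))  # 1.0: mantissa dropped, like FLOOR
```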
```diff
@@ -270,15 +318,17 @@ class ToMXConstrFunc(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, data_hp, elem_dtype, block_size):
-        scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size)
+    def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
+        scale_e8m0_biased, data_lp = to_mx(
+            data_hp, elem_dtype, block_size, scaling_mode
+        )
         return MXTensor(
             scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
         )
 
     @staticmethod
     def backward(ctx, g):
-        return g, None, None
+        return g, None, None, None
 
 
 @torch._dynamo.allow_in_graph
@@ -392,8 +442,9 @@ def to_mx(
         data_hp: torch.Tensor,
         elem_dtype: Union[torch.dtype, str],
         block_size: int = BLOCK_SIZE_DEFAULT,
+        scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
     ):
-        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size)
+        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
 
     def __tensor_flatten__(self):
         ctx = {
```
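Putting it together, the new keyword threads from MXTensor.to_mx through the autograd wrapper into to_mx, so the modes can be compared end to end. A quick sketch in the spirit of the test plan, assuming a CUDA device:

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode
from torchao.quantization.utils import compute_error

x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
for mode in ScaleCalculationMode:
    x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32, mode)
    x_dq = x_mx.to_dtype(torch.bfloat16)
    print(mode, compute_error(x, x_dq))  # SQNR per mode, higher is better
```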
