Commit 413e2a1
mx: add ceil and RNE rounding modes to the cast from fp32 to e8m0
Summary:

Why we want this: in the newly released cuBLAS 12.8 documentation, the RNE rounding mode is used for the cast to e8m0. We want to properly emulate this cast.

This is a copy-pasta of #516, with the modifications of keeping FLOOR as the default mode and removing e3m0; credit to NicoleMayer for the original code.

I don't have a way to check bitwise equivalency with the most recent cuBLAS version yet, but will come back and add tests when I do.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 49330726a68936ddece56d37e15129d3cc546b4e
ghstack-comment-id: 2613597310
Pull Request resolved: #1620
1 parent df7ccea commit 413e2a1
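For context: e8m0 is the MX scale format, an 8-bit biased exponent (bias 127) with no sign or mantissa bits, so casting an fp32 value to it means picking a nearby power of two. A minimal sketch of how a floor cast and a round-to-nearest cast can disagree (plain Python, not the tensor code in this commit; the function names are illustrative and tie handling at exact midpoints is simplified, not true RNE):

```python
import math

E8M0_BIAS = 127  # e8m0 stores only a biased 8-bit exponent

def f32_to_e8m0_floor(x: float) -> int:
    # keep the largest power of two <= x
    return math.floor(math.log2(x)) + E8M0_BIAS

def f32_to_e8m0_nearest(x: float) -> int:
    # pick the nearer of the two surrounding powers of two
    # (exact midpoints round up here; RNE tie-breaking is omitted)
    e = math.floor(math.log2(x))
    lo, hi = 2.0**e, 2.0 ** (e + 1)
    return (e + 1 if x - lo >= hi - x else e) + E8M0_BIAS

# 1.75 is closer to 2.0 than to 1.0, so the two casts differ by one exponent
print(f32_to_e8m0_floor(1.75))    # 127 -> scale 2^0
print(f32_to_e8m0_nearest(1.75))  # 128 -> scale 2^1
```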

1 file changed: +54 −6 lines changed

torchao/prototype/mx_formats/mx_tensor.py

```diff
@@ -16,6 +16,7 @@
 * Zeros: N/A
 """
 
+from enum import Enum, auto
 from typing import Dict, Union
 
 import torch
```
@@ -53,11 +54,38 @@
5354
unpack_uint4,
5455
)
5556

57+
# TODO(later): read from somewhere else?
58+
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
59+
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
60+
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
61+
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
62+
EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
63+
EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2
64+
65+
66+
class ScaleCalculationMode(Enum):
67+
"""
68+
Enum representing the different methods for calculating MX block scaling.
69+
There are three methods available:
70+
FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
71+
It result in overflow issues for large values and bad for gradient quantization.
72+
CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
73+
It uses X = 2^ceil(log2(max_abs(v))-max_exp).
74+
EVEN: This method is a trade-off between Option 1 and Option 2. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
75+
It provides better accuracy for MX4 training compared to FLOOR and CEIL.
76+
By default, we use the EVEN method for better accuracy.
77+
"""
78+
79+
FLOOR = auto()
80+
CEIL = auto()
81+
EVEN = auto()
82+
5683

5784
def to_mx(
5885
data_hp: torch.Tensor,
5986
elem_dtype: Union[torch.dtype, str],
6087
block_size: int,
88+
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
6189
):
6290
"""
6391
Takes a high precision tensor and converts to MX scale and raw data, in
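A hypothetical usage sketch of the new parameter (the input shape, block size, and variable names are made up; the import path follows this file's location in the repo, and per the docstring `to_mx` returns the e8m0 scale together with the low precision data):

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx

data = torch.randn(128, 32, dtype=torch.float32)

# default behavior is unchanged: FLOOR scaling, per the OCP MX spec
scale, data_lp = to_mx(data, torch.float8_e4m3fn, block_size=32)

# opt in to the new rounding behaviors explicitly
scale_ceil, data_ceil = to_mx(
    data, torch.float8_e4m3fn, block_size=32, scaling_mode=ScaleCalculationMode.CEIL
)
scale_even, data_even = to_mx(
    data, torch.float8_e4m3fn, block_size=32, scaling_mode=ScaleCalculationMode.EVEN
)
```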
```diff
@@ -88,25 +116,45 @@ def to_mx(
     # where the values are zero.
     eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
 
-    # Find largest power of 2 less than or equal to max_abs.
-    largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))
-
     # Set X to be the largest power-of-two less than or equal to
     # max_abs(v), divided by the largest power of two representable
-    # in the element data type
+    # in the element data type, and get the mbits at the same time
     if elem_dtype == torch.float8_e4m3fn:
         target_max_pow2 = F8E4M3_MAX_POW2
+        mbits = MBITS_F8_E4M3
     elif elem_dtype == torch.float8_e5m2:
         target_max_pow2 = F8E5M2_MAX_POW2
+        mbits = MBITS_F8_E5M2
     elif elem_dtype == DTYPE_FP6_E2M3:
         target_max_pow2 = F6_E2M3_MAX_POW2
+        mbits = MBITS_F6_E2M3
     elif elem_dtype == DTYPE_FP6_E3M2:
         target_max_pow2 = F6_E3M2_MAX_POW2
+        mbits = MBITS_F6_E3M2
     elif elem_dtype == DTYPE_FP4:
         target_max_pow2 = F4_E2M1_MAX_POW2
+        mbits = MBITS_F4_E2M1
     else:
-        raise AssertionError("unsupported")
-    scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2
+        raise AssertionError("unsupported element dtype")
+
+    # rounding before calculating the largest power of 2
+    # X = 2^(floor(log2(rounding(max_abs(v)))) - max_exp)
+    if scaling_mode == ScaleCalculationMode.EVEN:
+        nan_mask = torch.isnan(max_abs)
+        max_abs = max_abs.to(torch.float32).view(torch.int32)
+        val_to_add = 1 << (MBITS_F32 - mbits - 1)
+        mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
+        max_abs = (max_abs + val_to_add) & mask
+        max_abs = max_abs.view(torch.float32)
+        max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)
+
+    # Calculate the scale for different modes
+    if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
+        scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
+    elif scaling_mode == ScaleCalculationMode.CEIL:
+        scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
+    else:
+        raise AssertionError("unsupported scaling calculation mode")
 
     # Clamp to exponents that can be represented in e8m0
     scale_e8m0_unbiased = torch.clamp(
```
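The EVEN branch above rounds max_abs at the element dtype's mantissa width by manipulating the fp32 bits directly: add half a ULP at `mbits` precision, then mask away the mantissa so only the sign and exponent bits survive before `floor(log2(...))` runs. A standalone re-derivation on plain Python floats (struct-based rather than tensor-based; the function name is illustrative, using fp4 e2m1's mbits = 1):

```python
import struct

SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23  # same constants as the diff
MBITS_F4_E2M1 = 1

def round_max_abs_even(x: float, mbits: int = MBITS_F4_E2M1) -> float:
    bits = struct.unpack(">I", struct.pack(">f", x))[0]   # fp32 bits as an int
    val_to_add = 1 << (MBITS_F32 - mbits - 1)             # half a ULP at mbits precision
    mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32  # keep only sign + exponent
    bits = (bits + val_to_add) & mask
    return struct.unpack(">f", struct.pack(">I", bits))[0]

# with mbits = 1, significands >= 1.75 round up and carry into the exponent
print(round_max_abs_even(1.75))  # 2.0 -> floor(log2) = 1, one higher than FLOOR's 0
print(round_max_abs_even(1.6))   # 1.0 -> floor(log2) = 0, same as FLOOR
```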
