Commit e92c721

Implemented RCEIL (CUBLAS-style) MXFP scale factor derivation, with test cases. (#1835)
1 parent: 7bb7f23
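
In MX formats, every block of elements shares one power-of-two scale, stored as a biased E8M0 exponent. The default FLOOR mode derives that scale from the floored exponent of the block's amax; the RCEIL mode added here instead follows the cuBLAS-style recipe of rounding amax divided by the destination format's maximum value (448 for float8_e4m3fn) up to the next power of two, which keeps every scaled element representable in the element dtype.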

File tree: 2 files changed (+348, -89 lines)
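Before the diff, here is a rough sketch of what RCEIL computes per block. This is a minimal illustration, not the commit's implementation: rceil_scale_e8m0 is a hypothetical helper, the log2 formulation stands in for what is more likely direct exponent-bit manipulation, and the NaN/zero special cases are inferred from the ground-truth values in the new test.

    import torch

    E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude
    E8M0_BIAS = 127   # bias of the shared E8M0 block-scale exponent

    def rceil_scale_e8m0(block_hp: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper; returns the biased E8M0 scale exponent as uint8.
        amax = block_hp.abs().amax()
        if torch.isnan(amax):
            # E8M0 has no sign or mantissa bits; 255 encodes NaN.
            return torch.tensor([255], dtype=torch.uint8)
        if amax == 0:
            return torch.tensor([0], dtype=torch.uint8)  # smallest scale, 2**-127
        # RCEIL: round amax / destmax up to the next power of two, so that
        # amax / scale can never overflow the element dtype.
        unbiased = torch.ceil(torch.log2(amax.float() / E4M3_MAX))
        return torch.clamp(unbiased + E8M0_BIAS, 0, 254).to(torch.uint8).reshape(1)

One detail the sketch does not model: the denorm test cases below show fp32/bf16 denormal inputs quantizing to all zeros, i.e., the real implementation effectively flushes high-precision denorms.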

test/prototype/mx_formats/test_mx_tensor.py (+204 lines)
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) 2025, NVIDIA CORPORATION.
 # All rights reserved.
 
 # This source code is licensed under the license found in the
@@ -105,6 +106,209 @@ def test_some_zeros(elem_dtype):
     _test_mx(data, elem_dtype, block_size)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_to_mx_rceil():
+    # nan
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            2143289344, 1054459450, 1060527345, 1045656552, 1058239340, 1045057552, 1061158006, 1049626606,
+            1052757568, 1032293288, 1056992320, 1064929425, 1061036255, 1047450552, 1057077424, 1055125012,
+            1036491424, 1063542041, 1057099838, 1058731224, 1050189482, 1049114228, 1058347802, 1060065968,
+            1058846156, 1048878912, 1065109089, 1054494928, 1044803976, 1049117692, 1065222528, 1056965012,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([255], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            127, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 0, 0, 0, 0, 0,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    assert torch.isnan(data_mx._data[0])
+    assert torch.all(data_mx._data[1:] == 0)
+    # fp32 denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            6142315, 5096174, 3345704, 6178415, 5728750, 419002, 1716691, 4335089,
+            5785800, 6234845, 1697524, 33075, 3975816, 3714822, 5411407, 3040844,
+            7400945, 4474166, 7257182, 1273750, 5872176, 4694081, 2096530, 6273621,
+            67028, 7585260, 4532315, 4599275, 6133942, 4542483, 5992199, 6862780,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
+    ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # bf16 denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            101, 3, 47, 54, 36, 19, 70, 79,
+            35, 95, 28, 120, 84, 94, 20, 92,
+            18, 42, 98, 58, 3, 26, 64, 86,
+            60, 86, 52, 23, 61, 70, 59, 74,
+        ],
+        dtype=torch.uint16,
+    ).view(torch.bfloat16)
+    # fmt: on
+    ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
+    ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # fp32 some denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            8388608, 1063716449, 1064039365, 1063568877, 1051091338, 1062185569, 1034449408, 1060813641,
+            1054893736, 1034907680, 1036660744, 1023639888, 1058536559, 1050896496, 1049237634, 1064950601,
+            1051852994, 1059794063, 1054011102, 1062023602, 1059467900, 1062276774, 1059155029, 1053287574,
+            1064378711, 1055768540, 1045266076, 1059575077, 1054928758, 1040468200, 1058061961, 1053066436,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            0, 118, 119, 118, 106, 117, 91, 116,
+            110, 91, 93, 80, 113, 106, 105, 120,
+            107, 115, 109, 117, 114, 117, 114, 108,
+            119, 111, 101, 114, 110, 96, 113, 108,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # bf16 some denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            128, 16118, 16143, 16074, 16187, 16002, 16193, 16217,
+            15680, 16183, 16092, 16158, 16251, 15876, 15896, 16194,
+            16135, 16214, 16205, 16110, 16122, 15960, 15824, 16106,
+            16220, 16230, 15952, 15896, 16000, 16144, 16232, 16157,
+        ],
+        dtype=torch.uint16,
+    ).view(torch.bfloat16)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            0, 111, 113, 109, 116, 104, 116, 118,
+            84, 115, 110, 114, 120, 96, 98, 116,
+            112, 117, 117, 111, 112, 102, 93, 111,
+            118, 118, 101, 98, 104, 113, 118, 114,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # zero
+    data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
+    ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
+    ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # fp32 normal
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            1037408064, 1058534842, 1053630662, 1063310394, 994704128, 1057245441, 1060663708, 1058053571,
+            1052395648, 1064831570, 1038427336, 1064777688, 1059248393, 1060959028, 1062878286, 1057799482,
+            1057854101, 1053562724, 1027482352, 1060498324, 1063238522, 1060472055, 1054346794, 1029092912,
+            1056687298, 1059146141, 1037992128, 1064097772, 1056522806, 1059255744, 1064364912, 1060606252,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            93, 113, 109, 118, 53, 112, 116, 113,
+            108, 120, 94, 119, 114, 116, 118, 113,
+            113, 109, 84, 115, 118, 115, 110, 85,
+            112, 114, 94, 119, 112, 114, 119, 115,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # bf16 normal
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            15752, 16143, 16182, 15896, 16195, 16186, 16048, 16223,
+            15988, 16231, 16140, 16088, 16032, 16240, 16228, 16133,
+            16210, 16024, 16248, 16187, 16050, 15696, 16060, 15956,
+            16131, 16251, 15896, 16014, 15808, 16024, 16159, 16186,
+        ],
+        dtype=torch.uint16,
+    ).view(torch.bfloat16)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            88, 113, 115, 98, 116, 116, 107, 118,
+            103, 118, 113, 110, 106, 119, 118, 112,
+            117, 106, 120, 116, 107, 85, 108, 101,
+            112, 120, 98, 105, 92, 106, 114, 116,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_exponent_nan_in(elem_dtype):
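
For a spot check of the ground-truth scales above, a small standalone verification (not part of the commit): bit pattern 1064831570 is the largest magnitude in the "fp32 normal" block.

    import math
    import struct

    # Reinterpret the uint32 bit pattern as a float32, as the test's .view() does.
    amax = struct.unpack("f", struct.pack("I", 1064831570))[0]  # ~0.9689
    unbiased = math.ceil(math.log2(amax / 448.0))               # ceil(-8.85...) == -8
    assert unbiased + 127 == 119  # matches that case's ground_truth_scale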
