 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright (c) 2025, NVIDIA CORPORATION.
 # All rights reserved.
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
@@ -105,6 +106,218 @@ def test_some_zeros(elem_dtype):
     _test_mx(data, elem_dtype, block_size)
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_to_mx_rceil():
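+    """Exercise MXTensor.to_mx with ScaleCalculationMode.RCEIL on edge cases:
+    NaN input, fp32/bf16 denormals, blocks mixing denormal and normal values,
+    all zeros, and normal values.
+
+    Each input is built from raw bit patterns (uint32 viewed as float32,
+    uint16 viewed as bfloat16), and the resulting E8M0 scale and fp8 payload
+    are compared against precomputed ground truth.
+    """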
+    # nan
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            2143289344, 1054459450, 1060527345, 1045656552, 1058239340, 1045057552, 1061158006, 1049626606,
+            1052757568, 1032293288, 1056992320, 1064929425, 1061036255, 1047450552, 1057077424, 1055125012,
+            1036491424, 1063542041, 1057099838, 1058731224, 1050189482, 1049114228, 1058347802, 1060065968,
+            1058846156, 1048878912, 1065109089, 1054494928, 1044803976, 1049117692, 1065222528, 1056965012,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([255], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            127, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 0, 0, 0, 0, 0,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
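+    # lane 0 is expected to be NaN (0x7f in e4m3fn); verified via isnan below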
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    assert torch.isnan(data_mx._data[0])
+    assert torch.all(data_mx._data[1:] == 0)
+    # fp32 denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            6142315, 5096174, 3345704, 6178415, 5728750, 419002, 1716691, 4335089,
+            5785800, 6234845, 1697524, 33075, 3975816, 3714822, 5411407, 3040844,
+            7400945, 4474166, 7257182, 1273750, 5872176, 4694081, 2096530, 6273621,
+            67028, 7585260, 4532315, 4599275, 6133942, 4542483, 5992199, 6862780,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
+    ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # bf16 denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            101, 3, 47, 54, 36, 19, 70, 79,
+            35, 95, 28, 120, 84, 94, 20, 92,
+            18, 42, 98, 58, 3, 26, 64, 86,
+            60, 86, 52, 23, 61, 70, 59, 74,
+        ],
+        dtype=torch.uint16,
+    ).view(torch.bfloat16)
+    # fmt: on
+    ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
+    ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # fp32 some denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            8388608, 1063716449, 1064039365, 1063568877, 1051091338, 1062185569, 1034449408, 1060813641,
+            1054893736, 1034907680, 1036660744, 1023639888, 1058536559, 1050896496, 1049237634, 1064950601,
+            1051852994, 1059794063, 1054011102, 1062023602, 1059467900, 1062276774, 1059155029, 1053287574,
+            1064378711, 1055768540, 1045266076, 1059575077, 1054928758, 1040468200, 1058061961, 1053066436,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            0, 118, 119, 118, 106, 117, 91, 116,
+            110, 91, 93, 80, 113, 106, 105, 120,
+            107, 115, 109, 117, 114, 117, 114, 108,
+            119, 111, 101, 114, 110, 96, 113, 108,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # bf16 some denorm
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            128, 16118, 16143, 16074, 16187, 16002, 16193, 16217,
+            15680, 16183, 16092, 16158, 16251, 15876, 15896, 16194,
+            16135, 16214, 16205, 16110, 16122, 15960, 15824, 16106,
+            16220, 16230, 15952, 15896, 16000, 16144, 16232, 16157,
+        ],
+        dtype=torch.uint16,
+    ).view(torch.bfloat16)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            0, 111, 113, 109, 116, 104, 116, 118,
+            84, 115, 110, 114, 120, 96, 98, 116,
+            112, 117, 117, 111, 112, 102, 93, 111,
+            118, 118, 101, 98, 104, 113, 118, 114,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # zero
+    data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
+    ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
+    ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
+        torch.float8_e4m3fn
+    )
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # fp32 normal
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            1037408064, 1058534842, 1053630662, 1063310394, 994704128, 1057245441, 1060663708, 1058053571,
+            1052395648, 1064831570, 1038427336, 1064777688, 1059248393, 1060959028, 1062878286, 1057799482,
+            1057854101, 1053562724, 1027482352, 1060498324, 1063238522, 1060472055, 1054346794, 1029092912,
+            1056687298, 1059146141, 1037992128, 1064097772, 1056522806, 1059255744, 1064364912, 1060606252,
+        ],
+        dtype=torch.uint32,
+    ).view(torch.float32)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            93, 113, 109, 118, 53, 112, 116, 113,
+            108, 120, 94, 119, 114, 116, 118, 113,
+            113, 109, 84, 115, 118, 115, 110, 85,
+            112, 114, 94, 119, 112, 114, 119, 115,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+    # bf16 normal
+    # fmt: off
+    data_hp = torch.tensor(
+        [
+            15752, 16143, 16182, 15896, 16195, 16186, 16048, 16223,
+            15988, 16231, 16140, 16088, 16032, 16240, 16228, 16133,
+            16210, 16024, 16248, 16187, 16050, 15696, 16060, 15956,
+            16131, 16251, 15896, 16014, 15808, 16024, 16159, 16186,
+        ],
+        dtype=torch.uint16,
+    ).view(torch.bfloat16)
+    # fmt: on
+    ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
+    # fmt: off
+    ground_truth_fp8 = torch.tensor(
+        [
+            88, 113, 115, 98, 116, 116, 107, 118,
+            103, 118, 113, 110, 106, 119, 118, 112,
+            117, 106, 120, 116, 107, 85, 108, 101,
+            112, 120, 98, 105, 92, 106, 114, 116,
+        ],
+        dtype=torch.uint8,
+    ).view(torch.float8_e4m3fn)
+    # fmt: on
+    data_mx = MXTensor.to_mx(
+        data_hp, torch.float8_e4m3fn, 32, ScaleCalculationMode.RCEIL
+    )
+    torch.testing.assert_close(data_mx._scale_e8m0, ground_truth_scale)
+    torch.testing.assert_close(data_mx._data, ground_truth_fp8)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_exponent_nan_in(elem_dtype):