@@ -10,7 +10,6 @@
 import torch
 import torch.nn as nn
 
-from torchao.float8.float8_utils import is_row_major
 from torchao.prototype.mx_formats.config import (
     MXLinearConfig,
     MXLinearRecipeName,
@@ -334,57 +333,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 11.5
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    not is_sm_at_least_100(),
-    reason="MX gemms require CUDA capability 10.0",
-)
-def test_scaled_mm_wrapper():
-    # today, e8m0 isn't supported in torchinductor or triton
-    # for now, work around this by creating a wrapper around torch._scaled_mm
-    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
-    from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
-
-    M, K, N = 128, 256, 512
-    BLOCK_SIZE = 32
-    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
-    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
-
-    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
-    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
-
-    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
-
-    def wrapped(a, b, a_scale, b_scale, out_dtype):
-        if is_row_major(b.stride()):
-            b = b.t().contiguous().t()
-        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
-        return res
-
-    wrapped = torch.compile(wrapped)
-
-    # correct memory format of `b`
-    out2 = wrapped(
-        a,
-        b.t(),
-        a_scale.view(torch.uint8),
-        b_scale.view(torch.uint8),
-        out_dtype=torch.bfloat16,
-    )
-    torch.testing.assert_close(out, out2, atol=0, rtol=0)
-
-    # incorrect memory format of `b`
-    b_col_major = b.t().contiguous().t()
-    out3 = wrapped(
-        a,
-        b_col_major.t(),
-        a_scale.view(torch.uint8),
-        b_scale.view(torch.uint8),
-        out_dtype=torch.bfloat16,
-    )
-    torch.testing.assert_close(out, out3, atol=0, rtol=0)
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
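
For context on the workaround the removed test exercised: because e8m0 scales were not yet handled by torchinductor/triton, the wrapper accepted raw uint8 scales and only reinterpreted them as e8m0 at the point of the `torch._scaled_mm` call. Below is a minimal eager-mode sketch of that reinterpret-then-matmul idea; it is not the torchao implementation, the function name is illustrative, and it assumes a PyTorch build with `torch.float8_e8m0fnu` plus a GPU that supports MX gemms.

```python
import torch

def scaled_mm_with_uint8_scales_sketch(a, b, a_scale_u8, b_scale_u8, out_dtype=torch.bfloat16):
    # Reinterpret the raw uint8 bits as e8m0 scales; Tensor.view(dtype) is a
    # zero-copy bitcast between dtypes of the same element size.
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    # Dispatch to the regular scaled matmul; `b` must be column-major, which is
    # why the removed test passed `b.t()` of a row-major (N, K) tensor.
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
```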