@@ -10,7 +10,6 @@
 import torch
 import torch.nn as nn
 
-from torchao.float8.float8_utils import is_row_major
 from torchao.prototype.mx_formats.config import (
     MXLinearConfig,
     MXLinearRecipeName,
@@ -334,57 +333,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 11.5
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    not is_sm_at_least_100(),
-    reason="MX gemms require CUDA capability 10.0",
-)
-def test_scaled_mm_wrapper():
-    # today, e8m0 isn't supported in torchinductor or triton
-    # for now, work around this by creating a wrapper around torch._scaled_mm
-    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
-    from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
-
-    M, K, N = 128, 256, 512
-    BLOCK_SIZE = 32
-    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
-    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
-
-    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
-    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
-
-    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
-
-    def wrapped(a, b, a_scale, b_scale, out_dtype):
-        if is_row_major(b.stride()):
-            b = b.t().contiguous().t()
-        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
-        return res
-
-    wrapped = torch.compile(wrapped)
-
-    # correct memory format of `b`
-    out2 = wrapped(
-        a,
-        b.t(),
-        a_scale.view(torch.uint8),
-        b_scale.view(torch.uint8),
-        out_dtype=torch.bfloat16,
-    )
-    torch.testing.assert_close(out, out2, atol=0, rtol=0)
-
-    # incorrect memory format of `b`
-    b_col_major = b.t().contiguous().t()
-    out3 = wrapped(
-        a,
-        b_col_major.t(),
-        a_scale.view(torch.uint8),
-        b_scale.view(torch.uint8),
-        out_dtype=torch.bfloat16,
-    )
-    torch.testing.assert_close(out, out3, atol=0, rtol=0)
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
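For context on the deleted test: test_scaled_mm_wrapper exercised _scaled_mm_with_uint8_scales, a workaround for torchinductor/triton not yet understanding the e8m0 dtype; per the removed comments, the wrapper accepts scales as uint8 and reinterprets the bits as e8m0 before dispatching to torch._scaled_mm. A minimal sketch of that pattern (hypothetical name and body, not the actual torchao implementation; assumes a PyTorch build that exposes torch.float8_e8m0fnu) looks roughly like this:

import torch

def scaled_mm_with_uint8_scales_sketch(a, b, a_scale_u8, b_scale_u8, out_dtype):
    # Bitcast the uint8 scale tensors to float8_e8m0fnu (both are 1 byte per
    # element, so Tensor.view(dtype) reinterprets the bits without copying),
    # then call torch._scaled_mm as usual.
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)

Keeping the scales as plain uint8 outside the wrapper is what let the deleted test run the call under torch.compile and compare its output bit-for-bit (atol=0, rtol=0) against an eager torch._scaled_mm call with native e8m0 scales.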