
Commit 7e3978c

delete mxfp8 torch._scaled_mm wrapper (#1965)
* Update [ghstack-poisoned] (repeated)
1 parent 3766ed7 commit 7e3978c

File tree: 2 files changed (+3, -85 lines)
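For context: the wrapper being deleted existed only to hide the reinterpretation of uint8 scales as `torch.float8_e8m0fnu` from torchinductor (see the removed docstring, which points to https://github.com/pytorch/pytorch/issues/147873). Because e8m0 scales are one byte each, the reinterpretation is a zero-copy `view`, which the updated mx_ops.py now performs directly at the `torch._scaled_mm` call site (second diff below). A minimal sketch of that reinterpretation, assuming a PyTorch build that exposes `torch.float8_e8m0fnu`:

```python
import torch

# e8m0 is an 8-bit, exponent-only float format, so a uint8 tensor of the same
# shape can be reinterpreted in place with .view() -- no copy, no kernel launch.
raw_scales = torch.full((4, 8), 127, dtype=torch.uint8)  # byte 127 encodes 2**0 == 1.0
e8m0_scales = raw_scales.view(torch.float8_e8m0fnu)

print(e8m0_scales.dtype, e8m0_scales.shape)  # torch.float8_e8m0fnu torch.Size([4, 8])
```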

test/prototype/mx_formats/test_mx_linear.py (-52 lines)

```diff
@@ -10,7 +10,6 @@
 import torch
 import torch.nn as nn
 
-from torchao.float8.float8_utils import is_row_major
 from torchao.prototype.mx_formats.config import (
     MXLinearConfig,
     MXLinearRecipeName,
@@ -334,57 +333,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 11.5
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    not is_sm_at_least_100(),
-    reason="MX gemms require CUDA capability 10.0",
-)
-def test_scaled_mm_wrapper():
-    # today, e8m0 isn't supported in torchinductor or triton
-    # for now, work around this by creating a wrapper around torch._scaled_mm
-    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
-    from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
-
-    M, K, N = 128, 256, 512
-    BLOCK_SIZE = 32
-    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
-    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
-
-    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
-    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
-
-    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
-
-    def wrapped(a, b, a_scale, b_scale, out_dtype):
-        if is_row_major(b.stride()):
-            b = b.t().contiguous().t()
-        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
-        return res
-
-    wrapped = torch.compile(wrapped)
-
-    # correct memory format of `b`
-    out2 = wrapped(
-        a,
-        b.t(),
-        a_scale.view(torch.uint8),
-        b_scale.view(torch.uint8),
-        out_dtype=torch.bfloat16,
-    )
-    torch.testing.assert_close(out, out2, atol=0, rtol=0)
-
-    # incorrect memory format of `b`
-    b_col_major = b.t().contiguous().t()
-    out3 = wrapped(
-        a,
-        b_col_major.t(),
-        a_scale.view(torch.uint8),
-        b_scale.view(torch.uint8),
-        out_dtype=torch.bfloat16,
-    )
-    torch.testing.assert_close(out, out3, atol=0, rtol=0)
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
```
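The test removed above existed only to exercise the uint8-scale wrapper under `torch.compile`. With the wrapper gone, the equivalent check would be a direct compiled call to `torch._scaled_mm` with e8m0 scales. A minimal sketch of that direct path (not part of this commit), reusing the deleted test's shapes and all-ones scales; it assumes a GPU with CUDA capability 10.0 and a PyTorch build whose compile stack handles `torch.float8_e8m0fnu`:

```python
import torch

M, K, N = 128, 256, 512
BLOCK_SIZE = 32

a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# one e8m0 scale per 1x32 block along K; all-ones scales keep the numerics
# trivial regardless of scale layout
a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

def mxfp8_mm(a, b, a_scale, b_scale):
    # no uint8 round-trip: e8m0 scales are passed straight through
    return torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)

out_eager = mxfp8_mm(a, b, a_scale, b_scale)
out_compiled = torch.compile(mxfp8_mm)(a, b, a_scale, b_scale)

torch.testing.assert_close(out_eager, out_compiled, atol=0, rtol=0)
```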

torchao/prototype/mx_formats/mx_ops.py (+3, -33 lines)

```diff
@@ -35,41 +35,11 @@
     tensor_size_hpx3_to_fp6x4,
 )
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
 aten = torch.ops.aten
 
 MX_OPS_TABLE: Dict[Any, Any] = {}
 
-if TORCH_VERSION_AT_LEAST_2_5:
-
-    @torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
-    def _scaled_mm_with_uint8_scales(
-        a: torch.Tensor,
-        b: torch.Tensor,
-        a_scale: torch.Tensor,
-        b_scale: torch.Tensor,
-        out_dtype: torch.dtype,
-    ) -> torch.Tensor:
-        """
-        Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to
-        work around the lack of support for `torch.float8_e8m0fnu` in
-        torchinductor. We do so by hiding the cast of scales to e8m0 inside a
-        custom op.
-        """
-        # cast back to e8m0 where torchinductor can't see it
-        a_scale = a_scale.view(torch.float8_e8m0fnu)
-        b_scale = b_scale.view(torch.float8_e8m0fnu)
-        res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
-        return res
-
-    @_scaled_mm_with_uint8_scales.register_fake
-    def _(a, b, a_scale, b_scale, out_dtype):
-        m, k = a.shape
-        k2, n = b.shape
-        res = torch.empty(m, n, dtype=out_dtype, device=a.device)
-        return res
-
 
 def implements(aten_ops):
     """Register aten ops to the mx op table"""
@@ -119,11 +89,11 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
         if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = _scaled_mm_with_uint8_scales(
+            res = torch._scaled_mm(
                 a._data,
                 b._data,
-                a_scale_block,
-                b_scale_block,
+                a_scale_block.view(torch.float8_e8m0fnu),
+                b_scale_block.view(torch.float8_e8m0fnu),
                 out_dtype=torch.bfloat16,
             )
         else:
```
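After this change, the cuBLAS branch of `mx_mm` reinterprets the blocked scales as e8m0 inline instead of routing through the custom op. An illustrative reduction of that branch; the names `a_data`, `b_data`, `a_scale_block`, and `b_scale_block` stand in for the MXTensor internals and are not a public API:

```python
import torch

def mxfp8_cublas_gemm(a_data, b_data, a_scale_block, b_scale_block):
    # scales arrive as raw uint8 bytes in the blocked layout; reinterpret them
    # as e8m0 right at the call site -- the custom-op indirection is gone
    return torch._scaled_mm(
        a_data,
        b_data,
        a_scale_block.view(torch.float8_e8m0fnu),
        b_scale_block.view(torch.float8_e8m0fnu),
        out_dtype=torch.bfloat16,
    )
```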
