delete mxfp8 torch._scaled_mm wrapper #1965

Merged
merged 28 commits on Mar 28, 2025
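
For context, here is a minimal sketch of the call pattern this PR keeps, adapted from the test deleted below: torch._scaled_mm is called directly with torch.float8_e8m0fnu scales, with no uint8-scale wrapper in between. Per the deleted test's skip conditions, this assumes CUDA capability 10.0; the shapes and all-ones scales are illustrative only.

import torch

M, K, N = 128, 256, 512
BLOCK_SIZE = 32

# mxfp8 operands: a is (M, K) row-major, b.t() is (K, N) column-major
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

# one e8m0 scale per 1x32 block along K (all-ones here for simplicity)
a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

# direct call with e8m0 scales, no reinterpret-cast to/from uint8 needed
out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)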
52 changes: 0 additions & 52 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -10,7 +10,6 @@
import torch
import torch.nn as nn

from torchao.float8.float8_utils import is_row_major
from torchao.prototype.mx_formats.config import (
    MXLinearConfig,
    MXLinearRecipeName,
@@ -334,57 +333,6 @@ def test_inference_compile_simple(elem_dtype):
    assert sqnr >= 11.5


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_sm_at_least_100(),
    reason="MX gemms require CUDA capability 10.0",
)
def test_scaled_mm_wrapper():
    # today, e8m0 isn't supported in torchinductor or triton
    # for now, work around this by creating a wrapper around torch._scaled_mm
    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
    from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales

    M, K, N = 128, 256, 512
    BLOCK_SIZE = 32
    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)

    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)

    def wrapped(a, b, a_scale, b_scale, out_dtype):
        if is_row_major(b.stride()):
            b = b.t().contiguous().t()
        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
        return res

    wrapped = torch.compile(wrapped)

    # correct memory format of `b`
    out2 = wrapped(
        a,
        b.t(),
        a_scale.view(torch.uint8),
        b_scale.view(torch.uint8),
        out_dtype=torch.bfloat16,
    )
    torch.testing.assert_close(out, out2, atol=0, rtol=0)

    # incorrect memory format of `b`
    b_col_major = b.t().contiguous().t()
    out3 = wrapped(
        a,
        b_col_major.t(),
        a_scale.view(torch.uint8),
        b_scale.view(torch.uint8),
        out_dtype=torch.bfloat16,
    )
    torch.testing.assert_close(out, out3, atol=0, rtol=0)


def test_filter_fn():
    m1 = nn.Sequential(
        nn.Linear(32, 32),
36 changes: 3 additions & 33 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -35,41 +35,11 @@
    tensor_size_hpx3_to_fp6x4,
)
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

aten = torch.ops.aten

MX_OPS_TABLE: Dict[Any, Any] = {}

if TORCH_VERSION_AT_LEAST_2_5:

    @torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
    def _scaled_mm_with_uint8_scales(
        a: torch.Tensor,
        b: torch.Tensor,
        a_scale: torch.Tensor,
        b_scale: torch.Tensor,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to
        work around the lack of support for `torch.float8_e8m0fnu` in
        torchinductor. We do so by hiding the cast of scales to e8m0 inside a
        custom op.
        """
        # cast back to e8m0 where torchinductor can't see it
        a_scale = a_scale.view(torch.float8_e8m0fnu)
        b_scale = b_scale.view(torch.float8_e8m0fnu)
        res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
        return res

    @_scaled_mm_with_uint8_scales.register_fake
    def _(a, b, a_scale, b_scale, out_dtype):
        m, k = a.shape
        k2, n = b.shape
        res = torch.empty(m, n, dtype=out_dtype, device=a.device)
        return res


def implements(aten_ops):
"""Register aten ops to the mx op table"""
@@ -119,11 +89,11 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
         if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = _scaled_mm_with_uint8_scales(
+            res = torch._scaled_mm(
                 a._data,
                 b._data,
-                a_scale_block,
-                b_scale_block,
+                a_scale_block.view(torch.float8_e8m0fnu),
+                b_scale_block.view(torch.float8_e8m0fnu),
                 out_dtype=torch.bfloat16,
             )
         else: