
Commit d32afef

mx: triton kernel to cast to mx and write in col-major (#1932)
1 parent 59c7311 commit d32afef
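
This commit adds a Triton kernel that casts a bf16 tensor to MX format (float8_e4m3fn data with e8m0 scales) along dim1 and writes the quantized data out in column-major layout. The Python entry point, triton_to_mxfp8_dim1, is imported from torchao.prototype.mx_formats.custom_cast in the benchmark and test diffs below. As a minimal usage sketch, assuming only what those diffs show (a bf16 CUDA input, an inner_block_size keyword, and output dtypes taken from the benchmark's asserts):

import torch

from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1

# Requires a CUDA GPU with compute capability >= 8.9 (float8 support in Triton),
# per the skip condition added in the test file below.
x = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# Cast along dim1 in blocks of 32 elements. Per the benchmark's asserts,
# `data` is float8_e4m3fn and `scale` is float8_e8m0fnu; per the commit
# title, `data` is written in column-major memory layout.
data, scale = triton_to_mxfp8_dim1(x, inner_block_size=32)

The tensor shape above is only an example; any 2D bf16 CUDA tensor compatible with the block size should work the same way.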

3 files changed, +377 −3 lines


benchmarks/mx_formats/cast_bench.py

+45-1
@@ -5,6 +5,9 @@
 import triton
 from torch._inductor.utils import do_bench_using_profiling

+from torchao.prototype.mx_formats.custom_cast import (
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx

 torch.manual_seed(0)

@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
     return data_d0, scale_d0


+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
 def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     """Thin wrapper around do_bench_using_profiling"""
     no_args = lambda: func(*args, **kwargs)

@@ -67,7 +76,7 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
+    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")

     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

@@ -144,6 +153,41 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim1_mx":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_triton":
+        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
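
For context on what the two new benchmark modes measure: dim1_mx times a torch.compile'd pure-PyTorch reference (transpose, quantize with to_mx, transpose back), while dim1_mx_triton times the fused Triton kernel; both report achieved memory bandwidth as (bytes read + bytes written) / seconds. A hedged sketch of the same accounting outside the benchmark harness, reusing the function from the diff but timing with triton.testing.do_bench instead of the profiling-based helper (shape and block size are arbitrary example values):

import torch
import triton

from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1

M, K, BLOCK_SIZE = 16384, 16384, 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

# Time the fused Triton cast; do_bench returns the median time in milliseconds.
ms = triton.testing.do_bench(lambda: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE))

y, s = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
bytes_r = x.numel() * 2            # bf16 input read, 2 bytes per element
bytes_w = y.numel() + s.numel()    # fp8 data + e8m0 scales written, 1 byte each
gbps = (bytes_r + bytes_w) / (ms / 1e3) / 1e9
print(f"time: {ms:.3f} ms, achieved: {gbps:.1f} GB/s")

The point of the fused kernel is that the reference path pays for an extra transpose-and-materialize of the high-precision input, while the Triton kernel reads the bf16 input once and writes the column-major fp8 output directly.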

test/prototype/mx_formats/test_custom_cast.py

+24-1
@@ -29,6 +29,8 @@
     triton_f4_to_bf16,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim1,
+    triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.fp_format_spec import (

@@ -42,7 +44,11 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)

 torch.manual_seed(0)

@@ -444,3 +450,20 @@ def test_fp6_e3m2_pack_unpack():
         torch.float32
     )
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
+
+
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048))
+@pytest.mark.parametrize("K", (256, 2048))
+# @pytest.mark.parametrize("M", (256,))
+# @pytest.mark.parametrize("K", (256,))
+def test_triton_mxfp8_dim1(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
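
The new test checks the Triton kernel bit-for-bit against triton_to_mxfp8_dim1_reference (rtol=0, atol=0) but does not assert anything about memory layout. If you also want to probe the "write in col-major" part of the commit title, one hedged way is to inspect the output strides; the expected stride pattern below is an assumption drawn from the title, not something the test verifies:

import torch

from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1

x = torch.randn(256, 2048, dtype=torch.bfloat16, device="cuda")
data, scale = triton_to_mxfp8_dim1(x, inner_block_size=32)

# For a 2D tensor of shape (A, B), column-major memory would show strides
# (1, A) rather than the row-major (B, 1). This is an assumption based on
# the commit title, so we only print rather than assert.
print(data.shape, data.stride())
print(scale.shape, scale.dtype)  # float8_e8m0fnu scales, per the benchmark asserts

To run just the new test: pytest test/prototype/mx_formats/test_custom_cast.py -k test_triton_mxfp8_dim1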
