Skip to content

Commit ac177bd

Browse files
committed
[not for land] hook up MX to CUDA 12.8 cuBLAS MX gemm
Summary: Requires https://github.com/pytorch/pytorch/pull/145562/files None of this is for land - just testing for now as we work on a long term support plan. Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 5442d55 ghstack-comment-id: 2616580142 Pull Request resolved: #1625
1 parent 3c99a1b commit ac177bd

File tree

2 files changed

+208
-12
lines changed

2 files changed

+208
-12
lines changed

benchmarks/float8/bench_matmul.py

+115-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66
import itertools
7+
from enum import IntEnum
78
from typing import Optional
89

910
import fire
@@ -26,14 +27,44 @@
2627
h100_peak_flops_fp16_tc = 989e12
2728
h100_peak_tops_float8_tc = 1979e12
2829

29-
dtype_to_peak_tops = {
30+
# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
31+
# note: the numbers from the link above are divided by 2 to undo the effects of sparsity
32+
# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
33+
# something seems funky
34+
b200_peak_flops_float32 = 600e12
35+
b200_peak_flops_fp16_tc = 18e15
36+
b200_peak_tops_float8_tc = 36e15
37+
b200_peak_tops_float4_tc = 72e15
38+
39+
dtype_to_peak_tops_h100 = {
3040
torch.float32: h100_peak_flops_float32,
3141
torch.float16: h100_peak_flops_fp16_tc,
3242
torch.bfloat16: h100_peak_flops_fp16_tc,
3343
torch.float8_e4m3fn: h100_peak_tops_float8_tc,
3444
torch.float8_e5m2: h100_peak_tops_float8_tc,
3545
}
3646

47+
dtype_to_peak_tops_b200 = {
48+
torch.float32: b200_peak_flops_float32,
49+
torch.float16: b200_peak_flops_fp16_tc,
50+
torch.bfloat16: b200_peak_flops_fp16_tc,
51+
torch.float8_e4m3fn: b200_peak_tops_float8_tc,
52+
torch.float8_e5m2: b200_peak_tops_float8_tc,
53+
# TODO float4
54+
}
55+
56+
# TODO(this PR): switch automatically by detected hardware type
57+
# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
58+
dtype_to_peak_tops = dtype_to_peak_tops_b200
59+
60+
61+
# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
62+
class DataType(IntEnum):
63+
DEFAULT = 0
64+
E8M0 = 1
65+
FP4 = 2
66+
UFP8 = 3
67+
3768

3869
def benchmark_fn_in_sec(f, *args, **kwargs):
3970
# Manual warmup
@@ -75,6 +106,7 @@ def run(
75106
N: Optional[int] = None,
76107
use_gpu_kernel_time: bool = False,
77108
scaling_granularity: str = "tensorwise",
109+
blockwise_dtype: Optional[str] = None,
78110
):
79111
device = "cuda"
80112

@@ -85,15 +117,17 @@ def run(
85117
"K",
86118
"N",
87119
"ref_time_s",
88-
"fp8_time_s",
89-
"fp8_speedup",
120+
"lowp_time_s",
121+
"lowp_speedup",
90122
)
91123
results = []
92124

93125
dtype = torch.bfloat16
94126
name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
95127
fast_accum_vals = [True, False]
96-
scaling_granularity = ScalingGranularity(scaling_granularity)
128+
# Note: "blockwise" is not in the ScalingGranularity enum because blockwise scaling is still in prototype
129+
if scaling_granularity != "blockwise":
130+
scaling_granularity = ScalingGranularity(scaling_granularity)
97131

98132
for idx, (fast_accum, (name, (M, K, N))) in enumerate(
99133
itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@ def run(
119153
# raw float8 matmul (upper bound for what we can achieve in eager mode)
120154
# TODO(future): add e5m2
121155
d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
122-
A = torch.zeros(M, K, device=device, dtype=d1)
123-
B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
156+
A = torch.randn(M, K, device=device).to(d1)
157+
B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
124158
if scaling_granularity == ScalingGranularity.TENSORWISE:
125159
scale_a = torch.tensor([1.0], device=device)
126160
scale_b = torch.tensor([1.0], device=device)
127-
else:
128-
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
161+
elif scaling_granularity == ScalingGranularity.AXISWISE:
129162
scale_a = torch.ones(M, 1, device=device)
130163
scale_b = torch.ones(1, N, device=device)
164+
elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
165+
# TODO(this PR): also block size 16
166+
BLOCK_SIZE = 32
167+
A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
168+
torch.float8_e4m3fn
169+
)
170+
B = (
171+
torch.randint(128, (N, K), device=device, dtype=torch.uint8)
172+
.view(torch.float8_e4m3fn)
173+
.t()
174+
)
175+
scale_a = torch.randint(
176+
128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
177+
)
178+
scale_b = torch.randint(
179+
128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
180+
).t()
181+
elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
182+
# TODO(this PR): also block size 32
183+
BLOCK_SIZE = 16
184+
A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
185+
torch.float8_e4m3fn
186+
)
187+
B = (
188+
torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
189+
.view(torch.float8_e4m3fn)
190+
.t()
191+
)
192+
scale_a = torch.randint(
193+
128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
194+
)
195+
scale_b = torch.randint(
196+
128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
197+
).t()
198+
else:
199+
raise AssertionError(f"unsupported granularity {scaling_granularity}")
131200

132201
def do_matmul(A, B):
133202
nonlocal scale_a
134203
nonlocal scale_b
135-
return torch._scaled_mm(
136-
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
137-
)
204+
205+
if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
206+
return torch._scaled_mm(
207+
A,
208+
B,
209+
scale_a,
210+
scale_b,
211+
bias=None,
212+
scale_result=None,
213+
out_dtype=d3,
214+
use_fast_accum=fast_accum,
215+
a_dtype=None, # inferred from A
216+
b_dtype=None, # inferred from B
217+
scale_dtype=DataType.E8M0,
218+
)
219+
elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
220+
return torch._scaled_mm(
221+
A,
222+
B,
223+
scale_a,
224+
scale_b,
225+
bias=None,
226+
scale_result=None,
227+
out_dtype=d3,
228+
use_fast_accum=fast_accum,
229+
a_dtype=DataType.FP4,
230+
b_dtype=DataType.FP4,
231+
scale_dtype=DataType.E8M0,
232+
)
233+
234+
else:
235+
return torch._scaled_mm(
236+
A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
237+
)
238+
239+
# test
240+
# res = do_matmul(A, B)
138241

139242
fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
140243
tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
141244
)
142245
print(
143-
f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
246+
f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
144247
)
145248

146249
del A, B, scale_a, scale_b

test/prototype/mx_formats/test_mx_linear.py

+93
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
8+
from enum import IntEnum
89

910
import pytest
1011
import torch
@@ -30,6 +31,14 @@
3031
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
3132

3233

34+
# not for land, https://www.internalfb.com/phabricator/paste/view/P1717686991
35+
class DataType(IntEnum):
36+
DEFAULT = 0
37+
E8M0 = 1
38+
FP4 = 2
39+
UFP8 = 3
40+
41+
3342
# source: https://stackoverflow.com/a/22638709
3443
@pytest.fixture(autouse=True)
3544
def run_around_tests():
@@ -234,3 +243,87 @@ def test_filter_fn():
234243
swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501
235244
assert type(m2[0]) == MXInferenceLinear
236245
assert type(m2[1]) == torch.nn.Linear
246+
247+
248+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
249+
@pytest.mark.skipif(
250+
not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher"
251+
)
252+
def test_scaled_mm_mxfp8():
253+
# hello world
254+
# next: basic numerics
255+
256+
M, K, N = 8192, 4096, 8192
257+
BLOCK_SIZE = 32
258+
a = torch.randint(128, (M, K), device="cuda", dtype=torch.uint8).view(
259+
torch.float8_e4m3fn
260+
)
261+
b = (
262+
torch.randint(128, (N, K), device="cuda", dtype=torch.uint8)
263+
.view(torch.float8_e4m3fn)
264+
.t()
265+
)
266+
a_scales = torch.randint(
267+
128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8
268+
).view(M, K // BLOCK_SIZE)
269+
b_scales = (
270+
torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8)
271+
.view(N, K // BLOCK_SIZE)
272+
.t()
273+
)
274+
out = torch._scaled_mm(
275+
a,
276+
b,
277+
a_scales,
278+
b_scales,
279+
None,
280+
None,
281+
torch.bfloat16,
282+
False,
283+
None,
284+
None,
285+
DataType.E8M0,
286+
)
287+
print(out)
288+
289+
290+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
291+
@pytest.mark.skipif(
292+
not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher"
293+
)
294+
def test_scaled_mm_nvfp4():
295+
# hello world
296+
# next: basic numerics
297+
298+
M, K, N = 8192, 4096, 8192
299+
BLOCK_SIZE = 16
300+
a = torch.randint(128, ((M * K) // 2,), device="cuda", dtype=torch.uint8).view(
301+
M, K // 2
302+
)
303+
b = (
304+
torch.randint(128, ((K * N) // 2,), device="cuda", dtype=torch.uint8)
305+
.view(N, K // 2)
306+
.t()
307+
)
308+
a_scales = torch.randint(
309+
128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8
310+
).view(M, K // BLOCK_SIZE)
311+
b_scales = (
312+
torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8)
313+
.view(N, K // BLOCK_SIZE)
314+
.t()
315+
)
316+
out = torch._scaled_mm(
317+
a,
318+
b,
319+
a_scales,
320+
b_scales,
321+
None,
322+
None,
323+
torch.bfloat16,
324+
False,
325+
DataType.FP4,
326+
DataType.FP4,
327+
DataType.UFP8,
328+
)
329+
print(out)

0 commit comments

Comments
 (0)