Commit 1170bc2

[not for land] hook up MX to CUDA 12.8 cuBLAS MX gemm
Summary:

Requires https://github.com/pytorch/pytorch/pull/145562/files

None of this is for land - just testing for now as we work on a long term support plan.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: bebf96fe36e96ea4f0a3385fc1012e728d9ed5c4
ghstack-comment-id: 2616580142
Pull Request resolved: #1625
1 parent cf2eb3c commit 1170bc2
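
Not part of the commit: a minimal sketch of how the new blockwise path below could be exercised once the linked PyTorch PR is built. The module path and argument spellings here are assumptions read off the run() signature in the diff (the script is driven by fire), not a documented CLI.

# hypothetical driver for the new blockwise benchmark path (module path assumed)
from benchmarks.float8.bench_matmul import run

# mxfp8: fp8 e4m3 data with e8m0 block scales (block size 32 in this commit)
run(scaling_granularity="blockwise", blockwise_dtype="float8_e4m3")

# fp4: packed e2m1 data with e8m0 block scales (block size 16 in this commit)
run(scaling_granularity="blockwise", blockwise_dtype="float4")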

File tree: 2 files changed, +201 −12 lines changed

2 files changed

+201
-12
lines changed

benchmarks/float8/bench_matmul.py

+115 −12
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import itertools
+from enum import IntEnum
 from typing import Optional

 import fire
@@ -26,14 +27,44 @@
 h100_peak_flops_fp16_tc = 989e12
 h100_peak_tops_float8_tc = 1979e12

-dtype_to_peak_tops = {
+# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
+# note: divided numbers from ^ by 2 to undo the effects of sparsity
+# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
+# something seems funky
+b200_peak_flops_float32 = 600e12
+b200_peak_flops_fp16_tc = 18e15
+b200_peak_tops_float8_tc = 36e15
+b200_peak_tops_float4_tc = 72e15
+
+dtype_to_peak_tops_h100 = {
     torch.float32: h100_peak_flops_float32,
     torch.float16: h100_peak_flops_fp16_tc,
     torch.bfloat16: h100_peak_flops_fp16_tc,
     torch.float8_e4m3fn: h100_peak_tops_float8_tc,
     torch.float8_e5m2: h100_peak_tops_float8_tc,
 }

+dtype_to_peak_tops_b200 = {
+    torch.float32: b200_peak_flops_float32,
+    torch.float16: b200_peak_flops_fp16_tc,
+    torch.bfloat16: b200_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: b200_peak_tops_float8_tc,
+    torch.float8_e5m2: b200_peak_tops_float8_tc,
+    # TODO float4
+}
+
+# TODO(this PR): switch automatically by detected hardware type
+# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
dtype_to_peak_tops = dtype_to_peak_tops_b200
+
+
+# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+

 def benchmark_fn_in_sec(f, *args, **kwargs):
     # Manual warmup
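
As an aside (not from the commit), the "achieving 5% of peak TFLOPS" TODO above refers to the pct_peak value this benchmark prints: achieved throughput for one M x K x N matmul, counted as 2 * M * K * N ops, divided by the hardware peak from the table. A small illustration, with the peak constant copied from the table above:

b200_peak_tops_float8_tc = 36e15  # copied from the table above


def utilization(M, K, N, time_sec, peak=b200_peak_tops_float8_tc):
    flops = 2 * M * K * N  # one multiply-accumulate counted as 2 ops
    achieved = flops / time_sec  # achieved ops/sec
    return achieved, achieved / peak  # (ops/sec, fraction of peak)


# example: an 8192 x 4096 x 8192 fp8 gemm finishing in 50 microseconds
print(utilization(8192, 4096, 8192, 50e-6))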
@@ -75,6 +106,7 @@ def run(
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
     scaling_granularity: str = "tensorwise",
+    blockwise_dtype: Optional[str] = None,
 ):
     device = "cuda"

@@ -85,15 +117,17 @@ def run(
         "K",
         "N",
         "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "lowp_time_s",
+        "lowp_speedup",
     )
     results = []

     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
-    scaling_granularity = ScalingGranularity(scaling_granularity)
+    # Note: blockwise not in enum because blockwise is in prototype
+    if scaling_granularity != "blockwise":
+        scaling_granularity = ScalingGranularity(scaling_granularity)

     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@ def run(
         # raw float8 matmul (upper bound for what we can achive in eager mode)
         # TODO(future): add e5m2
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A = torch.randn(M, K, device=device).to(d1)
+        B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
         if scaling_granularity == ScalingGranularity.TENSORWISE:
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
-        else:
-            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+        elif scaling_granularity == ScalingGranularity.AXISWISE:
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 32
+            A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 16
+            A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        else:
+            raise AssertionError(f"unsupported granularity {scaling_granularity}")

         def do_matmul(A, B):
             nonlocal scale_a
             nonlocal scale_b
-            return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-            )
+
+            if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=None,  # inferred from A
+                    b_dtype=None,  # inferred from B
+                    scale_dtype=DataType.E8M0,
+                )
+            elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=DataType.FP4,
+                    b_dtype=DataType.FP4,
+                    scale_dtype=DataType.E8M0,
+                )
+
+            else:
+                return torch._scaled_mm(
+                    A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+                )
+
+        # test
+        # res = do_matmul(A, B)

         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
             tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
         )

         del A, B, scale_a, scale_b
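
Side note, not part of the diff: both blockwise branches above fill their scale tensors with raw uint8 bytes that the kernel interprets as e8m0 values. e8m0 is a pure biased-exponent format (bias 127, no mantissa, 0xFF reserved for NaN), which is why the tests in the next file use 127 to mean a scale of 1.0. A small decode sketch:

import torch


def decode_e8m0(scale_bytes: torch.Tensor) -> torch.Tensor:
    # e8m0: 8 exponent bits, no mantissa, bias 127; byte x represents 2**(x - 127)
    vals = torch.pow(2.0, scale_bytes.float() - 127.0)
    # 0xFF is the NaN encoding
    return torch.where(scale_bytes == 0xFF, torch.full_like(vals, float("nan")), vals)


# 126 -> 0.5, 127 -> 1.0, 128 -> 2.0
print(decode_e8m0(torch.tensor([126, 127, 128], dtype=torch.uint8)))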

test/prototype/mx_formats/test_mx_linear.py

+86
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import copy
+from enum import IntEnum

 import pytest
 import torch
@@ -30,6 +31,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


+# not for land, https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
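
Another aside, not part of the diff: the FP4 entry in this enum refers to the packed e2m1 format used by the nvfp4 test further below, where two 4-bit values share one uint8 (hence the K // 2 tensor shapes there). A small unpacking sketch, assuming the low nibble holds the first element; the actual packing order is a hardware/library detail not spelled out in this commit:

import torch

# the 8 non-negative e2m1 values, mirrored for the sign bit
_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)


def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    # packed: uint8 tensor of shape (..., K // 2) -> float32 of shape (..., K)
    lo = packed & 0x0F  # assumed: low nibble first
    hi = packed >> 4
    codes = torch.stack([lo, hi], dim=-1).flatten(start_dim=-2)
    return _E2M1_LUT.to(packed.device)[codes.long()]


# example: unpack one row of 8 packed bytes into 16 fp4 values
print(unpack_fp4(torch.randint(256, (1, 8), dtype=torch.uint8)))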
@@ -234,3 +243,80 @@ def test_filter_fn():
     swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_mxfp8():
+    # basic numerics with all scales 1.0
+    # next: other scale values
+
+    # M, K, N = 8192, 4096, 8192
+    M, K, N = 128, 128, 128
+    BLOCK_SIZE = 32
+    a = torch.eye(M, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
+    b = torch.eye(M, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t()
+
+    # 127 is 1.0 in e8m0
+    scale_val = 127
+
+    a_scales = torch.full(
+        (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8
+    )
+    b_scales = torch.full(
+        (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8
+    ).t()
+    out = torch._scaled_mm(
+        a,
+        b,
+        a_scales,
+        b_scales,
+        None,
+        None,
+        torch.bfloat16,
+        False,
+        None,
+        None,
+        DataType.E8M0,
+    )
+
+    # [[1, 0, ...], ..., [0, ..., 1]] - correct
+    print(out)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="blockwise torch._scaled_mm requires CUDA capability 10.0 or higher",
+)
+def test_scaled_mm_nvfp4():
+    # hello world
+    # next: basic numerics
+
+    M, K, N = 8192, 4096, 8192
+    BLOCK_SIZE = 16
+    a = torch.randint(128, (M, K // 2), device="cuda", dtype=torch.uint8)
+    b = torch.randint(128, (N, K // 2), device="cuda", dtype=torch.uint8).t()
+    a_scales = torch.randint(
+        128, (M, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8
+    )
+    b_scales = torch.randint(
+        128, (N, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8
+    ).t()
+    out = torch._scaled_mm(
+        a,
+        b,
+        a_scales,
+        b_scales,
+        None,
+        None,
+        torch.bfloat16,
+        False,
+        DataType.FP4,
+        DataType.FP4,
+        DataType.UFP8,
+    )
+    print(out)
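
Finally, one more illustration that is not part of the commit: a rough PyTorch-only emulation of the numerics test_scaled_mm_mxfp8 above exercises, where each length-32 block along K is dequantized by its decoded e8m0 scale before an ordinary matmul. Scale shapes follow the test (a_scales: (M, K // 32), b_scales: (N, K // 32)); this is a reference for eyeballing outputs, not the cuBLAS MX kernel this commit targets.

import torch


def emulated_mx_mm(a_fp8, b_fp8, a_scales_u8, b_scales_u8, block_size=32):
    # decode e8m0 block scales (bias 127) and broadcast each over its K-block
    a_s = torch.pow(2.0, a_scales_u8.float() - 127.0).repeat_interleave(block_size, dim=1)
    b_s = torch.pow(2.0, b_scales_u8.float() - 127.0).repeat_interleave(block_size, dim=1)
    a_hp = a_fp8.float() * a_s  # (M, K) dequantized A
    b_hp = b_fp8.float() * b_s.t()  # (K, N) dequantized B
    return (a_hp @ b_hp).to(torch.bfloat16)

With the identity inputs and all-127 (i.e. 1.0) scales from the test, this reproduces the identity matrix the test prints.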
