
Commit 4e33f57

[not for land] hook up MX to CUDA 12.8 cuBLAS MX gemm
Summary:

Requires https://github.com/pytorch/pytorch/pull/145562/files

None of this is for land - just testing for now as we work on a long-term support plan.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 946b39a28673d9a9b7945c4ad51c638bfd3963cc
ghstack-comment-id: 2616580142
Pull Request resolved: #1625
1 parent e30d654 commit 4e33f57
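
For reference, a plausible way to exercise the new blockwise path locally (an assumed invocation, not part of this commit: it relies on the benchmark's existing fire-based CLI, where flags map one-to-one to the `run()` keyword arguments shown in the diff below):

```
python benchmarks/float8/bench_matmul.py --scaling_granularity blockwise --blockwise_dtype float8_e4m3
python benchmarks/float8/bench_matmul.py --scaling_granularity blockwise --blockwise_dtype float4
```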

2 files changed: +426 -12 lines changed


benchmarks/float8/bench_matmul.py (+115 -12)
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import itertools
+from enum import IntEnum
 from typing import Optional
 
 import fire
@@ -26,14 +27,44 @@
 h100_peak_flops_fp16_tc = 989e12
 h100_peak_tops_float8_tc = 1979e12
 
-dtype_to_peak_tops = {
+# HGX B20 specs: https://www.nvidia.com/en-us/data-center/hgx/
+# note: divided numbers from ^ by 2 to undo the effects of sparsity
+# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
+# something seems funky
+b200_peak_flops_float32 = 600e12
+b200_peak_flops_fp16_tc = 18e15
+b200_peak_tops_float8_tc = 36e15
+b200_peak_tops_float4_tc = 72e15
+
+dtype_to_peak_tops_h100 = {
     torch.float32: h100_peak_flops_float32,
     torch.float16: h100_peak_flops_fp16_tc,
     torch.bfloat16: h100_peak_flops_fp16_tc,
     torch.float8_e4m3fn: h100_peak_tops_float8_tc,
     torch.float8_e5m2: h100_peak_tops_float8_tc,
 }
 
+dtype_to_peak_tops_b200 = {
+    torch.float32: b200_peak_flops_float32,
+    torch.float16: b200_peak_flops_fp16_tc,
+    torch.bfloat16: b200_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: b200_peak_tops_float8_tc,
+    torch.float8_e5m2: b200_peak_tops_float8_tc,
+    # TODO float4
+}
+
+# TODO(this PR): switch automatically by detected hardware type
+# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
+dtype_to_peak_tops = dtype_to_peak_tops_b200
+
+
+# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+
 
 def benchmark_fn_in_sec(f, *args, **kwargs):
     # Manual warmup
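
The hunk above leaves a TODO to select the peak-TOPS table from the detected hardware instead of hard-coding `dtype_to_peak_tops = dtype_to_peak_tops_b200`. One possible shape for that switch, sketched as an assumption (it relies on `torch.cuda.get_device_capability()` and on B200 reporting compute capability 10.x vs. H100's 9.x), is:

```python
import torch

def infer_dtype_to_peak_tops():
    # Assumption: H100 reports compute capability 9.x, B200 reports 10.x.
    # dtype_to_peak_tops_h100 / dtype_to_peak_tops_b200 are the tables
    # defined in the hunk above.
    major, _minor = torch.cuda.get_device_capability()
    return dtype_to_peak_tops_b200 if major >= 10 else dtype_to_peak_tops_h100
```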
@@ -75,6 +106,7 @@ def run(
     N: Optional[int] = None,
     use_gpu_kernel_time: bool = False,
     scaling_granularity: str = "tensorwise",
+    blockwise_dtype: Optional[str] = None,
 ):
     device = "cuda"
 
@@ -85,15 +117,17 @@ def run(
         "K",
         "N",
         "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "lowp_time_s",
+        "lowp_speedup",
     )
     results = []
 
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     fast_accum_vals = [True, False]
-    scaling_granularity = ScalingGranularity(scaling_granularity)
+    # Note: blockwise not in enum because blockwise is in prototype
+    if scaling_granularity != "blockwise":
+        scaling_granularity = ScalingGranularity(scaling_granularity)
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@ def run(
         # raw float8 matmul (upper bound for what we can achive in eager mode)
         # TODO(future): add e5m2
         d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A = torch.randn(M, K, device=device).to(d1)
+        B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
         if scaling_granularity == ScalingGranularity.TENSORWISE:
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
-        else:
-            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+        elif scaling_granularity == ScalingGranularity.AXISWISE:
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 32
+            A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 16
+            A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        else:
+            raise AssertionError(f"unsupported granularity {scaling_granularity}")
 
         def do_matmul(A, B):
             nonlocal scale_a
             nonlocal scale_b
-            return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-            )
+
+            if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=None,  # inferred from A
+                    b_dtype=None,  # inferred from B
+                    scale_dtype=DataType.E8M0,
+                )
+            elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=DataType.FP4,
+                    b_dtype=DataType.FP4,
+                    scale_dtype=DataType.E8M0,
+                )
+
+            else:
+                return torch._scaled_mm(
+                    A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+                )
+
+        # test
+        # res = do_matmul(A, B)
 
         fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
             tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
         )
 
         del A, B, scale_a, scale_b
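
To make the scale layout in the blockwise branches concrete: each E8M0 scale covers a 1x32 block of float8_e4m3 elements along K, or a 1x16 block of float4 elements, which the code stores packed two per uint8 byte (hence the K // 2 second dimension). A small illustrative check of the resulting shapes, using hypothetical M = K = N = 8192 (not values from this commit):

```python
# Illustrative shape arithmetic only; the M, K, N values are hypothetical.
M, K, N = 8192, 8192, 8192

# blockwise float8_e4m3: data is (M, K) fp8 elements, one e8m0 scale per 1x32 block
BLOCK_SIZE_FP8 = 32
assert (M, K // BLOCK_SIZE_FP8) == (8192, 256)  # scale_a shape

# blockwise float4: data is (M, K // 2) bytes (two fp4 values per byte),
# one e8m0 scale per 1x16 block
BLOCK_SIZE_FP4 = 16
assert (M, K // 2) == (8192, 4096)               # packed A shape
assert (M, K // BLOCK_SIZE_FP4) == (8192, 512)   # scale_a shape
```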
