
Commit b7126d8

Adding roofline benchmark for float8 inference;
1 parent a685747 commit b7126d8

1 file changed: 319 additions, 0 deletions
@@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
This is a script to estimate the benefit from converting a `torch.nn.Linear`
layer to float8, by estimating the difference in e2e GPU kernel time between:
1. bf16 gemms in fwd and bwd, and
2. float8 gemms in fwd and bwd, and float8 overhead

The gemm times are estimated either from direct measurements via benchmarks,
or with a roofline estimation based on TOPS and peak compute bandwidth of an
NVIDIA H100.

The float8 overhead times are estimated by counting memory reads and writes
based on the specified float8 scaling, and estimating that we can achieve
a certain % of machine peak memory bandwidth when performing these reads and writes.

Additional context:
1. the formulas for fwd/bwd gemms in a linear layer, with corresponding input
   and output sizes:

   input @ weight_t = output
   MxK @ KxN => MxN

   grad_output @ weight = grad_input
   MxN @ NxK => MxK

   input_t @ grad_output = grad_weight
   KxM @ MxN => KxN

2. we properly model the worst-case of the current torch.compile limitations regarding
   float8 scaling
3. assume for float8 activations/gradients that torch.compile will fuse to the
   preceding op. Note that this is not always true in practice.
4. assume no AC (TODO model it)
5. assume no float8 all-gather (TODO model it)
"""

import csv
import copy
import json
import os
import time
from typing import Optional

import fire
import pandas as pd
import sympy
import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.profiler import profile, ProfilerActivity, record_function

from utils import (
    get_name_to_shapes_iter,
    get_gpu_kernel_gemm_time_s,
    profiler_output_to_filtered_time_by_kernel_name,
)
from torchao.float8.roofline_utils import (
    get_gemm_time_sympy,
    get_float8_mem_sympy,
)
from torchao.float8 import (
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
    CastConfig,
)
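
# Illustrative sketch only, not used by the measurements below and not the
# torchao implementation: evaluating the roofline relations above for a single
# shape. The default peaks are rough H100 SXM figures (dense bf16 TOPS, HBM3
# bandwidth), and the read/write counting for a dynamic cast is an assumption
# made for illustration.
def example_gemm_roofline_time_s(M, K, N, peak_tops=989e12):
    # an MxK @ KxN gemm performs 2 * M * K * N FLOPs; assume it is compute-bound
    return 2 * M * K * N / peak_tops


def example_dyn_cast_overhead_time_s(numel, peak_mem_bw=3.35e12, achievable_frac=0.6):
    # one bf16 read to compute the amax/scale, then one bf16 read and one float8
    # write for the cast itself, at an assumed fraction of peak memory bandwidth
    bytes_moved = (numel * 2) + (numel * 2 + numel * 1)
    return bytes_moved / (peak_mem_bw * achievable_frac)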


class LNLinearSigmoid(torch.nn.Module):
    def __init__(self, fc_dim1, fc_dim2):
        super().__init__()
        self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False)
        self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.ln(x)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x

# TODO(next): hook this up


def benchmark_fn_in_sec(f, *args, **kwargs):
    # Manual warmup
    for _ in range(4):
        f(*args, **kwargs)
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    measurement = t0.blocked_autorange()
    return measurement.mean


def get_gpu_kernel_time(m, x):
    # warm up
    for _ in range(2):
        m(x).sum().backward()

    # capture a profiling run
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    n_iter = 5
    with profile(activities=activities) as prof:
        for _ in range(n_iter):
            m(x).sum().backward()
            torch.cuda.synchronize()
    # get the gpu kernel time and aggregate it
    num_leaf_tensors = 1 + len(list(m.parameters()))
    ref_times = profiler_output_to_filtered_time_by_kernel_name(
        prof, n_iter, num_leaf_tensors)
    total_time_s = sum(v for v in ref_times.values()) / 1e6 / n_iter
    return total_time_s

def get_gemm_times(M, K, N, fast_accum, cache_filename=None):

    # Note: this is definitely not the best way to build a cache,
    # but it will do for now.
    if cache_filename is not None:
        if os.path.isfile(cache_filename):
            # cache already exists, use it
            with open(cache_filename, 'r') as f:
                cache = json.load(f)
        else:
            # cache does not exist yet, create it
            cache = dict()
        key = f"{M},{K},{N},{fast_accum}"
        if key in cache:
            return cache[key]

    device = torch.device('cuda')

    # bf16 time
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
    bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

    # f8 time
    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
    A = torch.zeros(M, K, device=device, dtype=d1)
    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
    scale_a = torch.tensor([1.0], device=device)
    scale_b = torch.tensor([1.0], device=device)

    def do_matmul(A, B):
        return torch._scaled_mm(
            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
        )

    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

    # save to cache if needed
    if cache_filename is not None:
        cache[key] = [bf16_time_s, f8_time_s]
        with open(cache_filename, 'w') as f:
            json.dump(cache, f)

    return bf16_time_s, f8_time_s

def run(
    outfile: str,
    gemm_time_strategy: str = "benchmarks",
    model_torch_compile_limitations: bool = False,
    scaling_type_input: str = "dynamic",
    scaling_type_weight: str = "dynamic",
    scaling_type_grad_output: str = "dynamic",
    shape_gen_name: str = "square",
    gemm_cache_filename: Optional[str] = None,
    n_limit: Optional[int] = None,
):
    """
    Args:
    * `gemm_time_strategy`:
      - `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
      - `roofline`: use roofline model for gemm times (only accurate for large shapes)
    * `shape_gen_name`: `llama`, `square`, or `sweep`
    * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
    * `n_limit (optional)`: if specified, only runs `n_limit` iterations
    """

    print(f'gemm_time_strategy: {gemm_time_strategy}')
    print(f'shape_gen_name: {shape_gen_name}')

    assert gemm_time_strategy in ("benchmarks", "roofline"), \
        "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"

    M, K, N = sympy.symbols('M K N')

    fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
        M, K, N,
        model_torch_compile_limitations=True,
        scaling_type_input="dynamic",
        scaling_type_weight="dynamic",
        scaling_type_grad_output="dynamic",
    )
    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
        M, K, N,
        model_torch_compile_limitations=False,
        scaling_type_input="dynamic",
        scaling_type_weight="dynamic",
        scaling_type_grad_output="dynamic",
    )
    fp8_mem_time_sympy_del_limit = get_float8_mem_sympy(
        M, K, N,
        model_torch_compile_limitations=True,
        scaling_type_input="delayed",
        scaling_type_weight="delayed",
        scaling_type_grad_output="delayed",
    )
    fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy(
        M, K, N,
        model_torch_compile_limitations=False,
        scaling_type_input="delayed",
        scaling_type_weight="delayed",
        scaling_type_grad_output="delayed",
    )

    if gemm_time_strategy == "roofline":
        bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
        print('bf16_gemm_time_sympy', bf16_gemm_time_sympy)
        fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
        print('fp8_gemm_time_sympy', fp8_gemm_time_sympy)
        print()
    else:
        print()

    headers = [
        'fwd_M', 'fwd_K', 'fwd_N',
        # gemm microbenchmarks
        'bf16_gemm_s', 'fp8_gemm_s',
        # roofline memory overhead estimates
        'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
        'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
        # actual e2e measurements
        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
        'fp8_dyn_speedup', 'fp8_del_speedup',
    ]
    results = []

    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None)

    for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)):
        if n_limit is not None and idx >= n_limit:
            break

        if gemm_time_strategy == "benchmarks":
            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
        else:
            assert gemm_time_strategy == "roofline", "unsupported"
            bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
            fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)

        fp8_mem_time_dyn_limit_s = \
            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        fp8_mem_time_dyn_nolimit_s = \
            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        fp8_mem_time_del_limit_s = \
            fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        fp8_mem_time_del_nolimit_s = \
            fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)

        # create the model
        m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
        x = torch.randn(M_val, K_val, dtype=torch.bfloat16, device="cuda").requires_grad_()

        # get the bf16 gpu kernel time
        torch._dynamo.reset()
        m_bf16 = torch.compile(copy.deepcopy(m_orig))
        bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)

        # get the float8 dynamic scaling gpu kernel time
        torch._dynamo.reset()
        m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
        m_fp8_dyn = torch.compile(m_fp8_dyn)
        fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)

        # get the float8 delayed scaling gpu kernel time
        torch._dynamo.reset()
        config = Float8LinearConfig(
            enable_amax_init=False,
            enable_pre_and_post_forward=False,
            cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        )
        # convert a copy of the original model using the delayed scaling config
        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
        m_fp8_del = torch.compile(m_fp8_del)
        fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

        results.append([
            M_val, K_val, N_val,
            # gemm microbenchmarks
            bf16_time_val, fp8_gemm_time_s,
            # roofline overhead estimates
            fp8_mem_time_dyn_limit_s,
            fp8_mem_time_dyn_nolimit_s,
            fp8_mem_time_del_limit_s,
            fp8_mem_time_del_nolimit_s,
            # e2e numbers
            bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
            bf16_time_actual_s / fp8_dyn_time_actual_s,
            bf16_time_actual_s / fp8_del_time_actual_s,
        ])

    df = pd.DataFrame(results, columns=headers)
    print(df)
    df.to_csv(outfile)
    print('done')


if __name__ == '__main__':
    fire.Fire(run)
