Commit f478692

roofline estimator: simplify (#1783)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent cd69415 commit f478692

File tree: 2 files changed, +83 −96 lines changed

benchmarks/float8/float8_roofline.py: +82 −82
@@ -6,13 +6,14 @@
 
 """
 This is a script to estimate the benefit from converting a `torch.nn.Linear`
-layer to float8, by estimating the difference in e2e GPU kernel time between:
+layer to float8 given a single saturated GPU, by estimating the difference
+in e2e GPU kernel time between:
 1. bf16 gemms in fwd and bwd, and
 2. float8 gemms in fwd and bwd, and float8 overhead
 
 The gemm times are estimated either from direct measurements via benchmarks,
 or with a roofline estimation based on TOPS and peak compute bandwidth of an
-NVIDIA H100.
+NVIDIA H100 or B200.
 
 The float8 overhead times are estimated by counting memory reads and writes
 based on the specified float8 scaling, and estimating that we can achieve
@@ -31,12 +32,10 @@
     input_t @ grad_output = grad_weight
     KxM @ MxN => KxN
 
-2. we properly model the worst-case of the current torch.compile limitations regarding
-   float8 scaling
-3. assume for float8 activations/gradients that torch.compile will fuse to the
+2. assume for float8 activations/gradients that torch.compile will fuse to the
    preceding op. Note that this is not always true in practice.
-4. assume no AC (TODO model it)
-5. assume no float8 all-gather (TODO model it)
+3. assume no AC (TODO model it)
+4. assume no float8 all-gather (TODO model it)
 """
 
 import copy
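Editor's note: the docstring above mentions a roofline estimate built from peak TOPS and peak memory bandwidth. Below is a minimal sketch of that idea, with placeholder peak values and a hypothetical function name; the script itself derives its peak specs for H100/B200 from roofline_utils, not from these numbers.

# Minimal sketch of a roofline gemm-time estimate (placeholder peak specs,
# not the script's actual code).
def roofline_gemm_time_s(M, K, N, peak_flops_per_s=1.0e15, peak_mem_bw=3.0e12, bytes_per_el=2):
    flops = 2 * M * K * N                              # one M x K @ K x N matmul
    io_bytes = bytes_per_el * (M * K + K * N + M * N)  # read A and B, write C
    # the kernel is bound by whichever limit is slower
    return max(flops / peak_flops_per_s, io_bytes / peak_mem_bw)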
@@ -164,68 +163,60 @@ def do_matmul(A, B):
 
 def run(
     outfile: str,
-    gemm_time_strategy: str = "benchmarks",
-    model_torch_compile_limitations: bool = False,
+    do_benchmarks: bool = True,
     shape_gen_name: str = "square",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
 ):
     """
     Args:
-    * `gemm_time_strategy`:
-      - `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
-      - `roofline`: use roofline model for gemm times (only accurate for large shapes)
+    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
     * `shape_gen_name`: `llama`, `square`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
     """
 
-    print(f"gemm_time_strategy: {gemm_time_strategy}")
+    print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
 
-    assert gemm_time_strategy in (
-        "benchmarks",
-        "roofline",
-    ), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"
-
     M, K, N = sympy.symbols("M K N")
 
-    fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-    )
     fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
         M,
         K,
         N,
-        model_torch_compile_limitations=False,
     )
 
-    if gemm_time_strategy == "roofline":
-        bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-        print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-        fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
-        print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
-        print()
-    else:
-        print()
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print()
 
     headers = [
         "fwd_M",
         "fwd_K",
         "fwd_N",
-        # gemm microbenchmarks
-        "bf16_gemm_s",
-        "fp8_gemm_s",
-        # roofline memory overhead estimates
-        "fp8_oh_estimated",
-        "fp8_oh_ideal",
-        # actual e2e measurements
-        "bf16_s",
-        "fp8_dyn_s",
-        "fp8_dyn_sp",
+        # roofline - gemm time (fwd + bwd, 3 gemms)
+        "r_bf16_gemm_s",
+        "r_fp8_gemm_s",
+        # roofline - fp8 overhead time (by counting reads/writes in the ideal case)
+        "r_fp8_ovhd_s",
+        # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
+        "r_fp8_gemm_and_ovhd_s",
+        "r_fp8_gemm_and_ovhd_spdp",
+        # benchmarks - gemm time (fwd + bwd, 3 gemms)
+        "b_bf16_gemm_s",
+        "b_fp8_gemm_s",
+        # benchmarks - e2e LNLinearSigmoid time fwd + bwd
+        "b_bf16_e2e_s",
+        "b_fp8_e2e_s",
+        # note that e2e speedup is not the same as the roofline speedup:
+        # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time)
+        # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid)
+        # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
+        # we don't break them out and don't have a roofline for them.
+        "b_fp8_e2e_spdp",
     ]
     results = []
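Editor's note: to make the note in the new header comments concrete, here is a small worked example with made-up timings showing why the e2e speedup is lower than the roofline speedup (the LN and sigmoid terms appear in both numerator and denominator and are not sped up by float8):

# Hypothetical timings (seconds), for illustration only.
r_bf16_gemm_s, r_fp8_gemm_s, r_fp8_ovhd_s = 1.00e-3, 0.55e-3, 0.15e-3
roofline_speedup = r_bf16_gemm_s / (r_fp8_gemm_s + r_fp8_ovhd_s)  # ~1.43x

ln_s, sigmoid_s = 0.05e-3, 0.05e-3  # fwd+bwd elementwise time, unchanged by float8
e2e_speedup = (ln_s + r_bf16_gemm_s + sigmoid_s) / (
    ln_s + r_fp8_gemm_s + r_fp8_ovhd_s + sigmoid_s
)  # ~1.38x, lower because LN and sigmoid are not sped up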

@@ -235,7 +226,18 @@ def run(
         if n_limit is not None and idx >= n_limit:
             break
 
-        if gemm_time_strategy == "benchmarks":
+        # use roofline model to estimate gemm time
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_bf16_gemm_time_s = float(
+            bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+        r_fp8_gemm_time_s = float(
+            fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+
+        # if enabled, also measured observed gemm time
+        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        if do_benchmarks:
             bf16_g1, f8_g1 = get_gemm_times(
                 M_val, K_val, N_val, True, gemm_cache_filename
             )
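Editor's note: the roofline estimates above are symbolic sympy expressions in M, K, N that get evaluated per shape with .subs. A toy example of that pattern follows; the expression here is made up and stands in for the output of get_gemm_time_sympy / get_float8_mem_sympy.

import sympy

M, K, N = sympy.symbols("M K N")
# toy symbolic "time" expression, for illustration only
gemm_time_sympy = 2 * M * K * N / 1.0e15
# evaluate for one concrete shape; cast to float so pandas formats it nicely
t_s = float(gemm_time_sympy.subs(M, 4096).subs(K, 4096).subs(N, 4096))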
@@ -245,60 +247,58 @@ def run(
             bf16_g3, f8_g3 = get_gemm_times(
                 K_val, M_val, N_val, False, gemm_cache_filename
             )
-            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
-            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-        else:
-            assert gemm_time_strategy == "roofline", "unsupported"
-            bf16_time_val = (
-                bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
-            fp8_gemm_time_s = (
-                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
+            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
+            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
 
-        fp8_mem_time_dyn_limit_s = (
-            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        fp8_mem_time_dyn_nolimit_s = (
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_fp8_ovhd_time_s = float(
             fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
 
-        # create the model
-        m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
-        x = torch.randn(
-            M_val, K_val, dtype=torch.bfloat16, device="cuda"
-        ).requires_grad_()
+        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
+        if do_benchmarks:
+            # create the model
+            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            x = torch.randn(
+                M_val, K_val, dtype=torch.bfloat16, device="cuda"
+            ).requires_grad_()
 
-        # get the bf16 gpu kernel time
-        torch._dynamo.reset()
-        m_bf16 = torch.compile(copy.deepcopy(m_orig))
-        bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
+            # get the bf16 gpu kernel time
+            torch._dynamo.reset()
+            m_bf16 = torch.compile(copy.deepcopy(m_orig))
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
 
-        # get the float8 dynamic scaling gpu kernel time
+            # get the float8 dynamic scaling gpu kernel time
 
-        torch._dynamo.reset()
-        m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
-        m_fp8_dyn = torch.compile(m_fp8_dyn)
-        fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            torch._dynamo.reset()
+            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            m_fp8_dyn = torch.compile(m_fp8_dyn)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
 
         results.append(
             [
                 M_val,
                 K_val,
                 N_val,
-                # gemm microbenchmarks
-                bf16_time_val,
-                fp8_gemm_time_s,
-                # roofline overhead estimates
-                fp8_mem_time_dyn_limit_s,
-                fp8_mem_time_dyn_nolimit_s,
-                # e2e numbers
-                bf16_time_actual_s,
-                fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_time_actual_s,
+                # roofline - gemm
+                r_bf16_gemm_time_s,
+                r_fp8_gemm_time_s,
+                # roofline - fp8 overhead
+                r_fp8_ovhd_time_s,
+                # roofline - gemm + overhead, and speedup
+                r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
+                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                # benchmarks - gemm
+                b_bf16_gemm_time_s,
+                b_fp8_gemm_time_s,
+                # benchmarks - e2e, and speedup
+                b_bf16_e2e_time_s,
+                b_fp8_e2e_time_s,
+                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
             ]
         )
 
+    pd.set_option("display.precision", 2)
     df = pd.DataFrame(results, columns=headers)
     print(df)
     df.to_csv(outfile)
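Editor's note: with this change, calling run() directly only needs the new do_benchmarks flag; the other arguments follow the signature shown above. The values below are just an example.

# Example invocation of run() with the new signature; values are illustrative.
run(
    outfile="float8_roofline_results.csv",
    do_benchmarks=True,       # also measure real gemm and e2e LNLinearSigmoid times
    shape_gen_name="llama",   # or "square" / "sweep"
    n_limit=4,                # only the first 4 shapes, for a quick sanity check
)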

torchao/testing/float8/roofline_utils.py: +1 −14
@@ -56,7 +56,6 @@ def get_tensor_memory_traffic_bytes(
     dim0,
     dim1,
     fuse_with_prev=False,
-    model_torch_compile_limitations=False,
 ):
     # assumes input bf16, output f8
     numel = dim0 * dim1
@@ -75,15 +74,7 @@ def get_tensor_memory_traffic_bytes(
     # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
     kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
 
-    if model_torch_compile_limitations:
-        # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-        # has an extra memory read of the input in fp8
-        # context: https://github.com/pytorch/pytorch/issues/130015
-        tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-    else:
-        tc_adjustment = 0
-
-    return kernel_1_rw + kernel_3_rw + tc_adjustment
+    return kernel_1_rw + kernel_3_rw
 
 
 def get_gemm_time_sympy(M, K, N, dtype):
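Editor's note: after removing the torch.compile adjustment, the per-tensor cast traffic is just the ideal read/write count. A small standalone sketch of the kernel 3 traffic shown above, and the ideal overhead time it implies, follows; the bandwidth value is a placeholder, the script gets the real one from get_specs().

BYTES_PER_EL_BF16, BYTES_PER_EL_FLOAT8 = 2, 1

def ideal_cast_traffic_bytes(dim0, dim1):
    # read the bf16 input once, write the float8 output twice
    # (row-major and col-major), matching kernel_3_rw above
    numel = dim0 * dim1
    return BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

# ideal overhead time at a placeholder peak bandwidth of 3.0e12 bytes/s
overhead_s = ideal_cast_traffic_bytes(4096, 4096) / 3.0e12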
@@ -101,7 +92,6 @@ def get_float8_mem_sympy(
     M,
     K,
     N,
-    model_torch_compile_limitations: bool = False,
 ):
     specs = get_specs()
 
@@ -123,13 +113,11 @@
         M,
         K,
         fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations,
     )
     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
         K,
         N,
         fuse_with_prev=False,
-        model_torch_compile_limitations=model_torch_compile_limitations,
     )
     fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
 
@@ -140,7 +128,6 @@
         M,
         N,
         fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations,
     )
     # already casted, assuming that we save weight from fw to bw
     # TODO: model this if FSDP float8 all-gather is on
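Editor's note: a hedged sketch of how these utilities can be combined after this change. The import path and call signatures follow this diff; actually running it assumes a machine whose GPU get_specs() recognizes (H100 or B200 class).

import sympy
import torch

from torchao.testing.float8.roofline_utils import (
    get_float8_mem_sympy,
    get_gemm_time_sympy,
)

M, K, N = sympy.symbols("M K N")
bf16_gemm = get_gemm_time_sympy(M, K, N, torch.bfloat16)
fp8_gemm = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
fp8_ovhd = get_float8_mem_sympy(M, K, N)  # no more model_torch_compile_limitations arg

# symbolic roofline speedup, evaluated for one shape
speedup = bf16_gemm / (fp8_gemm + fp8_ovhd)
print(float(speedup.subs(M, 16384).subs(K, 8192).subs(N, 8192)))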

0 commit comments
