"""
This is a script to estimate the benefit from converting a `torch.nn.Linear`
-layer to float8, by estimating the difference in e2e GPU kernel time between:
+layer to float8 given a single saturated GPU, by estimating the difference
+in e2e GPU kernel time between:

1. bf16 gemms in fwd and bwd, and
2. float8 gemms in fwd and bwd, and float8 overhead

The gemm times are estimated either from direct measurements via benchmarks,
or with a roofline estimation based on TOPS and peak compute bandwidth of an
-NVIDIA H100.
+NVIDIA H100 or B200.

The float8 overhead times are estimated by counting memory reads and writes
based on the specified float8 scaling, and estimating that we can achieve

[...]

  input_t @ grad_output = grad_weight
  KxM @ MxN => KxN

-2. we properly model the worst-case of the current torch.compile limitations regarding
-   float8 scaling
-3. assume for float8 activations/gradients that torch.compile will fuse to the
+2. assume for float8 activations/gradients that torch.compile will fuse to the
   preceding op. Note that this is not always true in practice.
-4. assume no AC (TODO model it)
-5. assume no float8 all-gather (TODO model it)
+3. assume no AC (TODO model it)
+4. assume no float8 all-gather (TODO model it)
"""

import copy
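As a side note, here is a minimal sketch (not part of this diff) of the three gemms the docstring refers to, written with plain PyTorch tensors and the same MxK / KxN shape convention; all concrete shape values are made up for illustration:

import torch

M, K, N = 4, 8, 16                          # hypothetical shapes
x = torch.randn(M, K, requires_grad=True)   # input, MxK
w = torch.randn(N, K, requires_grad=True)   # weight (nn.Linear layout), NxK
go = torch.randn(M, N)                      # grad_output, MxN

out = x @ w.t()   # 1. input @ weight_t = output           (MxK @ KxN => MxN)
gi = go @ w       # 2. grad_output @ weight = grad_input   (MxN @ NxK => MxK)
gw = x.t() @ go   # 3. input_t @ grad_output = grad_weight (KxM @ MxN => KxN)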
@@ -164,68 +163,60 @@ def do_matmul(A, B):
def run(
    outfile: str,
-    gemm_time_strategy: str = "benchmarks",
-    model_torch_compile_limitations: bool = False,
+    do_benchmarks: bool = True,
    shape_gen_name: str = "square",
    gemm_cache_filename: Optional[str] = None,
    n_limit: Optional[int] = None,
):
    """
    Args:
-    * `gemm_time_strategy`:
-      - `benchmarks`: use benchmarks for gemm times (more accurate for all shapes)
-      - `roofline`: use roofline model for gemm times (only accurate for large shapes)
+    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
    * `shape_gen_name`: `llama`, `square`, or `sweep`
    * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
    * `n_limit (optional)`: if specified, only runs `n_limit` iterations
    """

-    print(f"gemm_time_strategy: {gemm_time_strategy}")
+    print(f"do_benchmarks: {do_benchmarks}")
    print(f"shape_gen_name: {shape_gen_name}")

-    assert gemm_time_strategy in (
-        "benchmarks",
-        "roofline",
-    ), "`gemm_time_strategy` must be 'benchmarks' or 'roofline'"
-
    M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy(
-        M,
-        K,
-        N,
-        model_torch_compile_limitations=True,
-    )
    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
        M,
        K,
        N,
-        model_torch_compile_limitations=False,
    )

-    if gemm_time_strategy == "roofline":
-        bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
-        print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-        fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
-        print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
-        print()
-    else:
-        print()
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
+    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
+    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print()

    headers = [
        "fwd_M",
        "fwd_K",
        "fwd_N",
-        # gemm microbenchmarks
-        "bf16_gemm_s",
-        "fp8_gemm_s",
-        # roofline memory overhead estimates
-        "fp8_oh_estimated",
-        "fp8_oh_ideal",
-        # actual e2e measurements
-        "bf16_s",
-        "fp8_dyn_s",
-        "fp8_dyn_sp",
+        # roofline - gemm time (fwd + bwd, 3 gemms)
+        "r_bf16_gemm_s",
+        "r_fp8_gemm_s",
+        # roofline - fp8 overhead time (by counting reads/writes in the ideal case)
+        "r_fp8_ovhd_s",
+        # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
+        "r_fp8_gemm_and_ovhd_s",
+        "r_fp8_gemm_and_ovhd_spdp",
+        # benchmarks - gemm time (fwd + bwd, 3 gemms)
+        "b_bf16_gemm_s",
+        "b_fp8_gemm_s",
+        # benchmarks - e2e LNLinearSigmoid time, fwd + bwd
+        "b_bf16_e2e_s",
+        "b_fp8_e2e_s",
+        # note that the e2e speedup is not the same as the roofline speedup:
+        # 1. roofline speedup: (bf16_gemm_time) / (fp8_gemm_time + fp8_ovhd_time)
+        # 2. e2e speedup: (ln + bf16_gemm_time + sigmoid) / (ln + fp8_gemm_time + fp8_ovhd_time + sigmoid)
+        # the difference is the fwd+bwd ln and sigmoid terms; for now, to keep things
+        # simple, we don't break them out and don't have a roofline for them.
+        "b_fp8_e2e_spdp",
    ]
    results = []
@@ -235,7 +226,18 @@ def run(
        if n_limit is not None and idx >= n_limit:
            break

-        if gemm_time_strategy == "benchmarks":
+        # use roofline model to estimate gemm time
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_bf16_gemm_time_s = float(
+            bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+        r_fp8_gemm_time_s = float(
+            fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
+
+        # if enabled, also measure the observed gemm time
+        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        if do_benchmarks:
            bf16_g1, f8_g1 = get_gemm_times(
                M_val, K_val, N_val, True, gemm_cache_filename
            )
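For reference, a minimal sympy sketch of the substitution pattern used in the added roofline code above; the 2*M*K*N flop count and the peak-throughput constant are illustrative assumptions, not the script's actual roofline model:

import sympy

M, K, N = sympy.symbols("M K N")
time_expr = 2 * M * K * N / 989e12      # flops / assumed peak flops-per-second
# subs returns a sympy Float; cast to a python float for pandas-friendly formatting
t_s = float(time_expr.subs(M, 4096).subs(K, 4096).subs(N, 4096))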
@@ -245,60 +247,58 @@ def run(
            bf16_g3, f8_g3 = get_gemm_times(
                K_val, M_val, N_val, False, gemm_cache_filename
            )
-            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
-            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
-        else:
-            assert gemm_time_strategy == "roofline", "unsupported"
-            bf16_time_val = (
-                bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
-            fp8_gemm_time_s = (
-                fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-            )
+            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
+            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3

-        fp8_mem_time_dyn_limit_s = (
-            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
-        )
-        fp8_mem_time_dyn_nolimit_s = (
+        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_fp8_ovhd_time_s = float(
            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        )

-        # create the model
-        m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
-        x = torch.randn(
-            M_val, K_val, dtype=torch.bfloat16, device="cuda"
-        ).requires_grad_()
+        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
+        if do_benchmarks:
+            # create the model
+            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            x = torch.randn(
+                M_val, K_val, dtype=torch.bfloat16, device="cuda"
+            ).requires_grad_()

-        # get the bf16 gpu kernel time
-        torch._dynamo.reset()
-        m_bf16 = torch.compile(copy.deepcopy(m_orig))
-        bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)
+            # get the bf16 gpu kernel time
+            torch._dynamo.reset()
+            m_bf16 = torch.compile(copy.deepcopy(m_orig))
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)

-        # get the float8 dynamic scaling gpu kernel time
-        torch._dynamo.reset()
-        m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
-        m_fp8_dyn = torch.compile(m_fp8_dyn)
-        fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            # get the float8 dynamic scaling gpu kernel time
+            torch._dynamo.reset()
+            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            m_fp8_dyn = torch.compile(m_fp8_dyn)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)

        results.append(
            [
                M_val,
                K_val,
                N_val,
-                # gemm microbenchmarks
-                bf16_time_val,
-                fp8_gemm_time_s,
-                # roofline overhead estimates
-                fp8_mem_time_dyn_limit_s,
-                fp8_mem_time_dyn_nolimit_s,
-                # e2e numbers
-                bf16_time_actual_s,
-                fp8_dyn_time_actual_s,
-                bf16_time_actual_s / fp8_dyn_time_actual_s,
+                # roofline - gemm
+                r_bf16_gemm_time_s,
+                r_fp8_gemm_time_s,
+                # roofline - fp8 overhead
+                r_fp8_ovhd_time_s,
+                # roofline - gemm + overhead, and speedup
+                r_fp8_gemm_time_s + r_fp8_ovhd_time_s,
+                r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s),
+                # benchmarks - gemm
+                b_bf16_gemm_time_s,
+                b_fp8_gemm_time_s,
+                # benchmarks - e2e, and speedup
+                b_bf16_e2e_time_s,
+                b_fp8_e2e_time_s,
+                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
            ]
        )

+    pd.set_option("display.precision", 2)
    df = pd.DataFrame(results, columns=headers)
    print(df)
    df.to_csv(outfile)
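A hypothetical usage sketch of the new signature (the script's CLI wiring is outside this diff, and the output filename below is made up):

if __name__ == "__main__":
    # roofline-only sweep over llama shapes, skipping the GPU benchmarks
    run("float8_roofline_results.csv", do_benchmarks=False, shape_gen_name="llama")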