|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +""" |
| 8 | +This is a script to estimate the benefit from converting a `torch.nn.Linear` |
| 9 | +layer to float8, by estimating the difference in e2e GPU kernel time between: |
| 10 | +1. bf16 gemms in fwd and bwd, and |
| 11 | +2. float8 gemms in fwd and bwd, and float8 overhead |
| 12 | +
|
| 13 | +The gemm times are estimated either from direct measurements via benchmarks, |
| 14 | +or with a roofline estimation based on TOPS and peak compute bandwidth of an |
| 15 | +NVIDIA H100. |
| 16 | +
|
| 17 | +The float8 overhead times are estimated by counting memory reads and writes |
| 18 | +based on the specified float8 scaling, and estimating that we can achieve |
| 19 | +a certain % of machine peak memory bandwidth when performing these reads and writes. |
| 20 | +
|
| 21 | +Additional context: |
| 22 | +1. the formulas for fwd/bwd gemms in a linear layer, with corresponding input |
| 23 | + and output sizes: |
| 24 | +
|
| 25 | + input @ weight_t = output |
| 26 | + MxK @ KxN => MxN |
| 27 | +
|
| 28 | + grad_output @ weight = grad_input |
| 29 | + MxN @ NxK => MxK |
| 30 | +
|
| 31 | + input_t @ grad_output = grad_weight |
| 32 | + KxM @ MxN => KxN |
| 33 | +
|
| 34 | +2. we properly model the worst-case of the current torch.compile limitations regarding |
| 35 | + float8 scaling |
| 36 | +3. assume for float8 activations/gradients that torch.compile will fuse to the |
| 37 | +preceding op. Note that this is not always true in practice. |
| 38 | +4. assume no AC (TODO model it) |
| 39 | +5. assume no float8 all-gather (TODO model it) |
| 40 | +""" |
| 41 | + |
| 42 | +import csv |
| 43 | +import copy |
| 44 | +import json |
| 45 | +import os |
| 46 | +import time |
| 47 | +from typing import Optional |
| 48 | + |
| 49 | +import fire |
| 50 | +import pandas as pd |
| 51 | +import sympy |
| 52 | +import tqdm |
| 53 | + |
| 54 | +import torch |
| 55 | +import torch.utils.benchmark as benchmark |
| 56 | +from torch.profiler import profile, ProfilerActivity, record_function |
| 57 | + |
| 58 | +from utils import ( |
| 59 | + get_name_to_shapes_iter, |
| 60 | + get_gpu_kernel_gemm_time_s, |
| 61 | + profiler_output_to_filtered_time_by_kernel_name, |
| 62 | +) |
| 63 | +from torchao.float8.roofline_utils import ( |
| 64 | + get_gemm_time_sympy, |
| 65 | + get_float8_mem_sympy, |
| 66 | +) |
| 67 | +from torchao.float8 import ( |
| 68 | + convert_to_float8_training, |
| 69 | + Float8LinearConfig, |
| 70 | + ScalingType, |
| 71 | + CastConfig, |
| 72 | +) |
| 73 | + |
| 74 | + |
| 75 | +class LNLinearSigmoid(torch.nn.Module): |
| 76 | + def __init__(self, fc_dim1, fc_dim2): |
| 77 | + super().__init__() |
| 78 | + self.ln = torch.nn.LayerNorm(fc_dim1, elementwise_affine=False) |
| 79 | + self.fc = torch.nn.Linear(fc_dim1, fc_dim2, bias=False) |
| 80 | + self.sigmoid = torch.nn.Sigmoid() |
| 81 | + |
| 82 | + def forward(self, x): |
| 83 | + x = self.ln(x) |
| 84 | + x = self.fc(x) |
| 85 | + x = self.sigmoid(x) |
| 86 | + return x |
| 87 | +# TODO(next): hook this up |
| 88 | + |
| 89 | + |
| 90 | +def benchmark_fn_in_sec(f, *args, **kwargs): |
| 91 | + # Manual warmup |
| 92 | + for _ in range(4): |
| 93 | + f(*args, **kwargs) |
| 94 | + t0 = benchmark.Timer( |
| 95 | + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} |
| 96 | + ) |
| 97 | + measurement = t0.blocked_autorange() |
| 98 | + return measurement.mean |
| 99 | + |
| 100 | + |
| 101 | +def get_gpu_kernel_time(m, x): |
| 102 | + # warm up |
| 103 | + # for _ in range(2): |
| 104 | + # m(x).sum().backward() |
| 105 | + |
| 106 | + # capture a profiling run |
| 107 | + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] |
| 108 | + n_iter = 5 |
| 109 | + with profile(activities=activities) as prof: |
| 110 | + for _ in range(n_iter): |
| 111 | + # m(x).sum().backward() |
| 112 | + torch.cuda.synchronize() |
| 113 | + # get the gpu kernel time and aggregate it |
| 114 | + num_leaf_tensors = 1 + len(list(m.parameters())) |
| 115 | + ref_times = profiler_output_to_filtered_time_by_kernel_name( |
| 116 | + prof, n_iter, num_leaf_tensors) |
| 117 | + total_time_s = sum(v for v in ref_times.values()) / 1e6 / n_iter |
| 118 | + return total_time_s |
| 119 | + |
| 120 | +def get_gemm_times(M, K, N, fast_accum, cache_filename=None): |
| 121 | + |
| 122 | + # Note: this is definitely not the best way to build a cache, |
| 123 | + # but it will do for now. |
| 124 | + if cache_filename is not None: |
| 125 | + if os.path.isfile(cache_filename): |
| 126 | + # cache already exists, use it |
| 127 | + with open(cache_filename, 'r') as f: |
| 128 | + cache = json.load(f) |
| 129 | + else: |
| 130 | + # cache does not exist yet, create it |
| 131 | + cache = dict() |
| 132 | + key = f"{M},{K},{N},{fast_accum}" |
| 133 | + if key in cache: |
| 134 | + return cache[key] |
| 135 | + |
| 136 | + device = torch.device('cuda') |
| 137 | + |
| 138 | + # bf16 time |
| 139 | + x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) |
| 140 | + w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() |
| 141 | + bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) |
| 142 | + |
| 143 | + # f8 time |
| 144 | + d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16 |
| 145 | + A = torch.zeros(M, K, device=device, dtype=d1) |
| 146 | + B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() |
| 147 | + scale_a = torch.tensor([1.0], device=device) |
| 148 | + scale_b = torch.tensor([1.0], device=device) |
| 149 | + |
| 150 | + def do_matmul(A, B): |
| 151 | + return torch._scaled_mm( |
| 152 | + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum |
| 153 | + ) |
| 154 | + f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) |
| 155 | + |
| 156 | + # save to cache if needed |
| 157 | + if cache_filename is not None: |
| 158 | + cache[key] = [bf16_time_s, f8_time_s] |
| 159 | + with open(cache_filename, 'w') as f: |
| 160 | + json.dump(cache, f) |
| 161 | + |
| 162 | + return bf16_time_s, f8_time_s |
| 163 | + |
| 164 | +def run( |
| 165 | + outfile: str, |
| 166 | + gemm_time_strategy: str = "benchmarks", |
| 167 | + model_torch_compile_limitations: bool = False, |
| 168 | + scaling_type_input: str = "dynamic", |
| 169 | + scaling_type_weight: str = "dynamic", |
| 170 | + scaling_type_grad_output: str = "dynamic", |
| 171 | + shape_gen_name: str = "square", |
| 172 | + gemm_cache_filename: Optional[str] = None, |
| 173 | + n_limit: Optional[int] = None, |
| 174 | +): |
| 175 | + """ |
| 176 | + Args: |
| 177 | + * `gemm_time_strategy`: |
| 178 | + - `benchmarks`: use benchmarks for gemm times (more accurate for all shapes) |
| 179 | + - `roofline`: use roofline model for gemm times (only accurate for large shapes) |
| 180 | + * `shape_gen_name`: `llama`, `square`, or `sweep` |
| 181 | + * `gemm_cache_filename (optional)`: file to cache gemm benchmark results |
| 182 | + * `n_limit (optional)`: if specified, only runs `n_limit` iterations |
| 183 | + """ |
| 184 | + |
| 185 | + print(f'gemm_time_strategy: {gemm_time_strategy}') |
| 186 | + print(f'shape_gen_name: {shape_gen_name}') |
| 187 | + |
| 188 | + assert gemm_time_strategy in ("benchmarks", "roofline"), \ |
| 189 | + "`gemm_time_strategy` must be 'benchmarks' or 'roofline'" |
| 190 | + |
| 191 | + M, K, N = sympy.symbols('M K N') |
| 192 | + |
| 193 | + fp8_mem_time_sympy_dyn_limit = get_float8_mem_sympy( |
| 194 | + M, K, N, |
| 195 | + model_torch_compile_limitations=True, |
| 196 | + scaling_type_input="dynamic", |
| 197 | + scaling_type_weight="dynamic", |
| 198 | + scaling_type_grad_output="dynamic", |
| 199 | + ) |
| 200 | + fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy( |
| 201 | + M, K, N, |
| 202 | + model_torch_compile_limitations=False, |
| 203 | + scaling_type_input="dynamic", |
| 204 | + scaling_type_weight="dynamic", |
| 205 | + scaling_type_grad_output="dynamic", |
| 206 | + ) |
| 207 | + fp8_mem_time_sympy_del_limit = get_float8_mem_sympy( |
| 208 | + M, K, N, |
| 209 | + model_torch_compile_limitations=True, |
| 210 | + scaling_type_input="delayed", |
| 211 | + scaling_type_weight="delayed", |
| 212 | + scaling_type_grad_output="delayed", |
| 213 | + ) |
| 214 | + fp8_mem_time_sympy_del_nolimit = get_float8_mem_sympy( |
| 215 | + M, K, N, |
| 216 | + model_torch_compile_limitations=False, |
| 217 | + scaling_type_input="delayed", |
| 218 | + scaling_type_weight="delayed", |
| 219 | + scaling_type_grad_output="delayed", |
| 220 | + ) |
| 221 | + |
| 222 | + if gemm_time_strategy == "roofline": |
| 223 | + bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16) |
| 224 | + print('bf16_gemm_time_sympy', bf16_gemm_time_sympy) |
| 225 | + fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn) |
| 226 | + print('fp8_gemm_time_sympy', fp8_gemm_time_sympy) |
| 227 | + print() |
| 228 | + else: |
| 229 | + print() |
| 230 | + |
| 231 | + headers = [ |
| 232 | + 'fwd_M', 'fwd_K', 'fwd_N', |
| 233 | + # gemm microbenchmarks |
| 234 | + 'bf16_gemm_s', 'fp8_gemm_s', |
| 235 | + # roofline memory overhead estimates |
| 236 | + 'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit', |
| 237 | + 'fp8_oh_del_limit', 'fp8_oh_del_nolimit', |
| 238 | + # actual e2e measurements |
| 239 | + 'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s', |
| 240 | + 'fp8_dyn_speedup', 'fp8_del_speedup', |
| 241 | + ] |
| 242 | + results = [] |
| 243 | + |
| 244 | + name_to_shapes = get_name_to_shapes_iter(shape_gen_name, None, None, None) |
| 245 | + |
| 246 | + for idx, (name, (M_val, K_val, N_val)) in enumerate(tqdm.tqdm(name_to_shapes)): |
| 247 | + if n_limit is not None and idx >= n_limit: |
| 248 | + break |
| 249 | + |
| 250 | + if gemm_time_strategy == "benchmarks": |
| 251 | + bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename) |
| 252 | + bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename) |
| 253 | + bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename) |
| 254 | + bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3 |
| 255 | + fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3 |
| 256 | + else: |
| 257 | + assert gemm_time_strategy == "roofline", "unsupported" |
| 258 | + bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) |
| 259 | + fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) |
| 260 | + |
| 261 | + fp8_mem_time_dyn_limit_s = \ |
| 262 | + fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) |
| 263 | + fp8_mem_time_dyn_nolimit_s = \ |
| 264 | + fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) |
| 265 | + fp8_mem_time_del_limit_s = \ |
| 266 | + fp8_mem_time_sympy_del_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val) |
| 267 | + fp8_mem_time_del_nolimit_s = \ |
| 268 | + fp8_mem_time_sympy_del_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) |
| 269 | + |
| 270 | + # create the model |
| 271 | + m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() |
| 272 | + x = torch.randn(M_val, K_val, dtype=torch.bfloat16, device="cuda").requires_grad_() |
| 273 | + |
| 274 | + # get the bf16 gpu kernel time |
| 275 | + torch._dynamo.reset() |
| 276 | + m_bf16 = torch.compile(copy.deepcopy(m_orig)) |
| 277 | + bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x) |
| 278 | + |
| 279 | + # get the float8 dynamic scaling gpu kernel time |
| 280 | + torch._dynamo.reset() |
| 281 | + m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig)) |
| 282 | + m_fp8_dyn = torch.compile(m_fp8_dyn) |
| 283 | + fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x) |
| 284 | + |
| 285 | + # get the float8 delayed scaling gpu kernel time |
| 286 | + torch._dynamo.reset() |
| 287 | + config = Float8LinearConfig( |
| 288 | + enable_amax_init=False, |
| 289 | + enable_pre_and_post_forward=False, |
| 290 | + cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED), |
| 291 | + cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED), |
| 292 | + cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED), |
| 293 | + ) |
| 294 | + m_fp8_del = convert_to_float8_training(m_orig) |
| 295 | + m_fp8_del = torch.compile(m_fp8_del) |
| 296 | + fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x) |
| 297 | + |
| 298 | + results.append([ |
| 299 | + M_val, K_val, N_val, |
| 300 | + # gemm microbenchmarks |
| 301 | + bf16_time_val, fp8_gemm_time_s, |
| 302 | + # roofline overhead estimates |
| 303 | + fp8_mem_time_dyn_limit_s, |
| 304 | + fp8_mem_time_dyn_nolimit_s, |
| 305 | + fp8_mem_time_del_limit_s, |
| 306 | + fp8_mem_time_del_nolimit_s, |
| 307 | + # e2e numbers |
| 308 | + bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s, |
| 309 | + bf16_time_actual_s / fp8_dyn_time_actual_s, |
| 310 | + bf16_time_actual_s / fp8_del_time_actual_s, |
| 311 | + ]) |
| 312 | + |
| 313 | + df = pd.DataFrame(results, columns=headers) |
| 314 | + print(df) |
| 315 | + df.to_csv(outfile) |
| 316 | + print('done') |
| 317 | + |
| 318 | +if __name__ == '__main__': |
| 319 | + fire.Fire(run) |
0 commit comments