
Commit 215e891

improve fp8 blockwise gemm perf
1 parent: 1526dfe

3 files changed: +200 -25 lines changed
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+
+import itertools
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from tqdm import tqdm
+from triton.testing import do_bench
+
+from torchao.prototype.blockwise_fp8_training.kernels import (
+    blockwise_fp8_gemm_1x128_128x128,
+    fp8_blockwise_act_quant_lhs,
+    fp8_blockwise_weight_quant_transposed_rhs,
+)
+
+device = torch.device("cuda")
+
+# Needed since changing args to function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    out_dtype: torch.dtype
+    m: int
+    n: int
+    k: int
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    triton_time_us: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    mnk_list = [
+        (16640, 8192, 5120),
+        (16640, 5120, 8192),
+    ]
+    out_dtypes = [torch.float32, torch.bfloat16]
+    configs = []
+    for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
+        m, n, k = mnk
+        configs.append(
+            ExperimentConfig(
+                out_dtype=out_dtype,
+                m=m,
+                n=n,
+                k=k,
+            )
+        )
+    return configs
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    # define test inputs
+    # Simulate output = input @ weight.T
+    M, N, K = config.m, config.n, config.k
+    A = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
+    B = torch.randn(N, K, dtype=config.out_dtype, device="cuda")
+    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
+    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(
+        B, dtype=torch.float8_e4m3fn
+    )
+
+    def warmup(func, *args, **kwargs):
+        for _ in range(10):
+            func(*args, **kwargs)
+
+    # Warm up then run triton bench
+    warmup(
+        blockwise_fp8_gemm_1x128_128x128,
+        A_q,
+        1.0 / A_s,
+        B_t_q,
+        1.0 / B_t_s,
+    )
+
+    triton_time_us = benchmark_cuda_function_in_microseconds(
+        blockwise_fp8_gemm_1x128_128x128,
+        A_q,
+        1.0 / A_s,
+        B_t_q,
+        1.0 / B_t_s,
+    )
+
+    return ExperimentResult(
+        triton_time_us=triton_time_us,
+    )
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "M",
+        "N",
+        "K",
+        "out_dtype",
+        "triton_time_us",
+        "tflops/sec",
+    ]
+    rows = []
+    for experiment in experiments:
+        m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
+        flops = 2 * m * n * k
+        seconds = experiment.result.triton_time_us / 1e6
+        tflops_per_sec = (flops / seconds) / 1e12
+        rows.append(
+            [
+                m,
+                n,
+                k,
+                experiment.config.out_dtype,
+                experiment.result.triton_time_us,
+                tflops_per_sec,
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def benchmark_cuda_function_in_microseconds(f, *args):
+    return do_bench(lambda: f(*args), return_mode="median") * 1e3
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use Tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
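
Note: as a quick sanity check of the tflops/sec arithmetic in print_results above, here is a minimal sketch; the 1000 us median kernel time below is purely hypothetical and not a measured result from this commit.

    m, n, k = 16640, 8192, 5120
    triton_time_us = 1000.0          # hypothetical, for illustration only
    flops = 2 * m * n * k            # ~1.396e12 FLOPs for one GEMM
    seconds = triton_time_us / 1e6
    print((flops / seconds) / 1e12)  # ~1395.9 TFLOPS/sec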

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 35 additions & 24 deletions
@@ -10,20 +10,20 @@
 import triton
 import triton.language as tl
 
+from torchao.prototype.moe_training.utils import (
+    _is_column_major,
+    _is_row_major,
+)
+
 fp8_gemm_configs_max_autotune = [
-    # Small
-    triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, num_warps=2),
-    # Medium
-    triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256}, num_warps=8),
-    # Large
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_warps=8),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=8),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=8),
+    triton.Config(
+        {"BLOCK_SIZE_M": block_size, "BLOCK_SIZE_N": block_size},
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    for block_size in [64, 128, 256]
+    for num_warps in [4, 8]
+    for num_stages in [2, 4]
 ]
 
 # For fast compile times during development.
@@ -57,6 +57,7 @@ def blockwise_fp8_gemm_1x128_128x128_kernel(
     M,
     N: tl.constexpr,
     K: tl.constexpr,
+    out_dtype: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -81,18 +82,16 @@ def blockwise_fp8_gemm_1x128_128x128_kernel(
     a_s_base_ptr = a_s_ptr + offs_m * a_s_stride_dim_0
     b_s_base_ptr = b_s_ptr + (offs_n // BLOCK_SIZE_K) * b_s_stride_dim_1
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+    b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
     for k in range(0, k_num_blocks):
-        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
         a = tl.load(a_ptrs, mask=a_mask, other=0.0)
-
-        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
         b = tl.load(b_ptrs, mask=b_mask, other=0.0)
 
         # Reciprocal scales to scale back to dynamic range of output dtype
        a_s = tl.load(a_s_base_ptr + k * a_s_stride_dim_1)
         b_s = tl.load(b_s_base_ptr + k * b_s_stride_dim_0)
-
-        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s
 
         a_ptrs += BLOCK_SIZE_K * a_stride_dim_1
         b_ptrs += BLOCK_SIZE_K * b_stride_dim_0
@@ -109,14 +108,22 @@ def blockwise_fp8_gemm_1x128_128x128(
     b: torch.Tensor,  # (K, N)
     b_s: torch.Tensor,  # (K // block_size, N // block_size)
     block_size: int = 128,
+    out_dtype: torch.dtype = torch.float32,
 ):
     # 'a' must be in row-major layout, 'b' must be in column-major layout
-    assert a.is_contiguous() and not b.is_contiguous()
-    assert a_s.is_contiguous() and b_s.is_contiguous()
+    assert _is_row_major(a) and _is_column_major(b), (
+        "a must be row-major, b must be column-major"
+    )
+
+    # a_scales must be row-major, b_scales must be column-major
+    assert _is_row_major(a_s) and _is_column_major(b_s), (
+        "a_s must be row-major, b_s must be column-major"
+    )
+
     M = a.size(0)
     K = a.size(1)
     N = b.size(1)
-    c = a.new_empty(M, N, dtype=torch.bfloat16)
+    c = a.new_empty(M, N, dtype=out_dtype)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -140,6 +147,7 @@ def blockwise_fp8_gemm_1x128_128x128(
         M,
         N,
         K,
+        out_dtype=out_dtype,
        BLOCK_SIZE_K=block_size,
     )
     return c
@@ -217,14 +225,15 @@ def blockwise_fp8_gemm_1x128_128x1(
     b: torch.Tensor,  # (K, N)
     b_s: torch.Tensor,  # (K // block_size, N) reciprocals of scales
     block_size: int = 128,
+    out_dtype: torch.dtype = torch.float32,
 ):
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert a.is_contiguous() and not b.is_contiguous()
     assert a_s.is_contiguous() and b_s.is_contiguous()
     M = a.size(0)
     K = a.size(1)
     N = b.size(1)
-    c = a.new_empty(M, N, dtype=torch.bfloat16)
+    c = a.new_empty(M, N, dtype=out_dtype)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -674,8 +683,10 @@ def fp8_blockwise_weight_quant_transposed_rhs(
     M, N = x.size()
     y = torch.empty(N, M, dtype=dtype, device=x.device)
     y = y.as_strided(y.size(), (1, y.size(0)))  # Column major
-    s = x.new_empty(
-        triton.cdiv(N, block_size), triton.cdiv(M, block_size), dtype=torch.float32
+    n_blocks, m_blocks = triton.cdiv(N, block_size), triton.cdiv(M, block_size)
+    s = x.new_empty(n_blocks, m_blocks, dtype=torch.float32).as_strided(
+        (n_blocks, m_blocks),  # shape
+        (1, n_blocks),  # stride
     )
     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_SIZE"]),
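
Note: a minimal usage sketch of the updated blockwise_fp8_gemm_1x128_128x128 entry point, mirroring how the new benchmark above calls it. The shapes are arbitrary illustrations; a CUDA device and the torchao prototype modules shown in this diff are assumed.

    import torch
    from torchao.prototype.blockwise_fp8_training.kernels import (
        blockwise_fp8_gemm_1x128_128x128,
        fp8_blockwise_act_quant_lhs,
        fp8_blockwise_weight_quant_transposed_rhs,
    )

    M, N, K = 1024, 2048, 4096  # illustrative shapes, divisible by block_size=128
    A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

    # Quantize to fp8 with blockwise scales: A_q / A_s are row-major,
    # B_t_q / B_t_s are column-major, as the new assertions require.
    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=torch.float8_e4m3fn)

    # The kernel takes reciprocal scales and now also accepts an out_dtype.
    C = blockwise_fp8_gemm_1x128_128x128(
        A_q, 1.0 / A_s, B_t_q, 1.0 / B_t_s, out_dtype=torch.bfloat16
    )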

torchao/prototype/moe_training/utils.py

Lines changed: 15 additions & 1 deletion
@@ -290,7 +290,21 @@ def _is_column_major(x: torch.Tensor) -> bool:
     A boolean indicating whether the input tensor is column-major.
     """
     assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
-    return x.stride(-2) == 1 and x.stride(-1) > 1
+    return x.stride(-2) == 1
+
+
+def _is_row_major(x: torch.Tensor) -> bool:
+    """
+    This function checks if the input tensor is row-major.
+
+    Args:
+        x (torch.Tensor): The input tensor to be checked.
+
+    Returns:
+        A boolean indicating whether the input tensor is row-major.
+    """
+    assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
+    return x.stride(-1) == 1
 
 
 def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
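
Note: a brief illustration (a sketch, not part of the commit) of what the stride-based checks accept: a contiguous 2D tensor is row-major, while its transpose, or a buffer strided the way the quantized weight scales above are built, is column-major.

    import torch

    x = torch.randn(4, 8)
    print(x.stride())      # (8, 1) -> _is_row_major(x) is True
    print(x.t().stride())  # (1, 8) -> _is_column_major(x.t()) is True

    # A tensor laid out like the scale buffer in
    # fp8_blockwise_weight_quant_transposed_rhs (stride (1, n_blocks))
    # also satisfies the column-major check:
    s = torch.empty(4, 8).as_strided((4, 8), (1, 4))
    print(s.stride())      # (1, 4) -> _is_column_major(s) is True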

0 commit comments
