# benchmark_attention.py (forked from littsk/test_attn)
import math
from contextlib import contextmanager
from typing import Literal, Tuple
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from hopper.flash_attn_interface import flash_attn_func as flash_attn_func_hopper
from torch.nn.attention import SDPBackend
from torch.nn.attention._flex_attention import _flex_attention
from transformer_engine.pytorch.attention import DotProductAttention
from triton.ops import attention as attention_triton
# Optional xformers import.
# NOTE(review): despite its name, HAS_FLASH only records whether the xformers
# ops imported; xformers is used elsewhere in this file without checking it.
try:
    import xformers.ops
    import xformers.ops.fmha as fmha

    HAS_FLASH = True
except Exception:
    # Tolerate any import-time failure (missing package, broken CUDA
    # extension), but not KeyboardInterrupt/SystemExit, which the previous
    # `except BaseException` would also have swallowed.
    HAS_FLASH = False
# torch.compile wrapper around flex attention, created once at import time.
# Compilation itself happens lazily on the first call. dynamic=False asks
# torch.compile to specialize on concrete input shapes instead of generating
# shape-polymorphic kernels.
compiled_flex_attention = torch.compile(
    _flex_attention,
    dynamic=False,
)
# Lazily-created torch.compile wrapper shared by all compiled_flash_attention_v2 calls.
_compiled_flash_attn_v2 = None


def compiled_flash_attention_v2(q, k, v):
    """Run ``flash_attn_func(q, k, v)`` through ``torch.compile``.

    The compiled wrapper is built on the first call and reused afterwards.
    Calling ``torch.compile`` on every invocation would construct a new
    wrapper each time, which puts that overhead inside benchmark loops.

    Args:
        q, k, v: Forwarded unchanged to ``flash_attn_func``. (Presumably
            (batch, seqlen, nheads, headdim) CUDA tensors, matching how this
            file calls it. Confirm against the flash-attn docs.)

    Returns:
        Whatever ``flash_attn_func`` returns.
    """
    global _compiled_flash_attn_v2
    if _compiled_flash_attn_v2 is None:
        # Shapes are left static. To avoid a recompile per kv length, dim 1
        # of k/v could be marked dynamic with torch._dynamo.mark_dynamic
        # before calling.
        _compiled_flash_attn_v2 = torch.compile(
            flash_attn_func,
            fullgraph=True,
            backend="inductor",
            mode="max-autotune-no-cudagraphs",
        )
    return _compiled_flash_attn_v2(q, k, v)
# Lazily-created torch.compile wrapper shared by all compiled_xformers_flash_hopper calls.
_compiled_xformers_flash3_fwd = None


def _xformers_flash3_fwd(q, k, v, scale):
    """Run the xformers forward pass with the FlashAttention-3 operator (eager)."""
    return fmha.memory_efficient_attention_forward(
        q,
        k,
        v,
        scale=scale,
        op=xformers.ops.fmha.flash3.FwOp,
    )


def compiled_xformers_flash_hopper(q, k, v):
    """Run xformers' flash3 forward attention through ``torch.compile``.

    The softmax scale is ``head_dim ** -0.5``, with head_dim taken from
    ``q.size(-1)``. The compiled wrapper is built on the first call and
    reused afterwards.

    Args:
        q, k, v: Forwarded to ``fmha.memory_efficient_attention_forward``.
            (Presumably (batch, seqlen, nheads, headdim) CUDA tensors, matching
            how this file calls it. Confirm against the xformers docs.)

    Returns:
        The attention output from ``memory_efficient_attention_forward``.
    """
    global _compiled_xformers_flash3_fwd
    if _compiled_xformers_flash3_fwd is None:
        # Compile the attention *call*, not the operator. `op=` takes the
        # xformers operator class itself. torch.compile(FwOp) wraps the class
        # rather than the computation xformers runs through it, so nothing
        # useful gets compiled.
        # fullgraph is not requested: xformers' Python dispatch may not be
        # fully traceable, and graph breaks should fall back to eager.
        _compiled_xformers_flash3_fwd = torch.compile(
            _xformers_flash3_fwd,
            backend="inductor",
            mode="max-autotune",
        )
    softmax_scale = q.size(-1) ** -0.5
    return _compiled_xformers_flash3_fwd(q, k, v, softmax_scale)
@contextmanager
def time_with_cuda_event(name, flops, peak_tflops=989.0):
    """
    Context manager to time CUDA operations and compute MFU.

    Two CUDA timing events are recorded around the ``with`` body, so the
    reported time covers the GPU work enqueued inside it. The body is also
    wrapped in an NVTX range called ``name``. The printed time is the total
    for the whole body, not per iteration.

    Args:
        name (str): Name of the attention implementation. Used both as the
            printed label and as the NVTX range name.
        flops (float): Total floating-point operations performed by all the
            work inside the body.
        peak_tflops (float): Device peak throughput in TFLOP/s, used as the
            MFU denominator. Defaults to 989.0, the value this function
            previously hard-coded (presumably H100 dense FP16/BF16; confirm
            for other GPUs).
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.nvtx.range_push(name)
    try:
        start.record()
        yield
        end.record()
    finally:
        # Keep NVTX push/pop balanced even if the timed body raises.
        torch.cuda.nvtx.range_pop()
    end.synchronize()
    elapsed_time = start.elapsed_time(end)  # milliseconds
    achieved_flops_per_s = flops / (elapsed_time * 1e-3)
    mfu = achieved_flops_per_s / (peak_tflops * 1e12)
    print(f"{name} took {elapsed_time:.4f} ms, mfu: {mfu:.2f}")
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float = None,
) -> torch.Tensor:
    """Unfused reference scaled dot-product attention built from plain tensor ops.

    Computes ``dropout(softmax(query @ key^T * scale + bias)) @ value``. This
    appears to follow the reference implementation in the
    ``torch.nn.functional.scaled_dot_product_attention`` docs.

    Args:
        query: Tensor whose last two dims are (query_len, head_dim).
        key: Tensor whose last two dims are (key_len, head_dim). Its leading
            dims must broadcast against ``query``'s (e.g. one shared KV head).
        value: Tensor whose last two dims are (key_len, value_dim).
        attn_mask: Optional mask broadcastable to (..., query_len, key_len).
            For a bool mask, True means "attend". Any other dtype is cast to
            ``query``'s dtype and added to the scores.
        dropout_p: Dropout probability on the attention weights. Dropout is
            always applied in training mode (``train=True``).
        is_causal: If True, apply a top-left-aligned lower-triangular mask.
            Must not be combined with ``attn_mask``.
        scale: Score scale. Defaults to ``1 / sqrt(head_dim)``.

    Returns:
        Attention output with last two dims (query_len, value_dim).
    """
    query_len, key_len = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(query_len, key_len, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        causal_keep = torch.ones(
            query_len, key_len, dtype=torch.bool, device=query.device
        ).tril(diagonal=0)
        attn_bias.masked_fill_(causal_keep.logical_not(), float("-inf"))
    if attn_mask is not None:
        # Out-of-place so that a mask with leading (batch/head) dims broadcasts
        # the 2-D bias up to its shape. The in-place forms fail unless the mask
        # fits inside the (query_len, key_len) bias.
        if attn_mask.dtype == torch.bool:
            attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias = attn_bias + attn_mask.to(attn_bias.dtype)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value
def get_qkv(
    batch_size: int,
    num_heads: int,
    q_len: int,
    kv_len: int,
    head_dim: int,
    mqa: bool,
    layout: Literal["bhsd", "bshd", "sbhd"] = "bhsd",
    *,
    device="cuda",
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random leaf q/k/v tensors for an attention benchmark.

    Args:
        batch_size: Batch dimension.
        num_heads: Number of query heads (also K/V heads unless ``mqa``).
        q_len: Query sequence length.
        kv_len: Key/value sequence length.
        head_dim: Per-head feature size.
        mqa: If True, k and v get a single head shared by all query heads.
        layout: "bhsd" gives (batch, heads, seq, head_dim). "bshd" gives
            contiguous (batch, seq, heads, head_dim). "sbhd" is a legacy alias
            for "bshd": despite its name it has always produced
            (batch, seq, heads, head_dim) tensors.
        device: Device to allocate on (default "cuda").
        dtype: Element type (default torch.float16).

    Returns:
        ``(q, k, v)``, each contiguous, a leaf, and with ``requires_grad=True``.

    Raises:
        ValueError: If ``layout`` is not one of the supported names.
    """
    if layout not in ("bhsd", "bshd", "sbhd"):
        raise ValueError(
            f"unsupported layout {layout!r}; expected 'bhsd' or 'bshd'"
        )
    kv_heads = 1 if mqa else num_heads
    q = torch.randn(batch_size, num_heads, q_len, head_dim, device=device, dtype=dtype)
    k = torch.randn(batch_size, kv_heads, kv_len, head_dim, device=device, dtype=dtype)
    v = torch.randn(batch_size, kv_heads, kv_len, head_dim, device=device, dtype=dtype)
    if layout != "bhsd":
        # (batch, heads, seq, dim) -> (batch, seq, heads, dim), materialized
        # contiguously so kernels see a dense bshd tensor.
        q, k, v = (t.permute(0, 2, 1, 3).contiguous() for t in (q, k, v))
    for t in (q, k, v):
        t.requires_grad_()
    return q, k, v
def _run_benchmark(name, fn, flops, warmup_iter, test_iter):
    """Warm up ``fn``, then time ``test_iter`` calls to it with CUDA events.

    Args:
        name: Label passed to ``time_with_cuda_event``.
        fn: Zero-argument callable that issues one forward pass.
        flops: Total FLOPs of all ``test_iter`` timed calls.
        warmup_iter: Untimed calls made first. These also trigger any lazy
            torch.compile / autotuning before timing starts.
        test_iter: Number of timed calls.
    """
    # Warmup and timing share one grad mode. This is a forward-only
    # benchmark, so no autograd graph should be recorded. Also, Dynamo's
    # guards include global grad mode: warming up with grad enabled and timing
    # under no_grad would recompile inside the timed region.
    with torch.no_grad():
        for _ in range(warmup_iter):
            fn()
        with time_with_cuda_event(name, flops):
            for _ in range(test_iter):
                fn()


if __name__ == "__main__":
    batch_size = 1
    num_heads = 8
    head_dim = 128
    q_len = 1024 * 12
    kv_lens = [512, q_len]
    warmup_iter = 10
    test_iter = 100
    mqa = False
    for kv_len in kv_lens:
        torch.cuda.empty_cache()
        print(f"========== kv_len={kv_len} ==========")
        # Layouts:
        #   FA / xformers / TE (qkv_format="bshd") -> (batch, seqlen, nheads, headdim)
        #   torch SDPA / flex attention / triton.ops -> (batch, nheads, seqlen, headdim)
        q_bhsd, k_bhsd, v_bhsd = get_qkv(
            batch_size, num_heads, q_len, kv_len, head_dim, mqa, layout="bhsd"
        )
        softmax_scale = q_bhsd.size(-1) ** -0.5
        # Non-causal forward: Q@K^T and P@V each cost 2*q_len*kv_len*head_dim
        # FLOPs per (batch, head). `flops` covers all `test_iter` timed calls.
        flops = 4 * batch_size * num_heads * q_len * kv_len * head_dim * test_iter
        torch.cuda.profiler.start()

        _run_benchmark(
            f"scaled_dot_product_attention_torch_fwd, kv_len={kv_len}",
            lambda: scaled_dot_product_attention(q_bhsd, k_bhsd, v_bhsd, is_causal=False),
            flops, warmup_iter, test_iter,
        )
        with torch.nn.attention.sdpa_kernel(SDPBackend.MATH):
            _run_benchmark(
                f"scaled_dot_product_attention_torch_math_fwd, kv_len={kv_len}",
                lambda: torch.nn.functional.scaled_dot_product_attention(
                    q_bhsd, k_bhsd, v_bhsd, is_causal=False
                ),
                flops, warmup_iter, test_iter,
            )
        _run_benchmark(
            f"flex_attention_fwd, kv_len={kv_len}",
            lambda: compiled_flex_attention(q_bhsd, k_bhsd, v_bhsd),
            flops, warmup_iter, test_iter,
        )

        # Despite its name, get_qkv's "sbhd" layout yields (batch, seq, heads, head_dim).
        q, k, v = get_qkv(
            batch_size, num_heads, q_len, kv_len, head_dim, mqa, layout="sbhd"
        )
        te_fused_attn = DotProductAttention(
            num_attention_heads=num_heads,
            kv_channels=head_dim,
            qkv_format="bshd",
            attn_mask_type="no_mask",
            num_gqa_groups=1 if mqa else None,
        )
        _run_benchmark(
            f"cudnn_attention_fwd, kv_len={kv_len}",
            lambda: te_fused_attn(q, k, v),
            flops, warmup_iter, test_iter,
        )
        _run_benchmark(
            f"flash_attn_func_v2_fwd, kv_len={kv_len}",
            lambda: flash_attn_func(q, k, v),
            flops, warmup_iter, test_iter,
        )
        _run_benchmark(
            f"flash_attn_func_v2_fwd_compiled, kv_len={kv_len}",
            lambda: compiled_flash_attention_v2(q, k, v),
            flops, warmup_iter, test_iter,
        )
        _run_benchmark(
            f"xformers_flash_attn_cutlass_fwd, kv_len={kv_len}",
            lambda: fmha.memory_efficient_attention_forward(
                q, k, v, scale=softmax_scale, op=xformers.ops.fmha.cutlass.FwOp
            ),
            flops, warmup_iter, test_iter,
        )
        _run_benchmark(
            f"xformers_flash_hopper_fwd, kv_len={kv_len}",
            lambda: compiled_xformers_flash_hopper(q, k, v),
            flops, warmup_iter, test_iter,
        )
        # triton.ops.attention takes (batch, nheads, seqlen, headdim) tensors
        # and uses the query length for K/V too, so it is only valid for
        # kv_len == q_len with equal head counts. Run it non-causal like every
        # other benchmark here, because `flops` assumes full (non-causal)
        # attention.
        if kv_len == q_len and not mqa:
            _run_benchmark(
                f"triton_attention_fwd, kv_len={kv_len}",
                lambda: attention_triton(q_bhsd, k_bhsd, v_bhsd, False, softmax_scale),
                flops, warmup_iter, test_iter,
            )
        else:
            print(
                f"triton_attention_fwd, kv_len={kv_len}: skipped "
                "(requires kv_len == q_len and no MQA)"
            )
        _run_benchmark(
            f"flash_attn_func_hopper_fwd, kv_len={kv_len}",
            lambda: flash_attn_func_hopper(q, k, v),
            flops, warmup_iter, test_iter,
        )
        torch.cuda.profiler.stop()