
Commit 67fc54c

Chillee and cjyabraham authored

Update 2024-08-07-flexattention.md (#1707)

* Update 2024-08-07-flexattention.md

* Update 2024-08-07-flexattention.md

---------

Co-authored-by: Chris Abraham <[email protected]>
1 parent 25636f5 commit 67fc54c

File tree

1 file changed: +5 -5 lines changed

_posts/2024-08-07-flexattention.md (+5 -5)
@@ -1,7 +1,7 @@
 ---
 layout: blog_detail
 title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
-author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
+author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
 ---
 
 ![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
 alibi_bias = generate_alibi_bias() # [num_heads]
 
 def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
     return score + bias
 ```
 
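For context on the sign fix in this hunk: with causal attention kv_idx <= q_idx, so `kv_idx - q_idx` is non-positive and the penalty grows with distance from the query, which is what ALiBi prescribes. Below is a minimal sketch (not part of this commit) of how the corrected `score_mod` might be used with `flex_attention`; the slope formula, tensor shapes, and CUDA device are illustrative assumptions, and `alibi_bias` here stands in for the post's `generate_alibi_bias()`.

```
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 256, 64  # assumed shapes for illustration
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# Stand-in for generate_alibi_bias(): one positive slope per head.
alibi_bias = 2.0 ** (-8.0 * torch.arange(1, H + 1, device="cuda") / H)

def alibi(score, b, h, q_idx, kv_idx):
    # Corrected sign: non-positive for kv_idx <= q_idx, so keys further in
    # the past receive a larger penalty.
    return score + alibi_bias[h] * (kv_idx - q_idx)

out = flex_attention(q, k, v, score_mod=alibi)
```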
@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 # If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks
 
 def sliding_window(b, h, q_idx, kv_idx)
     return q_idx - kv_idx <= SLIDING_WINDOW
 
-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
 ```
 
 We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
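On the `or_masks` to `and_masks` change in this hunk: a sliding-window-causal mask is the intersection of the causal predicate and the window predicate, so combining them with OR would let every query attend to all earlier keys and silently drop the window. A rough sketch (not from the commit) of the corrected combination with a block mask follows; the window size, sequence length, and shapes are arbitrary assumptions, and the imports used are the ones under `torch.nn.attention.flex_attention`.

```
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention

SLIDING_WINDOW = 1024  # assumed window size
S = 4096               # assumed sequence length

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# Intersection of the two predicates: a key is visible only if it is both
# causally visible AND inside the window.
sliding_window_causal = and_masks(causal_mask, sliding_window)

block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
q, k, v = (torch.randn(1, 8, S, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
```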
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
 - The Jax team's work on SplashAttention
 - Philippe Tillet and Keren Zhou for helping us with Triton
 - Ali Hassani for discussions on neighborhood attention
-- Everybody who’s complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who’s complained about attention kernels not supporting their favorite attention variant :)
