---
layout: blog_detail
title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
-author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
+author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
---

{:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias
```

@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
    return causal_mask & window_mask

# If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
```

We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask, as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask, as the sliding window mask has far more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
- The Jax team's work on SplashAttention
- Philippe Tillet and Keren Zhou for helping us with Triton
- Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
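
As a minimal, self-contained sketch (not part of the diff above), this is how the corrected `alibi` score_mod and the `and_masks`-composed causal sliding-window mask would plug into `flex_attention`. The tensor shapes, the `SLIDING_WINDOW` value, the explicit ALiBi slope computation (standing in for the post's `generate_alibi_bias()`), and the `torch.nn.attention.flex_attention` import path are assumptions; it expects a recent PyTorch release with FlexAttention and a CUDA device.

```python
import torch
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    and_masks,
)

# Assumed toy sizes and window length (not from the post).
B, H, S, D = 2, 8, 1024, 64
SLIDING_WINDOW = 256

q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# Stand-in for generate_alibi_bias(): the usual ALiBi slopes, one per head.
alibi_bias = torch.exp2(-8.0 * (torch.arange(H, device="cuda") + 1) / H)

def alibi(score, b, h, q_idx, kv_idx):
    # Corrected sign: the bias is zero on the diagonal and negative for older keys.
    return score + alibi_bias[h] * (kv_idx - q_idx)

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# Corrected composition: a position is attended to only if BOTH masks allow it.
sliding_window_causal = and_masks(causal_mask, sliding_window)
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S
)

out = flex_attention(q, k, v, score_mod=alibi, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 8, 1024, 64])
```

For the fused kernel, the call would normally go through `torch.compile(flex_attention)`; in eager mode PyTorch falls back to a reference implementation, which is still enough to check the mask and bias semantics.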