The `score_mod` function passed to `flex_attention` should operate on the pre-softmax attention scores, but the snippet below appears to be applying the attention biases (`s_aux`) and computing the post-softmax scores.
`transformers/src/transformers/integrations/flex_attention.py`, lines 275 to 280 in `67097bf`:

```python
if s_aux is not None:
    logits_max = torch.max(score, dim=-1, keepdim=True).values
    sinks = torch.exp(s_aux - logits_max)
    unnormalized_scores = torch.exp(score - logits_max)
    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
    score = unnormalized_scores / normalizer
```
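For contrast, a `score_mod` is only supposed to transform a single pre-softmax logit at a time. The soft-cap below is purely an illustrative example (not code from the file above):

```python
import torch

# Illustrative only: a score_mod stays on the pre-softmax side, receiving one
# attention logit plus its (batch, head, q_idx, kv_idx) indices and returning
# the modified logit. The soft-cap value of 30 is an arbitrary example.
def softcap_score_mod(score, batch, head, q_idx, kv_idx):
    return 30.0 * torch.tanh(score / 30.0)
```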
I don't think it is possible to apply (gpt-oss-style) attention sinks using the `score_mod` alone, but you can do it by passing `return_lse=True` to `flex_attention` and renormalising using the extra return value. If someone can point me to where the unit tests for this code should live, I'm happy to PR a fix.