
Attention sinks are not applied correctly in integrations.flex_attention #41026

@jonny-so

Description

The score_mod function passed to flex_attention should operate on the pre-softmax attention scores, but the snippet below appears to be applying the attention biases (s_aux) and computing the post-softmax scores.

if s_aux is not None:
    logits_max = torch.max(score, dim=-1, keepdim=True).values
    sinks = torch.exp(s_aux - logits_max)
    unnormalized_scores = torch.exp(score - logits_max)
    normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
    score = unnormalized_scores / normalizer
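
For reference, flex_attention calls score_mod elementwise on a single pre-softmax score, so there is no row to reduce over inside it. A minimal sketch of the intended shape of a score_mod (illustrative only; the bias tensor and its shape are placeholders):

import torch

# Placeholder bias, indexed per (head, query position, key position).
bias = torch.randn(8, 128, 128)

def bias_score_mod(score, b, h, q_idx, kv_idx):
    # `score` is a single pre-softmax logit for (batch b, head h, query q_idx,
    # key kv_idx); score_mod never sees a whole row of scores.
    return score + bias[h, q_idx, kv_idx]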

I don't think it is possible to apply (gpt-oss-style) attention sinks using score_mod alone, but you can do it by passing return_lse=True to flex_attention and renormalising the output using the extra return value. If someone can point me to where unit tests for this code should live, I'm happy to PR a fix.
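
As a rough sketch of what I mean (not a tested fix; flex_attention_with_sinks is a made-up name, and it assumes s_aux holds per-head sink logits and that the returned lse is the natural-log logsumexp with shape (batch, num_heads, q_len)):

import torch
from torch.nn.attention.flex_attention import flex_attention

def flex_attention_with_sinks(query, key, value, s_aux, **kwargs):
    # query/key/value: (batch, num_heads, seq_len, head_dim); s_aux: (num_heads,)
    out, lse = flex_attention(query, key, value, return_lse=True, **kwargs)
    # The kernel normalises each row of scores by exp(lse). With a sink, the
    # normaliser should be exp(lse) + exp(s_aux), so rescale the output by
    # exp(lse) / (exp(lse) + exp(s_aux)) = sigmoid(lse - s_aux).
    sink_scale = torch.sigmoid(lse - s_aux[None, :, None]).unsqueeze(-1)
    return out * sink_scale.to(out.dtype)

Any score_mod / block_mask kwargs pass straight through, since the sink only changes the softmax normaliser.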
