
Fix attention bias broadcast #24017

Open: wants to merge 3 commits into main
Conversation

tianleiwu (Contributor) commented Mar 13, 2025

Description

  • Fix broadcasting of attention bias on dim 1.
  • Add test cases to test_mha.py in the pipeline to cover this scenario.

Motivation and Context

This feature was added in #21710.

There was a bug in the offset computation when the attention bias broadcasts on dim 1 only, in both the CUDA and CPU kernels.

It is triggered when the attention bias has a shape like [batch_size, 1, sequence_length, total_sequence_length] with batch_size > 1 and the unfused kernel is selected. Note that cuDNN flash attention and the cutlass fused attention kernels also support attention bias, so the bug in the unfused kernel was not discovered previously.
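A minimal sketch of the offset logic in question (this is a hypothetical Python illustration, not the actual onnxruntime kernel code; the function name and signature are made up for clarity). The key point is that for a bias of shape [B or 1, N or 1, S, T], the offset of the (batch, head) slice must not advance along any dim of size 1, since that dim broadcasts:

```python
def attn_bias_offset(b, n, bias_dims, seq_len, total_seq_len):
    """Element offset of the (b, n) slice in a flattened attention bias.

    bias_dims = (dim0, dim1) of a bias shaped [dim0, dim1, S, T],
    where dim0 is batch_size or 1, and dim1 is num_heads or 1.
    """
    dim0, dim1 = bias_dims
    slice_size = seq_len * total_seq_len
    # Broadcast: clamp the index to 0 along any dim of size 1.
    b_idx = b if dim0 > 1 else 0
    n_idx = n if dim1 > 1 else 0
    # Stride along dim 0 is dim1 * slice_size, NOT num_heads * slice_size.
    # Conflating the two is exactly the kind of bug that shows up only
    # when dim1 == 1 while batch_size > 1.
    return (b_idx * dim1 + n_idx) * slice_size
```

For a bias of shape [2, 1, S, T], the slice for (b=1, n=2) must come from bias[1, 0], i.e. offset 1 * S * T, regardless of the head index.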

@tianleiwu tianleiwu requested a review from RyanUnderhill March 13, 2025 04:54