Skip to content

Conversation

@Jeff-Huang
Copy link
Contributor

Introduces support for padded sequence lengths in the backward pass of the variable-length flash attention (fmha_v3_varlen_bwd).

  • Updated Python and C++ function signatures to accept optional cu_seqlens_q_padded and cu_seqlens_k_padded arguments.
  • Modified the underlying CUDA kernels and code generation scripts to pass padding information via the new seqlen_q_ptr and seqlen_k_ptr fields in the CK fmha_bwd_args struct.
  • Modified the underlying kernels and code generation scripts to correctly handle pointers for both padded and unpadded sequence data.
  • Added comprehensive gradient verification to the test suite (test_mha_varlen.py) to ensure the correctness of the backward pass with various padding scenarios.

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@Jeff-Huang Jeff-Huang force-pushed the SWDEV-558893 branch 2 times, most recently from 7486448 to 67ca00e Compare October 20, 2025 05:25
Introduces support for padded sequence lengths in the backward pass of the variable-length flash attention (fmha_v3_varlen_bwd).
- Updated Python and C++ function signatures to accept optional `cu_seqlens_q_padded` and `cu_seqlens_k_padded` arguments.
- Modified the underlying CUDA kernels and code generation scripts to pass padding information via the new `seqlen_q_ptr` and `seqlen_k_ptr` fields in
     the CK `fmha_bwd_args` struct.
- Modified the underlying kernels and code generation scripts to correctly handle pointers for both padded and unpadded sequence data.
- Added comprehensive gradient verification to the test suite (`test_mha_varlen.py`) to ensure the correctness of the backward pass with various
     padding scenarios.
Refactor the FMHA forward and backward pass to align with the updated padding API in `composable_kernel`.

- Argument Simplification: Removed the manual calculation of `seqlen_q` and `seqlen_k` from `cu_seqlens` in the `mha.cu` interface. The underlying kernels now handle this logic.
- API Alignment: Updated the arguments passed to `aiter::mha_fwd` and `aiter::mha_bwd` to match the new `composable_kernel` API. This involves passing `cu_seqlen` pointers directly.
- Kernel Interface Update: Modified the `codegen.py` scripts for `gfx942` and `gfx950` to reflect the changes in the kernel's function signatures and argument handling for padded and unpadded sequence lengths.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant