[Fix] Pads query_start_loc to satisfy FIA/TND constraint #6357
Conversation
Code Review
This pull request introduces a fix to satisfy a layout constraint for the FIA/TND operator in full cudagraph mode by padding the query_start_loc buffer. The changes include increasing the buffer size, centralizing the padding logic into a new helper function _pad_query_start_loc_for_fia, and adding an assertion to ensure the constraint is met. This is a good improvement for correctness and maintainability.
I've found one critical issue in the new helper function related to a slice mismatch that would cause a runtime error. Please see my specific comment for details and a suggested fix.
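For context, the layout constraint the review refers to can be stated in one line: the last entry of the cumulative query-length array handed to the FIA/TND kernel must equal the token dimension of hidden_states. A minimal sketch of that relationship (illustrative tensors and sizes only, not the PR's actual variables):

```python
import torch

# Illustrative shapes: 3 requests whose padded tokens sum to 8.
num_tokens_padded = 8
hidden_states = torch.zeros(num_tokens_padded, 128)              # [num_tokens_padded, hidden_size]
query_start_loc = torch.tensor([0, 3, 6, 8], dtype=torch.int32)  # cumulative query lengths
actual_seq_lengths_q = query_start_loc[1:]                       # what FIA/TND consumes

# The constraint this PR enforces by padding query_start_loc:
assert actual_seq_lengths_q[-1].item() == hidden_states.shape[0]
```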
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
bf4cfa6 to 9b1c177
wangxiyuan left a comment
Enable the e2e test, the same as #6284?
yiz-liu left a comment
- Enable other skipped test cases
- Check other persistent buffer
- Check speculative decoding (workaround + reuse)
```python
if num_tokens_padded == num_reqs_padded * self.uniform_decode_query_len:
    # Uniform-batch case: num_reqs must be no greater than num_reqs_padded
    assert num_reqs <= num_reqs_padded

    last_loc = self.query_start_loc.np[num_reqs]
    self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1] = (
        self.arange_np[1 : num_reqs_padded + 1 - num_reqs]
        * self.uniform_decode_query_len
        + last_loc
    )
```
[0, 1, 2] -> [0, 1, 2, 3, 4]
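The example above can be reproduced standalone with NumPy; the values below (num_reqs=2, num_reqs_padded=4, uniform_decode_query_len=1) are assumed to match the reviewer's [0, 1, 2] -> [0, 1, 2, 3, 4] illustration:

```python
import numpy as np

num_reqs, num_reqs_padded = 2, 4
uniform_decode_query_len = 1

query_start_loc = np.zeros(num_reqs_padded + 1, dtype=np.int32)
query_start_loc[: num_reqs + 1] = [0, 1, 2]
arange_np = np.arange(num_reqs_padded + 1)

# Same arithmetic as the snippet above: extend with dummy single-token decodes.
last_loc = query_start_loc[num_reqs]
query_start_loc[num_reqs + 1 : num_reqs_padded + 1] = (
    arange_np[1 : num_reqs_padded + 1 - num_reqs] * uniform_decode_query_len + last_loc
)
print(query_start_loc)  # [0 1 2 3 4]
```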
```python
else:
    # Mixed-batch case: num_reqs must equal num_reqs_padded
    assert num_reqs == num_reqs_padded

    # Insert a dummy request instead of setting query_start_loc[num_reqs] = num_tokens_padded directly
    self.query_start_loc.np[num_reqs_padded + 1] = num_tokens_padded
    num_reqs_padded = num_reqs_padded + 1

self.query_start_loc.copy_to_gpu()
```
[0, 3] -> [0, 3, 4]
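Similarly, the mixed-batch branch can be checked in isolation; the values below (num_reqs = num_reqs_padded = 1, num_tokens_padded = 4) are assumed to match the [0, 3] -> [0, 3, 4] illustration:

```python
import numpy as np

num_reqs = num_reqs_padded = 1
num_tokens_padded = 4

# The buffer holds max_num_reqs + 2 entries, so index num_reqs_padded + 1 is valid.
query_start_loc = np.zeros(num_reqs_padded + 2, dtype=np.int32)
query_start_loc[: num_reqs + 1] = [0, 3]

# Append one dummy request that absorbs the padded token, as in the snippet above.
query_start_loc[num_reqs_padded + 1] = num_tokens_padded
num_reqs_padded += 1
print(query_start_loc[: num_reqs_padded + 1])  # [0 3 4]
```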
```python
# NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
# See _pad_query_start_loc_for_fia.
self.query_start_loc = self._make_buffer(
    self.max_num_reqs + 2, dtype=torch.int32  # type: ignore[has-type]
)
```
Check if other buffers should be extended too.
This is strange; no error has shown up so far.
```python
if cudagraph_mode != CUDAGraphMode.NONE:
    num_reqs_padded = self._pad_query_start_loc_for_fia(
        num_tokens_padded, num_reqs_padded, num_reqs
    )
```
Maybe current_platform.post_process_after_padding
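As a rough illustration of that suggestion (the method name post_process_after_padding is only proposed here; the signature below is hypothetical and does not exist in vLLM today), the padding could be routed through a platform hook so the base runner stays platform-agnostic:

```python
# Hypothetical sketch only -- neither this hook nor its signature exists yet.
class NPUPlatform:
    def post_process_after_padding(
        self, runner, num_tokens_padded: int, num_reqs_padded: int, num_reqs: int
    ) -> int:
        # Delegate to the Ascend-specific helper introduced in this PR.
        return runner._pad_query_start_loc_for_fia(
            num_tokens_padded, num_reqs_padded, num_reqs
        )
```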
Adds a larger reserve (+2) for the query_start_loc buffer in FULL cudagraph mode and introduces a helper to pad it so the first dimension of hidden_states equals the final element of actual_seq_lengths_q required by the FIA/TND operator. Handles both uniform and mixed batches (inserting a dummy request for mixed batches), moves ad-hoc padding into a single helper, copies the updated buffer to the device, and asserts the layout constraint before building attention metadata. These changes prevent kernel mismatches/failures and ensure correct shapes for FIA/TND execution in full graph modes.
Signed-off-by: Yizhou Liu <[email protected]>
…m-project#6357)" This reverts commit 56f5d3b.
Signed-off-by: wangli <[email protected]>
…onstraint (#6459)
This reverts commit 56f5d3b.
What this PR does / why we need it?
Patch #6357 breaks functionality in the spec_decode scenario, so let's revert it first and make CI happy.
Does this PR introduce any user-facing change?
How was this patch tested?
- vLLM version: v0.14.1
- vLLM main: vllm-project/vllm@dc917cc
Signed-off-by: wangli <[email protected]>
What this PR does / why we need it?
This handles both uniform and mixed batches (by inserting a dummy request for mixed batches), consolidates ad-hoc padding into a single helper, copies the updated buffer to the device, and asserts the layout constraint before building the attention metadata. Together, these changes prevent kernel mismatches or failures and ensure correct shapes for FIA/TND execution in full graph modes.
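Putting the fragments from the review together, the helper roughly takes the following shape; this is a sketch reconstructed from the diff snippets above, not necessarily the verbatim implementation:

```python
def _pad_query_start_loc_for_fia(self, num_tokens_padded, num_reqs_padded, num_reqs):
    """Pad query_start_loc so its last entry equals num_tokens_padded (FIA/TND)."""
    if num_tokens_padded == num_reqs_padded * self.uniform_decode_query_len:
        # Uniform-batch case: extend with dummy fixed-length decode requests.
        assert num_reqs <= num_reqs_padded
        last_loc = self.query_start_loc.np[num_reqs]
        self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1] = (
            self.arange_np[1 : num_reqs_padded + 1 - num_reqs]
            * self.uniform_decode_query_len
            + last_loc
        )
    else:
        # Mixed-batch case: append one dummy request covering the padded tokens.
        assert num_reqs == num_reqs_padded
        self.query_start_loc.np[num_reqs_padded + 1] = num_tokens_padded
        num_reqs_padded = num_reqs_padded + 1
    self.query_start_loc.copy_to_gpu()
    return num_reqs_padded
```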
We currently place this helper in execute_model. My original design was to include it in _prepare_inputs, but that doesn't work because it must run after padding. While I'd prefer to minimize the impact and reuse as much of the base class as possible in the future, it doesn't seem achievable at the moment.
Does this PR introduce any user-facing change?
None.
How was this patch tested?
Test cases added.