[Fix] Pads query_start_loc to satisfy FIA/TND constraint #6357
@@ -206,6 +206,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):

```python
vllm_config.scheduler_config.max_num_batched_tokens += max_pcp_pad_tokens
with _torch_cuda_wrapper():
    super().__init__(vllm_config, device)

# NOTE: For FULL mode we change +1 to +2 to reserve extra space for padding.
# See _pad_query_start_loc_for_fia.
self.query_start_loc = self._make_buffer(
    self.max_num_reqs + 2, dtype=torch.int32  # type: ignore[has-type]
)

vllm_config.scheduler_config.max_num_batched_tokens -= max_pcp_pad_tokens
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
```
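As a rough illustration of the sizing, here is a minimal numpy sketch (the real buffer is a CPU/GPU buffer pair created by `_make_buffer`; the plain array below is only a stand-in): `query_start_loc` is a prefix sum over per-request query lengths, so it needs `max_num_reqs + 1` slots, and the extra `+1` reserves room for the dummy request that `_pad_query_start_loc_for_fia` may append in the mixed-batch case.

```python
import numpy as np

# Stand-in for the runner's query_start_loc buffer (illustrative only).
max_num_reqs = 4
query_start_loc = np.zeros(max_num_reqs + 2, dtype=np.int32)

# Two real requests with 5 and 3 scheduled tokens.
num_reqs = 2
query_lens = np.array([5, 3], dtype=np.int32)
query_start_loc[1 : num_reqs + 1] = np.cumsum(query_lens)

print(query_start_loc)  # [0 5 8 0 0 0] -> trailing slots stay free for padding
```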
@@ -509,6 +516,36 @@ def get_model(self) -> nn.Module:

```python
        return self.model.unwrap()
    return self.model

def _pad_query_start_loc_for_fia(
    self, num_tokens_padded: int, num_reqs_padded: int, num_reqs: int
) -> int:
    """
    This function is only designed to satisfy the constraint that, when the layout
    is TND, the first dimension of `hidden_states` must equal the last element of
    `actual_seq_lengths_q`.
    """
    if num_tokens_padded == num_reqs_padded * self.uniform_decode_query_len:
        # Uniform-batch case: num_reqs must be no greater than num_reqs_padded
        assert num_reqs <= num_reqs_padded

        last_loc = self.query_start_loc.np[num_reqs]
        self.query_start_loc.np[num_reqs + 1 : num_reqs_padded + 1] = (
            self.arange_np[1 : num_reqs_padded + 1 - num_reqs]
            * self.uniform_decode_query_len
            + last_loc
        )
    else:
        # Mixed-batch case: num_reqs must equal num_reqs_padded
        assert num_reqs == num_reqs_padded

        # Insert a dummy request instead of setting
        # query_start_loc[num_reqs] = num_tokens_padded directly
        self.query_start_loc.np[num_reqs_padded + 1] = num_tokens_padded
        num_reqs_padded = num_reqs_padded + 1

    self.query_start_loc.copy_to_gpu()

    return num_reqs_padded

def _prepare_inputs(
    self,
    scheduler_output: "SchedulerOutput",
```

Comment on lines +527 to +536 (Collaborator, Author): [0, 1, 2] -> [0, 1, 2, 3, 4]

Comment on lines +537 to +545 (Collaborator, Author): [0, 3] -> [0, 3, 4]
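For clarity, here is a standalone numpy sketch of the same two branches (the function name and the driver values are illustrative, not taken from the runner). It reproduces the examples from the review comments above: the uniform-batch case extends `[0, 1, 2]` to `[0, 1, 2, 3, 4]`, while the mixed-batch case appends a dummy request so `[0, 3]` becomes `[0, 3, 4]`.

```python
import numpy as np


def pad_query_start_loc_for_fia_sketch(
    query_start_loc: np.ndarray,
    num_tokens_padded: int,
    num_reqs_padded: int,
    num_reqs: int,
    uniform_decode_query_len: int,
) -> int:
    """Illustrative stand-alone version of the padding logic; not the real method."""
    if num_tokens_padded == num_reqs_padded * uniform_decode_query_len:
        # Uniform-batch case: extend the prefix sum with fake requests of length
        # uniform_decode_query_len, e.g. [0, 1, 2] -> [0, 1, 2, 3, 4].
        assert num_reqs <= num_reqs_padded
        last_loc = query_start_loc[num_reqs]
        query_start_loc[num_reqs + 1 : num_reqs_padded + 1] = (
            np.arange(1, num_reqs_padded + 1 - num_reqs) * uniform_decode_query_len
            + last_loc
        )
    else:
        # Mixed-batch case: append one dummy request that absorbs the token
        # padding, e.g. [0, 3] -> [0, 3, 4].
        assert num_reqs == num_reqs_padded
        query_start_loc[num_reqs_padded + 1] = num_tokens_padded
        num_reqs_padded += 1
    return num_reqs_padded


# Uniform decode batch: 2 real single-token requests padded to 4 requests / 4 tokens.
qsl = np.array([0, 1, 2, 0, 0, 0], dtype=np.int32)
print(pad_query_start_loc_for_fia_sketch(qsl, 4, 4, 2, 1), qsl[:5])  # 4 [0 1 2 3 4]

# Mixed batch: 1 real request with 3 tokens, padded to 4 tokens via a dummy request.
qsl = np.array([0, 3, 0, 0], dtype=np.int32)
print(pad_query_start_loc_for_fia_sketch(qsl, 4, 1, 1, 1), qsl[:3])  # 2 [0 3 4]
```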
@@ -666,10 +703,6 @@ def _prepare_inputs(

```diff
 self.query_start_loc.np[0] = 0
 self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens
-# NOTE: Due to the FIA operator limitation, here we pad so that hidden_states.shape[0]
-# and self.query_start_loc[num_reqs_padded] are equal
-self.query_start_loc.np[num_reqs + 1:] = (self.arange_np[1:self.max_num_reqs + 1 - num_reqs]
-                                           * self.uniform_decode_query_len + cu_num_tokens[-1])
 self.query_start_loc.copy_to_gpu()

 self.seq_lens.np[:num_reqs] = (
```
@@ -1153,6 +1186,7 @@ def execute_model(

```python
    scheduler_output,
    num_scheduled_tokens_np,
)

num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
if self.pcp_size > 1:
    num_tokens_unpadded = self.pcp_manager.total_num_sampled_tokens_pcp
```
@@ -1207,6 +1241,11 @@ def execute_model(

```python
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

if cudagraph_mode != CUDAGraphMode.NONE:
    num_reqs_padded = self._pad_query_start_loc_for_fia(
        num_tokens_padded, num_reqs_padded, num_reqs
    )

(attn_metadata, spec_decode_common_attn_metadata) = (
    self._build_attention_metadata(
        num_tokens=num_tokens_unpadded,
```

Comment on lines +1244 to +1248 (Collaborator, Author): Maybe current_platform.post_process_after_padding
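To make the review suggestion above concrete, here is a hedged sketch of what such a platform hook could look like; `post_process_after_padding` is not an existing vLLM or vllm-ascend interface, and the class name and signature below are purely hypothetical.

```python
# Hypothetical platform hook sketched from the review suggestion above; not a real
# vLLM or vllm-ascend API. The idea is to keep the FIA/TND-specific padding out of
# the shared execute_model path and let the platform decide what post-processing
# it needs after token/request padding.
class AscendPlatformSketch:
    def post_process_after_padding(
        self, runner, num_tokens_padded: int, num_reqs_padded: int, num_reqs: int
    ) -> int:
        # Delegate to the runner's padding helper introduced in this PR.
        return runner._pad_query_start_loc_for_fia(
            num_tokens_padded, num_reqs_padded, num_reqs
        )
```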
@@ -1341,7 +1380,6 @@ def execute_model(

```python
assert broadcasted is not None
logits = broadcasted["logits"]

# Apply structured output bitmasks if present
self.execute_model_state = ExecuteModelState(
    scheduler_output,
```
@@ -1941,6 +1979,13 @@ def _get_block_table_and_slot_mapping(kv_cache_gid: int):

```python
long_seq_metdadata = _get_pcp_metadata(num_tokens)
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)

actual_last_loc = self.query_start_loc.np[num_reqs_padded]
error_msg = (
    f"Due to FIA kernel constraints, when the layout is TND, "
    f"the first dimension of `hidden_states` ({num_tokens_padded}) "
    f"must equal the last element of `actual_seq_lengths_q` ({actual_last_loc})."
)
assert self.query_start_loc.np[num_reqs_padded] == num_tokens_padded, error_msg
cm_base = AscendCommonAttentionMetadata(
    query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
    query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
```
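As a minimal illustration of the constraint this assert enforces (the values below are made up), after padding `actual_seq_lengths_q` is the tail of `query_start_loc`, and its last element must match the padded token count, i.e. the first dimension of `hidden_states`.

```python
import numpy as np

# Toy values (illustrative): 2 real requests padded out to 4 uniform decode steps.
num_tokens_padded = 4
num_reqs_padded = 4
query_start_loc = np.array([0, 1, 2, 3, 4], dtype=np.int32)  # after padding

# actual_seq_lengths_q handed to the FIA kernel is the cumulative form of the
# per-request query lengths, i.e. query_start_loc without the leading zero.
actual_seq_lengths_q = query_start_loc[1 : num_reqs_padded + 1]
assert actual_seq_lengths_q[-1] == num_tokens_padded  # TND layout requirement
```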
@@ -2193,9 +2238,15 @@ def _dummy_run(

```python
self.seq_lens.np[:num_reqs_padded] = seq_lens
self.seq_lens.np[num_reqs_padded:] = 0
self.seq_lens.copy_to_gpu()

cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
self.query_start_loc.np[1 : num_reqs_padded + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()

num_reqs_padded = self._pad_query_start_loc_for_fia(
    num_tokens_padded, num_reqs_padded, num_reqs
)

pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata(
    num_tokens=num_tokens_unpadded,
```
Comment: Check if other buffers should be extended too.

Reply: This is strange, no errors so far.