-
Notifications
You must be signed in to change notification settings - Fork 703
[pref] calculate local_total_toks in build meata #5543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: F.Liu <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a performance optimization by pre-calculating local_total_toks during the metadata build process, which avoids redundant calculations in _compute_prefill_context. The changes are logical and correctly implemented. I've added suggestions for a further minor optimization to avoid a redundant sum() operation during the initial calculation of local_total_toks.
| dcp_rank] | ||
| actual_seq_lengths_kv = torch.cumsum( | ||
| local_chunked_kv_lens_rank, dim=0).tolist() | ||
| local_total_toks = local_chunked_kv_lens_rank.sum() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid a redundant sum() operation, you can obtain local_total_toks from actual_seq_lengths_kv, which is calculated just before. The total sum is the last element of the cumulative sum list. This change will make local_total_toks an integer. Note that a corresponding change is needed at the call site to remove the .item() call.
| local_total_toks = local_chunked_kv_lens_rank.sum() | |
| local_total_toks = actual_seq_lengths_kv[-1] if actual_seq_lengths_kv else 0 |
| batch_chunk_seq_mask=batch_chunk_seq_mask, | ||
| chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices | ||
| chunk_seq_mask_filtered_indices=chunk_seq_mask_filtered_indices, | ||
| local_total_toks=local_total_toks.item() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signed-off-by: F.Liu <[email protected]>
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
|
This pull request has conflicts, please resolve those before we can evaluate the pull request. |
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?