[Feature] add the magicmtp speculative decoding acceleration algorithm #5542
Conversation
Code Review
This pull request introduces the MagicMTP speculative decoding acceleration algorithm, adding both Triton and PyTorch implementations, along with end-to-end tests. The new algorithm is enabled when num_speculative_tokens is 3 or more. My review identified a critical logical error in the Triton kernel's implementation of the cumulative acceptance probability, which deviates from the reference PyTorch implementation and the MagicMTP paper, potentially leading to incorrect sampling results. Additionally, a critical typo was found in the new test file that would prevent it from running. I have provided code suggestions to address both of these critical issues.
```python
                                 device=DEVICE)
RECOVERED_TOKEN_IDS = torch.full((BATCH_SIZE,),
                                 MAX_SPEC_LEN + 1,
                                 detype=torch.int64,
```
This keyword is the critical typo called out in the review summary: `torch.full` accepts `dtype`, not `detype`, so the test raises a `TypeError` before it can run.
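A minimal sketch of the corresponding fix (the call's trailing arguments are truncated in the quoted diff; the closing `device=DEVICE)` is assumed from the preceding tensor):

```python
RECOVERED_TOKEN_IDS = torch.full((BATCH_SIZE,),
                                 MAX_SPEC_LEN + 1,
                                 dtype=torch.int64,  # fixed: was `detype`, an invalid keyword
                                 device=DEVICE)      # assumed closing argument
```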
```python
pi = min(pi * target_prob / draft_prob, 1.0)
if draft_prob > 0 and pi >= uniform_prob:
    last_accepted_token_pos = pos
    rejected = False
else:
    rejected = True
```
The logic for calculating the cumulative acceptance probability `pi` is incorrect: it matches neither the block verification algorithm from the MagicMTP paper nor the provided PyTorch reference implementation. The current formula `pi = min(pi * target_prob / draft_prob, 1.0)` incorrectly inflates the probability when `target_prob / draft_prob > 1`. The correct formula is `pi = pi * min(target_prob / draft_prob, 1.0)`. For example, with per-position ratios 0.5 and 2.0, the current formula yields `pi = min(0.5 * 2.0, 1.0) = 1.0` at the second position, whereas the correct formula keeps `pi = 0.5 * min(2.0, 1.0) = 0.5`.
Furthermore, the current implementation handles division by zero unsafely when `draft_prob` is 0: it relies on floating-point `inf` behavior and breaks the cumulative-product logic in subsequent iterations.
The suggested change corrects the formula and handles the division-by-zero case safely:
```python
ratio = 0.0
if draft_prob > 0.0:
    ratio = target_prob / draft_prob   # never divides by zero
pi = pi * tl.minimum(ratio, 1.0)       # elementwise two-arg min (tl.min is a reduction)
if pi >= uniform_prob:
    last_accepted_token_pos = pos
    rejected = False
else:
    rejected = True
```
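Note that `ratio = 0.0` for a zero `draft_prob` makes `pi` collapse to zero at that position, so that token and every later one in the block are rejected; this keeps the cumulative product well defined instead of propagating `inf` through subsequent iterations.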
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
What performance improvement does this feature offer?
This pull request has conflicts; please resolve them before we can evaluate it.
Performance Improvement:
- Environment:
- Parameter:
- Data:
- Model:
- Improvement:
It increases the draft-token acceptance rate by accepting more potential draft tokens.
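A minimal PyTorch sketch of the block-verification rule discussed in the review above may help make the mechanism concrete (illustrative only: the function name, the single uniform draw per block, and the per-position probability inputs are my framing, not the PR's actual code):

```python
import torch

def block_verify(target_probs: torch.Tensor,
                 draft_probs: torch.Tensor,
                 u: float) -> int:
    """Return how many draft tokens of one block are accepted.

    target_probs / draft_probs: probabilities the target / draft model
    assigned to each drafted token (1-D tensors of equal length).
    u: one uniform sample in [0, 1) drawn for the whole block.
    """
    pi = 1.0          # cumulative acceptance probability
    accepted = 0
    for pos in range(target_probs.numel()):
        p = target_probs[pos].item()
        q = draft_probs[pos].item()
        ratio = p / q if q > 0.0 else 0.0   # safe division, as in the fix above
        pi *= min(ratio, 1.0)               # pi is non-increasing
        if pi >= u:
            accepted = pos + 1              # longest prefix with pi >= u wins
    return accepted

# Example: a high-ratio token cannot "re-inflate" pi for later positions.
tp = torch.tensor([0.30, 0.50, 0.10])
dp = torch.tensor([0.60, 0.25, 0.20])
print(block_verify(tp, dp, u=0.4))  # ratios 0.5, 2.0, 0.5 -> pi: 0.5, 0.5, 0.25 -> accepts 2
```

Because `pi` only shrinks, comparing the running product against a single uniform draw accepts the longest valid prefix in one pass, which is how influence among multiple draft tokens raises the acceptance rate.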
[Feature] add the magicmtp speculative decoding acceleration algorithm (vllm-project#5542)

### What this PR does / why we need it?
1. MagicMTP (paper: "Block Verification Accelerates Speculative Decoding") was introduced to consider the influence among multiple draft tokens, improving the acceptance rate without compromising accuracy.
2. Added Triton and PyTorch implementations, and added E2E test cases.

### Does this PR introduce _any_ user-facing change?
MagicMTP will automatically take effect when the parameter "num_speculative_tokens" >= 3.

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@7157596

Signed-off-by: chenaoxuan <[email protected]>
…into FIA_rebase

* 'FIA_rebase' of https://github.com/845473182/vllm-ascend: (39 commits)
  - [CI] Drop outdated cases (vllm-project#5709)
  - [EPLB][CI] EPLB add aclgraph and redundant expert ci (vllm-project#5625)
  - [CI] fix image build tag (vllm-project#5703)
  - Optimize the print info format when deprecated code is used in vllm-ascend (vllm-project#5696)
  - [Feature] add the magicmtp speculative decoding acceleration algorithm (vllm-project#5542)
  - [bugfix] adapt to new implemented get_kv_cache_spec in cpuoffload connector (vllm-project#4311)
  - [refactor] Refactor the interface for shard weight and remove the flashcomm2 o_shared interface. (vllm-project#5181)
  - [BugFix][P/D] Fix pre-create link parameter error (vllm-project#5694)
  - [Kernel] Add moe_gating_top_k operator support for Ascend NPU (vllm-project#5579)
  - [1/N][CI] Refactor accuracy test (vllm-project#5400)
  - [BugFix][Fusion] Fix graph fusion failure problem (vllm-project#5676)
  - [Tests] Add qwen3-8b nightly test (vllm-project#5597)
  - [Refactor] Import global var form vllm instead of overwirte it (vllm-project#5469)
  - [Refactor] Fix AttentionMaskBuilder singleton and remove redundant pcp_prefill_mask (vllm-project#4870)
  - [CI] move image and wheel job to schedule way (vllm-project#5685)
  - [Bugfix] Fix the graph capture failure issue in the eagle3+full scenario. (vllm-project#5553)
  - [Bugfix] fix resource are insufficient when pcp and piecewise (vllm-project#5377)
  - [CI] Add workflow to cancel running workflows on PR close (vllm-project#5646)
  - [CI] Bump lm-eval version to v0.4.9.2 (vllm-project#5655)
  - [CI] cleanup single/multi-card test (vllm-project#5623)
  - ...
What this PR does / why we need it?
Does this PR introduce any user-facing change?
MagicMTP will automatically take effect when the parameter "num_speculative_tokens" >= 3.
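As a rough usage sketch (not from this PR: the model name and `method` value are assumptions that depend on which MTP-capable checkpoint is served), no dedicated flag is needed beyond the speculative-token count:

```python
from vllm import LLM, SamplingParams

# Hypothetical setup: any model whose checkpoint ships MTP draft weights.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",       # assumed MTP-capable model
    speculative_config={
        "method": "deepseek_mtp",          # assumed method name for this model family
        "num_speculative_tokens": 3,       # >= 3 activates MagicMTP per this PR
    },
)

out = llm.generate(["Hello, world"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```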