Commit 3d64b0e

[Feature] add the magicmtp speculative decoding acceleration algorithm
Signed-off-by: chenaoxuan <cax1165@163.com>
1 parent e760aae

3 files changed: +355 -33 lines changed
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
import pytest
import torch
from torch.testing import assert_close

from vllm_ascend.ops.triton.reject_sample import rejection_random_sample_block_verify_kernel
from vllm_ascend.sample.rejection_sampler import rejection_random_sample_block_verify_pytorch

DEVICE = "npu"
BATCH_SIZE = 7
MAX_SPEC_LEN = 3
VOCAB_SIZE = 5

CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
                                   dtype=torch.int32,
                                   device=DEVICE)
DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
                               dtype=torch.int64,
                               device=DEVICE)
NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
DRAFT_PROBS = None
# Trailing 0/1 comments mark the expected reject (0) / accept (1) decision
# at each draft position for the non-greedy requests.
TARGET_PROBS = torch.tensor(
    [
        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
        [0.1, 0.9, 0.0, 0.0, 0.0],  # 1

        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0

        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0

        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0

        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1

        [0.4, 0.1, 0.3, 0.1, 0.1]  # 1
    ],
    dtype=torch.float32,
    device=DEVICE)
UNIFORM_PROBS = torch.tensor([
    0.9,
    0.0,

    0.9,
    0.7,
    0.8,

    0.5,
    0.45,
    1.0,

    # Request 4 is greedy and exits the kernel early; these three entries
    # are never read, they only keep the tensor aligned to [num_tokens].
    0.0,
    0.0,
    0.0,

    0.39,
    0.4,
    0.1,

    0.3,
],
                             dtype=torch.float32,
                             device=DEVICE)
BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
                             MAX_SPEC_LEN + 1,
                             dtype=torch.int64,
                             device=DEVICE)
# Sized [num_tokens]: the kernel indexes recovered tokens by token position.
RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
                                 MAX_SPEC_LEN + 1,
                                 dtype=torch.int64,
                                 device=DEVICE)
IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
IS_GREEDY[4] = True


@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS])
@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
@torch.inference_mode()
def test_rejection_sampler_block_verify_triton_kernel(
        cu_num_draft_tokens,  # [batch_size]
        draft_token_ids,  # [num_tokens]
        draft_probs,  # [num_tokens, vocab_size] or None
        target_probs,  # [num_tokens, vocab_size]
        bonus_token_ids,  # [batch_size]
        recovered_token_ids,  # [num_tokens]
        uniform_probs,  # [num_tokens]
        is_greedy,  # [batch_size]
        batch_size,  # int
        max_spec_len,  # int
        vocab_size,  # int
) -> None:
    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
                                      -1,
                                      dtype=torch.int64,
                                      device=DEVICE)

    output_token_ids_triton = output_token_ids_ref.clone()

    # Reference implementation in plain PyTorch.
    rejection_random_sample_block_verify_pytorch(
        output_token_ids=output_token_ids_ref,
        cu_num_draft_tokens=cu_num_draft_tokens,
        draft_token_ids=draft_token_ids,
        draft_probs=draft_probs,
        target_probs=target_probs,
        bonus_token_ids=bonus_token_ids,
        recovered_token_ids=recovered_token_ids,
        uniform_probs=uniform_probs,
        is_greedy=is_greedy,
        max_spec_len=max_spec_len,
        vocab_size=vocab_size,
        IS_NGRAM=draft_probs is None)

    # Triton kernel under test, launched with one program per request.
    rejection_random_sample_block_verify_kernel[(batch_size, )](
        output_token_ids_ptr=output_token_ids_triton,
        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
        draft_token_ids_ptr=draft_token_ids,
        draft_probs_ptr=draft_probs,
        target_probs_ptr=target_probs,
        bonus_token_ids_ptr=bonus_token_ids,
        recovered_token_ids_ptr=recovered_token_ids,
        uniform_probs_ptr=uniform_probs,
        is_greedy_ptr=is_greedy,
        max_spec_len=max_spec_len,
        vocab_size=vocab_size,
        NO_DRAFT_PROBS=draft_probs is None)

    assert_close(output_token_ids_ref, output_token_ids_triton)
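
For intuition, here is a hand trace of request 0 through the block-verify rule the test exercises (a worked example derived from the constants above, not part of the diff; with DRAFT_PROBS = None every draft probability is treated as 1):

# Request 0 spans tokens 0..1 (CU_NUM_DRAFT_TOKENS[0] == 2).
# pos 0: draft id 0, target_prob = 0.4, uniform product = 0.9
#        pi = min(1.0 * 0.4, 1.0) = 0.4 < 0.9    -> rejected
# pos 1: draft id 1, target_prob = 0.9, uniform product = 0.9 * 0.0 = 0.0
#        pi = min(0.4 * 0.9, 1.0) = 0.36 >= 0.0  -> accepted
# The last accepted position is 1, so both draft tokens are kept and the
# bonus token (MAX_SPEC_LEN + 1 == 4) is appended:
# expected output row 0: [0, 1, 4, -1]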

vllm_ascend/ops/triton/reject_sample.py

Lines changed: 70 additions & 0 deletions
@@ -375,3 +375,73 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
        MAX_NUM_TOKENS=max_num_tokens,  # To avoid recompilation.
        BLOCK_SIZE=BLOCK_SIZE,
    )


@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
        cu_num_draft_tokens_ptr,  # [batch_size]
        draft_token_ids_ptr,  # [num_tokens]
        draft_probs_ptr,  # [num_tokens, vocab_size] or None
        target_probs_ptr,  # [num_tokens, vocab_size]
        bonus_token_ids_ptr,  # [batch_size]
        recovered_token_ids_ptr,  # [num_tokens]
        uniform_probs_ptr,  # [num_tokens]
        is_greedy_ptr,  # [batch_size]
        max_spec_len,
        vocab_size,
        NO_DRAFT_PROBS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    is_greedy = tl.load(is_greedy_ptr + req_idx)
    if is_greedy:
        # Early exit for greedy sampling requests.
        return

    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr +
                                               req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    pi = 1.0
    uniform_prob = 1.0
    last_accepted_token_pos = -1

    for pos in range(num_draft_tokens):
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        target_prob = tl.load(target_probs_ptr +
                              (start_idx + pos) * vocab_size + draft_token_id)
        tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
        uniform_prob = uniform_prob * tmp_uniform_prob

        if NO_DRAFT_PROBS:
            # N-gram drafting provides no draft distribution, so the
            # ratio reduces to the target probability itself.
            draft_prob = 1.0
        else:
            draft_prob = tl.load(draft_probs_ptr +
                                 (start_idx + pos) * vocab_size +
                                 draft_token_id)

        # Block verification: accept the prefix ending at `pos` iff the
        # clipped product of target/draft ratios still dominates the
        # product of uniform samples.
        pi = min(pi * target_prob / draft_prob, 1.0)
        if draft_prob > 0 and pi >= uniform_prob:
            last_accepted_token_pos = pos
            rejected = False
        else:
            rejected = True

    if last_accepted_token_pos > -1:
        # Emit the accepted prefix of draft tokens.
        for pos in range(last_accepted_token_pos + 1):
            token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
            tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                     token_id)

    if rejected:
        # Place the recovered token (sampled from the adjusted
        # distribution) right after the accepted prefix.
        recovered_token_id = tl.load(recovered_token_ids_ptr + start_idx +
                                     last_accepted_token_pos + 1)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            last_accepted_token_pos + 1, recovered_token_id)
    else:
        # The whole block was accepted: append the bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        tl.store(
            output_token_ids_ptr + req_idx * (max_spec_len + 1) +
            num_draft_tokens, bonus_token_id)
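
The kernel verifies the draft block jointly instead of token by token: it keeps a running product pi of target/draft probability ratios clipped at 1 and a running product of the uniform samples, and the accepted prefix ends at the last position where pi still dominates the uniform product. A minimal pure-Python sketch of this acceptance rule, assuming strictly positive draft probabilities (the helper name is illustrative; the committed PyTorch reference lives in vllm_ascend/sample/rejection_sampler.py):

def block_verify_prefix(target_probs, draft_probs, uniforms):
    # Per-position probabilities of the drafted tokens for one request.
    # For n-gram drafting (the NO_DRAFT_PROBS case) pass 1.0 everywhere.
    pi = 1.0  # running min(prod(target / draft), 1)
    u = 1.0   # running product of uniform samples
    last_accepted, rejected = -1, False
    for pos, (t, d) in enumerate(zip(target_probs, draft_probs)):
        u *= uniforms[pos]
        pi = min(pi * t / d, 1.0)
        if pi >= u:
            last_accepted, rejected = pos, False
        else:
            rejected = True
    return last_accepted, rejected

With request 0 of the test data, block_verify_prefix([0.4, 0.9], [1.0, 1.0], [0.9, 0.0]) returns (1, False): the whole block is accepted, so the bonus token is appended after the two draft tokens.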
