Commit bcefa40

[Feature] add the magicmtp speculative decoding acceleration algorithm
Signed-off-by: chenaoxuan <cax1165@163.com>
1 parent 1140789 commit bcefa40

File tree

3 files changed (+372, -39 lines)

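What the diff adds, read from the kernel and test below: a "block verify" variant of rejection sampling for speculative decoding. Rather than accepting or rejecting each draft token on its own, the kernel accumulates the target/draft likelihood ratio over the drafted block and accepts the prefix x_1..x_t whenever

    min(prod_{i<=t} p_target(x_i) / p_draft(x_i), 1) >= prod_{i<=t} u_i

where the u_i are the per-token uniform samples (p_draft is taken to be 1 in n-gram mode, which has no draft distribution). A confident later token can thus rescue an earlier borderline one. If the loop ends in a rejected state, the pre-sampled recovered token is written just after the accepted prefix; if the whole block survives, the bonus token is appended.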

tests/e2e/nightly/single_node/ops/singlecard_ops/triton/test_rejection_sample.py

Lines changed: 135 additions & 3 deletions
@@ -2,11 +2,12 @@
 import torch
 from vllm.v1.sample.rejection_sampler import \
     rejection_random_sample_kernel as original_rejection_random_sample_kernel
-
 from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, rejection_random_sample_kernel)
+    cal_grid_and_block_size, rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel)
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
-
+from vllm_ascend.sample.rejection_sampler import \
+    rejection_random_sample_block_verify_pytorch
 
 @pytest.fixture(scope="function", autouse=True)
 def setup_device_properties():
@@ -93,3 +94,134 @@ def test_rejection_random_sample(max_spec_len, vocab_size, batch_size):
             BLOCK_SIZE=block_size)
     torch.npu.synchronize()
     assert torch.equal(original_output_token_ids, output_token_ids)
+
+
+DEVICE = "npu"
+BATCH_SIZE = 7
+MAX_SPEC_LEN = 3
+VOCAB_SIZE = 5
+CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
+                                   dtype=torch.int32,
+                                   device=DEVICE)
+DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
+                               dtype=torch.int64,
+                               device=DEVICE)
+NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
+DRAFT_PROBS = None
+TARGET_PROBS = torch.tensor(
+    [
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.1, 0.9, 0.0, 0.0, 0.0],  # 1
+        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
+        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
+        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
+        [0.4, 0.1, 0.3, 0.1, 0.1],  # 1
+    ],
+    dtype=torch.float32,
+    device=DEVICE)
+UNIFORM_PROBS = torch.tensor([
+    0.9,
+    0.0,
+    0.9,
+    0.7,
+    0.8,
+    0.5,
+    0.45,
+    1.0,
+    0.5,
+    0.45,
+    1.0,
+    0.39,
+    0.4,
+    0.1,
+    0.3,
+],
+                             dtype=torch.float32,
+                             device=DEVICE)
+BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
+                             MAX_SPEC_LEN + 1,
+                             dtype=torch.int64,
+                             device=DEVICE)
+RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
+                                 MAX_SPEC_LEN,
+                                 dtype=torch.int64,
+                                 device=DEVICE)
+IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
+IS_GREEDY[4] = True
+
+
+@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
+@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
+@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
+@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
+@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
+@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS])
+@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
+@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
+@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
+@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
+@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
+@torch.inference_mode()
+def test_rejection_sampler_block_verify_triton_kernel(
+        cu_num_draft_tokens,  # [batch_size]
+        draft_token_ids,  # [num_tokens]
+        draft_probs,  # [num_tokens, vocab_size] or None
+        target_probs,  # [num_tokens, vocab_size]
+        bonus_token_ids,  # [batch_size]
+        recovered_token_ids,  # [num_tokens]
+        uniform_probs,  # [num_tokens]
+        is_greedy,  # [batch_size]
+        batch_size,  # int
+        max_spec_len,  # int
+        vocab_size,  # int
+) -> None:
+
+    grid, block_size = cal_grid_and_block_size(batch_size)
+
+    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
+                                      -1,
+                                      dtype=torch.int64,
+                                      device=DEVICE)
+
+    output_token_ids_triton = output_token_ids_ref.clone()
+
+    rejection_random_sample_block_verify_pytorch(
+        output_token_ids=output_token_ids_ref,
+        cu_num_draft_tokens=cu_num_draft_tokens,
+        draft_token_ids=draft_token_ids,
+        draft_probs=draft_probs,
+        target_probs=target_probs,
+        bonus_token_ids=bonus_token_ids,
+        recovered_token_ids=recovered_token_ids,
+        uniform_probs=uniform_probs,
+        is_greedy=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        IS_NGRAM=draft_probs is None)
+
+    rejection_random_sample_block_verify_kernel[(grid, )](
+        output_token_ids_ptr=output_token_ids_triton,
+        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
+        draft_token_ids_ptr=draft_token_ids,
+        draft_probs_ptr=draft_probs,
+        target_probs_ptr=target_probs,
+        bonus_token_ids_ptr=bonus_token_ids,
+        recovered_token_ids_ptr=recovered_token_ids,
+        uniform_probs_ptr=uniform_probs,
+        is_greedy_ptr=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        vec_len=batch_size,
+        NO_DRAFT_PROBS=draft_probs is None,
+        BLOCK_SIZE=block_size)
+    torch.npu.synchronize()
+    assert torch.equal(output_token_ids_ref, output_token_ids_triton)
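Tracing the first request through the constants above is a useful hand check of the block rule (the numbers are computed here, not asserted by the commit). Request 0 covers draft positions [0, 2) with tokens [0, 1]; DRAFT_PROBS is None, so every draft probability is 1. At position 0 the target probability 0.4 loses to the uniform sample 0.9, so per-token verification would already stop. At position 1 the cumulative ratio is min(0.4 * 0.9, 1) = 0.36 while the cumulative uniform product is 0.9 * 0.0 = 0.0, and 0.36 >= 0.0 accepts the whole prefix [0, 1]. The loop ends un-rejected, so the bonus token (MAX_SPEC_LEN + 1 = 4) lands at position 2, and row 0 of both outputs should be [0, 1, 4, -1].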

vllm_ascend/ops/triton/reject_sample.py

Lines changed: 81 additions & 0 deletions
@@ -378,3 +378,84 @@ def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from,
         MAX_NUM_TOKENS=max_num_tokens,  # To avoid recompilation.
         BLOCK_SIZE=block_size,
     )
+
+
+@triton.jit(do_not_specialize=["max_spec_len"])
+def rejection_random_sample_block_verify_kernel(
+        output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+        cu_num_draft_tokens_ptr,  # [batch_size]
+        draft_token_ids_ptr,  # [num_tokens]
+        draft_probs_ptr,  # [num_tokens, vocab_size] or None
+        target_probs_ptr,  # [num_tokens, vocab_size]
+        bonus_token_ids_ptr,  # [batch_size]
+        recovered_token_ids_ptr,  # [num_tokens]
+        uniform_probs_ptr,  # [num_tokens]
+        is_greedy_ptr,  # [batch_size]
+        max_spec_len,
+        vocab_size,
+        vec_len,
+        NO_DRAFT_PROBS: tl.constexpr,
+        BLOCK_SIZE: tl.constexpr):
+    block_idx = tl.program_id(0)
+    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < vec_len
+    is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
+    not_greedy_mask = is_greedy == 0
+    start_idxs = tl.where(
+        offsets == 0, 0,
+        tl.load(cu_num_draft_tokens_ptr + offsets - 1, not_greedy_mask))
+    end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
+    n_num_draft_tokens = end_idxs - start_idxs
+    for req_i in range(BLOCK_SIZE):
+        not_greedy = tl.get_element(not_greedy_mask, (req_i, ))
+        if not_greedy:
+
+            rejected = False
+            pi = 1.0
+            uniform_prob = 1.0
+            last_accepted_token_pos = -1
+            start_idx = tl.get_element(start_idxs, (req_i, ))
+            req_idx = block_idx * BLOCK_SIZE + req_i
+            num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i, ))
+
+            for pos in range(num_draft_tokens):
+                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                target_prob = tl.load(target_probs_ptr +
+                                      (start_idx + pos) * vocab_size +
+                                      draft_token_id)
+                tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
+                uniform_prob = uniform_prob * tmp_uniform_prob
+
+                if NO_DRAFT_PROBS:
+                    draft_prob = 1
+                else:
+                    draft_prob = tl.load(draft_probs_ptr +
+                                         (start_idx + pos) * vocab_size +
+                                         draft_token_id)
+
+                pi = min(pi * target_prob / draft_prob, 1.0)
+                if draft_prob > 0 and pi >= uniform_prob:
+                    last_accepted_token_pos = pos
+                    rejected = False
+                else:
+                    rejected = True
+
+            if last_accepted_token_pos > -1:
+                for pos in range(last_accepted_token_pos + 1):
+                    token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+                    tl.store(
+                        output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                        pos, token_id)
+
+            if rejected:
+                recovered_token_id = tl.load(recovered_token_ids_ptr +
+                                             start_idx +
+                                             last_accepted_token_pos + 1)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                    last_accepted_token_pos + 1, recovered_token_id)
+            else:
+                bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
+                tl.store(
+                    output_token_ids_ptr + req_idx * (max_spec_len + 1) +
+                    num_draft_tokens, bonus_token_id)