1+ import pytest
2+ import torch
3+ from torch .testing import assert_close
4+
5+ from vllm_ascend .ops .triton .reject_sample import rejection_random_sample_block_verify_kernel
6+ from vllm_ascend .sample .rejection_sampler import rejection_random_sample_block_verify_pytorch
7+
8+ DEVICE = "npu"
9+ BATCH_SIZE = 7
10+ MAX_SPEC_LEN = 3
11+ VOCAB_SIZE = 5
12+
13+ CU_NUM_DRAFT_TOKENS = torch .tensor ([2 , 2 , 5 , 8 , 11 , 14 , 15 ],
14+ dtype = torch .int32 ,
15+ device = DEVICE )
16+ DRAFT_TOKEN_IDS = torch .tensor ([0 , 1 , 0 , 1 , 2 , 0 , 1 , 2 , 0 , 1 , 2 , 0 , 1 , 2 , 0 ],
17+ dtype = torch .int64 ,
18+ device = DEVICE )
19+ NUM_TOKENS = DRAFT_TOKEN_IDS .shape [0 ]
20+ DRAFT_PROBS = None
21+ TARGET_PROBS = torch .tensor (
22+ [
23+ [0.4 , 0.3 , 0.1 , 0.1 , 0.1 ], # 0
24+ [0.1 , 0.9 , 0.0 , 0.0 , 0.0 ], # 1
25+
26+ [0.2 , 0.1 , 0.2 , 0.4 , 0.1 ], # 0
27+ [0.1 , 0.4 , 0.1 , 0.1 , 0.3 ], # 0
28+ [0.2 , 0.1 , 0.4 , 0.1 , 0.2 ], # 0
29+
30+ [0.4 , 0.2 , 0.1 , 0.2 , 0.1 ], # 0
31+ [0.1 , 0.6 , 0.1 , 0.1 , 0.1 ], # 1
32+ [0.2 , 0.2 , 0.2 , 0.3 , 0.1 ], # 0
33+
34+ [0.4 , 0.2 , 0.1 , 0.2 , 0.1 ], # 0
35+ [0.1 , 0.6 , 0.1 , 0.1 , 0.1 ], # 1
36+ [0.2 , 0.2 , 0.2 , 0.3 , 0.1 ], # 0
37+
38+ [0.4 , 0.4 , 0.1 , 0.0 , 0.1 ], # 1
39+ [0.4 , 0.3 , 0.1 , 0.1 , 0.1 ], # 0
40+ [0.4 , 0.0 , 0.5 , 0.0 , 0.1 ], # 1
41+
42+ [0.4 , 0.1 , 0.3 , 0.1 , 0.1 ] # 1
43+ ],
44+ dtype = torch .float32 ,
45+ device = DEVICE )
46+ UNIFORM_PROBS = torch .tensor ([
47+ 0.9 ,
48+ 0.0 ,
49+
50+ 0.9 ,
51+ 0.7 ,
52+ 0.8 ,
53+
54+ 0.5 ,
55+ 0.45 ,
56+ 1.0 ,
57+
58+ 0.39 ,
59+ 0.4 ,
60+ 0.1 ,
61+
62+ 0.3 ,
63+ ],
64+ dtype = torch .float32 ,
65+ device = DEVICE )
66+ BONUS_TOKEN_IDS = torch .full ((BATCH_SIZE , ),
67+ MAX_SPEC_LEN + 1 ,
68+ dtype = torch .int64 ,
69+ device = DEVICE )
70+ RECOVERED_TOKEN_IDS = torch .full ((BATCH_SIZE ,),
71+ MAX_SPEC_LEN + 1 ,
72+ detype = torch .int64 ,
73+ device = DEVICE )
74+ IS_GREEDY = torch .zeros (BATCH_SIZE , dtype = torch .bool , device = DEVICE )
75+ IS_GREEDY [4 ]= True
76+
77+
78+ @pytest .mark .parametrize ("cu_num_draft_tokens" , [CU_NUM_DRAFT_TOKENS ])
79+ @pytest .mark .parametrize ("draft_token_ids" , [DRAFT_TOKEN_IDS ])
80+ @pytest .mark .parametrize ("draft_probs" , [DRAFT_PROBS ])
81+ @pytest .mark .parametrize ("target_probs" , [TARGET_PROBS ])
82+ @pytest .mark .parametrize ("bonus_token_ids" , [BONUS_TOKEN_IDS ])
83+ @pytest .mark .parametrize ("recovered_token_ids" , [RECOVERED_TOKEN_IDS ])
84+ @pytest .mark .parametrize ("uniform_probs" , [UNIFORM_PROBS ])
85+ @pytest .mark .parametrize ("is_greedy" , [IS_GREEDY ])
86+ @pytest .mark .parametrize ("batch_size" , [BATCH_SIZE ])
87+ @pytest .mark .parametrize ("max_spec_len" , [MAX_SPEC_LEN ])
88+ @pytest .mark .parametrize ("vocab_size" , [VOCAB_SIZE ])
89+ @torch .inference_mode ()
90+ def test_rejection_sampler_block_verify_triton_kernel (
91+ cu_num_draft_tokens , # [batch_size]
92+ draft_token_ids , # [num_tokens]
93+ draft_probs , # [num_tokens, vocab_size] or None
94+ target_probs , # [num_tokens, vocab_size]
95+ bonus_token_ids , # [batch_size]
96+ recovered_token_ids , # [num_tokens]
97+ uniform_probs , # [num_tokens]
98+ is_greedy , # [batch_size]
99+ batch_size , # int
100+ max_spec_len , # int
101+ vocab_size , # int
102+ ) -> None :
103+ output_token_ids_ref = torch .full ((batch_size , max_spec_len + 1 ),
104+ - 1 ,
105+ dtype = torch .int64 ,
106+ device = DEVICE )
107+
108+ output_token_ids_triton = output_token_ids_ref .clone ()
109+
110+ rejection_random_sample_block_verify_pytorch (
111+ output_token_ids = output_token_ids_ref ,
112+ cu_num_draft_tokens = cu_num_draft_tokens ,
113+ draft_token_ids = draft_token_ids ,
114+ draft_probs = draft_probs ,
115+ target_probs = target_probs ,
116+ bonus_token_ids = bonus_token_ids ,
117+ recovered_token_ids = recovered_token_ids ,
118+ uniform_probs = uniform_probs ,
119+ is_greedy = is_greedy ,
120+ max_spec_len = max_spec_len ,
121+ vocab_size = vocab_size ,
122+ IS_NGRAM = draft_probs is None )
123+
124+ rejection_random_sample_block_verify_kernel [(batch_size , )](
125+ output_token_ids_ptr = output_token_ids_triton ,
126+ cu_num_draft_tokens_ptr = cu_num_draft_tokens ,
127+ draft_token_ids_ptr = draft_token_ids ,
128+ draft_probs_ptr = draft_probs ,
129+ target_probs_ptr = target_probs ,
130+ bonus_token_ids_ptr = bonus_token_ids ,
131+ recovered_token_ids_ptr = recovered_token_ids ,
132+ uniform_probs_ptr = uniform_probs ,
133+ is_greedy_ptr = is_greedy ,
134+ max_spec_len = max_spec_len ,
135+ vocab_size = vocab_size ,
136+ NO_DRAFT_PROBS = draft_probs is None )
137+
138+ assert_close (output_token_ids_ref , output_token_ids_triton )
0 commit comments