@@ -2,11 +2,12 @@
 import torch
 from vllm.v1.sample.rejection_sampler import \
     rejection_random_sample_kernel as original_rejection_random_sample_kernel
-
 from vllm_ascend.ops.triton.reject_sample import (
-    cal_grid_and_block_size, rejection_random_sample_kernel)
+    cal_grid_and_block_size, rejection_random_sample_block_verify_kernel,
+    rejection_random_sample_kernel)
 from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
-
+from vllm_ascend.sample.rejection_sampler import \
+    rejection_random_sample_block_verify_pytorch
 
 @pytest.fixture(scope="function", autouse=True)
 def setup_device_properties():
@@ -93,3 +94,134 @@ def test_rejection_random_sample(max_spec_len, vocab_size, batch_size): |
          BLOCK_SIZE=block_size)
     torch.npu.synchronize()
     assert torch.equal(original_output_token_ids, output_token_ids)
+
+
+DEVICE = "npu"
+BATCH_SIZE = 7
+MAX_SPEC_LEN = 3
+VOCAB_SIZE = 5
+CU_NUM_DRAFT_TOKENS = torch.tensor([2, 2, 5, 8, 11, 14, 15],
+                                   dtype=torch.int32,
+                                   device=DEVICE)
+DRAFT_TOKEN_IDS = torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0],
+                               dtype=torch.int64,
+                               device=DEVICE)
+NUM_TOKENS = DRAFT_TOKEN_IDS.shape[0]
+DRAFT_PROBS = None
+TARGET_PROBS = torch.tensor(
+    [
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.1, 0.9, 0.0, 0.0, 0.0],  # 1
+        [0.2, 0.1, 0.2, 0.4, 0.1],  # 0
+        [0.1, 0.4, 0.1, 0.1, 0.3],  # 0
+        [0.2, 0.1, 0.4, 0.1, 0.2],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.2, 0.1, 0.2, 0.1],  # 0
+        [0.1, 0.6, 0.1, 0.1, 0.1],  # 1
+        [0.2, 0.2, 0.2, 0.3, 0.1],  # 0
+        [0.4, 0.4, 0.1, 0.0, 0.1],  # 1
+        [0.4, 0.3, 0.1, 0.1, 0.1],  # 0
+        [0.4, 0.0, 0.5, 0.0, 0.1],  # 1
+        [0.4, 0.1, 0.3, 0.1, 0.1],  # 1
+    ],
+    dtype=torch.float32,
+    device=DEVICE)
+UNIFORM_PROBS = torch.tensor([
+    0.9,
+    0.0,
+    0.9,
+    0.7,
+    0.8,
+    0.5,
+    0.45,
+    1.0,
+    0.5,
+    0.45,
+    1.0,
+    0.39,
+    0.4,
+    0.1,
+    0.3,
+],
+                             dtype=torch.float32,
+                             device=DEVICE)
+BONUS_TOKEN_IDS = torch.full((BATCH_SIZE, ),
+                             MAX_SPEC_LEN + 1,
+                             dtype=torch.int64,
+                             device=DEVICE)
+RECOVERED_TOKEN_IDS = torch.full((NUM_TOKENS, ),
+                                 MAX_SPEC_LEN,
+                                 dtype=torch.int64,
+                                 device=DEVICE)
+IS_GREEDY = torch.zeros(BATCH_SIZE, dtype=torch.bool, device=DEVICE)
+IS_GREEDY[4] = True
+
+
+@pytest.mark.parametrize("cu_num_draft_tokens", [CU_NUM_DRAFT_TOKENS])
+@pytest.mark.parametrize("draft_token_ids", [DRAFT_TOKEN_IDS])
+@pytest.mark.parametrize("draft_probs", [DRAFT_PROBS])
+@pytest.mark.parametrize("target_probs", [TARGET_PROBS])
+@pytest.mark.parametrize("bonus_token_ids", [BONUS_TOKEN_IDS])
+@pytest.mark.parametrize("recovered_token_ids", [RECOVERED_TOKEN_IDS])
+@pytest.mark.parametrize("uniform_probs", [UNIFORM_PROBS])
+@pytest.mark.parametrize("is_greedy", [IS_GREEDY])
+@pytest.mark.parametrize("batch_size", [BATCH_SIZE])
+@pytest.mark.parametrize("max_spec_len", [MAX_SPEC_LEN])
+@pytest.mark.parametrize("vocab_size", [VOCAB_SIZE])
+@torch.inference_mode()
+def test_rejection_sampler_block_verify_triton_kernel(
+    cu_num_draft_tokens,  # [batch_size]
+    draft_token_ids,  # [num_tokens]
+    draft_probs,  # [num_tokens, vocab_size] or None
+    target_probs,  # [num_tokens, vocab_size]
+    bonus_token_ids,  # [batch_size]
+    recovered_token_ids,  # [num_tokens]
+    uniform_probs,  # [num_tokens]
+    is_greedy,  # [batch_size]
+    batch_size,  # int
+    max_spec_len,  # int
+    vocab_size,  # int
+) -> None:
+
+    grid, block_size = cal_grid_and_block_size(batch_size)
+
+    output_token_ids_ref = torch.full((batch_size, max_spec_len + 1),
+                                      -1,
+                                      dtype=torch.int64,
+                                      device=DEVICE)
+
+    output_token_ids_triton = output_token_ids_ref.clone()
+
+    rejection_random_sample_block_verify_pytorch(
+        output_token_ids=output_token_ids_ref,
+        cu_num_draft_tokens=cu_num_draft_tokens,
+        draft_token_ids=draft_token_ids,
+        draft_probs=draft_probs,
+        target_probs=target_probs,
+        bonus_token_ids=bonus_token_ids,
+        recovered_token_ids=recovered_token_ids,
+        uniform_probs=uniform_probs,
+        is_greedy=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        IS_NGRAM=draft_probs is None)
+
+    rejection_random_sample_block_verify_kernel[(grid, )](
+        output_token_ids_ptr=output_token_ids_triton,
+        cu_num_draft_tokens_ptr=cu_num_draft_tokens,
+        draft_token_ids_ptr=draft_token_ids,
+        draft_probs_ptr=draft_probs,
+        target_probs_ptr=target_probs,
+        bonus_token_ids_ptr=bonus_token_ids,
+        recovered_token_ids_ptr=recovered_token_ids,
+        uniform_probs_ptr=uniform_probs,
+        is_greedy_ptr=is_greedy,
+        max_spec_len=max_spec_len,
+        vocab_size=vocab_size,
+        vec_len=batch_size,
+        NO_DRAFT_PROBS=draft_probs is None,
+        BLOCK_SIZE=block_size)
+    torch.npu.synchronize()
+    assert torch.equal(output_token_ids_ref, output_token_ids_triton)
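
A note on the fixture data above: CU_NUM_DRAFT_TOKENS = [2, 2, 5, 8, 11, 14, 15] is an inclusive cumulative sum, so the per-request draft counts are [2, 0, 3, 3, 3, 3, 1] (request 1 contributes no draft tokens, and request 4 is exercised on the greedy path via IS_GREEDY[4] = True). The `# 0` / `# 1` comments on the TARGET_PROBS rows appear to mark the per-token reject/accept outcome under the usual ngram-path rule applied when draft_probs is None: draft token t at position i is accepted when UNIFORM_PROBS[i] < TARGET_PROBS[i][t]. A minimal sketch that recomputes those markers from the constants above (a hypothetical helper, not part of the test; the block-verify semantics themselves live in rejection_random_sample_block_verify_pytorch):

# Hypothetical helper: recompute the accept (1) / reject (0) markers
# annotating TARGET_PROBS, assuming the ngram-path acceptance rule
# (accept iff uniform < target prob of the draft token).
def accept_markers(draft_token_ids, target_probs, uniform_probs):
    return [
        int(u < row[t]) for t, row, u in zip(draft_token_ids.tolist(),
                                             target_probs.tolist(),
                                             uniform_probs.tolist())
    ]

# Yields [0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1], matching the
# row comments in the fixture.
print(accept_markers(DRAFT_TOKEN_IDS, TARGET_PROBS, UNIFORM_PROBS))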