@@ -10,7 +10,7 @@
 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
@@ -1122,6 +1122,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
             # The convention is that the response only has new tokens
             assert (
@@ -1166,26 +1168,34 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )
 
     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output):
         torch.manual_seed(0)
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
         )
@@ -1221,6 +1231,76 @@ def _check_lps(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )
 
+    @pytest.fixture(scope="module")
+    def llm_model(self):
+        import vllm
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    def test_vllm_batch_run(self, pad, generate, llm_model):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            llm_model,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        output = policy(data)
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
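
A minimal usage sketch of the batching pattern the new test exercises, assuming only what the diff shows: the `from_vllm` wrapper with the new `pad_output` flag, a vLLM `LLM` model, and ragged outputs retrieved with `get(..., as_list=True)`. The import path for `from_vllm` is an assumption and may differ in your install.

    # Sketch only, not part of the commit.
    import vllm
    from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
    from torchrl.modules import from_vllm  # assumed import path; adjust as needed

    model = vllm.LLM("gpt2")
    # pad_output=False keeps per-sample outputs as lists instead of padded tensors.
    policy = from_vllm(
        model, from_text=True, generate=True, return_log_probs=True, pad_output=False
    )

    # Two prompts of different lengths, stacked lazily so no padding is required.
    data = LazyStackedTensorDict(
        *TensorDict(
            text=NonTensorStack("a string", "another very long string"),
            batch_size=[2],
        ).unbind(0)
    )

    out = policy(data)
    # With pad_output=False these come back as lists of per-sample tensors.
    tokens_response = out.get("tokens_response", as_list=True)
    log_probs = out.get("log_probs", as_list=True)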