 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
+
+from torchrl.collectors import SyncDataCollector
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
 from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
+from torchrl.envs import LLMEnv
 from torchrl.modules import (
     from_hf_transformers,
     from_vllm,

 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
-    from pytorch.rl.test.mocking_classes import NestedCountingEnv
+    from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv
 else:
     from _utils_internal import get_default_devices
-    from mocking_classes import NestedCountingEnv
+    from mocking_classes import DummyStrDataLoader, NestedCountingEnv
 
 _has_vllm = importlib.util.find_spec("vllm") is not None
 
@@ -1122,6 +1125,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
         # The convention is that the response only has new tokens
         assert (
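For reference, the shape convention these assertions encode, as a minimal self-contained sketch (toy shapes, not values from the test): `tokens` carries the prompt while `tokens_response` carries only the newly generated tokens, so the two agree on every batch dimension and differ only in the trailing sequence dimension.

```python
import torch

# Toy shapes for illustration: batch of 2, prompts of 10 tokens, 4 new tokens each.
tokens = torch.randint(1024, (2, 10))          # prompt tokens
tokens_response = torch.randint(1024, (2, 4))  # response holds only the new tokens

# Same check as in _run_check: batch dims agree, sequence lengths may differ.
assert tokens_response.shape[:-1] == tokens.shape[:-1]
```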
@@ -1166,26 +1171,34 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
|
1166 | 1171 | )
|
1167 | 1172 |
|
1168 | 1173 | @pytest.mark.parametrize(
|
1169 |
| - "from_text, tokens, attention_mask", |
| 1174 | + "pad_output, from_text, tokens, attention_mask", |
1170 | 1175 | [
|
1171 |
| - (True, None, None), |
| 1176 | + (True, True, None, None), |
| 1177 | + (False, True, None, None), |
1172 | 1178 | (
|
| 1179 | + True, |
1173 | 1180 | False,
|
1174 | 1181 | torch.randint(1024, (1, 10)),
|
1175 | 1182 | torch.ones(1, 10, dtype=torch.int64),
|
1176 | 1183 | ),
|
1177 |
| - (False, torch.randint(1024, (1, 10)), None), |
| 1184 | + (True, False, torch.randint(1024, (1, 10)), None), |
1178 | 1185 | ],
|
1179 | 1186 | )
|
1180 |
| - def test_from_vllm_logprobs(self, from_text, tokens, attention_mask): |
| 1187 | + def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output): |
1181 | 1188 | torch.manual_seed(0)
|
1182 | 1189 | from vllm import LLM
|
1183 | 1190 |
|
1184 | 1191 | model = LLM(model="facebook/opt-125m")
|
1185 | 1192 | m_generate = from_vllm(
|
1186 |
| - model, from_text=from_text, generate=True, return_log_probs=True |
| 1193 | + model, |
| 1194 | + from_text=from_text, |
| 1195 | + generate=True, |
| 1196 | + return_log_probs=True, |
| 1197 | + pad_output=pad_output, |
| 1198 | + ) |
| 1199 | + m_logprobs = from_vllm( |
| 1200 | + model, from_text=from_text, generate=False, pad_output=pad_output |
1187 | 1201 | )
|
1188 |
| - m_logprobs = from_vllm(model, from_text=from_text, generate=False) |
1189 | 1202 | self._check_lps(
|
1190 | 1203 | m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
|
1191 | 1204 | )
|
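For context on the new `pad_output` flag exercised above: with padding the wrapper can return one dense tensor per key, whereas without padding variable-length sequences stay as a list of per-sample tensors (see the type checks in `test_vllm_batch_run` below). A minimal sketch of the difference, assuming right-padding with a dummy pad id of 0 (illustrative only, not the wrapper's actual internals):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two responses of different lengths.
responses = [torch.tensor([5, 9, 2]), torch.tensor([7, 1, 4, 8, 3])]

# pad_output=True: one right-padded (2, 5) tensor.
padded = pad_sequence(responses, batch_first=True, padding_value=0)
assert padded.shape == (2, 5)

# pad_output=False: the list itself, each entry keeping its true length.
unpadded = responses
assert [t.shape[-1] for t in unpadded] == [3, 5]
```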
@@ -1221,6 +1234,127 @@ def _check_lps(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )
 
+    @pytest.fixture(scope="module")
+    def llm_model(self):
+        import vllm
+
+        # Build the model once per module and give it a pad token for batching.
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    @pytest.mark.parametrize("use_tensorclass", [True, False])
+    def test_vllm_batch_run(self, pad, generate, use_tensorclass, llm_model):
+        # Test the generate/padding combinations.
+        policy = from_vllm(
+            llm_model,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        if use_tensorclass:
+            data = LLMData.from_tensordict(data)
+        output = policy(data)
+        # Dense .get() fails on ragged (unpadded) outputs; fall back to a list view.
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        # TODO: this is not ideal...
+        if use_tensorclass:
+            assert isinstance(text, list)
+        else:
+            assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        if use_tensorclass:
+            assert isinstance(text_response, list)
+        else:
+            assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+
+    def test_vllm_collection(self):
+        from vllm import LLM
+
+        llm = LLM("gpt2")
+        policy = from_vllm(
+            llm,
+            from_text=True,
+            generate=True,
+            return_log_probs=True,
+            pad_output=False,
+            generate_kwargs={"max_tokens": 10},
+        )
+        self._run_check_collector(policy)
+
+    def test_transformers_collection(self):
+        ...
+
+    @classmethod
+    def env_constructor(cls):
+        dl = DummyStrDataLoader(batch_size=32)
+        # 16 envs with 4 grouped repeats per prompt -> env batch of 64.
+        env = LLMEnv.from_dataloader(
+            dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
+        )
+        assert env.batch_size == (64,)
+        return env
+
+    def _run_check_collector(self, policy):
+        collector = SyncDataCollector(
+            self.env_constructor,
+            policy=policy,
+            frames_per_batch=128,
+            total_frames=512,
+            use_buffers=False,
+        )
+        for data in collector:
+            assert isinstance(data, LazyStackedTensorDict)
+            assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
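A note on the arithmetic baked into `env_constructor` and `_run_check_collector`: `batch_size=16` with `repeats=4` and `group_repeats=True` yields an env batch of 16 × 4 = 64, so each 128-frame collector batch corresponds to 2 steps of the batched env, and `total_frames=512` to 4 collector iterations. A quick check of those numbers:

```python
# Sizes implied by the test (plain arithmetic, no torchrl imports needed).
env_batch = 16 * 4                  # LLMEnv batch_size * repeats (grouped)
assert env_batch == 64              # matches `assert env.batch_size == (64,)`

steps_per_batch = 128 // env_batch  # frames_per_batch / env batch
n_iterations = 512 // 128           # total_frames / frames_per_batch
assert (steps_per_batch, n_iterations) == (2, 4)
```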