Commit cb89b3e

[Feature] Allow using lists of tensors in vllm instead of padded tensors
ghstack-source-id: 46615f1
Pull Request resolved: #2861
1 parent d068095 commit cb89b3e

File tree

5 files changed: +415 -131 lines

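In short, passing `pad_output=False` to `from_vllm` keeps ragged per-sample outputs as lists of tensors (retrieved with `get(..., as_list=True)`) instead of stacking them into a single padded tensor. Below is a minimal usage sketch condensed from the new `test_vllm_batch_run` test; the `"gpt2"` checkpoint and the output keys mirror the test rather than a documented public API.

```python
import vllm
from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
from torchrl.modules import from_vllm

model = vllm.LLM("gpt2")
policy = from_vllm(
    model,
    from_text=True,
    generate=True,
    return_log_probs=True,
    pad_output=False,  # keep ragged outputs as lists instead of padding them
)
# Inputs are a lazy stack of per-sample tensordicts holding raw strings.
data = LazyStackedTensorDict(
    *TensorDict(
        text=NonTensorStack("a string", "another very long string"),
        batch_size=[2],
    ).unbind(0)
)
out = policy(data)
# With pad_output=False, ragged keys come back as lists of tensors
# rather than a single padded tensor.
tokens_response = out.get("tokens_response", as_list=True)
assert isinstance(tokens_response, list)
```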

test/test_actors.py

+160 -16
@@ -10,14 +10,17 @@
 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
+
+from torchrl.collectors import SyncDataCollector
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
 from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
+from torchrl.envs import LLMEnv
 from torchrl.modules import (
     from_hf_transformers,
     from_vllm,
@@ -42,10 +45,10 @@
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
-    from pytorch.rl.test.mocking_classes import NestedCountingEnv
+    from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv
 else:
     from _utils_internal import get_default_devices
-    from mocking_classes import NestedCountingEnv
+    from mocking_classes import DummyStrDataLoader, NestedCountingEnv
 
 _has_vllm = importlib.util.find_spec("vllm") is not None
 
@@ -922,6 +925,18 @@ def test_lmhead_actorvalueoperator(device):
 @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
 @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
 class TestLLMActor:
+    @pytest.fixture(scope="module")
+    def vllm_instance(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
     @pytest.mark.parametrize(
         "from_text, generate, return_log_probs, tokens, attention_mask",
         [
@@ -1005,12 +1020,17 @@ def test_from_hf_transformers(
         ],
     )
     def test_from_vllm(
-        self, from_text, generate, return_log_probs, tokens, attention_mask
+        self,
+        from_text,
+        generate,
+        return_log_probs,
+        tokens,
+        attention_mask,
+        vllm_instance,
     ):
         torch.manual_seed(0)
-        from vllm import LLM
 
-        model = LLM(model="facebook/opt-125m")
+        model = vllm_instance
         m = from_vllm(
             model,
             from_text=from_text,
@@ -1122,6 +1142,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
             # The convention is that the response only has new tokens
             assert (
@@ -1166,28 +1188,43 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )
 
     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(
+        self, from_text, tokens, attention_mask, pad_output, vllm_instance
+    ):
         torch.manual_seed(0)
-        from vllm import LLM
 
-        model = LLM(model="facebook/opt-125m")
+        model = vllm_instance
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
-            m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
+            m_generate,
+            m_logprobs,
+            tokens,
+            attention_mask,
+            from_text,
+            has_logits=False,
+            tol=1e-1,
         )
 
     def _check_lps(
@@ -1198,6 +1235,7 @@ def _check_lps(
         attention_mask,
         from_text,
         has_logits,
+        tol=1e-2,
     ):
         # Checks that the log-probs gathered with generate=False equate those with generate=True
        tdin_genetate = self._make_data(
@@ -1218,8 +1256,114 @@ def _check_lps(
         assert td_generate.log_probs.shape == td_generate.tokens_response.shape
         assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
         torch.testing.assert_close(
-            td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
+            td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
+        )
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    @pytest.mark.parametrize("use_tensorclass", [True, False])
+    def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            vllm_instance,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        if use_tensorclass:
+            data = LLMData.from_tensordict(data)
+        output = policy(data)
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        # TODO: this is not ideal...
+        if use_tensorclass:
+            assert isinstance(text, list)
+        else:
+            assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        if use_tensorclass:
+            assert isinstance(text_response, list)
+        else:
+            assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+
+    def test_vllm_collection(self, vllm_instance):
+        policy = from_vllm(
+            vllm_instance,
+            from_text=True,
+            generate=True,
+            return_log_probs=True,
+            pad_output=False,
+            generate_kwargs={"max_tokens": 10},
+        )
+        self._run_check_collector(policy)
+
+    def test_transformers_collection(self):
+        ...
+
+    @classmethod
+    def env_constructor(cls):
+        dl = DummyStrDataLoader(batch_size=32)
+        env = LLMEnv.from_dataloader(
+            dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
+        )
+        assert env.batch_size == (64,)
+        return env
+
+    def _run_check_collector(self, policy):
+        collector = SyncDataCollector(
+            self.env_constructor,
+            policy=policy,
+            frames_per_batch=128,
+            total_frames=512,
+            use_buffers=False,
         )
+        for data in collector:
+            assert isinstance(data, LazyStackedTensorDict)
+            assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
 
 
 if __name__ == "__main__":

torchrl/envs/custom/llm.py

+9 -1
@@ -10,6 +10,7 @@
 
 from tensordict import (
     is_leaf_nontensor,
+    LazyStackedTensorDict,
     NestedKey,
     TensorDict,
     TensorDictBase,
@@ -266,6 +267,7 @@ def from_dataloader(
         stack_method: Callable[[Any], Any]
         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
         repeats: int | None = None,
+        group_repeats: bool = False,
     ) -> LLMEnv:
         """Creates an LLMEnv instance from a dataloader.
 
@@ -331,6 +333,8 @@ def from_dataloader(
             repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
                 situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
                 samples (rather than an advantage module).
+            group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
+                all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.
 
         Returns:
             LLMEnv: The created LLMEnv instance.
@@ -398,6 +402,8 @@ def from_dataloader(
             stack_method=stack_method,
             repeats=repeats,
             device=device,
+            group_repeats=group_repeats,
+            batch_size=batch_size,
         )
         env = LLMEnv(
             str2str=str2str,
@@ -411,7 +417,7 @@ def from_dataloader(
             no_stack=no_stack,
             assign_reward=assign_reward,
             assign_done=assign_done,
-            batch_size=batch_size if batch_size is not None else primer.batch_size,
+            batch_size=primer.batch_size,
             has_attention=has_attention,
             as_llm_data=as_llm_data,
         )
@@ -565,6 +571,8 @@ def check_str():
                 f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (eg, "
                 f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
             )
+        if not isinstance(tensordict, LazyStackedTensorDict):
+            tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
         td_reset = tensordict.copy()
         if td_reset.device != self.device:
             if self.device is None:
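
The new `group_repeats` flag groups all repeats of the same prompt into a single batch, so the environment's effective batch size becomes `batch_size * repeats`. A small sketch of the behavior, lifted from the `env_constructor` helper added to the tests (`DummyStrDataLoader` is a test mock from `mocking_classes`, not part of the library):

```python
from mocking_classes import DummyStrDataLoader  # test helper, not a torchrl API

from torchrl.envs import LLMEnv

dl = DummyStrDataLoader(batch_size=32)
# 16 distinct prompts per batch, each repeated 4 times and grouped together.
env = LLMEnv.from_dataloader(
    dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
)
assert env.batch_size == (64,)
```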

torchrl/envs/transforms/llm.py

+3 -0
@@ -123,6 +123,9 @@ class DataLoadingPrimer(TensorDictPrimer):
     .. note:: The batch-size of the Primer must match the batch-size of the parent environment (typically a
         wrapper around :class:`~torchrl.envs.LLMEnv`).
 
+    group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
+        all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.
+
     Attributes:
         dataloader (Iterable[Any]): The dataloader to load data from.
         endless_dataloader (Iterable[Any]): An endless iterator over the dataloader.

torchrl/envs/utils.py

+4 -2
@@ -508,6 +508,7 @@ def _set_single_key(
     if isinstance(key, str):
         key = (key,)
     for k in key:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source._get_str(k, None)
             if is_tensor_collection(val):
@@ -528,7 +529,7 @@ def _set_single_key(
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
                     _set_single_key(s_td, d_td, k, clone=clone, device=device)
@@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded):
     total_key = total_key + (key,)
     non_empty = False
     if unravel_key(total_key) not in excluded:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
        try:
            val = source.get(key)
            if is_tensor_collection(val) and not isinstance(
@@ -571,7 +573,7 @@ def _set(source, dest, key, total_key, excluded):
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 non_empty_local = False
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
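
For context on the TODOs above: the try/except guards against heterogeneous ("het") keys, i.e. entries whose shapes differ across the tensordicts of a lazy stack, where a plain `get` has to stack ragged tensors and raises a `RuntimeError` (whose wording changed, hence the updated regex). A hypothetical illustration, assuming recent tensordict behavior; the exact error text may vary between versions:

```python
import torch
from tensordict import LazyStackedTensorDict, TensorDict

# Two elements whose "tokens" entries have different lengths (a het key).
td = LazyStackedTensorDict(
    TensorDict(tokens=torch.zeros(3, dtype=torch.int64), batch_size=[]),
    TensorDict(tokens=torch.zeros(5, dtype=torch.int64), batch_size=[]),
)
try:
    td.get("tokens")  # stacking ragged tensors can raise here
except RuntimeError as err:
    print(err)  # newer tensordict wording: "Failed to stack tensors within a tensordict"
# as_list=True sidesteps the stacking entirely: one tensor per element.
tokens = td.get("tokens", as_list=True)
assert len(tokens) == 2
```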
