
Commit 6ce4708

[Feature] Allow using lists of tensors in vllm instead of padded tensors

ghstack-source-id: cab5367
Pull Request resolved: #2861
1 parent: 1d5cba6
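Note: the gist of this change is a new `pad_output` flag on the vLLM wrapper. With `pad_output=False`, variable-length outputs (tokens, log-probs) are returned as lists of tensors rather than a single right-padded tensor. A minimal sketch of the intended usage, pieced together from the tests below (the model name and prompt strings are illustrative, not part of the diff):

```python
# Sketch based on test_vllm_batch_run below; not part of the diff itself.
import vllm
from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
from torchrl.modules import from_vllm

model = vllm.LLM("gpt2")  # illustrative model choice
policy = from_vllm(
    model,
    from_text=True,
    generate=True,
    return_log_probs=True,
    pad_output=False,  # new: return lists of tensors, no padding
)
# Inputs are carried as a lazy stack so each row can keep its own length.
data = LazyStackedTensorDict(
    *TensorDict(
        text=NonTensorStack("a string", "another very long string"),
        batch_size=[2],
    ).unbind(0)
)
out = policy(data)
# With pad_output=False, heterogeneous entries come back as a list of
# 1D tensors (one per prompt) instead of a padded 2D tensor.
log_probs = out.get("log_probs", as_list=True)
```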

File tree: 5 files changed, +395 −124 lines

test/test_actors.py (+140 −9)
```diff
@@ -10,14 +10,17 @@
 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
 from torch import distributions as dist, nn
+
+from torchrl.collectors import SyncDataCollector
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
 from torchrl.data.llm import LLMData
 from torchrl.data.llm.dataset import _has_transformers
+from torchrl.envs import LLMEnv
 from torchrl.modules import (
     from_hf_transformers,
     from_vllm,
@@ -42,10 +45,10 @@
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices
-    from pytorch.rl.test.mocking_classes import NestedCountingEnv
+    from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv
 else:
     from _utils_internal import get_default_devices
-    from mocking_classes import NestedCountingEnv
+    from mocking_classes import DummyStrDataLoader, NestedCountingEnv
 
 _has_vllm = importlib.util.find_spec("vllm") is not None
 
@@ -1122,6 +1125,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
             # The convention is that the response only has new tokens
             assert (
@@ -1166,26 +1171,34 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )
 
     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output):
         torch.manual_seed(0)
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
         )
@@ -1221,6 +1234,124 @@ def _check_lps(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )
 
+    @pytest.fixture(scope="module")
+    def llm_model(self):
+        import vllm
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    @pytest.mark.parametrize("use_tensorclass", [True, False])
+    def test_vllm_batch_run(self, pad, generate, use_tensorclass, llm_model):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            llm_model,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        if use_tensorclass:
+            data = LLMData.from_tensordict(data)
+        output = policy(data)
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        # TODO: this is not ideal...
+        if use_tensorclass:
+            assert isinstance(text, list)
+        else:
+            assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        if use_tensorclass:
+            assert isinstance(text_response, list)
+        else:
+            assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+
+    def test_vllm_collection(self):
+        from vllm import LLM
+
+        llm = LLM("gpt2")
+        policy = from_vllm(
+            llm,
+            from_text=True,
+            generate=True,
+            return_log_probs=True,
+            pad_output=False,
+            generate_kwargs={"max_tokens": 10},
+        )
+        self._run_check_collector(policy)
+
+    def test_transformers_collection(self):
+        ...
+
+    @classmethod
+    def env_constructor(cls):
+        dl = DummyStrDataLoader(batch_size=32)
+        env = LLMEnv.from_dataloader(
+            dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
+        )
+        assert env.batch_size == (64,)
+        return env
+
+    def _run_check_collector(self, policy):
+        collector = SyncDataCollector(
+            self.env_constructor,
+            policy=policy,
+            frames_per_batch=128,
+            total_frames=512,
+            use_buffers=False,
+        )
+        for data in collector:
+            assert isinstance(data, LazyStackedTensorDict)
+            assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
```
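Note on the `try: get(...) / except: get(..., as_list=True)` pattern in `test_vllm_batch_run`: when `pad_output=False`, rows of the lazy stack hold tensors of different lengths, so a plain `get` (which stacks) can raise, while `as_list=True` returns the per-row tensors unstacked. A small helper distilling that pattern (the helper name is hypothetical, not part of the diff):

```python
# Hypothetical helper mirroring the access pattern used in the test:
# plain .get() succeeds when entries share a shape (padded output);
# as_list=True recovers the unstacked, variable-length entries.
def get_maybe_list(td, key):
    try:
        return td.get(key)
    except Exception:
        return td.get(key, as_list=True)
```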

torchrl/envs/custom/llm.py (+9 −1)
```diff
@@ -10,6 +10,7 @@
 
 from tensordict import (
     is_leaf_nontensor,
+    LazyStackedTensorDict,
     NestedKey,
     TensorDict,
     TensorDictBase,
@@ -266,6 +267,7 @@ def from_dataloader(
         stack_method: Callable[[Any], Any]
         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
         repeats: int | None = None,
+        group_repeats: bool = False,
     ) -> LLMEnv:
         """Creates an LLMEnv instance from a dataloader.
 
@@ -331,6 +333,8 @@ def from_dataloader(
             repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
                 situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
                 samples (rather than an advantage module).
+            group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
+                all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.
 
         Returns:
             LLMEnv: The created LLMEnv instance.
@@ -398,6 +402,8 @@ def from_dataloader(
             stack_method=stack_method,
             repeats=repeats,
             device=device,
+            group_repeats=group_repeats,
+            batch_size=batch_size,
         )
         env = LLMEnv(
             str2str=str2str,
@@ -411,7 +417,7 @@ def from_dataloader(
             no_stack=no_stack,
             assign_reward=assign_reward,
             assign_done=assign_done,
-            batch_size=batch_size if batch_size is not None else primer.batch_size,
+            batch_size=primer.batch_size,
             has_attention=has_attention,
             as_llm_data=as_llm_data,
         )
@@ -565,6 +571,8 @@ def check_str():
                 f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (eg, "
                 f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
             )
+        if not isinstance(tensordict, LazyStackedTensorDict):
+            tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
         td_reset = tensordict.copy()
         if td_reset.device != self.device:
             if self.device is None:
```
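Note: the new `group_repeats` flag changes how `repeats` interacts with the batch size. When `True`, the batch size handed to the `DataLoadingPrimer` is multiplied by the number of repeats, so all repeats of a prompt land in one collected batch (this is also why `LLMEnv` now takes its batch size from `primer.batch_size` unconditionally). A sketch mirroring `env_constructor` from the tests (the wrapper function name is hypothetical; any string dataloader works where the test uses `DummyStrDataLoader`):

```python
# Sketch of group_repeats, mirroring env_constructor in test_actors.py.
from torchrl.envs import LLMEnv

def make_grouped_env(dataloader):  # hypothetical helper name
    env = LLMEnv.from_dataloader(
        dataloader,
        batch_size=16,
        repeats=4,
        str2str=True,
        group_repeats=True,  # group all 4 repeats of each prompt together
    )
    # The primer multiplies batch_size by repeats: 16 * 4 = 64.
    assert env.batch_size == (64,)
    return env
```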

torchrl/envs/transforms/llm.py (+3 −0)
```diff
@@ -123,6 +123,9 @@ class DataLoadingPrimer(TensorDictPrimer):
     .. note:: The batch-size of the Primer must match the batch-size of the parent environment (typically a
         wrapper around :class:`~torchrl.envs.LLMEnv`).
 
+    group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
+        all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.
+
     Attributes:
         dataloader (Iterable[Any]): The dataloader to load data from.
         endless_dataloader (Iterable[Any]): An endless iterator over the dataloader.
```

torchrl/envs/utils.py (+4 −2)
```diff
@@ -508,6 +508,7 @@ def _set_single_key(
     if isinstance(key, str):
         key = (key,)
     for k in key:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source._get_str(k, None)
             if is_tensor_collection(val):
@@ -528,7 +529,7 @@ def _set_single_key(
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
                     _set_single_key(s_td, d_td, k, clone=clone, device=device)
@@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded):
     total_key = total_key + (key,)
     non_empty = False
     if unravel_key(total_key) not in excluded:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source.get(key)
             if is_tensor_collection(val) and not isinstance(
@@ -571,7 +573,7 @@ def _set(source, dest, key, total_key, excluded):
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 non_empty_local = False
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
```
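Note: the regex change tracks an updated error message in tensordict: stacking heterogeneous entries used to raise "Found more than one unique shape in the tensors" and now raises "Failed to stack tensors within a tensordict", so the heterogeneous-key detection matches on the new text. A minimal illustration of that check (the helper name is hypothetical):

```python
# Hypothetical helper showing the het-key detection after this change.
import re

def is_het_key_error(err: RuntimeError) -> bool:
    # tensordict now reports this message when entries cannot be stacked
    return bool(re.match(r"Failed to stack tensors within a tensordict", str(err)))
```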
