Commit 978746b

[Feature] Allow using lists of tensors in vllm instead of padded tensors
ghstack-source-id: 020791a
Pull Request resolved: #2861
Parent: 5ab3f24
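
For orientation before the diff: the new `pad_output` flag on `from_vllm` switches the wrapper between returning right-padded tensors and returning lists of per-sample tensors. A minimal sketch of the call, using only argument names that appear in this PR's tests (the import path for `from_vllm` is an assumption and may differ in your torchrl checkout):

    import vllm
    from torchrl.modules import from_vllm  # assumed import path

    model = vllm.LLM("gpt2")
    # pad_output=False is the new behavior introduced here: ragged outputs
    # come back as lists of 1D tensors, one entry per prompt, instead of a
    # single padded 2D tensor per key.
    policy = from_vllm(
        model,
        from_text=True,
        generate=True,
        return_log_probs=True,
        pad_output=False,
    )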

File tree: 3 files changed (+289 −104 lines)

test/test_actors.py (+87 −7)
@@ -10,7 +10,7 @@
 
 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 
@@ -1122,6 +1122,8 @@ def _run_check(
 
         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
             # The convention is that the response only has new tokens
             assert (
@@ -1166,26 +1168,34 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )
 
     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output):
         torch.manual_seed(0)
         from vllm import LLM
 
         model = LLM(model="facebook/opt-125m")
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
         )
@@ -1221,6 +1231,76 @@ def _check_lps(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )
 
+    @pytest.fixture(scope="module")
+    def llm_model(self):
+        import vllm
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    def test_vllm_batch_run(self, pad, generate, llm_model):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            llm_model,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        output = policy(data)
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
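
Two access patterns in the new test are worth noting: `output.get(key)` densely stacks an entry, which is what `pad_output=True` produces, while `output.get(key, as_list=True)` returns the unpadded per-sample tensors; the try/except simply falls back to the list form when dense stacking is not possible. A minimal sketch of the distinction, assuming a tensordict version with the `as_list` feature this PR relies on:

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict

    # Two samples whose "tokens" entries have different lengths.
    td = LazyStackedTensorDict(
        TensorDict(tokens=torch.arange(3), batch_size=[]),
        TensorDict(tokens=torch.arange(5), batch_size=[]),
    )
    # Dense access would require padding; the list accessor keeps the
    # ragged structure: [tensor([0, 1, 2]), tensor([0, 1, 2, 3, 4])]
    tokens = td.get("tokens", as_list=True)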

torchrl/envs/utils.py (+4 −2)
@@ -508,6 +508,7 @@ def _set_single_key(
     if isinstance(key, str):
         key = (key,)
     for k in key:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source._get_str(k, None)
             if is_tensor_collection(val):
@@ -528,7 +529,7 @@
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
                     _set_single_key(s_td, d_td, k, clone=clone, device=device)
@@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded):
     total_key = total_key + (key,)
     non_empty = False
     if unravel_key(total_key) not in excluded:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source.get(key)
             if is_tensor_collection(val) and not isinstance(
@@ -571,7 +573,7 @@
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 non_empty_local = False
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
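
Both hunks in this file only track an upstream change to the error message tensordict raises when a dense get hits heterogeneous shapes; the heterogeneous-key handling itself is unchanged. A sketch of the situation the regex detects, with the message text taken from the updated pattern above (the exact wording comes from tensordict, not from this repo):

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict

    het = LazyStackedTensorDict(
        TensorDict(obs=torch.zeros(3), batch_size=[]),
        TensorDict(obs=torch.zeros(5), batch_size=[]),
    )
    try:
        het.get("obs")  # stacking (3,) with (5,) cannot yield one dense tensor
    except RuntimeError as err:
        # heterogeneous-key path taken by _set / _set_single_key
        assert "Failed to stack tensors within a tensordict" in str(err)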
