Commit ed7512f

[Feature] Allow using lists of tensors in vllm instead of padded tensors
ghstack-source-id: c037e99
Pull Request resolved: #2861

1 parent 7b6e9a8 · commit ed7512f
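
What this enables, in a minimal sketch (the wrapper and the pad_output flag are taken from the tests below; the import path is an assumption, not verified against this revision):

import vllm
from torchrl.modules.llm import from_vllm  # assumed import path

model = vllm.LLM("gpt2")

# Before: ragged outputs (tokens, log-probs) were right-padded into one
# dense tensor per key.
padded_policy = from_vllm(model, from_text=True, generate=True, pad_output=True)

# After this commit: pad_output=False keeps each sample at its own length,
# returning lists of tensors instead of a padded tensor.
list_policy = from_vllm(model, from_text=True, generate=True, pad_output=False)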

File tree

3 files changed (+333 −121 lines)

test/test_actors.py

+97 −7
@@ -10,7 +10,7 @@

 import pytest
 import torch
-from tensordict import NonTensorStack, TensorDict
+from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
 from tensordict.nn import CompositeDistribution, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor

@@ -1122,6 +1122,8 @@ def _run_check(

         # If from text and not generating, the tokens are not returned for now
         if not (from_text and not generate):
+            assert td.tokens_response is not None
+            assert td.tokens is not None
             assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
             # The convention is that the response only has new tokens
             assert (
@@ -1166,26 +1168,34 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
         )

     @pytest.mark.parametrize(
-        "from_text, tokens, attention_mask",
+        "pad_output, from_text, tokens, attention_mask",
         [
-            (True, None, None),
+            (True, True, None, None),
+            (False, True, None, None),
             (
+                True,
                 False,
                 torch.randint(1024, (1, 10)),
                 torch.ones(1, 10, dtype=torch.int64),
             ),
-            (False, torch.randint(1024, (1, 10)), None),
+            (True, False, torch.randint(1024, (1, 10)), None),
         ],
     )
-    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
+    def test_from_vllm_logprobs(self, from_text, tokens, attention_mask, pad_output):
         torch.manual_seed(0)
         from vllm import LLM

         model = LLM(model="facebook/opt-125m")
         m_generate = from_vllm(
-            model, from_text=from_text, generate=True, return_log_probs=True
+            model,
+            from_text=from_text,
+            generate=True,
+            return_log_probs=True,
+            pad_output=pad_output,
+        )
+        m_logprobs = from_vllm(
+            model, from_text=from_text, generate=False, pad_output=pad_output
         )
-        m_logprobs = from_vllm(model, from_text=from_text, generate=False)
         self._check_lps(
             m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
         )
@@ -1221,6 +1231,86 @@ def _check_lps(
             td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
         )

+    @pytest.fixture(scope="module")
+    def llm_model(self):
+        import vllm
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.mark.parametrize("pad", [True, False])
+    @pytest.mark.parametrize("generate", [True, False])
+    @pytest.mark.parametrize("use_tensorclass", [True, False])
+    def test_vllm_batch_run(self, pad, generate, use_tensorclass, llm_model):
+        # Test generate - padding combinations
+        policy = from_vllm(
+            llm_model,
+            from_text=True,
+            generate=generate,
+            return_log_probs=True,
+            pad_output=pad,
+            generate_kwargs={"max_tokens": 10000},
+        )
+        if generate:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        else:
+            data = LazyStackedTensorDict(
+                *TensorDict(
+                    text=NonTensorStack("a string", "another very long string"),
+                    text_response=NonTensorStack(
+                        " is a string", " is still a very long string"
+                    ),
+                    batch_size=[2],
+                ).unbind(0)
+            )
+        if use_tensorclass:
+            data = LLMData.from_tensordict(data)
+        output = policy(data)
+        try:
+            log_probs = output.get("log_probs")
+        except Exception:
+            log_probs = output.get("log_probs", as_list=True)
+        if pad:
+            assert isinstance(log_probs, torch.Tensor)
+        else:
+            assert isinstance(log_probs, list)
+        text = output.get("text", as_list=True)
+        # TODO: this is not ideal...
+        if use_tensorclass:
+            assert isinstance(text, list)
+        else:
+            assert isinstance(text, NonTensorStack)
+        text_response = output.get("text_response", as_list=True)
+        if use_tensorclass:
+            assert isinstance(text_response, list)
+        else:
+            assert isinstance(text_response, NonTensorStack)
+        try:
+            tokens_response = output.get("tokens_response")
+        except Exception:
+            tokens_response = output.get("tokens_response", as_list=True)
+        if pad:
+            assert isinstance(tokens_response, torch.Tensor)
+        else:
+            assert isinstance(tokens_response, list)
+        try:
+            tokens = output.get("tokens")
+        except Exception:
+            tokens = output.get("tokens", as_list=True)
+        if not generate:
+            assert tokens is None
+        elif pad:
+            assert isinstance(tokens, torch.Tensor), tokens
+        else:
+            assert isinstance(tokens, list)
+

 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
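
The try/except blocks around output.get(...) above exist because a plain get on a LazyStackedTensorDict raises when the stacked entries have heterogeneous shapes, while get(key, as_list=True) returns the per-sample tensors. A minimal sketch of that behavior (key name and shapes are illustrative):

import torch
from tensordict import LazyStackedTensorDict, TensorDict

# Two samples whose "tokens" entries have different lengths.
td = LazyStackedTensorDict(
    TensorDict(tokens=torch.zeros(3, dtype=torch.int64), batch_size=[]),
    TensorDict(tokens=torch.zeros(5, dtype=torch.int64), batch_size=[]),
)

try:
    td.get("tokens")  # eager stack of ragged shapes raises
except RuntimeError:
    pass

# as_list=True sidesteps the stack and returns one tensor per sample.
tokens = td.get("tokens", as_list=True)
assert [t.shape[-1] for t in tokens] == [3, 5]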

torchrl/envs/utils.py

+4 −2
@@ -508,6 +508,7 @@ def _set_single_key(
     if isinstance(key, str):
         key = (key,)
     for k in key:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source._get_str(k, None)
             if is_tensor_collection(val):
@@ -528,7 +529,7 @@
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
                     _set_single_key(s_td, d_td, k, clone=clone, device=device)
@@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded):
     total_key = total_key + (key,)
     non_empty = False
     if unravel_key(total_key) not in excluded:
+        # TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
         try:
             val = source.get(key)
             if is_tensor_collection(val) and not isinstance(
@@ -571,7 +573,7 @@
         # This is a temporary solution to understand if a key is heterogeneous
         # while not having performance impact when the exception is not raised
         except RuntimeError as err:
-            if re.match(r"Found more than one unique shape in the tensors", str(err)):
+            if re.match(r"Failed to stack tensors within a tensordict", str(err)):
                 # this is a het key
                 non_empty_local = False
                 for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
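
The regex update tracks a change in the error message tensordict raises when a heterogeneous ("het") key cannot be stacked eagerly. A sketch of the detection pattern used in _set / _set_single_key (shapes are illustrative; the message is the one matched as of this commit):

import re
import torch
from tensordict import LazyStackedTensorDict, TensorDict

source = LazyStackedTensorDict(
    TensorDict(obs=torch.zeros(3), batch_size=[]),
    TensorDict(obs=torch.zeros(5), batch_size=[]),
)

try:
    val = source.get("obs")  # ragged shapes cannot be stacked eagerly
except RuntimeError as err:
    # New message matched by this commit (previously "Found more than one
    # unique shape in the tensors").
    if re.match(r"Failed to stack tensors within a tensordict", str(err)):
        # het key: fall back to per-sample access; as the TODOs above note,
        # as_list / as_nested_tensor could replace this try/except.
        vals = [s_td.get("obs") for s_td in source.tensordicts]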
