
[Feature] Allow using lists of tensors in vllm instead of padded tensors #2861



Merged · 10 commits · Mar 20, 2025
176 changes: 160 additions & 16 deletions test/test_actors.py
@@ -10,14 +10,17 @@

import pytest
import torch
from tensordict import NonTensorStack, TensorDict
from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
from tensordict.nn import CompositeDistribution, TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from torch import distributions as dist, nn

from torchrl.collectors import SyncDataCollector
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
from torchrl.data.llm import LLMData
from torchrl.data.llm.dataset import _has_transformers
from torchrl.envs import LLMEnv
from torchrl.modules import (
from_hf_transformers,
from_vllm,
@@ -42,10 +45,10 @@

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import get_default_devices
from pytorch.rl.test.mocking_classes import NestedCountingEnv
from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv
else:
from _utils_internal import get_default_devices
from mocking_classes import NestedCountingEnv
from mocking_classes import DummyStrDataLoader, NestedCountingEnv

_has_vllm = importlib.util.find_spec("vllm") is not None

@@ -922,6 +925,18 @@ def test_lmhead_actorvalueoperator(device):
@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
class TestLLMActor:
@pytest.fixture(scope="module")
def vllm_instance(self):
try:
import vllm
except ImportError:
pytest.skip(reason="missing vllm")

llm_model = vllm.LLM("gpt2")
tokenizer = llm_model.get_tokenizer()
tokenizer.pad_token = tokenizer.eos_token
return llm_model

@pytest.mark.parametrize(
"from_text, generate, return_log_probs, tokens, attention_mask",
[
@@ -1005,12 +1020,17 @@ def test_from_hf_transformers(
],
)
def test_from_vllm(
self, from_text, generate, return_log_probs, tokens, attention_mask
self,
from_text,
generate,
return_log_probs,
tokens,
attention_mask,
vllm_instance,
):
torch.manual_seed(0)
from vllm import LLM

model = LLM(model="facebook/opt-125m")
model = vllm_instance
m = from_vllm(
model,
from_text=from_text,
@@ -1122,6 +1142,8 @@ def _run_check(

# If from text and not generating, the tokens are not returned for now
if not (from_text and not generate):
assert td.tokens_response is not None
assert td.tokens is not None
assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
# The convention is that the response only has new tokens
assert (
@@ -1166,28 +1188,43 @@ def test_from_hf_logprobs(self, from_text, tokens, attention_mask):
)

@pytest.mark.parametrize(
"from_text, tokens, attention_mask",
"pad_output, from_text, tokens, attention_mask",
[
(True, None, None),
(True, True, None, None),
(False, True, None, None),
(
True,
False,
torch.randint(1024, (1, 10)),
torch.ones(1, 10, dtype=torch.int64),
),
(False, torch.randint(1024, (1, 10)), None),
(True, False, torch.randint(1024, (1, 10)), None),
],
)
def test_from_vllm_logprobs(self, from_text, tokens, attention_mask):
def test_from_vllm_logprobs(
self, from_text, tokens, attention_mask, pad_output, vllm_instance
):
torch.manual_seed(0)
from vllm import LLM

model = LLM(model="facebook/opt-125m")
model = vllm_instance
m_generate = from_vllm(
model, from_text=from_text, generate=True, return_log_probs=True
model,
from_text=from_text,
generate=True,
return_log_probs=True,
pad_output=pad_output,
)
m_logprobs = from_vllm(
model, from_text=from_text, generate=False, pad_output=pad_output
)
m_logprobs = from_vllm(model, from_text=from_text, generate=False)
self._check_lps(
m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False
m_generate,
m_logprobs,
tokens,
attention_mask,
from_text,
has_logits=False,
tol=1e-1,
)

def _check_lps(
@@ -1198,6 +1235,7 @@ def _check_lps(
attention_mask,
from_text,
has_logits,
tol=1e-2,
):
# Checks that the log-probs gathered with generate=False match those gathered with generate=True
tdin_genetate = self._make_data(
@@ -1218,8 +1256,114 @@
assert td_generate.log_probs.shape == td_generate.tokens_response.shape
assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
torch.testing.assert_close(
td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2
td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
)

@pytest.mark.parametrize("pad", [True, False])
@pytest.mark.parametrize("generate", [True, False])
@pytest.mark.parametrize("use_tensorclass", [True, False])
def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
# Test generate - padding combinations
policy = from_vllm(
vllm_instance,
from_text=True,
generate=generate,
return_log_probs=True,
pad_output=pad,
generate_kwargs={"max_tokens": 10000},
)
if generate:
data = LazyStackedTensorDict(
*TensorDict(
text=NonTensorStack("a string", "another very long string"),
batch_size=[2],
).unbind(0)
)
else:
data = LazyStackedTensorDict(
*TensorDict(
text=NonTensorStack("a string", "another very long string"),
text_response=NonTensorStack(
" is a string", " is still a very long string"
),
batch_size=[2],
).unbind(0)
)
if use_tensorclass:
data = LLMData.from_tensordict(data)
output = policy(data)
try:
log_probs = output.get("log_probs")
except Exception:
log_probs = output.get("log_probs", as_list=True)
if pad:
assert isinstance(log_probs, torch.Tensor)
else:
assert isinstance(log_probs, list)
text = output.get("text", as_list=True)
# TODO: this is not ideal...
if use_tensorclass:
assert isinstance(text, list)
else:
assert isinstance(text, NonTensorStack)
text_response = output.get("text_response", as_list=True)
if use_tensorclass:
assert isinstance(text_response, list)
else:
assert isinstance(text_response, NonTensorStack)
try:
tokens_response = output.get("tokens_response")
except Exception:
tokens_response = output.get("tokens_response", as_list=True)
if pad:
assert isinstance(tokens_response, torch.Tensor)
else:
assert isinstance(tokens_response, list)
try:
tokens = output.get("tokens")
except Exception:
tokens = output.get("tokens", as_list=True)
if not generate:
assert tokens is None
elif pad:
assert isinstance(tokens, torch.Tensor), tokens
else:
assert isinstance(tokens, list)

def test_vllm_collection(self, vllm_instance):
policy = from_vllm(
vllm_instance,
from_text=True,
generate=True,
return_log_probs=True,
pad_output=False,
generate_kwargs={"max_tokens": 10},
)
self._run_check_collector(policy)

def test_transformers_collection(self):
...

@classmethod
def env_constructor(cls):
dl = DummyStrDataLoader(batch_size=32)
env = LLMEnv.from_dataloader(
dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
)
assert env.batch_size == (64,)
return env

def _run_check_collector(self, policy):
collector = SyncDataCollector(
self.env_constructor,
policy=policy,
frames_per_batch=128,
total_frames=512,
use_buffers=False,
)
for data in collector:
assert isinstance(data, LazyStackedTensorDict)
assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)


if __name__ == "__main__":
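For reference, here is a minimal sketch (not part of the diff) of the un-padded path these tests exercise, assuming a local vllm install and the small gpt2 checkpoint used by the `vllm_instance` fixture above:

```python
# Sketch only: pad_output=False makes the wrapper return lists of tensors
# (one entry per prompt) instead of a single right-padded tensor.
import vllm
from tensordict import LazyStackedTensorDict, NonTensorStack, TensorDict
from torchrl.modules import from_vllm

model = vllm.LLM("gpt2")
policy = from_vllm(
    model,
    from_text=True,
    generate=True,
    return_log_probs=True,
    pad_output=False,
    generate_kwargs={"max_tokens": 10},
)
data = LazyStackedTensorDict(
    *TensorDict(
        text=NonTensorStack("a string", "another very long string"),
        batch_size=[2],
    ).unbind(0)
)
out = policy(data)
# Heterogeneous entries come back as plain Python lists of tensors.
log_probs = out.get("log_probs", as_list=True)
tokens_response = out.get("tokens_response", as_list=True)
```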
10 changes: 9 additions & 1 deletion torchrl/envs/custom/llm.py
@@ -10,6 +10,7 @@

from tensordict import (
is_leaf_nontensor,
LazyStackedTensorDict,
NestedKey,
TensorDict,
TensorDictBase,
@@ -266,6 +267,7 @@ def from_dataloader(
stack_method: Callable[[Any], Any]
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
repeats: int | None = None,
group_repeats: bool = False,
) -> LLMEnv:
"""Creates an LLMEnv instance from a dataloader.

@@ -331,6 +333,8 @@ def from_dataloader(
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
samples (rather than an advantage module).
group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.

Returns:
LLMEnv: The created LLMEnv instance.
@@ -398,6 +402,8 @@ def from_dataloader(
stack_method=stack_method,
repeats=repeats,
device=device,
group_repeats=group_repeats,
batch_size=batch_size,
)
env = LLMEnv(
str2str=str2str,
@@ -411,7 +417,7 @@
no_stack=no_stack,
assign_reward=assign_reward,
assign_done=assign_done,
batch_size=batch_size if batch_size is not None else primer.batch_size,
batch_size=primer.batch_size,
has_attention=has_attention,
as_llm_data=as_llm_data,
)
@@ -565,6 +571,8 @@ def check_str():
f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. Make sure a TensorDictPrimer (eg, "
f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms."
)
if not isinstance(tensordict, LazyStackedTensorDict):
tensordict = LazyStackedTensorDict(*tensordict.unbind(0))
td_reset = tensordict.copy()
if td_reset.device != self.device:
if self.device is None:
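As a usage note, a hedged sketch of the `group_repeats` behaviour added above, mirroring the `env_constructor` helper in the tests (`DummyStrDataLoader` is a test mock, so the import path is an assumption):

```python
# Sketch: with repeats=4 and group_repeats=True the primer's batch-size is
# multiplied by the number of repeats, so an env built with batch_size=16
# ends up with batch_size == (64,).
from torchrl.envs import LLMEnv
from mocking_classes import DummyStrDataLoader  # test helper; import path assumed

dl = DummyStrDataLoader(batch_size=32)
env = LLMEnv.from_dataloader(
    dl,
    batch_size=16,
    repeats=4,
    str2str=True,
    group_repeats=True,
)
assert env.batch_size == (64,)
```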
3 changes: 3 additions & 0 deletions torchrl/envs/transforms/llm.py
@@ -123,6 +123,9 @@ class DataLoadingPrimer(TensorDictPrimer):
.. note:: The batch-size of the Primer must match the batch-size of the parent environment (typically a
wrapper around :class:`~torchrl.envs.LLMEnv`).

group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.

Attributes:
dataloader (Iterable[Any]): The dataloader to load data from.
endless_dataloader (Iterable[Any]): An endless iterator over the dataloader.
6 changes: 4 additions & 2 deletions torchrl/envs/utils.py
@@ -508,6 +508,7 @@ def _set_single_key(
if isinstance(key, str):
key = (key,)
for k in key:
# TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
try:
val = source._get_str(k, None)
if is_tensor_collection(val):
@@ -528,7 +529,7 @@
# This is a temporary solution to understand if a key is heterogeneous
# while not having performance impact when the exception is not raised
except RuntimeError as err:
if re.match(r"Found more than one unique shape in the tensors", str(err)):
if re.match(r"Failed to stack tensors within a tensordict", str(err)):
# this is a het key
for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
_set_single_key(s_td, d_td, k, clone=clone, device=device)
@@ -541,6 +542,7 @@ def _set(source, dest, key, total_key, excluded):
total_key = total_key + (key,)
non_empty = False
if unravel_key(total_key) not in excluded:
# TODO: we can do better than try/except by leveraging the as_list / as_nested_tensor feature
try:
val = source.get(key)
if is_tensor_collection(val) and not isinstance(
@@ -571,7 +573,7 @@ def _set(source, dest, key, total_key, excluded):
# This is a temporary solution to understand if a key is heterogeneous
# while not having performance impact when the exception is not raised
except RuntimeError as err:
if re.match(r"Found more than one unique shape in the tensors", str(err)):
if re.match(r"Failed to stack tensors within a tensordict", str(err)):
# this is a het key
non_empty_local = False
for s_td, d_td in zip(source.tensordicts, dest.tensordicts):
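For context, a hedged illustration of the heterogeneous-key case the updated regex guards against, assuming tensordict now raises "Failed to stack tensors within a tensordict" when a lazy stack cannot be stacked into a single tensor (the `as_list` escape hatch is the feature this PR leans on):

```python
# Sketch: entries with different shapes cannot be stacked into one tensor,
# so a plain get() raises; retrieving them as a list avoids the exception.
import torch
from tensordict import LazyStackedTensorDict, TensorDict

td = LazyStackedTensorDict(
    TensorDict(tokens=torch.zeros(3), batch_size=[]),
    TensorDict(tokens=torch.zeros(5), batch_size=[]),
)
try:
    td.get("tokens")  # heterogeneous shapes -> RuntimeError
except RuntimeError as err:
    print(err)  # message matched by the updated regex (assumed wording)
tokens = td.get("tokens", as_list=True)  # list of tensors, no padding needed
```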