Commit afc5d59

[Quality] Better defaults for vllm wrapper and LLMEnv
ghstack-source-id: 88d6f0b
Pull Request resolved: #2874
Parent: 33051e7

4 files changed (+40 −21)
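In short, the commit flips the wrapper and environment defaults so that the common text-in/text-out vLLM pipeline works without extra flags. A minimal sketch of the intent (assuming vLLM is installed and a small model such as `facebook/opt-125m` is available; this snippet is illustrative, not taken from the repo):

```python
from vllm import LLM
from torchrl.modules import vLLMWrapper

model = LLM("facebook/opt-125m")  # illustrative model choice

# Before this commit the text pipeline had to be spelled out:
#   vLLMWrapper(model, from_text=True, generate=True, pad_output=False)
# After it, from_text=True and pad_output=False are the defaults:
policy = vLLMWrapper(model, generate_kwargs={"max_tokens": 10})
```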

test/test_actors.py (+13 −4)
```diff
@@ -1333,10 +1333,7 @@ def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
     def test_vllm_collection(self, vllm_instance):
         policy = vLLMWrapper(
             vllm_instance,
-            from_text=True,
-            generate=True,
             return_log_probs=True,
-            pad_output=False,
             generate_kwargs={"max_tokens": 10},
         )
         self._run_check_collector(policy)
@@ -1348,7 +1345,10 @@ def test_transformers_collection(self):
         def env_constructor(cls):
             dl = DummyStrDataLoader(batch_size=32)
             env = LLMEnv.from_dataloader(
-                dl, batch_size=16, repeats=4, str2str=True, group_repeats=True
+                dl,
+                batch_size=16,
+                repeats=4,
+                # str2str=True, group_repeats=True
             )
             assert env.batch_size == (64,)
             return env
@@ -1364,6 +1364,15 @@ def _run_check_collector(self, policy):
         for data in collector:
             assert isinstance(data, LazyStackedTensorDict)
             assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
+            # action
+            assert "text_response" in data
+            assert "tokens_response" in data
+            # obs
+            assert "text" in data
+            assert ("next", "text") in data
+            # tokens
+            assert "tokens" in data
+            # assert ("next", "tokens") in data
 
 
 if __name__ == "__main__":
```
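The new assertions pin down the keys a collector batch should carry: text and token responses as actions, and current/next text as observations. A rough standalone equivalent of the test loop (a sketch; `env` and `policy` are assumed to be built as in the constructors above, and the collector settings are illustrative):

```python
from torchrl.collectors import SyncDataCollector

# env and policy assumed from the constructors shown in the diff above.
collector = SyncDataCollector(env, policy, frames_per_batch=64, total_frames=128)
for data in collector:
    assert "text_response" in data    # action (text)
    assert "tokens_response" in data  # action (tokens)
    assert "text" in data             # observation
    assert ("next", "text") in data   # next observation
```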

torchrl/envs/custom/llm.py (+9 −6)
```diff
@@ -48,6 +48,9 @@ class LLMEnv(EnvBase):
     Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via
     :meth:`~from_dataloader`.
 
+    .. note:: The default arguments of the `LLMEnv` class are set to make it easy to run this environment with
+        the vllm backend (:class:`~torchrl.modules.vLLMWrapper`).
+
     Keyword Args:
         token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`).
             Defaults to ``"tokens"``.
@@ -59,7 +62,7 @@ class LLMEnv(EnvBase):
             ``"tokens_response"`` or ``"text_response"``.
         reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
             Defaults to ``"reward"``.
-        str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``.
+        str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
         device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
         vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
             unbounded vocabulary. Defaults to ``None``.
@@ -102,7 +105,7 @@ def __init__(
         attention_key: NestedKey | None = None,
         action_key: NestedKey | None = None,
         reward_key: NestedKey = "reward",
-        str2str: bool = False,
+        str2str: bool = True,
         device: torch.device | None = None,
         vocab_size: int | None = None,
         no_stack: bool = True,
@@ -250,7 +253,7 @@ def from_dataloader(
         attention_key: NestedKey | None = None,
         action_key: NestedKey | None = None,
         reward_key: NestedKey = "reward",
-        str2str: bool = False,
+        str2str: bool = True,
         device: torch.device | None = None,
         vocab_size: int | None = None,
         no_stack: bool = False,
@@ -267,7 +270,7 @@ def from_dataloader(
         stack_method: Callable[[Any], Any]
         | Literal["as_nested_tensor", "as_padded_tensor"] = None,
         repeats: int | None = None,
-        group_repeats: bool = False,
+        group_repeats: bool = True,
     ) -> LLMEnv:
         """Creates an LLMEnv instance from a dataloader.
 
@@ -297,7 +300,7 @@ def from_dataloader(
             ``("tokens_out", "sequences")``.
         reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
             Defaults to ``"reward"``.
-        str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``.
+        str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
         device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
         vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
             unbounded vocabulary. Defaults to ``None``.
@@ -334,7 +337,7 @@ def from_dataloader(
             situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
             samples (rather than an advantage module).
         group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
-            all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``.
+            all repeats are grouped in a single batch collected from the buffer. Defaults to ``True``.
 
     Returns:
         LLMEnv: The created LLMEnv instance.
```
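Because `group_repeats` now defaults to ``True``, the environment batch size becomes `batch_size * repeats`, which is why the test above expects ``(64,)`` from ``batch_size=16, repeats=4``. A sketch of the resulting call (assuming `DummyStrDataLoader`, a test helper, or any dataloader yielding prompt strings):

```python
from torchrl.envs.custom.llm import LLMEnv

dl = DummyStrDataLoader(batch_size=32)  # test helper; any str dataloader works
env = LLMEnv.from_dataloader(dl, batch_size=16, repeats=4)
# group_repeats=True (new default): 16 prompts x 4 repeats = 64 per batch
assert env.batch_size == (64,)
```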

torchrl/modules/llm/transformers_wrapper.py (+4 −4)
```diff
@@ -31,7 +31,7 @@ class TransformersWrapper(CategoricalSequential):
             encoding and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to
             `None`.
         from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to
-            be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `False`.
+            be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `True`.
         device (torch.device | None, optional): The device to use for computation. If `None`, the default device will
             be used. Defaults to `None`.
         generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on
@@ -86,8 +86,8 @@ class TransformersWrapper(CategoricalSequential):
         >>> output_data = wrapper(input_data)
         >>> print(output_data["text_response"])
 
-    .. seealso:: :func:`~torchrl.modules.from_hf_transformers` for a similar interface using the Hugging Face
-        Transformers library.
+    .. seealso:: :func:`~torchrl.modules.vLLMWrapper` for a similar interface using vLLM.
+
     """
 
     text_key: NestedKey = ("text",)
@@ -105,7 +105,7 @@ def __init__(
         tokenizer: transformers.tokenization_utils.PreTrainedTokenizer  # noqa
         | None = None,
         # noqa
-        from_text: bool = False,
+        from_text: bool = True,
         device: torch.device | None = None,
         generate: bool = True,
         generate_kwargs: dict | None = None,
```
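With `from_text` defaulting to ``True`` here as well, the Hugging Face wrapper accepts raw strings out of the box, matching the vLLM side. A sketch (the model name is illustrative; any small causal LM should do):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchrl.modules import TransformersWrapper

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# from_text=True is now the default, so string prompts need no extra flag.
wrapper = TransformersWrapper(model, tokenizer=tokenizer, generate=True)
```

Note that `pad_output` keeps different defaults across the two wrappers, as the new warning in the vLLM docstring below points out.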

torchrl/modules/llm/vllm_wrapper.py (+14 −7)
```diff
@@ -13,7 +13,7 @@
     LazyStackedTensorDict,
     NestedKey,
     TensorDict,
-    TensorDictBase,
+    TensorDictBase, maybe_dense_stack,
 )
 from tensordict.tensorclass import from_dataclass, NonTensorStack, TensorClass
 from tensordict.utils import _zip_strict, expand_as_right
@@ -28,6 +28,9 @@ class vLLMWrapper(CategoricalSequential):
     This class allows for handling both text and token inputs, enabling text generation and log probability
     computation based on the specified configuration.
 
+    .. note:: The default arguments of the `vLLMWrapper` class are set to make it easy to run this backend with
+        the :class:`~torchrl.envs.custom.llm.LLMEnv` class.
+
     Args:
         model (vllm.LLM): The vLLM model to wrap.
 
@@ -38,7 +41,7 @@ class vLLMWrapper(CategoricalSequential):
             encoding and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to
             `None`.
         from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to
-            be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `False`.
+            be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `True`.
         device (torch.device | None, optional): The device to use for computation. If `None`, the default device will
             be used. Defaults to `None`.
         generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on
@@ -50,7 +53,11 @@ class vLLMWrapper(CategoricalSequential):
             control aspects of the tokenization process, such as padding and truncation. Defaults to `None`.
         pad_output (bool, optional): Whether to pad the output sequences to a uniform length. If `True`, the output
             sequences will be padded and represented as tensors. If `False`, lists of tokens will be used without
-            padding. Defaults to `True`.
+            padding. Defaults to `False`.
+
+            .. warning:: The default value of `pad_output` differs from :func:`~torchrl.modules.TransformersWrapper`
+                which does not handle non-padded inputs.
+
         inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
             operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
             created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
@@ -93,7 +100,7 @@ class vLLMWrapper(CategoricalSequential):
         >>> output_data = wrapper(input_data)
         >>> print(output_data.text_response)
 
-    .. seealso:: :func:`~torchrl.modules.from_hf_transformers` for a similar interface using the Hugging Face
+    .. seealso:: :func:`~torchrl.modules.TransformersWrapper` for a similar interface using the Hugging Face
         Transformers library.
     """
 
@@ -112,12 +119,12 @@ def __init__(
         tokenizer: transformers.tokenization_utils.PreTrainedTokenizer  # noqa
         | None = None,
         # noqa
-        from_text: bool = False,
+        from_text: bool = True,
         device: torch.device | None = None,
         generate: bool = True,
         generate_kwargs: dict | None = None,
         tokenizer_kwargs: dict | None = None,
-        pad_output: bool = True,
+        pad_output: bool = False,
         inplace: Literal[True, False, "empty"] | None = True,
     ):
         super().__init__()
@@ -545,7 +552,7 @@ def get_logprob(output):
         if len(outputs) == 1:
             self.outputs = outputs[0]
         else:
-            self.outputs = torch.stack(outputs)
+            self.outputs = maybe_dense_stack(outputs)
         self.prompt_logprobs = torch.tensor(
             [
                 v[tid].logprob if v is not None else 0.0
```
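The switch from `torch.stack` to `maybe_dense_stack` complements the new `pad_output=False` default: responses of unequal length cannot be stacked into one dense tensor. As I understand the tensordict helper, it stacks densely when shapes allow and falls back to a lazy stack otherwise; a small sketch of that behavior (illustrative, not from the repo):

```python
import torch
from tensordict import TensorDict, maybe_dense_stack

same = [TensorDict({"a": torch.zeros(3)}, []) for _ in range(2)]
ragged = [TensorDict({"a": torch.zeros(n)}, []) for n in (3, 5)]

dense = maybe_dense_stack(same)   # matching shapes: dense stack
lazy = maybe_dense_stack(ragged)  # ragged entries: lazy-stack fallback
```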
