@@ -48,6 +48,9 @@ class LLMEnv(EnvBase):
48
48
Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via
49
49
:meth:`~from_dataloader`.
50
50
51
+ .. note:: The default arguments of the `LLMEnv` class are set to make it easy to run this environment with
52
+ the vLLM backend (:class:`~torchrl.modules.vLLMWrapper`).
53
+
51
54
Keyword Args:
52
55
token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`).
53
56
Defaults to ``"tokens"``.
@@ -59,7 +62,7 @@ class LLMEnv(EnvBase):
59
62
``"tokens_response"`` or ``"text_response"``.
60
63
reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
61
64
Defaults to ``"reward"``.
62
- str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False ``.
65
+ str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
63
66
device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
64
67
vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
65
68
unbounded vocabulary. Defaults to ``None``.
@@ -102,7 +105,7 @@ def __init__(
102
105
attention_key : NestedKey | None = None ,
103
106
action_key : NestedKey | None = None ,
104
107
reward_key : NestedKey = "reward" ,
105
- str2str : bool = False ,
108
+ str2str : bool = True ,
106
109
device : torch .device | None = None ,
107
110
vocab_size : int | None = None ,
108
111
no_stack : bool = True ,
@@ -250,7 +253,7 @@ def from_dataloader(
250
253
attention_key : NestedKey | None = None ,
251
254
action_key : NestedKey | None = None ,
252
255
reward_key : NestedKey = "reward" ,
253
- str2str : bool = False ,
256
+ str2str : bool = True ,
254
257
device : torch .device | None = None ,
255
258
vocab_size : int | None = None ,
256
259
no_stack : bool = False ,
@@ -267,7 +270,7 @@ def from_dataloader(
267
270
stack_method : Callable [[Any ], Any ]
268
271
| Literal ["as_nested_tensor" , "as_padded_tensor" ] = None ,
269
272
repeats : int | None = None ,
270
- group_repeats : bool = False ,
273
+ group_repeats : bool = True ,
271
274
) -> LLMEnv :
272
275
"""Creates an LLMEnv instance from a dataloader.
273
276
@@ -297,7 +300,7 @@ def from_dataloader(
297
300
``("tokens_out", "sequences")``.
298
301
reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`.
299
302
Defaults to ``"reward"``.
300
- str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False ``.
303
+ str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``.
301
304
device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``.
302
305
vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an
303
306
unbounded vocabulary. Defaults to ``None``.
@@ -334,7 +337,7 @@ def from_dataloader(
334
337
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
335
338
samples (rather than an advantage module).
336
339
group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that
337
- all repeats are grouped in a single batch collected from the buffer. Defaults to ``False ``.
340
+ all repeats are grouped in a single batch collected from the buffer. Defaults to ``True``.
338
341
339
342
Returns:
340
343
LLMEnv: The created LLMEnv instance.
0 commit comments