Skip to content

Commit f72153c

Browse files
Apply suggestions from code review
Co-authored-by: mikaylagawarecki <[email protected]>
1 parent e8cd9f9 commit f72153c

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

torchrl/envs/custom/llm.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class LLMEnv(EnvBase):
3131
integers representing a sequence of tokens.
3232
The action is also a string or a tensor of integers, which is concatenated to the previous observation to form the
3333
new observation.
34-
34+
Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader`.
3535
Args:
3636
observation_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults to
3737
``"observation"``.
@@ -138,7 +138,7 @@ def from_dataloader(
138138
) -> LLMEnv:
139139
"""Creates an LLMEnv instance from a dataloader.
140140
141-
This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which loads data from the provided dataloader.
141+
This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which populates ``data_keys`` (by default ``observation_key``) with data from the provided dataloader when the environment is reset.
142142
143143
Args:
144144
dataloader (DataLoader): The dataloader to load data from.
@@ -151,7 +151,7 @@ def from_dataloader(
151151
unbounded vocabulary. Defaults to ``None``.
152152
primers (Composite | None, optional): The primers to use for each key in the dataloader.
153153
Defaults to ``None``.
154-
data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader.
154+
data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed, ``observation_key`` will be populated with the data.
155155
Defaults to ``None``.
156156
data_specs (list[TensorSpec] | None, optional): The specs to use for each item in the dataloader.
157157
Defaults to ``None``.

torchrl/envs/transforms/rlhf.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,15 @@ def pad_tensor(tensor):
8282

8383

8484
class DataLoadingPrimer(TensorDictPrimer):
85-
"""A primer that loads data from a dataloader and stacks it into a tensordict.
85+
"""A primer that loads data from a dataloader and converts it into a tensordict using ``stack_method``.
8686
8787
Args:
8888
dataloader (Iterable[Any]): The dataloader to load data from.
8989
primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
9090
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
9191
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
9292
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
93-
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to None.
93+
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
9494
9595
Attributes:
9696
dataloader (Iterable[Any]): The dataloader to load data from.

0 commit comments

Comments
 (0)