diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 141b5beaba2..b53ac84585d 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -440,6 +440,7 @@ TorchRL offers a series of custom built-in environments. ChessEnv PendulumEnv TicTacToeEnv + LLMEnv LLMHashingEnv @@ -1033,6 +1034,7 @@ to be able to create this other composition: Compose ConditionalSkip Crop + DataLoadingPrimer DTypeCastTransform DeviceCastTransform DiscreteActionProjection diff --git a/test/test_env.py b/test/test_env.py index 440614037de..8a2642efb05 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -12,6 +12,7 @@ import pickle import random import re +import string from collections import defaultdict from functools import partial from sys import platform @@ -43,9 +44,11 @@ CatTensors, ChessEnv, ConditionalSkip, + DataLoadingPrimer, DoubleToFloat, EnvBase, EnvCreator, + LLMEnv, LLMHashingEnv, ParallelEnv, PendulumEnv, @@ -57,6 +60,7 @@ from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv +from torchrl.envs.transforms.rlhf import as_padded_tensor from torchrl.envs.transforms.transforms import ( AutoResetEnv, AutoResetTransform, @@ -95,7 +99,7 @@ try: this_dir = os.path.dirname(os.path.realpath(__file__)) - with open(os.path.join(this_dir, "configs", "atari.yaml"), "r") as file: + with open(os.path.join(this_dir, "configs", "atari.yaml")) as file: atari_confs = yaml.load(file, Loader=yaml.FullLoader) _atari_found = True except FileNotFoundError: @@ -4503,6 +4507,141 @@ def test_skipping_history_env_collector(self, device_env, collector_cls): count += 1 +class TestLLMEnv: + @pytest.fixture(scope="class", autouse=True) + def set_capture(self): + with set_capture_non_tensor_stack(False): + yield None + return + + class DummyDataLoader: + def __init__(self, batch_size=0): + self.batch_size = batch_size + + def generate_random_string(self, length=10): + """Generate a random string of a given length.""" + return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) + + def __iter__(self): + return self + + def __next__(self): + if self.batch_size == 0: + return self.generate_random_string() + else: + return [self.generate_random_string() for _ in range(self.batch_size)] + + class DummyTensorDataLoader: + def __init__(self, batch_size=0, max_length=10, padding=False): + self.batch_size = batch_size + self.max_length = max_length + self.padding = padding + + def generate_random_tensor(self): + """Generate a tensor of random int64 values.""" + length = random.randint(1, self.max_length) + return torch.tensor( + [random.randint(0, 100) for _ in range(length)], dtype=torch.int64 + ) + + def pad_tensor(self, tensor): + """Pad a tensor to the maximum length.""" + padding_length = self.max_length - len(tensor) + return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor)) + + def __iter__(self): + return self + + def __next__(self): + if self.batch_size == 0: + tensor = self.generate_random_tensor() + return self.pad_tensor(tensor) if self.padding else tensor + else: + tensors = [ + self.generate_random_tensor() for _ in range(self.batch_size) + ] + if self.padding: + tensors = [self.pad_tensor(tensor) for tensor in tensors] + return torch.stack(tensors) + else: + return tensors + + @pytest.mark.parametrize( + "str2str,stack_method", + [ + [True, None], + [False, "as_padded_tensor"], + # TODO: a bit experimental, fails with check_env_specs + # [False, "as_nested_tensor"], + [False, None], + ], + ) + @pytest.mark.parametrize("batched", [True, False]) + @pytest.mark.parametrize("device", [None, "cpu"]) + def test_llm_env(self, str2str, batched, stack_method, device): + env = LLMEnv(str2str=str2str, device=device) + if str2str: + primer = DataLoadingPrimer( + dataloader=self.DummyDataLoader(), + data_keys=["observation"], + example_data="a string!", + ) + else: + if stack_method is None: + stack_method = as_padded_tensor + primer = DataLoadingPrimer( + dataloader=self.DummyTensorDataLoader(padding=True), + data_keys=["observation"], + data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], + stack_method=stack_method, + ) + assert not env.batch_locked + env = env.append_transform(primer) + assert not env.batch_locked + if batched: + td = env.reset(TensorDict(batch_size=[3])) + env.check_env_specs(break_when_any_done="both", tensordict=td) + else: + env.check_env_specs(break_when_any_done="both") + + @pytest.mark.parametrize( + "str2str,stack_method", + [ + [True, None], + [False, "as_padded_tensor"], + # TODO: a bit experimental, fails with check_env_specs + # [False, "as_nested_tensor"], + [False, None], + ], + ) + @pytest.mark.parametrize("batched", [True, False]) + @pytest.mark.parametrize("device", [None, "cpu"]) + def test_llm_from_dataloader(self, str2str, batched, stack_method, device): + if str2str: + kwargs = { + "dataloader": self.DummyDataLoader(), + "data_keys": ["observation"], + "example_data": "a string!", + } + else: + if stack_method is None: + stack_method = as_padded_tensor + kwargs = { + "dataloader": self.DummyTensorDataLoader(padding=True), + "data_keys": ["observation"], + "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], + "stack_method": stack_method, + } + kwargs.update({"str2str": str2str, "device": device}) + env = LLMEnv.from_dataloader(**kwargs) + assert not env.batch_locked + if batched: + td = env.reset(TensorDict(batch_size=[3])) + env.check_env_specs(break_when_any_done="both", tensordict=td) + else: + env.check_env_specs(break_when_any_done="both") + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 84f9fa8b0a6..d9753eafc08 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv +from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( @@ -58,6 +58,7 @@ Compose, ConditionalSkip, Crop, + DataLoadingPrimer, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index d2c85a7198f..bbd780aadd7 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -4,6 +4,6 @@ # LICENSE file in the root directory of this source tree. from .chess import ChessEnv -from .llm import LLMHashingEnv +from .llm import LLMEnv, LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 92f265e85d2..f6dfc835e87 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -4,23 +4,259 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -from typing import Callable, List, Union +from typing import Any, Callable, Literal import torch -from tensordict import NestedKey, TensorDict, TensorDictBase +from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.tensorclass import NonTensorData, NonTensorStack - +from tensordict.utils import _zip_strict +from torch.utils.data import DataLoader from torchrl.data import ( + Bounded, Categorical as CategoricalSpec, Composite, NonTensor, SipHash, + TensorSpec, Unbounded, ) from torchrl.envs import EnvBase from torchrl.envs.utils import _StepMDP +class LLMEnv(EnvBase): + """A text generation environment. + + This environment is designed to work with language models, where the observation is a string or a tensor of + integers representing a sequence of tokens. + The action is also a string or a tensor of integers, which is concatenated to the previous observation to form the + new observation. + + By default, this environment is meant to track history for a prompt. Users can append transforms to tailor + this to their use case, such as Chain of Thought (CoT) reasoning or other custom processing. + + Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt. + + Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader` + + Args: + observation_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults to + ``"observation"``. + action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``. + str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. + device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. + vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an + unbounded vocabulary. Defaults to ``None``. + + .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples. + + Methods: + from_dataloader: Creates an LLMEnv instance from a dataloader. + + """ + + def __init__( + self, + *, + observation_key: NestedKey = "observation", + action_key: NestedKey = "action", + str2str: bool = False, + device: torch.device | None = None, + vocab_size: int | None = None, + ) -> None: + super().__init__(device=device) + self._batch_locked = False + self.str2str = str2str + self.vocab_size = vocab_size + self.observation_key = unravel_key(observation_key) + # self.action_key = unravel_key(action_key) + if str2str: + self.observation_spec = Composite( + { + observation_key: NonTensor( + example_data="a string", batched=True, shape=() + ) + } + ) + self.action_spec = Composite( + {action_key: NonTensor(example_data="a string", batched=True, shape=())} + ) + else: + if vocab_size is None: + self.observation_spec = Composite( + { + observation_key: Unbounded( + shape=(-1,), dtype=torch.int64, device=device + ) + } + ) + self.action_spec = Composite( + { + action_key: Unbounded( + shape=(-1,), dtype=torch.int64, device=device + ) + } + ) + else: + self.observation_spec = Composite( + { + observation_key: Bounded( + shape=(-1,), + dtype=torch.int64, + low=0, + high=vocab_size, + device=device, + ) + } + ) + self.action_spec = Composite( + { + action_key: Bounded( + shape=(-1,), + dtype=torch.int64, + low=0, + high=vocab_size, + device=device, + ) + } + ) + self.full_done_spec = Composite( + done=Unbounded(shape=(1,), dtype=torch.bool), + truncated=Unbounded(shape=(1,), dtype=torch.bool), + terminated=Unbounded(shape=(1,), dtype=torch.bool), + ) + + @classmethod + def from_dataloader( + cls, + dataloader: DataLoader, + *, + observation_key: NestedKey = "observation", + action_key: NestedKey = "action", + str2str: bool = False, + device: torch.device | None = None, + vocab_size: int | None = None, + primers: Composite | None = None, + data_keys: list[NestedKey] | None = None, + data_specs: list[TensorSpec] | None = None, + example_data: Any = None, + stack_method: Callable[[Any], Any] + | Literal["as_nested_tensor", "as_padded_tensor"] = None, + ) -> LLMEnv: + """Creates an LLMEnv instance from a dataloader. + + This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which populates ``data_keys`` (by default ``observation_key``) with data from the provided dataloader when the environment is reset. + + Args: + dataloader (DataLoader): The dataloader to load data from. + observation_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults + to ``"observation"``. + action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``. + str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. + device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. + vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an + unbounded vocabulary. Defaults to ``None``. + primers (Composite | None, optional): The primers to use for each key in the dataloader. + Defaults to ``None``. + data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed ``observation_key`` will be populated with the data. + Defaults to ``None``. + data_specs (list[TensorSpec] | None, optional): The specs to use for each item in the dataloader. + Defaults to ``None``. + example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``. + stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The + method to use for stacking the data. Defaults to ``None``. + + Returns: + LLMEnv: The created LLMEnv instance. + """ + from torchrl.envs import DataLoadingPrimer + + primer = DataLoadingPrimer( + dataloader=dataloader, + primers=primers, + data_keys=data_keys if data_keys is not None else [observation_key], + data_specs=data_specs, + example_data=example_data, + stack_method=stack_method, + ) + env = LLMEnv( + str2str=str2str, + device=device, + observation_key=observation_key, + action_key=action_key, + vocab_size=vocab_size, + ) + return env.append_transform(primer) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + # Cat action entry with prev obs + if self.str2str: + obs = tensordict[self.observation_key] + action = tensordict[self.action_key] + if not tensordict.batch_size: + if not isinstance(obs, str) or not isinstance(action, str): + raise TypeError( + "The tensordict is batchless, yet the action and/or observations are not " + f"strings but {type(action)} and {type(obs)}, respectivly." + ) + observation = obs + action + else: + observation = [ + _obs + _action for (_obs, _action) in _zip_strict(obs, action) + ] + else: + try: + obs: torch.Tensor = tensordict.get(self.observation_key) + action = tensordict.get(self.action_key) + if getattr(obs, "is_nested", False): + observation = torch.nested.as_nested_tensor( + [ + torch.cat( + [ + _obs, + _action, + ], + -1, + ) + for _obs, _action in _zip_strict( + obs.unbind(0), action.unbind(0) + ) + ], + layout=obs.layout, + ) + else: + observation = torch.cat( + [ + obs, + action, + ], + -1, + ) + except TypeError: + raise TypeError( + "Failed to cat action and observation tensors. Check that str2str argument is correctly " + f"set in {type(self).__name__}." + ) + return tensordict.empty().set(self.observation_key, observation) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + # We should have an observation by this time, if not raise an exception + if tensordict is None or self.observation_key not in tensordict.keys( + isinstance(self.observation_key, tuple) + ): + raise KeyError( + f"Observation key {self.observation_key} is not defined. Make sure a TensorDictPrimer (eg, " + f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms." + ) + return tensordict.copy() + + def _set_seed(self, seed: int | None): + return seed + + class LLMHashingEnv(EnvBase): """A text generation environment that uses a hashing module to identify unique observations. @@ -84,7 +320,7 @@ def __init__( hashing_module: Callable[[torch.Tensor], torch.Tensor] = None, observation_key: NestedKey = "observation", text_output: bool = True, - tokenizer: Callable[[Union[str, List[str]]], torch.Tensor] | None = None, + tokenizer: Callable[[str | list[str]], torch.Tensor] | None = None, text_key: NestedKey | None = "text", ): super().__init__() @@ -117,7 +353,7 @@ def __init__( self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,))) _StepMDP(self) - def make_tensordict(self, input: str | List[str]) -> TensorDict: + def make_tensordict(self, input: str | list[str]) -> TensorDict: """Converts a string or list of strings in a TensorDict with appropriate shape and device.""" list_len = len(input) if isinstance(input, list) else 0 tensordict = TensorDict( diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index de12f1a0302..736bb7a2c9a 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -6,7 +6,7 @@ from .gym_transforms import EndOfLifeTransform from .r3m import R3MTransform from .rb_transforms import MultiStepTransform -from .rlhf import KLRewardTransform +from .rlhf import DataLoadingPrimer, KLRewardTransform from .transforms import ( ActionDiscretizer, ActionMask, diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index feae60a1c59..963002e8c05 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -4,18 +4,413 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +from collections.abc import Mapping from copy import copy, deepcopy +from typing import Any, Callable, Iterable, List, Literal import torch -from tensordict import TensorDict, TensorDictBase, unravel_key +from tensordict import ( + maybe_dense_stack, + NestedKey, + TensorDict, + TensorDictBase, + unravel_key, +) from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams -from tensordict.utils import is_seq_of_nested_key +from tensordict.utils import _zip_strict, is_seq_of_nested_key from torch import nn -from torchrl.data.tensor_specs import Composite, Unbounded -from torchrl.envs.transforms.transforms import Transform +from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded +from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param +def as_nested_tensor(list_of_tensordicts: list[TensorDictBase]) -> TensorDictBase: + """Stacks a list of tensordicts into a single tensordict with nested tensors. + + Args: + list_of_tensordicts (list[TensorDictBase]): A list of tensordicts to stack. + + Returns: + TensorDictBase: A tensordict with nested tensors. + + """ + + def _as_nested_tensor(*list_of_tensors): + return torch.nested.as_nested_tensor(list_of_tensors, layout=torch.jagged) + + batch_size = list(list_of_tensordicts[0].shape) + batch_size.insert(0, len(list_of_tensordicts)) + return list_of_tensordicts[0].apply( + _as_nested_tensor, *list_of_tensordicts[1:], batch_size=batch_size + ) + + +def as_padded_tensor( + list_of_tensordicts: list[[TensorDictBase]], dim=0, stack_dim: int = 0 +) -> TensorDictBase: + """Stacks a list of tensordicts into a single tensordict with padded tensors. + + Args: + list_of_tensordicts (list[[TensorDictBase]]): A list of tensordicts to stack. + dim (int, optional): The dimension along which to pad. Defaults to 0. + stack_dim (int, optional): The dimension along which to stack. Defaults to 0. + + Returns: + TensorDictBase: A tensordict with padded tensors. + """ + + def _stack_tensors(*list_of_tensors): + if dim < 0: + raise ValueError("dim must be >= 0") + max_length = max([t.size(dim) for t in list_of_tensors]) + + def pad_tensor(tensor): + padding_length = max_length - tensor.size(dim) + shape = [ + s if i != dim else padding_length for i, s in enumerate(tensor.shape) + ] + return torch.cat((tensor.new_zeros(shape), tensor), dim=dim) + + return torch.stack([pad_tensor(t) for t in list_of_tensors], dim=stack_dim) + + batch_size = list(list_of_tensordicts[0].shape) + batch_size.insert(dim, len(list_of_tensordicts)) + result = list_of_tensordicts[0].apply( + _stack_tensors, *list_of_tensordicts[1:], batch_size=batch_size + ) + return result + + +class DataLoadingPrimer(TensorDictPrimer): + """A primer that loads data from a dataloader and converts it into a tensordict using ``stack_method``. + + Args: + dataloader (Iterable[Any]): The dataloader to load data from. + primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None. + data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None. + data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None. + example_data (Any, optional): Example data to use for initializing the primer. Defaults to None. + stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``. + + Attributes: + dataloader (Iterable[Any]): The dataloader to load data from. + endless_dataloader (Iterable[Any]): An endless iterator over the dataloader. + data_keys (List[NestedKey]): The keys to use for each item in the dataloader. + stack_method (Callable[[Any], Any]): The method to use for stacking the data. + + .. seealso:: :class:`~torchrl.envs.LLMEnv` and :class:`~torchrl.envs.LLMEnv.from_dataloader`. + + Example of a dataloader yielding strings: + >>> import random + >>> import string + >>> import tensordict as td + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.data import Unbounded + >>> from torchrl.envs import DataLoadingPrimer, LLMEnv + >>> td.set_capture_non_tensor_stack(False).set() + >>> class DummyDataLoader: + ... '''A dummy dataloader that generates random strings.''' + ... def __init__(self, batch_size: int = 0): + ... self.batch_size = batch_size + ... def generate_random_string(self, length: int = 10) -. str: + ... '''Generate a random string of a given length.''' + ... return ''.join(random.choice(string.ascii_lowercase) for _ in range(length)) + ... def __iter__(self): + ... return self + ... def __next__(self): + ... if self.batch_size == 0: + ... return self.generate_random_string() + ... else: + ... return [self.generate_random_string() for _ in range(self.batch_size)] + >>> # Create an LLM environment with string-to-string input/output. + >>> env = LLMEnv(str2str=True) + >>> # Append a DataLoadingPrimer to the environment. + >>> env = env.append_transform( + >>> DataLoadingPrimer( + >>> dataloader=DummyDataLoader(), + >>> data_keys=["observation"], + >>> example_data="a string!", + >>> ) + >>> ) + >>> # Test the environment. + >>> print(env.rand_action(TensorDict())) + TensorDict( + fields={ + action: NonTensorData(data=a string, batch_size=torch.Size([]), device=None)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(env.rollout(3)) + TensorDict( + fields={ + action: NonTensorStack( + ['a string', 'a string', 'a string'], + batch_size=torch.Size([3]), + device=None), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: NonTensorStack( + ['zxwvupirska string', 'zxwvupirska stringa string..., + batch_size=torch.Size([3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False), + observation: NonTensorStack( + ['zxwvupirsk', 'zxwvupirska string', 'zxwvupirska ..., + batch_size=torch.Size([3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3]), + device=None, + is_shared=False) + >>> # Roll out the environment with a specific initial state. + >>> init_state = env.reset(TensorDict(batch_size=[3])) + >>> print(env.rollout(3, auto_reset=False, tensordict=init_state)) + TensorDict( + fields={ + action: NonTensorStack( + [['a string', 'a string', 'a string'], ['a string'..., + batch_size=torch.Size([3, 3]), + device=None), + done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: NonTensorStack( + [[array(['nngcmflsana string', 'vrrbnhzpmga string..., + batch_size=torch.Size([3, 3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3, 3]), + device=None, + is_shared=False), + observation: NonTensorStack( + [['nngcmflsan', array(['nngcmflsana string', 'vrrb..., + batch_size=torch.Size([3, 3]), + device=None), + terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([3, 3]), + device=None, + is_shared=False) + + Example of dataloader yielding tensors: + >>> import random + >>> import string + >>> + >>> import tensordict as td + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.data import Unbounded + >>> from torchrl.envs import DataLoadingPrimer, LLMEnv + >>> + >>> td.set_capture_non_tensor_stack(False).set() + >>> + >>> + >>> class DummyTensorDataLoader: + ... '''A dummy dataloader that generates tensors of random int64 values.''' + ... + ... def __init__(self, batch_size: int = 0, max_length: int = 10, padding: bool = False): + ... ''' + ... Args: + ... batch_size (int, optional): The batch size of the generated tensors. Defaults to 0. + ... max_length (int, optional): The maximum length of the generated tensors. Defaults to 10. + ... padding (bool, optional): Whether to pad the tensors to the maximum length. Defaults to False. + ... ''' + ... self.batch_size = batch_size + ... self.max_length = max_length + ... self.padding = padding + ... + ... def generate_random_tensor(self) -. torch.Tensor: + ... '''Generate a tensor of random int64 values.''' + ... length = random.randint(1, self.max_length) + ... return torch.tensor([random.randint(0, 100) for _ in range(length)], dtype=torch.int64) + ... + ... def pad_tensor(self, tensor: torch.Tensor) -. torch.Tensor: + ... '''Pad a tensor to the maximum length.''' + ... padding_length = self.max_length - len(tensor) + ... return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor)) + ... + ... def __iter__(self): + ... return self + ... + ... def __next__(self): + ... if self.batch_size == 0: + ... tensor = self.generate_random_tensor() + ... return self.pad_tensor(tensor) if self.padding else tensor + ... else: + ... tensors = [self.generate_random_tensor() for _ in range(self.batch_size)] + ... if self.padding: + ... tensors = [self.pad_tensor(tensor) for tensor in tensors] + ... return torch.stack(tensors) + ... else: + ... return tensors + >>> + >>> # Create an LLM environment with non-string input/output and append a DataLoadingPrimer. + >>> env = LLMEnv(str2str=False) + >>> env = env.append_transform( + >>> DataLoadingPrimer( + >>> dataloader=DummyTensorDataLoader(), + >>> data_keys=["observation"], + >>> data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], + >>> ) + >>> ) + >>> print(env.rand_action(TensorDict())) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(env.rollout(3)) + LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([3]), + device=None, + is_shared=False, + stack_dim=0), + observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([3]), + device=None, + is_shared=False, + stack_dim=0) + >>> # Create an LLM environment with padded tensor input/output and append a DataLoadingPrimer. + >>> env = LLMEnv(str2str=False) + >>> env = env.append_transform( + >>> DataLoadingPrimer( + >>> dataloader=DummyTensorDataLoader(padding=True), + >>> data_keys=["observation"], + >>> data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], + >>> stack_method="as_padded_tensor", + >>> ) + >>> ) + >>> print(env.rollout(3, auto_reset=False, tensordict=env.reset(TensorDict(batch_size=[3])))) + LazyStackedTensorDict( + fields={ + action: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: LazyStackedTensorDict( + fields={ + done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([3, 3]), + device=None, + is_shared=False, + stack_dim=1), + observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False), + terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + exclusive_fields={ + }, + batch_size=torch.Size([3, 3]), + device=None, + is_shared=False, + stack_dim=1) + + """ + + def __init__( + self, + dataloader: Iterable[Any], + primers: Composite | None = None, + data_keys: List[NestedKey] | None = None, + data_specs: List[TensorSpec] | None = None, + example_data: Any = None, + stack_method: Callable[[Any], Any] + | Literal["as_nested_tensor", "as_padded_tensor"] = None, + ): + self.dataloader = dataloader + self.endless_dataloader = self._endless_iter(self.dataloader) + if primers is None: + if data_keys is None: + data_keys = ["data"] + if data_specs is None: + data_specs = [NonTensor(example_data=example_data, shape=())] + primers = Composite( + { + data_key: data_spec + for data_key, data_spec in _zip_strict(data_keys, data_specs) + } + ) + self.data_keys = data_keys + else: + self.data_keys = list(primers.keys(True, True)) + + if stack_method is None: + stack_method = maybe_dense_stack + elif stack_method == "as_nested_tensor": + stack_method = as_nested_tensor + elif stack_method == "as_padded_tensor": + stack_method = as_padded_tensor + elif not callable(stack_method): + raise ValueError(f"Unknown stack_method={stack_method}") + self.stack_method = stack_method + + super().__init__( + primers=primers, + default_value=self._load_from_dataloader, + reset_key=None, + expand_specs=None, + single_default_value=True, + call_before_env_reset=True, + ) + + @classmethod + def _endless_iter(self, obj): + while True: + yield from obj + + def _load_from_dataloader(self, reset: torch.Tensor | None = None): + if reset is not None: + if not reset.any(): + raise RuntimeError("reset must have at least one True value.") + if reset.ndim > 0: + return self.stack_method( + [self._load_from_dataloader() for i in range(reset.sum())] + ) + data = next(self.endless_dataloader) + # Some heuristic here: + # if data is a map, assume its keys match the keys in spec + # TODO: one could rename the keys too + if isinstance(data, Mapping): + out = TensorDict(data) + elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)): + out = TensorDict({k: val for k, val in _zip_strict(self.data_keys, data)}) + elif len(self.data_keys) == 1: + out = TensorDict({self.data_keys[0]: data}) + else: + raise ValueError( + f"Unrecognized data type: {type(data)} with keys {self.data_keys}." + ) + return out + + class KLRewardTransform(Transform): """A transform to add a KL[pi_current||pi_0] correction term to the reward.