diff --git a/test/test_actors.py b/test/test_actors.py index 51fa2f9031c..629da3cbf7d 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -8,14 +8,14 @@ import pytest import torch -from tensordict import TensorDict +from tensordict import NonTensorStack, TensorDict from tensordict.nn import CompositeDistribution, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import distributions as dist, nn from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot from torchrl.data.llm.dataset import _has_transformers -from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal +from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal from torchrl.modules.tensordict_module.actors import ( _process_action_space_spec, ActorValueOperator, @@ -907,6 +907,55 @@ def test_lmhead_actorvalueoperator(device): ) == len(policy_params) +@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") +class TestTransformerActor: + @pytest.mark.parametrize( + "from_text, generate, tokens, attention_mask", + [ + (True, True, None, None), + (True, False, None, None), + ( + False, + True, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, True, torch.randint(1024, (1, 10)), None), + ], + ) + def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask): + from torchrl.data.llm import LLMData + from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + model = GPT2LMHeadModel(GPT2Config()) + tokenizer.padding_side = "left" + m = from_hf_transformers( + model, tokenizer=tokenizer, from_text=from_text, generate=generate + ) + if from_text: + tdin = LLMData(text=NonTensorStack("a text"), batch_size=1) + else: + tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1) + td = m(tdin) + assert td is tdin + assert isinstance(td, LLMData) + if from_text and generate: + assert td.text_response is not None + else: + assert td.text_response is None + if attention_mask is not None or from_text: + assert td.attention_mask is not None + else: + assert td.attention_mask is None + if not generate: + assert td.text_response is None + assert td.tokens_response is None + assert td.log_probs is not None + assert td.logits is not None + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_env.py b/test/test_env.py index 96d898ddd31..dfcc5a5e87d 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4644,11 +4644,13 @@ def __next__(self): @pytest.mark.parametrize("batch_size", [0, 4]) @pytest.mark.parametrize("device", [None, "cpu"]) def test_llm_env(self, str2str, batched, stack_method, device, batch_size): - env = LLMEnv(str2str=str2str, device=device) + env = LLMEnv( + str2str=str2str, device=device, has_attention=False, no_stack=False + ) if str2str: primer = DataLoadingPrimer( dataloader=self.DummyDataLoader(batch_size=batch_size), - data_keys=["observation"], + data_keys=[LLMEnv._DEFAULT_STR_KEY], example_data="a string!", ) else: @@ -4658,7 +4660,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size): dataloader=self.DummyTensorDataLoader( batch_size=batch_size, padding=True ), - data_keys=["observation"], + data_keys=[LLMEnv._DEFAULT_TOKEN_KEY], 
data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], stack_method=stack_method, ) @@ -4668,7 +4670,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size): if batched: td = env.reset(TensorDict(batch_size=[3])) env.check_env_specs(break_when_any_done="both", tensordict=td) - r = env.rollout(10, tensordict=TensorDict(batch_size=[3])) + env.rollout(10, tensordict=TensorDict(batch_size=[3])) else: env.check_env_specs(break_when_any_done="both") @@ -4691,7 +4693,7 @@ def test_llm_from_dataloader( if str2str: kwargs = { "dataloader": self.DummyDataLoader(batch_size=batch_size), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", } else: @@ -4701,11 +4703,18 @@ def test_llm_from_dataloader( "dataloader": self.DummyTensorDataLoader( padding=True, batch_size=batch_size ), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], "stack_method": stack_method, } - kwargs.update({"str2str": str2str, "device": device}) + kwargs.update( + { + "str2str": str2str, + "device": device, + "has_attention": False, + "no_stack": False, + } + ) env = LLMEnv.from_dataloader(**kwargs) assert not env.batch_locked if batched: @@ -4718,13 +4727,15 @@ def test_llm_from_dataloader( def policy(td): if str2str: if not td.shape: - td["action"] = "" + td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "" else: - td["action"] = NonTensorStack( + td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( *["" for _ in range(td.shape[0])] ) else: - td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64) + td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( + td.shape + (1,), dtype=torch.int64 + ) return td if batched: @@ -4732,32 +4743,48 @@ def policy(td): r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3])) assert r.ndim == 2 if str2str: - assert isinstance(r[0, 0]["observation"], str) - assert isinstance(r[0, 1]["observation"], str) + assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str) + assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str) assert ( - r[0, 0]["observation"] - == r[0, 1]["observation"][: -len(r[0, 0]["action"])] + r[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] ) assert ( - r[0, 1]["observation"] - == r[0, 2]["observation"][: -len(r[0, 1]["action"])] + r[0, 1][LLMEnv._DEFAULT_STR_KEY] + == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] ) assert ( - r[-1, 0]["observation"] - == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])] + r[-1, 0][LLMEnv._DEFAULT_STR_KEY] + == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] ) assert ( - r[-1, 1]["observation"] - == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])] + r[-1, 1][LLMEnv._DEFAULT_STR_KEY] + == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) + ] ) else: - assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all() - assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all() assert ( - r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1] + r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] + ).all() + assert ( + r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] + == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] ).all() assert ( - r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1] + r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] + 
).all() + assert ( + r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY] + == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] ).all() else: r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[])) @@ -4783,7 +4810,7 @@ def test_llm_from_dataloader_repeats( if str2str: kwargs = { "dataloader": self.DummyDataLoader(batch_size=batch_size), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", "repeats": repeats, } @@ -4794,12 +4821,19 @@ def test_llm_from_dataloader_repeats( "dataloader": self.DummyTensorDataLoader( padding=True, batch_size=batch_size ), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], "stack_method": stack_method, "repeats": repeats, } - kwargs.update({"str2str": str2str, "device": device}) + kwargs.update( + { + "str2str": str2str, + "device": device, + "has_attention": False, + "no_stack": False, + } + ) env = LLMEnv.from_dataloader(**kwargs) assert env.transform.repeats == repeats @@ -4809,13 +4843,15 @@ def test_llm_from_dataloader_repeats( def policy(td): if str2str: if not td.shape: - td["action"] = "" + td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "" else: - td["action"] = NonTensorStack( + td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( *["" for _ in range(td.shape[0])] ) else: - td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64) + td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( + td.shape + (1,), dtype=torch.int64 + ) return td if batched: @@ -4831,34 +4867,58 @@ def policy(td): r_reset = r[..., ::max_steps] if not batched: if str2str: - assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"] - assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"] - assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"] + assert ( + r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY] + != r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY] + ) else: assert ( - r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"] + r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"] + r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"] + r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY] + != r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY] ).any() else: # When batched, each block contains the 3 reset packs if str2str: - assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"] - assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"] - assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] + assert ( + r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY] + != r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY] + ) else: assert ( - r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"] + r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"] + r_reset[0, 
0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] + r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + != r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] ).any() @pytest.mark.parametrize( @@ -4892,7 +4952,7 @@ def test_done_and_reward( if str2str: kwargs = { "dataloader": self.DummyDataLoader(batch_size=batch_size), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", "repeats": repeats, "assign_reward": assign_reward, @@ -4905,20 +4965,27 @@ def test_done_and_reward( "dataloader": self.DummyTensorDataLoader( padding=True, batch_size=batch_size ), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], "stack_method": stack_method, "repeats": repeats, "assign_reward": assign_reward, "assign_done": assign_done, } - kwargs.update({"str2str": str2str, "device": device}) + kwargs.update( + { + "str2str": str2str, + "device": device, + "has_attention": False, + "no_stack": False, + } + ) env = LLMEnv.from_dataloader(**kwargs) # We want to make sure that transforms that rely on the done state work appropriately env.append_transform(StepCounter(max_steps=10)) def policy(td): - td["action"] = torch.ones( + td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64 ) return td diff --git a/torchrl/data/llm/utils.py b/torchrl/data/llm/utils.py index c4282e8ab30..6ffa3641467 100644 --- a/torchrl/data/llm/utils.py +++ b/torchrl/data/llm/utils.py @@ -626,7 +626,7 @@ class LLMData(TensorClass["nocast"]): """ - tokens: torch.Tensor + tokens: torch.Tensor | None = None tokens_response: torch.Tensor | None = None attention_mask: torch.Tensor | None = None token_list: list[int] | list[list[int]] | None = None @@ -634,3 +634,4 @@ class LLMData(TensorClass["nocast"]): logits: torch.Tensor | None = None log_probs: torch.Tensor | None = None text: str | list[str] | None = None + text_response: str | list[str] | None = None diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 1d6a4ac69e4..c9f6715984d 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1536,10 +1536,10 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): - if isinstance(storage, ListStorage): - return _stack_anything - elif isinstance(storage, TensorStorage): + if isinstance(storage, (LazyStackStorage, TensorStorage)): return _collate_id + elif isinstance(storage, ListStorage): + return _stack_anything else: raise NotImplementedError( f"Could not find a default collate_fn for storage {type(storage)}." ) diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index a734b8ffd6e..07a0880ba5b 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Literal import torch + from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _zip_strict @@ -40,9 +41,16 @@ class LLMEnv(EnvBase): Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader` Keyword Args: - observation_key (NestedKey, optional): The key in the tensordict where the observation is stored.
Defaults to - ``"observation"``. - action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``. + token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`). + Defaults to ``"tokens"``. + str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `str2str=True`). + Defaults to ``"text"``. + attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. + Defaults to ``"attention_mask"``. + action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to + ``"tokens_response"`` or ``"text_response"``. + reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. + Defaults to ``"reward"``. str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an @@ -50,9 +58,21 @@ class LLMEnv(EnvBase): no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed to be the plain output of the LLM, including the input tokens / strings. - assign_reward (bool, optional): TODO - assign_done (bool, optional): TODO - reward_key (NestedKey, optional): TODO + has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by + :attr:`attention_key`. Defaults to ``True``. + assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape + is written during calls to `step()`. Defaults to ``False``. + assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the + action shape is written during calls to `step()`. Defaults to ``False``. + + .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root + as it is a requirement for all TorchRL environments. + + batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment + is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size. + Defaults to ``None`` (batch-unlocked). + as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`. + Defaults to ``False``. .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples.
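As a quick orientation for the key conventions documented above, a minimal smoke-run sketch mirroring the updated tests (illustrative only; it assumes the reset tensordict already carries the ``"tokens"`` entry that a :class:`~torchrl.envs.DataLoadingPrimer` would normally provide):

>>> import torch
>>> from tensordict import TensorDict
>>> from torchrl.envs.custom.llm import LLMEnv
>>> env = LLMEnv(str2str=False, has_attention=False, no_stack=False)
>>> td = env.reset(TensorDict({"tokens": torch.randint(1024, (10,))}, batch_size=[]))
>>> td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones(1, dtype=torch.int64)  # policy output
>>> td = env.step(td)  # with no_stack=False, the next "tokens" is the prompt with the response appended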
@@ -61,31 +81,58 @@ class LLMEnv(EnvBase): """ + _DEFAULT_TOKEN_KEY = "tokens" + _DEFAULT_STR_KEY = "text" + _DEFAULT_ATTENTION_KEY = "attention_mask" + _DEFAULT_ACTION_TOKENS_KEY = "tokens_response" + _DEFAULT_ACTION_STR_KEY = "text_response" + def __init__( self, *, - token_key: NestedKey = "observation", + token_key: NestedKey | None = None, + str_key: NestedKey | None = None, attention_key: NestedKey | None = None, - action_key: NestedKey = "action", + action_key: NestedKey | None = None, + reward_key: NestedKey = "reward", str2str: bool = False, device: torch.device | None = None, vocab_size: int | None = None, - no_stack: bool = False, + no_stack: bool = True, assign_reward: bool = False, assign_done: bool = False, - reward_key: NestedKey = "reward", - batch_size: int | None = None, + batch_size: int | torch.Size | None = None, + has_attention: bool = True, + as_llm_data: bool = False, ) -> None: + self.as_llm_data = as_llm_data + if token_key is None: + token_key = self._DEFAULT_TOKEN_KEY + if str_key is None: + str_key = self._DEFAULT_STR_KEY + if attention_key is None: + attention_key = self._DEFAULT_ATTENTION_KEY + if action_key is None: + if str2str: + action_key = self._DEFAULT_ACTION_STR_KEY + else: + action_key = self._DEFAULT_ACTION_TOKENS_KEY if batch_size is None: self._batch_locked = False + batch_size = () else: self._batch_locked = True + if not isinstance(batch_size, (tuple, list)): + batch_size = (batch_size,) super().__init__( - device=device, batch_size=() if batch_size is None else (batch_size,) + device=device, + batch_size=batch_size, ) + self.has_attention = has_attention self.str2str = str2str self.vocab_size = vocab_size - self.observation_key = unravel_key(token_key) + self.token_key = unravel_key(token_key) + self.str_key = unravel_key(str_key) if attention_key is not None: attention_key = unravel_key(attention_key) self.attention_key = attention_key @@ -96,7 +143,7 @@ def __init__( # self.action_key = unravel_key(action_key) if str2str: self.full_observation_spec_unbatched = Composite( - {token_key: NonTensor(example_data="a string", batched=True, shape=())} + {self.str_key: NonTensor(example_data="a string", batched=True, shape=())} ) self.full_action_spec_unbatched = Composite( {action_key: NonTensor(example_data="a string", batched=True, shape=())} @@ -106,7 +153,7 @@ def __init__( observation_spec = { token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device) } - if attention_key is not None: + if self.has_attention: observation_spec[attention_key] = Unbounded( shape=(-1,), dtype=torch.int64, device=device ) @@ -157,15 +204,15 @@ def __init__( if not self.assign_done: # Use single done self.full_done_spec_unbatched = Composite( - done=Unbounded(shape=(-1,), dtype=torch.bool), - terminated=Unbounded(shape=(-1,), dtype=torch.bool), + done=Unbounded(shape=(1,), dtype=torch.bool), + terminated=Unbounded(shape=(1,), dtype=torch.bool), ) elif self.str2str: raise STR2STR_ERR else: # Use single done self.full_done_spec_unbatched = Composite( - tokens=Composite( + tokens_data=Composite( done=Unbounded(shape=(-1,), dtype=torch.bool), terminated=Unbounded(shape=(-1,), dtype=torch.bool), ), @@ -178,17 +225,20 @@ def from_dataloader( cls, dataloader: DataLoader, *, - token_key: NestedKey = "observation", + token_key: NestedKey | None = None, + str_key: NestedKey | None = None, attention_key: NestedKey | None = None, - action_key: NestedKey = "action", + action_key: NestedKey | None = None, + reward_key: NestedKey = "reward", str2str: bool = False, device: 
torch.device | None = None, vocab_size: int | None = None, no_stack: bool = False, - batch_size: int | None = None, + as_llm_data: bool = False, + batch_size: int | torch.Size | None = None, + has_attention: bool = True, assign_reward: bool = False, assign_done: bool = False, - reward_key: NestedKey = "reward", primers: Composite | None = None, data_keys: list[NestedKey] | None = None, data_specs: list[TensorSpec] | None = None, @@ -203,10 +253,16 @@ def from_dataloader( Args: dataloader (DataLoader): The dataloader to load data from. - token_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults - to ``"observation"``. - attention_key: TODO - action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``. + token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`). + Defaults to ``"tokens"``. + str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `str2str=True`). + Defaults to ``"text"``. + attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. + Defaults to ``"attention_mask"``. + action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to + ``"tokens_response"`` or ``"text_response"``. + reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. + Defaults to ``"reward"``. str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an @@ -214,10 +270,19 @@ def from_dataloader( no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed to be the plain output of the LLM, including the input tokens / strings. - assign_reward (bool, optional): TODO - assign_done (bool, optional): TODO - reward_key (NestedKey, optional): TODO - batch_size (int, optional): TODO + has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by + :attr:`attention_key`. Defaults to ``True``. + assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape + is written during calls to `step()`. Defaults to ``False``. + assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the + action shape is written during calls to `step()`. Defaults to ``False``. + + .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root + as it is a requirement for all TorchRL environments. + + batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment + is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size. + Defaults to ``None`` (batch-unlocked). primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to ``None``. data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed ``observation_key`` will be populated with the data.
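For reference, the dataloader path above can be exercised the way the new tests do; a hypothetical sketch (``my_dataloader`` stands in for any dataloader yielding padded token tensors, and ``policy`` for a module that writes the ``"tokens_response"`` entry):

>>> import torch
>>> from torchrl.data import Unbounded
>>> from torchrl.envs.custom.llm import LLMEnv
>>> env = LLMEnv.from_dataloader(
...     dataloader=my_dataloader,
...     data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
...     data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
...     str2str=False,
...     has_attention=False,
...     no_stack=False,
... )
>>> rollout = env.rollout(3, policy)  # each step appends the response to the prompt tokens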
@@ -230,16 +295,35 @@ def from_dataloader( repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo samples (rather than an advantage module). + as_llm_data (bool, optional): If ``True``, the data will be of type :class:`~torchrl.data.LLMData`. + Defaults to ``False``. Returns: LLMEnv: The created LLMEnv instance. """ from torchrl.envs import DataLoadingPrimer + if data_keys is None: + if str2str: + if str_key is None: + data_keys = [LLMEnv._DEFAULT_STR_KEY] + else: + data_keys = [str_key] + else: + if token_key is None: + data_keys = [LLMEnv._DEFAULT_TOKEN_KEY] + else: + data_keys = [token_key] + if has_attention: + if attention_key is None: + data_keys.append(LLMEnv._DEFAULT_ATTENTION_KEY) + else: + data_keys.append(attention_key) + primer = DataLoadingPrimer( dataloader=dataloader, primers=primers, - data_keys=data_keys if data_keys is not None else [token_key], + data_keys=data_keys, data_specs=data_specs, example_data=example_data, stack_method=stack_method, @@ -249,14 +333,17 @@ def from_dataloader( str2str=str2str, device=device, token_key=token_key, + str_key=str_key, attention_key=attention_key, action_key=action_key, + reward_key=reward_key, vocab_size=vocab_size, no_stack=no_stack, assign_reward=assign_reward, assign_done=assign_done, - reward_key=reward_key, batch_size=batch_size, + has_attention=has_attention, + as_llm_data=as_llm_data, ) return env.append_transform(primer) @@ -276,6 +363,8 @@ def _step( self._make_next_obs(tensordict, next_td) self._maybe_make_reward(tensordict, next_td) self._maybe_make_done(tensordict, next_td) + if self.as_llm_data: + raise NotImplementedError() return next_td def _maybe_make_reward( @@ -301,14 +390,14 @@ def _maybe_make_done( ) else: done = torch.zeros_like(action, dtype=torch.bool) - next_td.set(("tokens", "terminated"), done) - next_td.set(("tokens", "done"), done.clone()) + next_td.set(("tokens_data", "terminated"), done) + next_td.set(("tokens_data", "done"), done.clone()) next_td.set( - "terminated", next_td.get(("tokens", "done")).any(-1, keepdim=True) + "done", next_td.get(("tokens_data", "done")).any(-1, keepdim=True) ) next_td.set( "terminated", - next_td.get(("tokens", "terminated")).any(-1, keepdim=True), + next_td.get(("tokens_data", "terminated")).any(-1, keepdim=True), ) return next_td @@ -316,9 +405,11 @@ def _make_next_obs( self, tensordict: TensorDictBase, nex_td: TensorDictBase ) -> TensorDictBase: if self.no_stack: + if self.str2str: + raise NotImplementedError action = tensordict.get(self.action_key) - nex_td.set(self.observation_key, action) - if self.attention_key is not None: + nex_td.set(self.token_key, action) + if self.has_attention: attention_mask = tensordict.get(self.attention_key) n = action.shape[-1] - attention_mask.shape[-1] if n > 0: @@ -335,7 +426,7 @@ def _make_next_obs( # Cat action entry with prev obs if self.str2str: - obs = tensordict[self.observation_key] + obs = tensordict[self.str_key] action = tensordict[self.action_key] if not tensordict.batch_size: if not isinstance(obs, str) or not isinstance(action, str): raise TypeError( @@ -351,20 +442,15 @@ def _make_next_obs( for (_obs, _action) in _zip_strict(obs, action) ] ) + return nex_td.set(self.str_key, observation) else: try: - obs: torch.Tensor = tensordict.get(self.observation_key) + obs: torch.Tensor = tensordict.get(self.token_key) action = tensordict.get(self.action_key) if getattr(obs,
"is_nested", False): observation = torch.nested.as_nested_tensor( [ - torch.cat( - [ - _obs, - _action, - ], - -1, - ) + torch.cat([_obs, _action], -1) for _obs, _action in _zip_strict( obs.unbind(0), action.unbind(0) ) @@ -372,31 +458,36 @@ def _make_next_obs( layout=obs.layout, ) else: - observation = torch.cat( - [ - obs, - action, - ], - -1, - ) + observation = torch.cat([obs, action], -1) except TypeError: raise TypeError( "Failed to cat action and observation tensors. Check that str2str argument is correctly " f"set in {type(self).__name__}." ) - return nex_td.set(self.observation_key, observation) + return nex_td.set(self.token_key, observation) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # We should have an observation by this time, if not raise an exception - if tensordict is None or self.observation_key not in tensordict.keys( - isinstance(self.observation_key, tuple) - ): + def check_token(): + return not self.str2str and ( + self.token_key not in tensordict.keys(isinstance(self.token_key, tuple)) + ) + + def check_str(): + return self.str2str and ( + self.str_key not in tensordict.keys(isinstance(self.str_key, tuple)) + ) + + if tensordict is None or check_token() or check_str(): raise KeyError( - f"Observation key {self.observation_key} is not defined. Make sure a TensorDictPrimer (eg, " + f"Observation key {self.token_key} is not defined. Make sure a TensorDictPrimer (eg, " f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms." ) td_reset = tensordict.copy() - return self._maybe_make_done(tensordict, td_reset) + tensordict = self._maybe_make_done(tensordict, td_reset) + if self.as_llm_data: + raise NotImplementedError() + return tensordict def _set_seed(self, seed: int | None): return seed diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index 4dd1b0f9b3d..94984ee1a63 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -24,6 +24,7 @@ from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param +from torchrl.envs.utils import make_composite_from_td def as_nested_tensor(list_of_tensordicts: list[TensorDictBase]) -> TensorDictBase: @@ -374,12 +375,26 @@ def __init__( use_buffer = True self.use_buffer = use_buffer + if self.use_buffer: + self._queue = deque() + # No auto_batch_size if we know we have a single element self.auto_batch_size = auto_batch_size and ( getattr(dataloader, "batch_size", 1) > 0 ) self.endless_dataloader = self._endless_iter(self.dataloader) - if primers is None: + + if stack_method is None: + stack_method = maybe_dense_stack + elif stack_method == "as_nested_tensor": + stack_method = as_nested_tensor + elif stack_method == "as_padded_tensor": + stack_method = as_padded_tensor + elif not callable(stack_method): + raise ValueError(f"Unknown stack_method={stack_method}") + self.stack_method = stack_method + + if primers is None and not self.use_buffer: if data_keys is None: data_keys = ["data"] if data_specs is None: @@ -391,19 +406,17 @@ def __init__( } ) self.data_keys = data_keys + elif primers is None: + self.data_keys = data_keys + # We can get the primer from the dataloader itself + data = self._load_from_dataloader() + primers = make_composite_from_td(data, dynamic_shape=True) + self._queue.insert(0, data) + if data_keys is None: + self.data_keys = list(primers.keys(True, 
True)) else: self.data_keys = list(primers.keys(True, True)) - if stack_method is None: - stack_method = maybe_dense_stack - elif stack_method == "as_nested_tensor": - stack_method = as_nested_tensor - elif stack_method == "as_padded_tensor": - stack_method = as_padded_tensor - elif not callable(stack_method): - raise ValueError(f"Unknown stack_method={stack_method}") - self.stack_method = stack_method - super().__init__( primers=primers, default_value=self._load_from_dataloader, @@ -413,8 +426,6 @@ def __init__( call_before_env_reset=True, ) self._reset_key = "_reset" - if self.use_buffer: - self._queue = deque() @classmethod def _endless_iter(self, obj): @@ -445,6 +456,12 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): out = TensorDict.from_dict( data, auto_batch_size=self.auto_batch_size, batch_dims=1 ) + elif self.data_keys is None: + raise RuntimeError( + f"Cannot lazily instantiate the {type(self).__name__} as the data_keys was " + f"not passed but the data is not a Mapping, therefore the keys cannot be retrieved " + f"automatically. Please pass the data_keys to the constructor." + ) elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)): out = TensorDict.from_dict( {k: val for k, val in _zip_strict(self.data_keys, data)}, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0fffe6ccd7e..f2d38c66121 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1619,10 +1619,14 @@ def __len__(self): return len(self.transforms) def __repr__(self) -> str: - layers_str = ",\n".join( - [indent(str(trsf), 4 * " ") for trsf in self.transforms] - ) - return f"{self.__class__.__name__}(\n{indent(layers_str, 4 * ' ')})" + if len(self.transforms): + layers_str = ",\n".join( + [indent(str(trsf), 4 * " ") for trsf in self.transforms] + ) + layers_str = f"\n{indent(layers_str, 4 * ' ')}" + else: + layers_str = "" + return f"{self.__class__.__name__}({layers_str})" def empty_cache(self): for t in self.transforms: @@ -8127,7 +8131,7 @@ def _reset( _reset = tensordict.get(reset_key, None) if _reset is None: done_key = _replace_last(init_key, "done") - shape = self.parent.full_done_spec[done_key].shape + shape = self.parent.full_done_spec[done_key]._safe_shape tensordict_reset.set( init_key, torch.ones( diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 3225da8e437..e5b52a8a1f0 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -93,91 +93,93 @@ ) from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip +from .llm import from_hf_transformers __all__ = [ - "DistributionalDQNnet", - "Delta", - "distributions_maps", - "IndependentNormal", - "MaskedCategorical", - "MaskedOneHotCategorical", - "NormalParamExtractor", - "NormalParamWrapper", - "OneHotCategorical", - "OneHotOrdinal", - "Ordinal", - "ReparamGradientStrategy", - "TanhDelta", - "TanhNormal", - "TruncatedNormal", + "Actor", + "ActorCriticOperator", + "ActorCriticWrapper", + "ActorValueOperator", + "AdditiveGaussianModule", + "AdditiveGaussianWrapper", "BatchRenorm1d", + "CEMPlanner", "ConsistentDropout", "ConsistentDropoutModule", "Conv3dNet", "ConvNet", + "DTActor", "DdpgCnnActor", "DdpgCnnQNet", "DdpgMlpActor", "DdpgMlpQNet", "DecisionTransformer", - "DreamerActor", - "DTActor", - "DuelingCnnDQNet", - "MLP", - "MultiAgentConvNet", - "MultiAgentMLP", - "MultiAgentNetBase", - "NoisyLazyLinear", - "NoisyLinear", - 
"ObsDecoder", - "ObsEncoder", - "OnlineDTActor", - "QMixer", - "reset_noise", - "RSSMPosterior", - "RSSMPrior", - "RSSMRollout", - "Squeeze2dLayer", - "SqueezeLayer", - "VDNMixer", - "Actor", - "ActorCriticOperator", - "ActorCriticWrapper", - "ActorValueOperator", - "AdditiveGaussianModule", - "AdditiveGaussianWrapper", "DecisionTransformerInferenceWrapper", + "Delta", + "DistributionalDQNnet", "DistributionalQValueActor", "DistributionalQValueHook", "DistributionalQValueModule", + "DreamerActor", + "DuelingCnnDQNet", "EGreedyModule", "EGreedyWrapper", "GRU", "GRUCell", "GRUModule", + "IndependentNormal", "LMHeadActorValueOperator", "LSTM", "LSTMCell", "LSTMModule", + "MLP", + "MPCPlannerBase", + "MPPIPlanner", + "MaskedCategorical", + "MaskedOneHotCategorical", + "MultiAgentConvNet", + "MultiAgentMLP", + "MultiAgentNetBase", "MultiStepActorWrapper", + "NoisyLazyLinear", + "NoisyLinear", + "NormalParamExtractor", + "NormalParamWrapper", + "ObsDecoder", + "ObsEncoder", + "OneHotCategorical", + "OneHotOrdinal", + "OnlineDTActor", + "Ordinal", "OrnsteinUhlenbeckProcessModule", "OrnsteinUhlenbeckProcessWrapper", "ProbabilisticActor", + "QMixer", "QValueActor", "QValueHook", "QValueModule", - "recurrent_mode", + "RSSMPosterior", + "RSSMPrior", + "RSSMRollout", + "ReparamGradientStrategy", "SafeModule", "SafeProbabilisticModule", "SafeProbabilisticTensorDictSequential", "SafeSequential", - "set_recurrent_mode", + "Squeeze2dLayer", + "SqueezeLayer", + "TanhDelta", "TanhModule", + "TanhNormal", + "TruncatedNormal", + "VDNMixer", "ValueOperator", "VmapModule", "WorldModelWrapper", + "distributions_maps", + "from_hf_transformers", "get_primers_from_module", - "CEMPlanner", - "MPCPlannerBase", - "MPPIPlanner", + "recurrent_mode", + "reset_noise", + "set_recurrent_mode", ] diff --git a/torchrl/modules/llm/__init__.py b/torchrl/modules/llm/__init__.py new file mode 100644 index 00000000000..467ecfd24aa --- /dev/null +++ b/torchrl/modules/llm/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .transformers_policy import from_hf_transformers + +__all__ = ["from_hf_transformers"] diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py new file mode 100644 index 00000000000..7494fe8b10b --- /dev/null +++ b/torchrl/modules/llm/transformers_policy.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# TODO: lazy imports + +import torch + +import transformers +from tensordict import NestedKey, TensorDictBase +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictModuleBase, + TensorDictSequential as Seq, + WrapModule, +) +from tensordict.tensorclass import NonTensorData, NonTensorStack +from torchrl.data.llm import LLMData +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + + +def _maybe_clear_device(td): + if td.device is None: + return td + return td.set("_source_device", NonTensorData(td.device)).clear_device_() + + +def _maybe_set_device(td): + device = td.pop("_source_device", None) + if device is None: + return td + elif isinstance(device, NonTensorData): + device: torch.device = device.data + return td.to(device) + + +def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: + """Computes the log_probs from a Transformer formatted TensorDict. + + Required keys in tensordict: + + - "tokens_out": containing + + - "scores": logits of shape (B, seq-len, vocab_size) + - "sequences": token sequences of shape (B, seq-len) + + Written keys in tensordict: + + - "logits": model scores of shape (B, seq-len, vocab_size) + - "log_probs": log probabilities of shape (B, seq-len, 1) + + Note: The following keys will be deleted from the tensordict: + + - "tokens_out", "past_key_values" + - "tokens_out", "scores" + + """ + # TODO: how do we avoid getting these? + del td["tokens_out", "past_key_values"] + scores = dict(td["tokens_out", "scores"].items()) + scores = torch.stack( + [scores[str(k)] for k in range(len(scores))], 1 + ) # shape (B, seq-len, vocab_size) + logits = scores - scores.logsumexp(dim=-1, keepdim=True) + td["logits"] = scores + del td["tokens_out", "scores"] + seq_len = scores.shape[1] + tokens = td["tokens_out", "sequences"][..., -seq_len:] # shape (B, seq-len) + log_probs = logits.gather(-1, tokens.unsqueeze(-1)) + td["log_probs"] = log_probs + return td + + +def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: + """Computes the log_probs from a Transformer formatted TensorDict. + + Required keys in tensordict: + + - "forward": containing + - "logits": logits of shape (B, seq-len, vocab_size) + - "tokens_in": containing + - "input_ids": token sequences of shape (B, seq-len) + + Written keys in tensordict: + + - "logits": model scores of shape (B, seq-len, vocab_size) + - "log_probs": log probabilities of shape (B, seq-len, 1) + + Note: The following keys will be deleted from the tensordict: + - "forward", "past_key_values" + - "forward" + """ + # TODO: how do we avoid getting these?
+ del td["forward", "past_key_values"] + scores = td["forward", "logits"] + logits = scores - scores.logsumexp(dim=-1, keepdim=True) + td["logits"] = scores + del td["forward"] + scores.shape[1] + tokens = td["tokens_in", "input_ids"] + log_probs = logits.gather(-1, tokens.unsqueeze(-1)) + td["log_probs"] = log_probs + return td + + +def from_hf_transformers( + model: transformers.modeling_utils.PreTrainedModel, + *, + generate: bool = True, + return_log_probs: bool = True, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + from_text: bool = False, + device: torch.device | None = None, + # Keys: + text_key: NestedKey = "text", + token_key: NestedKey = "tokens", + attention_mask_key: NestedKey = "attention_mask", + kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, +) -> TensorDictModuleBase: + # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + + + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + if from_text: + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + # TODO: add other paddings + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + module_dict["encode"] = Mod( + tokenizer, + in_keys=[text_key], + out_keys=["tokens_in"], + method_kwargs=tokenizer_kwargs, + strict=True, + # We don't need the text after this + inplace=False, + ) + else: + module_dict["format"] = Mod( + lambda *x: x, + in_keys=[token_key, attention_mask_key], + out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + strict=False, + # We don't need the text after this + inplace=False, + ) + + if device: + module_dict["to_dest_device"] = Mod( + lambda tensor: tensor.to(device), + in_keys=["tokens_in"], + out_keys=["tokens_in"], + strict=True, + ) + + if generate: + if not kwargs: + kwargs = {} + if return_log_probs: + if not kwargs.setdefault("output_scores", True): + raise RuntimeError + if not kwargs.setdefault("return_dict_in_generate", True): + raise RuntimeError + if ( + kwargs.setdefault("tokenizer", tokenizer) is not tokenizer + and tokenizer is not None + ): + raise RuntimeError + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + if return_log_probs: + module_dict["extract_log_probs"] = WrapModule( + log_probs_from_scores, + in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")], + out_keys=["logits", "log_probs"], + ) + if from_text: + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=[("tokens_out", "sequences")], + out_keys=["text_out"], + strict=True, + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=[ + ("tokens_out", "sequences"), + ("tokens_in", "input_ids"), + ("tokens_in", "attention_mask"), + "text_out", + "log_probs", + "logits", + ], + out_keys=[ + "tokens_response", + "tokens", + "attention_mask", + "text_response", + "log_probs", + "logits", + ], + strict=True, + inplace=False, + ) + else: + if device: + 
module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=[("tokens_out", "sequences"), "log_probs", "logits"], + out_keys=["tokens_response", "log_probs", "logits"], + inplace=False, + ) + else: + if not kwargs: + kwargs = {} + if not kwargs.setdefault("return_dict", True): + raise RuntimeError + if not return_log_probs: + raise RuntimeError + module_dict["get_logprobs"] = Mod( + model, + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["forward"], + out_to_in_map=True, + strict=True, + ) + module_dict["extract_log_probs"] = WrapModule( + log_probs_from_logits, + in_keys=[("tokens_in", "input_ids"), ("forward", "logits")], + out_keys=["logits", "log_probs"], + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + if from_text: + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=["log_probs", "logits", ("tokens_in", "attention_mask")], + out_keys=["log_probs", "logits", "attention_mask"], + inplace=False, + ) + else: + module_dict["rebuild"] = Mod( + lambda *x: x, + in_keys=["log_probs", "logits"], + out_keys=["log_probs", "logits"], + inplace=False, + ) + + return Seq(module_dict, inplace=True) + + +if __name__ == "__main__": + max_seq_length = 50000 + + tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer.pad_token = tokenizer.eos_token + model = GPT2LMHeadModel(GPT2Config()) + + tokenizer.padding_side = "left" + + m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, generate=True) + td = m(LLMData(text=NonTensorStack("a text"), batch_size=1)) + + m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, generate=False) + td = m(LLMData(text=NonTensorStack("a text"), batch_size=1)) + + m = from_hf_transformers(model, tokenizer=tokenizer, from_text=False, generate=True) + td = m( + LLMData( + tokens=torch.randint(1024, (1, 10)), + attention_mask=torch.ones(1, 10, dtype=torch.int64), + batch_size=1, + ) + ) + + m = from_hf_transformers(model, tokenizer=tokenizer, from_text=False, generate=True) + td = m(LLMData(tokens=torch.randint(1024, (1, 10)), batch_size=1)) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index e4b91c1a543..358ef2006d2 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -2208,8 +2208,8 @@ class MultiStepActorWrapper(TensorDictModuleBase): Args: actor (TensorDictModuleBase): An actor. - n_steps (int): the number of actions the actor outputs at once - (lookahead window). + n_steps (int, optional): the number of actions the actor outputs at once + (lookahead window). Defaults to `None`. Keyword Args: action_keys (list of NestedKeys, optional): the action keys from @@ -2220,6 +2220,8 @@ class MultiStepActorWrapper(TensorDictModuleBase): when the environment has gone through a reset. Defaults to ``"is_init"`` which is the ``out_key`` from the :class:`~torchrl.envs.transforms.InitTracker` transform. + keep_dim (bool, optional): whether to keep the time dimension of + the macro during indexing. Defaults to ``False``. 
Examples: >>> import torch.nn @@ -2288,14 +2290,16 @@ class MultiStepActorWrapper(TensorDictModuleBase): def __init__( self, actor: TensorDictModuleBase, - n_steps: int, + n_steps: int | None = None, *, action_keys: list[NestedKey] | None = None, init_key: list[NestedKey] | None = None, + keep_dim: bool = False, ): self.action_keys = action_keys self.init_key = init_key self.n_steps = n_steps + self.keep_dim = keep_dim super().__init__() self.actor = actor @@ -2367,7 +2371,10 @@ def forward( action_entry = parent_td.get(action_key_orig[-1], None) if action_entry is None: raise self._NO_INIT_ERR - if action_entry.shape[parent_td.ndim] != self.n_steps: + if ( + self.n_steps is not None + and action_entry.shape[parent_td.ndim] != self.n_steps + ): raise RuntimeError( f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). " f"The action shape was {action_entry.shape}." @@ -2377,7 +2384,10 @@ def forward( None, ), ) * parent_td.ndim - cur_action = action_entry[base_idx + (0,)] + if not self.keep_dim: + cur_action = action_entry[base_idx + (0,)] + else: + cur_action = action_entry[base_idx + (slice(1),)] tensordict.set(action_key, cur_action) tensordict.set( action_key_orig, diff --git a/tutorials/sphinx-tutorials-save/README.rst b/tutorials/sphinx-tutorials-save/README.rst deleted file mode 100644 index 7995a1fbb2e..00000000000 --- a/tutorials/sphinx-tutorials-save/README.rst +++ /dev/null @@ -1,4 +0,0 @@ -README Tutos -============ - -Check the tutorials on torchrl documentation: https://pytorch.org/rl
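A closing note on the log-prob extraction in log_probs_from_scores / log_probs_from_logits above: the normalize-then-gather pattern is a log-softmax over the vocabulary followed by indexing with the realized tokens. A self-contained equivalence check (illustrative shapes only):

>>> import torch
>>> scores = torch.randn(2, 5, 100)  # (batch, seq-len, vocab_size)
>>> tokens = torch.randint(100, (2, 5))  # realized token ids
>>> logits = scores - scores.logsumexp(dim=-1, keepdim=True)  # log-softmax
>>> log_probs = logits.gather(-1, tokens.unsqueeze(-1))  # (batch, seq-len, 1)
>>> reference = torch.log_softmax(scores, dim=-1).gather(-1, tokens.unsqueeze(-1))
>>> assert torch.allclose(log_probs, reference)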