From d5cdc481955a991fe3bf6f950028122b7e968f2b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 16:33:24 -0800 Subject: [PATCH 1/3] Update [ghstack-poisoned] --- torchrl/modules/llm/transformers.py | 186 ++++++++++++++++++++++++++++ 1 file changed, 186 insertions(+) create mode 100644 torchrl/modules/llm/transformers.py diff --git a/torchrl/modules/llm/transformers.py b/torchrl/modules/llm/transformers.py new file mode 100644 index 00000000000..79fb1fa458e --- /dev/null +++ b/torchrl/modules/llm/transformers.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# TODO: lazy imports + +from transformers import AutoModelForCausalLM, AutoTokenizer +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq, TensorDictModuleBase, WrapModule +from tensordict import NestedKey, TensorDictBase, TensorDict +import transformers +import torch + +def _maybe_clear_device(td): + if td.device is None: + return td + return td.set(NonTensorData("_source_device"), td.device).clear_device_() + + +def _maybe_set_device(td): + device = td.pop("_source_device", None) + if device is None: + return td + elif isinstance(device, NonTensorData): + device: torch.device = device.data + return td.to(device) + + +def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: + # TODO: how do we avoid getting these? 
+ del td["tokens_out", "past_key_values"] + scores = dict(td["tokens_out", "scores"].items()) + scores = torch.stack([scores[str(k)] for k in range(len(scores))], 1) # shape (B, seq-len, vocab_size) + logits = scores - scores.logsumexp(dim=-1, keepdim=True) + td["logits"] = scores + del td["tokens_out", "scores"] + seq_len = scores.shape[1] + tokens = td["tokens_out", "sequences"][..., -seq_len:] # shape (B, seq-len) + log_probs = logits.gather(-1, tokens.unsqueeze(-1)) + td["log_probs"] = log_probs + return td + +def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: + # TODO: how do we avoid getting these? + del td["forward", "past_key_values"] + scores = td["forward", "logits"] + logits = scores - scores.logsumexp(dim=-1, keepdim=True) + td["logits"] = scores + del td["forward"] + seq_len = scores.shape[1] + tokens = td["tokens_in", "input_ids"] + log_probs = logits.gather(-1, tokens.unsqueeze(-1)) + td["log_probs"] = log_probs + return td + + +def from_hf_transformers( + model: transformers.modeling_utils.PreTrainedModel, + *, + generate: bool = True, + return_log_probs: bool = True, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + from_text: bool = False, + device: torch.device | None = None, + text_key: NestedKey = "text", + input_key: NestedKey = "input_ids", + kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, + ) -> TensorDictModuleBase: + + # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + if from_text: + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + # TODO: add other paddings + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if 
tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + module_dict["encode"] = Mod( + tokenizer, + in_keys=[text_key], + out_keys=["tokens_in"], + method_kwargs=tokenizer_kwargs, + strict=True, + ) + if device: + module_dict["to_dest_device"] = Mod( + lambda tensor: tensor.to(device), + in_keys=["tokens_in"], + out_keys=["tokens_in"], + strict=True + ) + + if generate: + if not kwargs: + kwargs = {} + if return_log_probs: + if not kwargs.setdefault("output_scores", True): + raise RuntimeError + if not kwargs.setdefault("return_dict_in_generate", True): + raise RuntimeError + if kwargs.setdefault("tokenizer", tokenizer) is not tokenizer and tokenizer is not None: + raise RuntimeError + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + if return_log_probs: + module_dict["extract_log_probs"] = WrapModule( + log_probs_from_scores, + in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")], + out_keys=["logits", "log_probs"] + ) + if from_text: + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=[("tokens_out", "sequences")], + out_keys=["action"], + strict=True, + ) + + else: + if not kwargs: + kwargs = {} + if not kwargs.setdefault("return_dict", True): + raise RuntimeError + if not return_log_probs: + raise RuntimeError + module_dict["get_logprobs"] = Mod( + model, + method_kwargs=kwargs, + in_keys={ + "input_ids": ("tokens_in", "input_ids"), + "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["forward"], + out_to_in_map=True, + strict=True, + ) + module_dict["extract_log_probs"] = WrapModule( + log_probs_from_logits, + in_keys=[("tokens_in", "input_ids"), ("forward", "logits")], + out_keys=["logits", "log_probs"] + ) + if device: + module_dict["to_source_device"] = 
_maybe_set_device + return Seq(module_dict) + + +if __name__ == "__main__": + max_seq_length = 50000 + model_name = "Qwen/Qwen2.5-7B-Instruct" + + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + tokenizer.padding_side = "left" + + m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True) + td = m(TensorDict(text="a text")) + + m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False) + td = m(TensorDict(text="a text")) From d37086784b35433c083085833a8d96af1435eec1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 3 Mar 2025 17:40:19 -0800 Subject: [PATCH 2/3] Update [ghstack-poisoned] --- torchrl/data/postprocs/postprocs.py | 2 +- torchrl/envs/custom/llm.py | 34 +++++++------- torchrl/envs/transforms/rlhf.py | 2 +- torchrl/modules/llm/transformers.py | 72 +++++++++++++++++------------ 4 files changed, 62 insertions(+), 48 deletions(-) diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py index 53d283dbdad..a0ded99c892 100644 --- a/torchrl/data/postprocs/postprocs.py +++ b/torchrl/data/postprocs/postprocs.py @@ -12,7 +12,6 @@ from torch import nn - def _get_reward( gamma: float, reward: torch.Tensor, @@ -367,6 +366,7 @@ def __init__( discount: float = 1.0, ): from torchrl.objectives.value.functional import reward2go + super().__init__() self.in_keys = [unravel_key(reward_key), unravel_key(done_key)] if reward_key_out is None: diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 7282a7f60b2..13778f3c5d7 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -80,7 +80,9 @@ def __init__( self._batch_locked = False else: self._batch_locked = True - super().__init__(device=device, batch_size=() if batch_size is None else (batch_size,)) + super().__init__( + device=device, batch_size=() if 
batch_size is None else (batch_size,) + ) self.str2str = str2str self.vocab_size = vocab_size self.observation_key = unravel_key(token_key) @@ -92,11 +94,7 @@ def __init__( # self.action_key = unravel_key(action_key) if str2str: self.full_observation_spec_unbatched = Composite( - { - token_key: NonTensor( - example_data="a string", batched=True, shape=() - ) - } + {token_key: NonTensor(example_data="a string", batched=True, shape=())} ) self.full_action_spec_unbatched = Composite( {action_key: NonTensor(example_data="a string", batched=True, shape=())} @@ -104,17 +102,13 @@ def __init__( else: if vocab_size is None: observation_spec = { - token_key: Unbounded( - shape=(-1,), dtype=torch.int64, device=device - ) - } + token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device) + } if attention_key is not None: observation_spec[attention_key] = Unbounded( - shape=(-1,), dtype=torch.int64, device=device - ) - self.full_observation_spec_unbatched = Composite( - observation_spec - ) + shape=(-1,), dtype=torch.int64, device=device + ) + self.full_observation_spec_unbatched = Composite(observation_spec) self.full_action_spec_unbatched = Composite( { action_key: Unbounded( @@ -325,7 +319,13 @@ def _make_next_obs( if self.attention_key is not None: attention_mask = tensordict.get(self.attention_key) n = action.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat([attention_mask, attention_mask.new_ones(attention_mask.shape[:-1] + (n,))], -1) + attention_mask = torch.cat( + [ + attention_mask, + attention_mask.new_ones(attention_mask.shape[:-1] + (n,)), + ], + -1, + ) nex_td.set(self.attention_key, attention_mask) return nex_td @@ -384,7 +384,7 @@ def _make_next_obs( def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # We should have an observation by this time, if not raise an exception - print('tensordict', tensordict) + print("tensordict", tensordict) if tensordict is None or self.observation_key not in tensordict.keys( 
isinstance(self.observation_key, tuple) ): diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 85bdd1f4bbc..99fb89efe62 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -461,7 +461,7 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): raise ValueError( f"Unrecognized data type: {type(data)} with keys {self.data_keys}." ) - print('out', out) + print("out", out) if self.use_buffer: if not out.ndim: out = out.unsqueeze(0) diff --git a/torchrl/modules/llm/transformers.py b/torchrl/modules/llm/transformers.py index 79fb1fa458e..19e6b21b57e 100644 --- a/torchrl/modules/llm/transformers.py +++ b/torchrl/modules/llm/transformers.py @@ -5,11 +5,17 @@ # TODO: lazy imports -from transformers import AutoModelForCausalLM, AutoTokenizer -from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq, TensorDictModuleBase, WrapModule -from tensordict import NestedKey, TensorDictBase, TensorDict -import transformers import torch +import transformers +from tensordict import NestedKey, TensorDict, TensorDictBase +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictModuleBase, + TensorDictSequential as Seq, + WrapModule, +) +from transformers import AutoModelForCausalLM, AutoTokenizer + def _maybe_clear_device(td): if td.device is None: @@ -30,7 +36,9 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: # TODO: how do we avoid getting these? 
del td["tokens_out", "past_key_values"] scores = dict(td["tokens_out", "scores"].items()) - scores = torch.stack([scores[str(k)] for k in range(len(scores))], 1) # shape (B, seq-len, vocab_size) + scores = torch.stack( + [scores[str(k)] for k in range(len(scores))], 1 + ) # shape (B, seq-len, vocab_size) logits = scores - scores.logsumexp(dim=-1, keepdim=True) td["logits"] = scores del td["tokens_out", "scores"] @@ -40,6 +48,7 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: td["log_probs"] = log_probs return td + def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: # TODO: how do we avoid getting these? del td["forward", "past_key_values"] @@ -47,7 +56,7 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: logits = scores - scores.logsumexp(dim=-1, keepdim=True) td["logits"] = scores del td["forward"] - seq_len = scores.shape[1] + scores.shape[1] tokens = td["tokens_in", "input_ids"] log_probs = logits.gather(-1, tokens.unsqueeze(-1)) td["log_probs"] = log_probs @@ -55,18 +64,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: def from_hf_transformers( - model: transformers.modeling_utils.PreTrainedModel, - *, - generate: bool = True, - return_log_probs: bool = True, - tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, - from_text: bool = False, - device: torch.device | None = None, - text_key: NestedKey = "text", - input_key: NestedKey = "input_ids", - kwargs: dict | None = None, - tokenizer_kwargs: dict | None = None, - ) -> TensorDictModuleBase: + model: transformers.modeling_utils.PreTrainedModel, + *, + generate: bool = True, + return_log_probs: bool = True, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + from_text: bool = False, + device: torch.device | None = None, + text_key: NestedKey = "text", + input_key: NestedKey = "input_ids", + kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, +) -> TensorDictModuleBase: # 
TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks @@ -98,7 +107,7 @@ def from_hf_transformers( lambda tensor: tensor.to(device), in_keys=["tokens_in"], out_keys=["tokens_in"], - strict=True + strict=True, ) if generate: @@ -109,7 +118,10 @@ def from_hf_transformers( raise RuntimeError if not kwargs.setdefault("return_dict_in_generate", True): raise RuntimeError - if kwargs.setdefault("tokenizer", tokenizer) is not tokenizer and tokenizer is not None: + if ( + kwargs.setdefault("tokenizer", tokenizer) is not tokenizer + and tokenizer is not None + ): raise RuntimeError module_dict["generate"] = Mod( @@ -128,8 +140,8 @@ def from_hf_transformers( module_dict["extract_log_probs"] = WrapModule( log_probs_from_scores, in_keys=[("tokens_out", "sequences"), ("tokens_out", "scores")], - out_keys=["logits", "log_probs"] - ) + out_keys=["logits", "log_probs"], + ) if from_text: module_dict["decode"] = Mod( tokenizer.batch_decode, @@ -159,8 +171,8 @@ def from_hf_transformers( module_dict["extract_log_probs"] = WrapModule( log_probs_from_logits, in_keys=[("tokens_in", "input_ids"), ("forward", "logits")], - out_keys=["logits", "log_probs"] - ) + out_keys=["logits", "log_probs"], + ) if device: module_dict["to_source_device"] = _maybe_set_device return Seq(module_dict) @@ -171,16 +183,18 @@ def from_hf_transformers( model_name = "Qwen/Qwen2.5-7B-Instruct" model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype="auto", - device_map="auto" + model_name, torch_dtype="auto", device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.padding_side = "left" - m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True) + m = from_hf_transformers( + model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=True + ) td = m(TensorDict(text="a text")) - m = from_hf_transformers(model, tokenizer=tokenizer, from_text=True, device="cuda:0", 
generate=False) + m = from_hf_transformers( + model, tokenizer=tokenizer, from_text=True, device="cuda:0", generate=False + ) td = m(TensorDict(text="a text")) From fc163e473a64d437b31ac80fb83c4aebabc47030 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 4 Mar 2025 17:36:34 -0800 Subject: [PATCH 3/3] Update [ghstack-poisoned] --- test/test_env.py | 159 ++++++++++----- torchrl/data/replay_buffers/storages.py | 6 +- torchrl/envs/custom/llm.py | 202 ++++++++++++++------ torchrl/envs/libs/unity_mlagents.py | 1 - torchrl/envs/transforms/rlhf.py | 45 +++-- torchrl/envs/transforms/transforms.py | 2 +- torchrl/envs/utils.py | 34 +++- torchrl/modules/llm/transformers.py | 2 + torchrl/modules/tensordict_module/actors.py | 17 +- tutorials/sphinx-tutorials-save/README.rst | 4 - 10 files changed, 326 insertions(+), 146 deletions(-) delete mode 100644 tutorials/sphinx-tutorials-save/README.rst diff --git a/test/test_env.py b/test/test_env.py index 6f962538f95..f4b83b92d19 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4581,11 +4581,13 @@ def __next__(self): @pytest.mark.parametrize("batch_size", [0, 4]) @pytest.mark.parametrize("device", [None, "cpu"]) def test_llm_env(self, str2str, batched, stack_method, device, batch_size): - env = LLMEnv(str2str=str2str, device=device) + env = LLMEnv( + str2str=str2str, device=device, has_attention=False, no_stack=False + ) if str2str: primer = DataLoadingPrimer( dataloader=self.DummyDataLoader(batch_size=batch_size), - data_keys=["observation"], + data_keys=[LLMEnv._DEFAULT_STR_KEY], example_data="a string!", ) else: @@ -4595,7 +4597,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size): dataloader=self.DummyTensorDataLoader( batch_size=batch_size, padding=True ), - data_keys=["observation"], + data_keys=[LLMEnv._DEFAULT_TOKEN_KEY], data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], stack_method=stack_method, ) @@ -4605,7 +4607,7 @@ def test_llm_env(self, str2str, batched, stack_method, 
device, batch_size): if batched: td = env.reset(TensorDict(batch_size=[3])) env.check_env_specs(break_when_any_done="both", tensordict=td) - r = env.rollout(10, tensordict=TensorDict(batch_size=[3])) + env.rollout(10, tensordict=TensorDict(batch_size=[3])) else: env.check_env_specs(break_when_any_done="both") @@ -4628,7 +4630,7 @@ def test_llm_from_dataloader( if str2str: kwargs = { "dataloader": self.DummyDataLoader(batch_size=batch_size), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", } else: @@ -4638,11 +4640,18 @@ def test_llm_from_dataloader( "dataloader": self.DummyTensorDataLoader( padding=True, batch_size=batch_size ), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], "stack_method": stack_method, } - kwargs.update({"str2str": str2str, "device": device}) + kwargs.update( + { + "str2str": str2str, + "device": device, + "has_attention": False, + "no_stack": False, + } + ) env = LLMEnv.from_dataloader(**kwargs) assert not env.batch_locked if batched: @@ -4655,13 +4664,15 @@ def test_llm_from_dataloader( def policy(td): if str2str: if not td.shape: - td["action"] = "" + td[LLMEnv._DEFAULT_ACTION_KEY] = "" else: - td["action"] = NonTensorStack( + td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack( *["" for _ in range(td.shape[0])] ) else: - td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64) + td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones( + td.shape + (1,), dtype=torch.int64 + ) return td if batched: @@ -4669,32 +4680,48 @@ def policy(td): r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3])) assert r.ndim == 2 if str2str: - assert isinstance(r[0, 0]["observation"], str) - assert isinstance(r[0, 1]["observation"], str) + assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str) + assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str) assert ( - r[0, 0]["observation"] - == r[0, 1]["observation"][: 
-len(r[0, 0]["action"])] + r[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_KEY]) + ] ) assert ( - r[0, 1]["observation"] - == r[0, 2]["observation"][: -len(r[0, 1]["action"])] + r[0, 1][LLMEnv._DEFAULT_STR_KEY] + == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_KEY]) + ] ) assert ( - r[-1, 0]["observation"] - == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])] + r[-1, 0][LLMEnv._DEFAULT_STR_KEY] + == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_KEY]) + ] ) assert ( - r[-1, 1]["observation"] - == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])] + r[-1, 1][LLMEnv._DEFAULT_STR_KEY] + == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ + : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_KEY]) + ] ) else: - assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all() - assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all() assert ( - r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1] + r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] + ).all() + assert ( + r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] + == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] ).all() assert ( - r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1] + r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] + ).all() + assert ( + r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY] + == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] ).all() else: r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[])) @@ -4720,7 +4747,7 @@ def test_llm_from_dataloader_repeats( if str2str: kwargs = { "dataloader": self.DummyDataLoader(batch_size=batch_size), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", "repeats": repeats, } @@ -4731,12 +4758,19 @@ def test_llm_from_dataloader_repeats( "dataloader": self.DummyTensorDataLoader( padding=True, batch_size=batch_size ), - "data_keys": ["observation"], + 
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], "stack_method": stack_method, "repeats": repeats, } - kwargs.update({"str2str": str2str, "device": device}) + kwargs.update( + { + "str2str": str2str, + "device": device, + "has_attention": False, + "no_stack": False, + } + ) env = LLMEnv.from_dataloader(**kwargs) assert env.transform.repeats == repeats @@ -4746,13 +4780,15 @@ def test_llm_from_dataloader_repeats( def policy(td): if str2str: if not td.shape: - td["action"] = "" + td[LLMEnv._DEFAULT_ACTION_KEY] = "" else: - td["action"] = NonTensorStack( + td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack( *["" for _ in range(td.shape[0])] ) else: - td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64) + td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones( + td.shape + (1,), dtype=torch.int64 + ) return td if batched: @@ -4768,34 +4804,58 @@ def policy(td): r_reset = r[..., ::max_steps] if not batched: if str2str: - assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"] - assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"] - assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"] + assert ( + r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY] + != r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY] + ) else: assert ( - r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"] + r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"] + r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"] + r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY] + != 
r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY] ).any() else: # When batched, each block contains the 3 reset packs if str2str: - assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"] - assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"] - assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] + assert ( + r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY] + == r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY] + ) + assert ( + r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY] + != r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY] + ) else: assert ( - r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"] + r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"] + r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + == r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY] ).all() assert ( - r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"] + r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] + != r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] ).any() @pytest.mark.parametrize( @@ -4829,7 +4889,7 @@ def test_done_and_reward( if str2str: kwargs = { "dataloader": self.DummyDataLoader(batch_size=batch_size), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_STR_KEY], "example_data": "a string!", "repeats": repeats, "assign_reward": assign_reward, @@ -4842,20 +4902,27 @@ def test_done_and_reward( "dataloader": self.DummyTensorDataLoader( padding=True, batch_size=batch_size ), - "data_keys": ["observation"], + "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY], "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)], "stack_method": stack_method, "repeats": repeats, "assign_reward": assign_reward, "assign_done": assign_done, } - kwargs.update({"str2str": str2str, "device": device}) + kwargs.update( + { + "str2str": str2str, + "device": device, + "has_attention": False, + "no_stack": 
False, + } + ) env = LLMEnv.from_dataloader(**kwargs) # We want to make sure that transforms that rely on the done state work appropriately env.append_transform(StepCounter(max_steps=10)) def policy(td): - td["action"] = torch.ones( + td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones( td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64 ) return td diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 1d6a4ac69e4..c9f6715984d 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -1536,10 +1536,10 @@ def _collate_id(x): def _get_default_collate(storage, _is_tensordict=False): - if isinstance(storage, ListStorage): - return _stack_anything - elif isinstance(storage, TensorStorage): + if isinstance(storage, LazyStackStorage) or isinstance(storage, TensorStorage): return _collate_id + elif isinstance(storage, ListStorage): + return _stack_anything else: raise NotImplementedError( f"Could not find a default collate_fn for storage {type(storage)}." diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 13778f3c5d7..fad9d0d2c29 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Literal import torch + from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.tensorclass import NonTensorData, NonTensorStack from tensordict.utils import _zip_strict @@ -40,9 +41,16 @@ class LLMEnv(EnvBase): Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via :meth:`~from_dataloader` Keyword Args: - observation_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults to - ``"observation"``. - action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``. 
+ token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`). + Defaults to ``("tokens_in", "input_ids")``. + str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `str2str=True`). + Defaults to ``"text"``. + attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. + Defaults to ``("tokens_in", "attention_mask")``. + action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to + ``("tokens_out", "sequences")``. + reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. + Defaults to ``"reward"``. str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an @@ -50,9 +58,19 @@ class LLMEnv(EnvBase): no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed to be the plain output of the LLM, including the input tokens / strings. - assign_reward (bool, optional): TODO - assign_done (bool, optional): TODO - reward_key (NestedKey, optional): TODO + has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by + :attr:`attention_key`. Defaults to ``True``. + assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape + is written during calls to `step()`. Defaults to ``False``. + assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the + action shape is written during calls to `step()`. Defaults to ``False``.
+ + .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root + as it is a requirement for all TorchRL environments. + + batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment + is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size. + Defaults to ``None`` (batch-unlocked). .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples. @@ -61,31 +79,52 @@ class LLMEnv(EnvBase): """ + _DEFAULT_TOKEN_KEY = ("tokens_in", "input_ids") + _DEFAULT_STR_KEY = "text" + _DEFAULT_ATTENTION_KEY = ("tokens_in", "attention_mask") + _DEFAULT_ACTION_KEY = ("tokens_out", "sequences") + def __init__( self, *, - token_key: NestedKey = "observation", + token_key: NestedKey | None = None, + str_key: NestedKey | None = None, attention_key: NestedKey | None = None, - action_key: NestedKey = "action", + action_key: NestedKey | None = None, + reward_key: NestedKey = "reward", str2str: bool = False, device: torch.device | None = None, vocab_size: int | None = None, - no_stack: bool = False, + no_stack: bool = True, assign_reward: bool = False, assign_done: bool = False, - reward_key: NestedKey = "reward", - batch_size: int | None = None, + batch_size: int | torch.Size | None = None, + has_attention: bool = True, ) -> None: + if token_key is None: + token_key = self._DEFAULT_TOKEN_KEY + if str_key is None: + str_key = self._DEFAULT_STR_KEY + if attention_key is None: + attention_key = self._DEFAULT_ATTENTION_KEY + if action_key is None: + action_key = self._DEFAULT_ACTION_KEY if batch_size is None: self._batch_locked = False + batch_size = () else: self._batch_locked = True + if not isinstance(batch_size, (tuple, list)): + batch_size = (batch_size,) super().__init__( - device=device, batch_size=() if batch_size is None else (batch_size,) + device=device, + batch_size=batch_size, ) + self.has_attention = has_attention self.str2str = str2str self.vocab_size = 
vocab_size - self.observation_key = unravel_key(token_key) + self.token_key = unravel_key(token_key) + self.str_key = unravel_key(str_key) self.attention_key = unravel_key(attention_key) self.no_stack = no_stack self.assign_reward = assign_reward @@ -94,7 +133,11 @@ def __init__( # self.action_key = unravel_key(action_key) if str2str: self.full_observation_spec_unbatched = Composite( - {token_key: NonTensor(example_data="a string", batched=True, shape=())} + { + self.str_key: NonTensor( + example_data="a string", batched=True, shape=() + ) + } ) self.full_action_spec_unbatched = Composite( {action_key: NonTensor(example_data="a string", batched=True, shape=())} @@ -104,7 +147,7 @@ def __init__( observation_spec = { token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device) } - if attention_key is not None: + if self.has_attention: observation_spec[attention_key] = Unbounded( shape=(-1,), dtype=torch.int64, device=device ) @@ -155,8 +198,8 @@ def __init__( if not self.assign_done: # Use single done self.full_done_spec_unbatched = Composite( - done=Unbounded(shape=(-1,), dtype=torch.bool), - terminated=Unbounded(shape=(-1,), dtype=torch.bool), + done=Unbounded(shape=(1,), dtype=torch.bool), + terminated=Unbounded(shape=(1,), dtype=torch.bool), ) elif self.str2str: raise STR2STR_ERR @@ -176,17 +219,19 @@ def from_dataloader( cls, dataloader: DataLoader, *, - token_key: NestedKey = "observation", + token_key: NestedKey | None = None, + str_key: NestedKey | None = None, attention_key: NestedKey | None = None, - action_key: NestedKey = "action", + action_key: NestedKey | None = None, + reward_key: NestedKey = "reward", str2str: bool = False, device: torch.device | None = None, vocab_size: int | None = None, no_stack: bool = False, - batch_size: int | None = None, + batch_size: int | torch.Size | None = None, + has_attention: bool = True, assign_reward: bool = False, assign_done: bool = False, - reward_key: NestedKey = "reward", primers: Composite | None = None, 
data_keys: list[NestedKey] | None = None, data_specs: list[TensorSpec] | None = None, @@ -201,10 +246,16 @@ def from_dataloader( Args: dataloader (DataLoader): The dataloader to load data from. - token_key (NestedKey, optional): The key in the tensordict where the observation is stored. Defaults - to ``"observation"``. - attention_key: TODO - action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to ``"action"``. + token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `str2str=False`). + Defaults to ``("tokens_in", "input_ids")``. + str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `str2str=True`). + Defaults to ``"test"``. + attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. + Defaults to ``("tokens_in", "input_ids")`` + action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to + ``("tokens_out", "sequences")``. + reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. + Defaults to ``"reward"``. str2str (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``False``. device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an @@ -212,10 +263,19 @@ def from_dataloader( no_stack (bool, optional): If ``False`` (default), the environment should stack the action with the past observation, each action being a new, unseen part of a conversation. Otherwise, the action is assumed to be the plain output of the LLM, including the input tokens / strings. 
- assign_reward (bool, optional): TODO - assign_done (bool, optional): TODO - reward_key (NestedKey, optional): TODO - batch_size (int, optional): TODO + has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by + :attr:`attention_key`. Defaults to ``True``. + assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to the action shape + is written during calls to `step()`. Defaults to ``False``. + assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to the + action shape is written during calls to `step()`. Defaults to ``False``. + + .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root + as it is a requirement for all TorchRL environments. + + batch_size (int or torch.Size, optional): Batch size of the environment. If left empty, the environment + is batchless (or batch-unlocked), meaning that it can accept tensordicts of any batch size. + Defaults to ``None`` (batch-unlocked). primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to ``None``. data_keys (list[NestedKey] | None, optional): The keys to use for each item in the dataloader. If not passed ``observation_key`` will be populated with the data.
@@ -234,10 +294,27 @@ def from_dataloader( """ from torchrl.envs import DataLoadingPrimer + if data_keys is None: + if str2str: + if str_key is None: + data_keys = [LLMEnv._DEFAULT_STR_KEY] + else: + data_keys = [str_key] + else: + if token_key is None: + data_keys = [LLMEnv._DEFAULT_TOKEN_KEY] + else: + data_keys = [token_key] + if has_attention: + if attention_key is None: + data_keys.append(LLMEnv._DEFAULT_ATTENTION_KEY) + else: + data_keys.append(attention_key) + primer = DataLoadingPrimer( dataloader=dataloader, primers=primers, - data_keys=data_keys if data_keys is not None else [token_key], + data_keys=data_keys, data_specs=data_specs, example_data=example_data, stack_method=stack_method, @@ -247,14 +324,16 @@ def from_dataloader( str2str=str2str, device=device, token_key=token_key, + str_key=str_key, attention_key=attention_key, action_key=action_key, + reward_key=reward_key, vocab_size=vocab_size, no_stack=no_stack, assign_reward=assign_reward, assign_done=assign_done, - reward_key=reward_key, batch_size=batch_size, + has_attention=has_attention, ) return env.append_transform(primer) @@ -314,24 +393,27 @@ def _make_next_obs( self, tensordict: TensorDictBase, nex_td: TensorDictBase ) -> TensorDictBase: if self.no_stack: + if self.str2str: + raise NotImplementedError action = tensordict.get(self.action_key) - nex_td.set(self.observation_key, action) - if self.attention_key is not None: + nex_td.set(self.token_key, action) + if self.has_attention: attention_mask = tensordict.get(self.attention_key) n = action.shape[-1] - attention_mask.shape[-1] - attention_mask = torch.cat( - [ - attention_mask, - attention_mask.new_ones(attention_mask.shape[:-1] + (n,)), - ], - -1, - ) + if n: + attention_mask = torch.cat( + [ + attention_mask, + attention_mask.new_ones(attention_mask.shape[:-1] + (n,)), + ], + -1, + ) nex_td.set(self.attention_key, attention_mask) return nex_td # Cat action entry with prev obs if self.str2str: - obs = tensordict[self.observation_key] + obs 
= tensordict[self.str_key] action = tensordict[self.action_key] if not tensordict.batch_size: if not isinstance(obs, str) or not isinstance(action, str): @@ -347,20 +429,15 @@ def _make_next_obs( for (_obs, _action) in _zip_strict(obs, action) ] ) + return nex_td.set(self.str_key, observation) else: try: - obs: torch.Tensor = tensordict.get(self.observation_key) + obs: torch.Tensor = tensordict.get(self.token_key) action = tensordict.get(self.action_key) if getattr(obs, "is_nested", False): observation = torch.nested.as_nested_tensor( [ - torch.cat( - [ - _obs, - _action, - ], - -1, - ) + torch.cat([_obs, _action], -1) for _obs, _action in _zip_strict( obs.unbind(0), action.unbind(0) ) @@ -368,28 +445,29 @@ def _make_next_obs( layout=obs.layout, ) else: - observation = torch.cat( - [ - obs, - action, - ], - -1, - ) + observation = torch.cat([obs, action], -1) except TypeError: raise TypeError( "Failed to cat action and observation tensors. Check that str2str argument is correctly " f"set in {type(self).__name__}." ) - return nex_td.set(self.observation_key, observation) + return nex_td.set(self.token_key, observation) def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: # We should have an observation by this time, if not raise an exception - print("tensordict", tensordict) - if tensordict is None or self.observation_key not in tensordict.keys( - isinstance(self.observation_key, tuple) - ): + def check_token(): + return not self.str2str and ( + self.token_key not in tensordict.keys(isinstance(self.token_key, tuple)) + ) + + def check_str(): + return self.str2str and ( + self.str_key not in tensordict.keys(isinstance(self.str_key, tuple)) + ) + + if tensordict is None or check_token() or check_str(): raise KeyError( - f"Observation key {self.observation_key} is not defined. Make sure a TensorDictPrimer (eg, " + f"Observation key {self.token_key} is not defined. 
Make sure a TensorDictPrimer (eg, " f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms." ) td_reset = tensordict.copy() diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index ab3b568cd87..74b93df0ee1 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -132,7 +132,6 @@ def _collect_agents(self, env): for steps_idx in [0, 1]: for behavior in env.behavior_specs.keys(): steps = env.get_steps(behavior)[steps_idx] - is_terminal = steps_idx == 1 agent_ids = steps.agent_id group_ids = steps.group_id diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 99fb89efe62..3b412ac146b 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -24,6 +24,7 @@ from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param +from torchrl.envs.utils import make_composite_from_td def as_nested_tensor(list_of_tensordicts: list[TensorDictBase]) -> TensorDictBase: @@ -374,12 +375,26 @@ def __init__( use_buffer = True self.use_buffer = use_buffer + if self.use_buffer: + self._queue = deque() + # No auto_batch_size if we know we have a single element self.auto_batch_size = auto_batch_size and ( getattr(dataloader, "batch_size", 1) > 0 ) self.endless_dataloader = self._endless_iter(self.dataloader) - if primers is None: + + if stack_method is None: + stack_method = maybe_dense_stack + elif stack_method == "as_nested_tensor": + stack_method = as_nested_tensor + elif stack_method == "as_padded_tensor": + stack_method = as_padded_tensor + elif not callable(stack_method): + raise ValueError(f"Unknown stack_method={stack_method}") + self.stack_method = stack_method + + if primers is None and not self.use_buffer: if data_keys is None: data_keys = ["data"] if data_specs is 
None: @@ -391,18 +406,18 @@ def __init__( } ) self.data_keys = data_keys + elif primers is None: + self.data_keys = data_keys + # We can get the primer from the dataloader itself + data = self._load_from_dataloader() + primers = make_composite_from_td(data, dynamic_shape=True) + self._queue.insert(0, data) + if data_keys is None: + self.data_keys = list(primers.keys(True, True)) else: self.data_keys = list(primers.keys(True, True)) - if stack_method is None: - stack_method = maybe_dense_stack - elif stack_method == "as_nested_tensor": - stack_method = as_nested_tensor - elif stack_method == "as_padded_tensor": - stack_method = as_padded_tensor - elif not callable(stack_method): - raise ValueError(f"Unknown stack_method={stack_method}") - self.stack_method = stack_method + self._reset_key = "_reset" super().__init__( primers=primers, @@ -412,9 +427,6 @@ def __init__( single_default_value=True, call_before_env_reset=True, ) - self._reset_key = "_reset" - if self.use_buffer: - self._queue = deque() @classmethod def _endless_iter(self, obj): @@ -445,6 +457,12 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): out = TensorDict.from_dict( data, auto_batch_size=self.auto_batch_size, batch_dims=1 ) + elif self.data_keys is None: + raise RuntimeError( + f"Cannot lazily instantiate the {type(self).__name__} as the data_keys was " + f"not passed but the data is not a Mapping, therefore the keys cannot be retrieved " + f"automatically. Please pass the data_keys to the constructor." + ) elif len(self.data_keys) > 1 and isinstance(data, (list, tuple)): out = TensorDict.from_dict( {k: val for k, val in _zip_strict(self.data_keys, data)}, @@ -461,7 +479,6 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None): raise ValueError( f"Unrecognized data type: {type(data)} with keys {self.data_keys}." 
) - print("out", out) if self.use_buffer: if not out.ndim: out = out.unsqueeze(0) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 732fb203619..dfae8064a27 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8124,7 +8124,7 @@ def _reset( _reset = tensordict.get(reset_key, None) if _reset is None: done_key = _replace_last(init_key, "done") - shape = self.parent.full_done_spec[done_key].shape + shape = self.parent.full_done_spec[done_key]._safe_shape tensordict_reset.set( init_key, torch.ones( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index e2d2d7bfe93..66c1ac032b1 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -888,13 +888,19 @@ def _sort_keys(element): return element -def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): +def make_composite_from_td( + data, *, unsqueeze_null_shapes: bool = True, dynamic_shape: bool = False +): """Creates a Composite instance from a tensordict, assuming all values are unbounded. Args: data (tensordict.TensorDict): a tensordict to be mapped onto a Composite. + + Keyword Args: unsqueeze_null_shapes (bool, optional): if ``True``, every empty shape will be unsqueezed to (1,). Defaults to ``True``. + dynamic_shape (bool, optional): if ``True``, all tensors will be assumed to have a dynamic shape + along the last dimension. Defaults to ``False``. Examples: >>> from tensordict import TensorDict @@ -919,18 +925,28 @@ def make_composite_from_td(data, unsqueeze_null_shapes: bool = True): """ # custom function to convert a tensordict in a similar spec structure # of unbounded values. 
+ def make_shape(shape): + if shape or not unsqueeze_null_shapes: + if dynamic_shape: + return shape[:-1] + (-1,) + else: + return shape + return torch.Size([1]) + composite = Composite( { - key: make_composite_from_td(tensor) + key: make_composite_from_td( + tensor, + unsqueeze_null_shapes=unsqueeze_null_shapes, + dynamic_shape=dynamic_shape, + ) if isinstance(tensor, TensorDictBase) - else NonTensor(shape=data.shape, device=tensor.device) + else NonTensor( + shape=data.shape, example_data=data.data, device=tensor.device + ) if is_non_tensor(tensor) else Unbounded( - dtype=tensor.dtype, - device=tensor.device, - shape=tensor.shape - if tensor.shape or not unsqueeze_null_shapes - else [1], + dtype=tensor.dtype, device=tensor.device, shape=make_shape(tensor.shape) ) for key, tensor in data.items() }, @@ -1390,7 +1406,6 @@ def _update_during_reset( if not reset_keys: return tensordict.update(tensordict_reset) roots = set() - print("reset_keys", reset_keys) for reset_key in reset_keys: # get the node of the reset key if isinstance(reset_key, tuple): @@ -1406,7 +1421,6 @@ def _update_during_reset( reset_key_tuple = (reset_key,) # get the reset signal reset = tensordict.pop(reset_key, None) - print("reset popped", reset) # check if this reset should be ignored -- this happens whenever the # root node has already been updated diff --git a/torchrl/modules/llm/transformers.py b/torchrl/modules/llm/transformers.py index 19e6b21b57e..30890653b3e 100644 --- a/torchrl/modules/llm/transformers.py +++ b/torchrl/modules/llm/transformers.py @@ -6,6 +6,7 @@ # TODO: lazy imports import torch + import transformers from tensordict import NestedKey, TensorDict, TensorDictBase from tensordict.nn import ( @@ -14,6 +15,7 @@ TensorDictSequential as Seq, WrapModule, ) +from tensordict.tensorclass import NonTensorData from transformers import AutoModelForCausalLM, AutoTokenizer diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 
e4b91c1a543..b1db6c6712a 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -2208,8 +2208,8 @@ class MultiStepActorWrapper(TensorDictModuleBase): Args: actor (TensorDictModuleBase): An actor. - n_steps (int): the number of actions the actor outputs at once - (lookahead window). + n_steps (int, optional): the number of actions the actor outputs at once + (lookahead window). Defaults to `None`. Keyword Args: action_keys (list of NestedKeys, optional): the action keys from @@ -2220,6 +2220,8 @@ class MultiStepActorWrapper(TensorDictModuleBase): when the environment has gone through a reset. Defaults to ``"is_init"`` which is the ``out_key`` from the :class:`~torchrl.envs.transforms.InitTracker` transform. + keep_dim (bool, optional): whether to keep the time dimension of + the macro during indexing. Defaults to ``False``. Examples: >>> import torch.nn @@ -2288,14 +2290,16 @@ class MultiStepActorWrapper(TensorDictModuleBase): def __init__( self, actor: TensorDictModuleBase, - n_steps: int, + n_steps: int | None = None, *, action_keys: list[NestedKey] | None = None, init_key: list[NestedKey] | None = None, + keep_dim: bool = False, ): self.action_keys = action_keys self.init_key = init_key self.n_steps = n_steps + self.keep_dim = keep_dim super().__init__() self.actor = actor @@ -2367,7 +2371,7 @@ def forward( action_entry = parent_td.get(action_key_orig[-1], None) if action_entry is None: raise self._NO_INIT_ERR - if action_entry.shape[parent_td.ndim] != self.n_steps: + if self.n_steps is not None and action_entry.shape[parent_td.ndim] != self.n_steps: raise RuntimeError( f"The action's time dimension (dim={parent_td.ndim}) doesn't match the n_steps argument ({self.n_steps}). " f"The action shape was {action_entry.shape}." 
@@ -2377,7 +2381,10 @@ def forward( None, ), ) * parent_td.ndim - cur_action = action_entry[base_idx + (0,)] + if not self.keep_dim: + cur_action = action_entry[base_idx + (0,)] + else: + cur_action = action_entry[base_idx + (slice(1),)] tensordict.set(action_key, cur_action) tensordict.set( action_key_orig, diff --git a/tutorials/sphinx-tutorials-save/README.rst b/tutorials/sphinx-tutorials-save/README.rst deleted file mode 100644 index 7995a1fbb2e..00000000000 --- a/tutorials/sphinx-tutorials-save/README.rst +++ /dev/null @@ -1,4 +0,0 @@ -README Tutos -============ - -Check the tutorials on torchrl documentation: https://pytorch.org/rl