diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 67cbda48f14..aec148f9523 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -4941,7 +4941,7 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
                     spec.shape = self.shape
                 else:
                     raise ValueError(
-                        f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
+                        f"The shapes of the spec {type(spec).__name__} and the {type(self).__name__} mismatch: the first "
                         f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                         f"Composite.shape={self.shape}."
                     )
diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py
index d4f38843ba1..1fb7666d24d 100644
--- a/torchrl/envs/transforms/llm.py
+++ b/torchrl/envs/transforms/llm.py
@@ -6,18 +6,20 @@
 
 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal
 
 import torch
 from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn
 
 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td
 
 
@@ -500,6 +502,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out
 
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+
 
 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.
diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py
index c506481746e..e977fc18be6 100644
--- a/torchrl/modules/llm/transformers_policy.py
+++ b/torchrl/modules/llm/transformers_policy.py
@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
     - "tokens_out", "scores"
 
     """
-    # TODO: how do we avoid getting these?
     tokens_out = td["tokens_out", "sequences"]
     seq_len = tokens_out.shape[1]
 
     del td["tokens_out", "past_key_values"]
+
     scores = dict(td["tokens_out", "scores"].items())
     scores = torch.stack(
         [scores[str(k)] for k in range(len(scores))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
     - "forward", "past_key_values"
     - "forward"
 
     """
-    # TODO: how do we avoid getting these?
+    tokens_out = td["tokens_response", "input_ids"]
+    seq_len = tokens_out.shape[-1]
+
     del td["forward", "past_key_values"]
+
     scores = td["forward", "logits"]
+    scores = scores[..., -seq_len:, :]
     logits = scores - scores.logsumexp(dim=-1, keepdim=True)
     td["logits"] = scores
     del td["forward"]
     scores.shape[1]
-    tokens = td["tokens_in", "input_ids"]
-    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    log_probs = logits.gather(-1, tokens_out.unsqueeze(-1))
     td["log_probs"] = log_probs
     return td
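
For context on the `transformers_policy.py` change: the patch switches `log_probs_from_logits` from gathering over the prompt tokens (`tokens_in`) to gathering over the response tokens (`tokens_response`), after first slicing the last `seq_len` positions out of the full-sequence logits. Below is a minimal, self-contained sketch of that computation; it is not part of the patch, and the tensor names and shapes are illustrative assumptions.

```python
# Standalone sketch (assumed names/shapes) of the log-prob computation
# that the revised log_probs_from_logits performs.
import torch

B, prompt_len, resp_len, vocab = 2, 5, 3, 11
logits = torch.randn(B, prompt_len + resp_len, vocab)  # full-sequence logits
response_ids = torch.randint(vocab, (B, resp_len))     # response token ids

# Keep only the positions that correspond to the response, mirroring
# scores[..., -seq_len:, :] in the patch.
scores = logits[..., -resp_len:, :]

# scores - logsumexp(scores) is log_softmax over the vocab dimension.
logps = scores - scores.logsumexp(dim=-1, keepdim=True)

# Gather the log-prob of each realized token along the vocab dimension,
# mirroring logits.gather(-1, tokens_out.unsqueeze(-1)).
token_logps = logps.gather(-1, response_ids.unsqueeze(-1))
assert token_logps.shape == (B, resp_len, 1)
```

Since `scores - scores.logsumexp(dim=-1, keepdim=True)` is exactly a log-softmax over the vocabulary, the subsequent `gather` picks out the log-probability the model assigned to each response token actually generated.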