
Commit 867d870

[BugFix] Right log-prob size in transformer wrapper
ghstack-source-id: e7c6f04f19cb5b78191478fe8fcbacf2130efb62
Pull Request resolved: #2854
1 parent fd14ef0 commit 867d870

File tree

4 files changed: +170 -89 lines changed


torchrl/data/tensor_specs.py

+1 -1

@@ -4941,7 +4941,7 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
                 spec.shape = self.shape
             else:
                 raise ValueError(
-                    f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
+                    f"The shapes of the spec {type(spec).__name__} and the {type(self).__name__} mismatch: the first "
                     f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                     f"Composite.shape={self.shape}."
                 )
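For context, the reworded message is raised by Composite.set when the leading dimensions of an incoming spec do not line up with the composite's batch shape. A minimal sketch of that behavior follows; the key names and shapes are illustrative and not part of this commit:

from torchrl.data import Composite, Unbounded

spec = Composite(shape=(4,))
spec.set("obs", Unbounded(shape=(4, 3)))    # fine: the leading dim matches Composite.shape
# spec.set("bad", Unbounded(shape=(2, 3)))  # mismatched leading dim -> raises the ValueError above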

torchrl/envs/transforms/llm.py

+62 -33
@@ -6,18 +6,20 @@

 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal

 import torch
 from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn

 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td


@@ -500,6 +502,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out

+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+

 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.
@@ -572,6 +578,8 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
+        log_prob_key: NestedKey = "sample_log_prob",
+        action_key: NestedKey = "action",
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -598,35 +606,38 @@ def __init__(
         self.in_keys = self.in_keys + actor.in_keys

         # check that the model has parameters
-        params = TensorDict.from_module(actor)
-        with params.apply(
-            _stateless_param, device="meta", filter_empty=False
-        ).to_module(actor):
-            # copy a stateless actor
-            self.__dict__["functional_actor"] = deepcopy(actor)
+        # params = TensorDict.from_module(actor)
+        # with params.apply(
+        #     _stateless_param, device="meta", filter_empty=False
+        # ).to_module(actor):
+        #     # copy a stateless actor
+        #     self.__dict__["functional_actor"] = deepcopy(actor)
+        self.__dict__["functional_actor"] = actor
+
         # we need to register these params as buffer to have `to` and similar
         # methods work properly

-        def _make_detached_param(x):
-
-            if isinstance(x, nn.Parameter):
-                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-            elif x.requires_grad:
-                raise ValueError(
-                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
-                )
-            return x.clone()
-
-        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        if requires_grad:
-            # includes the frozen params/buffers in the module parameters/buffers
-            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        # def _make_detached_param(x):
+        #
+        #     if isinstance(x, nn.Parameter):
+        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+        #     elif x.requires_grad:
+        #         raise ValueError(
+        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
+        #         )
+        #     return x.clone()
+        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        # if requires_grad:
+        #     # includes the frozen params/buffers in the module parameters/buffers
+        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)

         # self._buffers["actor_params"] = params.clone().detach()

+        self.action_key = action_key
+
         # find the sample log-prob key
-        self.sample_log_prob_key = "sample_log_prob"
+        self.sample_log_prob_key = log_prob_key

         def find_sample_log_prob(module):
             if hasattr(module, "log_prob_key"):
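Because the transform now keeps a direct reference to the actor (the stateless deep-copy and frozen-parameter machinery above is commented out), the reference model shares parameters with whatever module is passed in. A hedged sketch of how a caller might freeze a separate snapshot before handing it to the transform; the policy module below is a stand-in, not part of this commit:

import copy
from torch import nn

policy = nn.Linear(8, 8)              # stand-in for the live policy module
ref_model = copy.deepcopy(policy)     # separate snapshot to serve as the KL reference
for p in ref_model.parameters():
    p.requires_grad_(False)           # keep the reference frozen while the policy trains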
@@ -647,16 +658,25 @@ def _reset(

     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get(self.action_key, None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
-        with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
-        # get the log_prob given the original model
-        log_prob = dist.log_prob(action)
+        # with self.frozen_params.to_module(self.functional_actor):
+        if isinstance(
+            self.functional_actor,
+            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
+        ):
+            dist = self.functional_actor.get_dist(next_tensordict.copy())
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
+        else:
+            log_prob = self.functional_actor(next_tensordict.copy()).get(
+                self.sample_log_prob_key
+            )
+
         reward_key = self.in_keys[0]
         reward = next_tensordict.get("next").get(reward_key)
         curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
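With the new else-branch, the reference actor no longer has to expose get_dist: any module that writes log-probs under the configured key can serve. A minimal sketch exercising that path; the lambda actor, key names, and any constructor behavior beyond what this diff shows are assumptions for illustration only:

import torch
from tensordict.nn import TensorDictModule
from torchrl.envs.transforms.llm import KLRewardTransform

# Stand-in "reference model": writes per-token log-probs directly instead of exposing get_dist.
ref_model = TensorDictModule(
    lambda tokens: torch.zeros(tokens.shape, dtype=torch.float),
    in_keys=[("tokens_response", "input_ids")],
    out_keys=["log_probs"],
)
kl = KLRewardTransform(
    ref_model,
    log_prob_key="log_probs",
    action_key=("tokens_response", "input_ids"),
)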
@@ -679,12 +699,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

             if in_key == "reward" and out_key == "reward":
                 parent = self.parent
+
+                reward_keys = parent.reward_keys
+                if len(reward_keys) == 1:
+                    reward_key = reward_keys[0]
+                elif "reward" in reward_keys:
+                    reward_key = "reward"
+                else:
+                    raise KeyError("Couln't find the reward key.")
+
                 reward_spec = Unbounded(
                     device=output_spec.device,
-                    shape=output_spec["full_reward_spec"][parent.reward_key].shape,
+                    shape=output_spec["full_reward_spec"][reward_key].shape,
                 )
                 output_spec["full_reward_spec"] = Composite(
-                    {parent.reward_key: reward_spec},
+                    {reward_key: reward_spec},
                     shape=output_spec["full_reward_spec"].shape,
                 )
             elif in_key == "reward":

torchrl/modules/llm/transformers_policy.py

+7 -4
@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
        - "tokens_out", "scores"

    """
-    # TODO: how do we avoid getting these?
     tokens_out = td["tokens_out", "sequences"]
     seq_len = tokens_out.shape[1]

     del td["tokens_out", "past_key_values"]
+
     scores = dict(td["tokens_out", "scores"].items())
     scores = torch.stack(
         [scores[str(k)] for k in range(len(scores))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
        - "forward", "past_key_values"
        - "forward"
    """
-    # TODO: how do we avoid getting these?
+    tokens_out = td["tokens_response", "input_ids"]
+    seq_len = tokens_out.shape[-1]
+
     del td["forward", "past_key_values"]
+
     scores = td["forward", "logits"]
+    scores = scores[..., -seq_len:, :]
     logits = scores - scores.logsumexp(dim=-1, keepdim=True)
     td["logits"] = scores
     del td["forward"]
     scores.shape[1]
-    tokens = td["tokens_in", "input_ids"]
-    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    log_probs = logits.gather(-1, tokens_out.unsqueeze(-1))
     td["log_probs"] = log_probs
     return td

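This hunk is the "right log-prob size" fix from the commit title: the forward pass returns logits for the whole input sequence, so the function now keeps only the trailing seq_len positions and gathers log-probs for the response tokens rather than the prompt tokens. A self-contained sketch of the same shape arithmetic with toy tensors; the shapes and names are illustrative, not taken from the library:

import torch

batch, prompt_len, resp_len, vocab = 2, 5, 3, 11
logits_full = torch.randn(batch, prompt_len + resp_len, vocab)   # logits over prompt + response
response_ids = torch.randint(vocab, (batch, resp_len))           # generated response tokens

scores = logits_full[..., -resp_len:, :]                         # keep the response positions only
logp = scores - scores.logsumexp(dim=-1, keepdim=True)           # normalize into log-probabilities
log_probs = logp.gather(-1, response_ids.unsqueeze(-1))          # one log-prob per response token
assert log_probs.shape == (batch, resp_len, 1)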
