
Commit 4c4e4f9

[BugFix] Right log-prob size in transformer wrapper
ghstack-source-id: fd11bc55e61c1e3b40ce6702c075da885f6dca27
Pull Request resolved: #2854
1 parent 06a8c86

File tree: 3 files changed, +18 -9 lines


torchrl/data/tensor_specs.py (+1, -1)

@@ -4941,7 +4941,7 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
             spec.shape = self.shape
         else:
             raise ValueError(
-                f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
+                f"The shapes of the spec {type(spec).__name__} and the {type(self).__name__} mismatch: the first "
                 f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                 f"Composite.shape={self.shape}."
             )
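For context, this ValueError fires when a spec whose leading dimensions disagree with the Composite's batch shape is assigned into it. A minimal sketch of the failure mode (shapes chosen for illustration, assuming the Composite and Unbounded spec classes that this module imports):

from torchrl.data.tensor_specs import Composite, Unbounded

# A Composite with batch shape (3,); assigning a spec whose first
# dimension is 4 trips the shape check and raises the reworded error.
spec = Composite(shape=(3,))
try:
    spec["obs"] = Unbounded(shape=(4, 2))
except ValueError as err:
    print(err)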

torchrl/envs/transforms/llm.py (+10, -4)

@@ -6,18 +6,20 @@

 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal

 import torch
 from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn

 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td


@@ -500,6 +502,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out

+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+

 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.

torchrl/modules/llm/transformers_policy.py (+7, -4)

@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
     - "tokens_out", "scores"

     """
-    # TODO: how do we avoid getting these?
     tokens_out = td["tokens_out", "sequences"]
     seq_len = tokens_out.shape[1]

     del td["tokens_out", "past_key_values"]
+
     scores = dict(td["tokens_out", "scores"].items())
     scores = torch.stack(
         [scores[str(k)] for k in range(len(scores))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
     - "forward", "past_key_values"
     - "forward"
     """
-    # TODO: how do we avoid getting these?
+    tokens_out = td["tokens_response", "input_ids"]
+    seq_len = tokens_out.shape[-1]
+
     del td["forward", "past_key_values"]
+
     scores = td["forward", "logits"]
+    scores = scores[..., -seq_len:, :]
     logits = scores - scores.logsumexp(dim=-1, keepdim=True)
     td["logits"] = scores
     del td["forward"]
     scores.shape[1]
-    tokens = td["tokens_in", "input_ids"]
-    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    log_probs = logits.gather(-1, tokens_out.unsqueeze(-1))
     td["log_probs"] = log_probs
     return td

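This hunk is the size fix the commit title refers to: log_probs_from_logits previously gathered against ("tokens_in", "input_ids"), so the log-prob tensor was sized to the input tokens rather than the response; it now slices the logits to the trailing response positions and gathers against ("tokens_response", "input_ids"). A self-contained sketch of the corrected computation with toy shapes (no model, random logits):

import torch

batch, prompt_len, resp_len, vocab = 2, 5, 3, 11

# A causal LM's forward pass returns logits over prompt + response.
full_logits = torch.randn(batch, prompt_len + resp_len, vocab)
response_tokens = torch.randint(vocab, (batch, resp_len))

# Keep only the positions scoring the response, as the patch does...
scores = full_logits[..., -resp_len:, :]
# ...normalize to log-probabilities...
logits = scores - scores.logsumexp(dim=-1, keepdim=True)
# ...and pick out each response token's log-prob.
log_probs = logits.gather(-1, response_tokens.unsqueeze(-1))
assert log_probs.shape == (batch, resp_len, 1)  # one value per response token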
