
Commit b7e0645

[Feature] Make PPO ready for text-based data
ghstack-source-id: 534a429
Pull Request resolved: #2855
1 parent bc454f7 commit b7e0645

2 files changed: +162 -84 lines changed

torchrl/envs/transforms/llm.py: +62 -33
@@ -6,7 +6,7 @@

 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal

 import torch
@@ -17,13 +17,15 @@
     TensorDictBase,
     unravel_key,
 )
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn

 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td


@@ -506,6 +508,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out

+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+

 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.
@@ -578,6 +584,8 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
+        log_prob_key: NestedKey = "sample_log_prob",
+        action_key: NestedKey = "action",
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -604,35 +612,38 @@ def __init__(
         self.in_keys = self.in_keys + actor.in_keys

         # check that the model has parameters
-        params = TensorDict.from_module(actor)
-        with params.apply(
-            _stateless_param, device="meta", filter_empty=False
-        ).to_module(actor):
-            # copy a stateless actor
-            self.__dict__["functional_actor"] = deepcopy(actor)
+        # params = TensorDict.from_module(actor)
+        # with params.apply(
+        #     _stateless_param, device="meta", filter_empty=False
+        # ).to_module(actor):
+        #     # copy a stateless actor
+        #     self.__dict__["functional_actor"] = deepcopy(actor)
+        self.__dict__["functional_actor"] = actor
+
         # we need to register these params as buffer to have `to` and similar
         # methods work properly

-        def _make_detached_param(x):
-
-            if isinstance(x, nn.Parameter):
-                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-            elif x.requires_grad:
-                raise ValueError(
-                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
-                )
-            return x.clone()
-
-        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        if requires_grad:
-            # includes the frozen params/buffers in the module parameters/buffers
-            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        # def _make_detached_param(x):
+        #
+        #     if isinstance(x, nn.Parameter):
+        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+        #     elif x.requires_grad:
+        #         raise ValueError(
+        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
+        #         )
+        #     return x.clone()
+        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        # if requires_grad:
+        #     # includes the frozen params/buffers in the module parameters/buffers
+        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)

         # self._buffers["actor_params"] = params.clone().detach()

+        self.action_key = action_key
+
         # find the sample log-prob key
-        self.sample_log_prob_key = "sample_log_prob"
+        self.sample_log_prob_key = log_prob_key

         def find_sample_log_prob(module):
             if hasattr(module, "log_prob_key"):
@@ -653,16 +664,25 @@ def _reset(

     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get(self.action_key, None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
-        with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
-        # get the log_prob given the original model
-        log_prob = dist.log_prob(action)
+        # with self.frozen_params.to_module(self.functional_actor):
+        if isinstance(
+            self.functional_actor,
+            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
+        ):
+            dist = self.functional_actor.get_dist(next_tensordict.copy())
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
+        else:
+            log_prob = self.functional_actor(next_tensordict.copy()).get(
+                self.sample_log_prob_key
+            )
+
         reward_key = self.in_keys[0]
         reward = next_tensordict.get("next").get(reward_key)
         curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
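
For context, here is a minimal sketch, not part of the commit, of the kind of non-probabilistic actor the new else-branch above is written for: any module that is not a ProbabilisticTensorDictModule or ProbabilisticTensorDictSequential is simply called on a copy of the tensordict and is expected to write the reference log-probabilities under sample_log_prob_key. The lambda scorer, the "tokens" key, and the shapes below are illustrative assumptions, not torchrl API.

# Hypothetical reference scorer (illustration only): a plain TensorDictModule
# that writes per-sample log-probs under "sample_log_prob". With such an actor,
# _call above skips get_dist() and reads that key directly.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

ref_scorer = TensorDictModule(
    # placeholder: zero log-prob per sample; a real scorer would evaluate the
    # frozen reference model on the generated tokens
    lambda tokens: torch.zeros(tokens.shape[0]),
    in_keys=["tokens"],
    out_keys=["sample_log_prob"],
)

td = TensorDict({"tokens": torch.randint(0, 100, (4, 16))}, batch_size=[4])
ref_scorer(td)
print(td["sample_log_prob"].shape)  # torch.Size([4])
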
@@ -685,12 +705,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

             if in_key == "reward" and out_key == "reward":
                 parent = self.parent
+
+                reward_keys = parent.reward_keys
+                if len(reward_keys) == 1:
+                    reward_key = reward_keys[0]
+                elif "reward" in reward_keys:
+                    reward_key = "reward"
+                else:
+                    raise KeyError("Couldn't find the reward key.")
+
                 reward_spec = Unbounded(
                     device=output_spec.device,
-                    shape=output_spec["full_reward_spec"][parent.reward_key].shape,
+                    shape=output_spec["full_reward_spec"][reward_key].shape,
                 )
                 output_spec["full_reward_spec"] = Composite(
-                    {parent.reward_key: reward_spec},
+                    {reward_key: reward_spec},
                     shape=output_spec["full_reward_spec"].shape,
                 )
             elif in_key == "reward":
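
And a hedged usage sketch, also not from the commit, of how the two new constructor arguments might be wired up with such an actor. It assumes the first positional argument of KLRewardTransform is the reference actor, as the __init__ hunk suggests, and the key names are illustrative.

# Illustration only: point the transform at the token key and the log-prob key
# introduced by this commit. `ref_scorer` is the hypothetical module sketched
# above; `base_env` would be any text-based TorchRL environment.
from torchrl.envs.transforms.llm import KLRewardTransform

kl_transform = KLRewardTransform(
    ref_scorer,                        # frozen reference actor (positional)
    log_prob_key="sample_log_prob",    # new argument: where log-probs are read from
    action_key="tokens",               # new argument: where the generated tokens live
)
# env = TransformedEnv(base_env, kl_transform)  # assumed wiring, not shown in the diff
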
