[Feature] Make PPO ready for text-based data #2855

Closed · wants to merge 3 commits
torchrl/envs/transforms/llm.py (62 additions & 33 deletions)
@@ -6,7 +6,7 @@

from collections import deque
from collections.abc import Mapping
from copy import copy, deepcopy
from copy import copy
from typing import Any, Callable, Iterable, Literal

import torch
@@ -17,13 +17,15 @@
TensorDictBase,
unravel_key,
)
from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
from tensordict.nn import (
ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
)
from tensordict.utils import _zip_strict, is_seq_of_nested_key
from torch import nn

from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
from torchrl.envs.transforms.utils import _set_missing_tolerance
from torchrl.envs.utils import make_composite_from_td


@@ -506,6 +508,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
return self._queue.popleft()
return out

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"


class KLRewardTransform(Transform):
"""A transform to add a KL[pi_current||pi_0] correction term to the reward.
Expand Down Expand Up @@ -578,6 +584,8 @@ def __init__(
in_keys=None,
out_keys=None,
requires_grad=False,
log_prob_key: NestedKey = "sample_log_prob",
action_key: NestedKey = "action",
):
if in_keys is None:
in_keys = self.DEFAULT_IN_KEYS
@@ -604,35 +612,38 @@ def __init__(
self.in_keys = self.in_keys + actor.in_keys

# check that the model has parameters
params = TensorDict.from_module(actor)
with params.apply(
_stateless_param, device="meta", filter_empty=False
).to_module(actor):
# copy a stateless actor
self.__dict__["functional_actor"] = deepcopy(actor)
# params = TensorDict.from_module(actor)
# with params.apply(
# _stateless_param, device="meta", filter_empty=False
# ).to_module(actor):
# # copy a stateless actor
# self.__dict__["functional_actor"] = deepcopy(actor)
self.__dict__["functional_actor"] = actor

# we need to register these params as buffer to have `to` and similar
# methods work properly

def _make_detached_param(x):

if isinstance(x, nn.Parameter):
# we need an nn.Parameter since some modules (RNN) require nn.Parameters
return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
elif x.requires_grad:
raise ValueError(
"Encountered a value that requires gradients but is not an nn.Parameter instance."
)
return x.clone()

self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
if requires_grad:
# includes the frozen params/buffers in the module parameters/buffers
self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
# def _make_detached_param(x):
#
# if isinstance(x, nn.Parameter):
# # we need an nn.Parameter since some modules (RNN) require nn.Parameters
# return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
# elif x.requires_grad:
# raise ValueError(
# "Encountered a value that requires gradients but is not an nn.Parameter instance."
# )
# return x.clone()
# self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
# if requires_grad:
# # includes the frozen params/buffers in the module parameters/buffers
# self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)

# self._buffers["actor_params"] = params.clone().detach()

self.action_key = action_key

# find the sample log-prob key
self.sample_log_prob_key = "sample_log_prob"
self.sample_log_prob_key = log_prob_key

def find_sample_log_prob(module):
if hasattr(module, "log_prob_key"):
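A minimal sketch of how the new `log_prob_key` and `action_key` arguments could be wired up, assuming the existing `coef` argument is unchanged; the toy Gaussian actor and all hyperparameters below are illustrative and not taken from this PR:

```python
import torch
from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torchrl.envs.transforms.llm import KLRewardTransform
from torchrl.modules import NormalParamExtractor

# toy reference policy: backbone producing loc/scale, followed by a Normal head
backbone = TensorDictModule(
    torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticTensorDictSequential(
    backbone,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=torch.distributions.Normal,
        return_log_prob=True,  # writes the log-prob under "sample_log_prob"
    ),
)
transform = KLRewardTransform(
    actor,
    coef=0.1,                        # assumed unchanged from the existing signature
    log_prob_key="sample_log_prob",  # new: key holding the rollout log-prob
    action_key="action",             # new: key holding the action to score
)
```

Since the deep-copied, stateless actor is gone, the module passed here is used as-is by the transform.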
@@ -653,16 +664,25 @@ def _reset(

def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
# run the actor on the tensordict
action = next_tensordict.get("action", None)
action = next_tensordict.get(self.action_key, None)
if action is None:
# being called after reset or without action, skipping
if self.out_keys[0] != ("reward",) and self.parent is not None:
next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
return next_tensordict
with self.frozen_params.to_module(self.functional_actor):
dist = self.functional_actor.get_dist(next_tensordict.clone(False))
# get the log_prob given the original model
log_prob = dist.log_prob(action)
# with self.frozen_params.to_module(self.functional_actor):
if isinstance(
self.functional_actor,
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
):
dist = self.functional_actor.get_dist(next_tensordict.copy())
# get the log_prob given the original model
log_prob = dist.log_prob(action)
else:
log_prob = self.functional_actor(next_tensordict.copy()).get(
self.sample_log_prob_key
)

reward_key = self.in_keys[0]
reward = next_tensordict.get("next").get(reward_key)
curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
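With the new branch above, actors that are not `ProbabilisticTensorDictModule`/`ProbabilisticTensorDictSequential` are simply run, and the reference log-prob is read back under `log_prob_key`. A sketch of such a duck-typed actor, e.g. a wrapper around a frozen language model that scores generated tokens; the function and key names are illustrative only:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

def reference_log_prob(tokens):
    # stand-in for a frozen reference model scoring the generated tokens
    return torch.zeros(tokens.shape[:-1])

ref_actor = TensorDictModule(
    reference_log_prob,
    in_keys=["action"],
    out_keys=["sample_log_prob"],  # the key the transform reads in the else-branch
)

td = TensorDict({"action": torch.randint(0, 100, (2, 16))}, batch_size=[2])
print(ref_actor(td)["sample_log_prob"].shape)  # torch.Size([2])
```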
@@ -685,12 +705,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

if in_key == "reward" and out_key == "reward":
parent = self.parent

reward_keys = parent.reward_keys
if len(reward_keys) == 1:
reward_key = reward_keys[0]
elif "reward" in reward_keys:
reward_key = "reward"
else:
raise KeyError("Couln't find the reward key.")

reward_spec = Unbounded(
device=output_spec.device,
shape=output_spec["full_reward_spec"][parent.reward_key].shape,
shape=output_spec["full_reward_spec"][reward_key].shape,
)
output_spec["full_reward_spec"] = Composite(
{parent.reward_key: reward_spec},
{reward_key: reward_spec},
shape=output_spec["full_reward_spec"].shape,
)
elif in_key == "reward":
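The reward-key resolution introduced above, written out as a standalone helper for clarity (a sketch, not code from the PR): a single reward key is taken as-is, otherwise a literal "reward" entry is preferred, otherwise the lookup fails.

```python
def resolve_reward_key(reward_keys):
    # mirror of the precedence used in transform_output_spec
    if len(reward_keys) == 1:
        return reward_keys[0]
    if "reward" in reward_keys:
        return "reward"
    raise KeyError("Couldn't find the reward key.")

assert resolve_reward_key([("agent", "reward")]) == ("agent", "reward")
assert resolve_reward_key(["reward", "bonus"]) == "reward"
```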