diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py
index e5c14bf3c96..192ee480af3 100644
--- a/torchrl/envs/transforms/llm.py
+++ b/torchrl/envs/transforms/llm.py
@@ -6,7 +6,7 @@
 
 from collections import deque
 from collections.abc import Mapping
-from copy import copy, deepcopy
+from copy import copy
 from typing import Any, Callable, Iterable, Literal
 
 import torch
@@ -17,13 +17,15 @@
     TensorDictBase,
     unravel_key,
 )
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+)
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
-from torch import nn
 
 from torchrl.data.tensor_specs import Composite, NonTensor, TensorSpec, Unbounded
 from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform
-from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
+from torchrl.envs.transforms.utils import _set_missing_tolerance
 from torchrl.envs.utils import make_composite_from_td
 
 
@@ -506,6 +508,10 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
             return self._queue.popleft()
         return out
 
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})"
+
 
 class KLRewardTransform(Transform):
     """A transform to add a KL[pi_current||pi_0] correction term to the reward.
@@ -578,6 +584,8 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
+        log_prob_key: NestedKey = "sample_log_prob",
+        action_key: NestedKey = "action",
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -604,35 +612,38 @@ def __init__(
         self.in_keys = self.in_keys + actor.in_keys
 
         # check that the model has parameters
-        params = TensorDict.from_module(actor)
-        with params.apply(
-            _stateless_param, device="meta", filter_empty=False
-        ).to_module(actor):
-            # copy a stateless actor
-            self.__dict__["functional_actor"] = deepcopy(actor)
+        # params = TensorDict.from_module(actor)
+        # with params.apply(
+        #     _stateless_param, device="meta", filter_empty=False
+        # ).to_module(actor):
+        #     # copy a stateless actor
+        #     self.__dict__["functional_actor"] = deepcopy(actor)
+        self.__dict__["functional_actor"] = actor
+
         # we need to register these params as buffer to have `to` and similar
         # methods work properly
 
-        def _make_detached_param(x):
-
-            if isinstance(x, nn.Parameter):
-                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-            elif x.requires_grad:
-                raise ValueError(
-                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
-                )
-            return x.clone()
-
-        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        if requires_grad:
-            # includes the frozen params/buffers in the module parameters/buffers
-            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        # def _make_detached_param(x):
+        #
+        #     if isinstance(x, nn.Parameter):
+        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+        #     elif x.requires_grad:
+        #         raise ValueError(
+        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
+        #         )
+        #     return x.clone()
+        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        # if requires_grad:
+        #     # includes the frozen params/buffers in the module parameters/buffers
+        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
 
         # self._buffers["actor_params"] = params.clone().detach()
 
+        self.action_key = action_key
+
         # find the sample log-prob key
-        self.sample_log_prob_key = "sample_log_prob"
+        self.sample_log_prob_key = log_prob_key
 
         def find_sample_log_prob(module):
             if hasattr(module, "log_prob_key"):
@@ -653,16 +664,25 @@ def _reset(
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get(self.action_key, None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
-        with self.frozen_params.to_module(self.functional_actor):
-            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
-        # get the log_prob given the original model
-        log_prob = dist.log_prob(action)
+        # with self.frozen_params.to_module(self.functional_actor):
+        if isinstance(
+            self.functional_actor,
+            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
+        ):
+            dist = self.functional_actor.get_dist(next_tensordict.copy())
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
+        else:
+            log_prob = self.functional_actor(next_tensordict.copy()).get(
+                self.sample_log_prob_key
+            )
+
         reward_key = self.in_keys[0]
         reward = next_tensordict.get("next").get(reward_key)
         curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
@@ -685,12 +705,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
 
         if in_key == "reward" and out_key == "reward":
             parent = self.parent
+
+            reward_keys = parent.reward_keys
+            if len(reward_keys) == 1:
+                reward_key = reward_keys[0]
+            elif "reward" in reward_keys:
+                reward_key = "reward"
+            else:
+                raise KeyError("Couldn't find the reward key.")
+
             reward_spec = Unbounded(
                 device=output_spec.device,
-                shape=output_spec["full_reward_spec"][parent.reward_key].shape,
+                shape=output_spec["full_reward_spec"][reward_key].shape,
             )
             output_spec["full_reward_spec"] = Composite(
-                {parent.reward_key: reward_spec},
+                {reward_key: reward_spec},
                 shape=output_spec["full_reward_spec"].shape,
             )
         elif in_key == "reward":
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index db887cf0fba..22191e049da 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -24,6 +24,7 @@
     ProbabilisticTensorDictSequential,
     set_composite_lp_aggregate,
     TensorDictModule,
+    TensorDictModuleBase,
 )
 from tensordict.utils import NestedKey
 from torch import distributions as d
@@ -327,7 +328,7 @@ def __init__(
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
-        critic_coef: float = 1.0,
+        critic_coef: float | None = None,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -349,10 +350,15 @@ def __init__(
         if critic is not None:
             critic_network = critic
             del critic
-        if actor_network is None or critic_network is None:
+        if actor_network is None or (
+            critic_network is None and critic_coef not in (None, 0.0)
+        ):
             raise TypeError(
                 "Missing positional arguments actor_network or critic_network."
             )
+        critic_coef = (
+            1.0 if critic_coef is None and critic_network is not None else critic_coef
+        )
 
         if reduction is None:
             reduction = "mean"
@@ -373,7 +379,7 @@ def __init__(
             policy_params = list(actor_network.parameters())
         else:
             policy_params = None
-        if functional:
+        if functional and critic_network is not None:
             self.convert_to_functional(
                 critic_network, "critic_network", compare_against=policy_params
             )
@@ -424,12 +430,15 @@ def __init__(
                    f"clip_value must be a float or a scalar tensor, got {clip_value}."
                )
         self.register_buffer("clip_value", clip_value)
-        log_prob_keys = self.actor_network.log_prob_keys
-        action_keys = self.actor_network.dist_sample_keys
-        if len(log_prob_keys) > 1:
-            self.set_keys(sample_log_prob=log_prob_keys, action=action_keys)
-        else:
-            self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0])
+        try:
+            log_prob_keys = self.actor_network.log_prob_keys
+            action_keys = self.actor_network.dist_sample_keys
+            if len(log_prob_keys) > 1:
+                self.set_keys(sample_log_prob=log_prob_keys, action=action_keys)
+            else:
+                self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0])
+        except AttributeError:
+            pass
 
     @property
     def functional(self):
@@ -439,7 +448,8 @@ def _set_in_keys(self):
         keys = []
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys)
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next")
-        _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
+        if self.critic_network is not None:
+            _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
         _maybe_add_or_extend_key(keys, self.tensor_keys.action)
         _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
         _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
@@ -516,27 +526,49 @@ def _get_entropy(
                 entropy.batch_size = adv_shape
         return entropy.unsqueeze(-1)
 
-    def _log_weight(
-        self, tensordict: TensorDictBase, adv_shape: torch.Size
-    ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
-
-        with self.actor_network_params.to_module(
-            self.actor_network
-        ) if self.functional else contextlib.nullcontext():
-            dist = self.actor_network.get_dist(tensordict)
-
-        is_composite = isinstance(dist, CompositeDistribution)
+    def _get_cur_log_prob(self, tensordict):
+
+        if isinstance(
+            self.actor_network,
+            (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
+        ):
+            with self.actor_network_params.to_module(
+                self.actor_network
+            ) if self.functional else contextlib.nullcontext():
+                dist = self.actor_network.get_dist(tensordict)
+
+            is_composite = isinstance(dist, CompositeDistribution)
+
+            if is_composite:
+                action = tensordict.select(
+                    *(
+                        (self.tensor_keys.action,)
+                        if isinstance(self.tensor_keys.action, NestedKey)
+                        else self.tensor_keys.action
+                    )
+                )
+            else:
+                action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
 
-        if is_composite:
-            action = tensordict.select(
-                *(
-                    (self.tensor_keys.action,)
-                    if isinstance(self.tensor_keys.action, NestedKey)
-                    else self.tensor_keys.action
+            if action.requires_grad:
+                raise RuntimeError(
+                    f"tensordict stored {self.tensor_keys.action} requires grad."
                 )
-            )
+            log_prob = dist.log_prob(action)
         else:
-            action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
+            with self.actor_network_params.to_module(
+                self.actor_network
+            ) if self.functional else contextlib.nullcontext():
+                td = self.actor_network(tensordict)
+                log_prob = td.get(self.tensor_keys.sample_log_prob)
+                # TODO: make this generic; it currently assumes a categorical head exposing "logits"
+                dist = torch.distributions.Categorical(td.get("logits"))
+                is_composite = False
+        return log_prob, dist, is_composite
+
+    def _log_weight(
+        self, tensordict: TensorDictBase, adv_shape: torch.Size
+    ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
 
         prev_log_prob = _maybe_get_or_select(
             tensordict,
@@ -549,11 +581,8 @@ def _log_weight(
                 f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
             )
 
-        if action.requires_grad:
-            raise RuntimeError(
-                f"tensordict stored {self.tensor_keys.action} requires grad."
-            )
-        log_prob = dist.log_prob(action)
+        log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
+
         if is_composite:
             with set_composite_lp_aggregate(False):
                 if log_prob.batch_size != adv_shape:
@@ -570,7 +599,6 @@ def _log_weight(
             if is_tensor_collection(log_prob):
                 log_prob = _sum_td_features(log_prob)
             log_prob.view_as(prev_log_prob)
-
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
@@ -1002,7 +1030,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
-        if self.critic_coef is not None:
+        if self.critic_coef is not None and self.critic_coef > 0:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
@@ -1203,7 +1231,8 @@ def _set_in_keys(self):
         keys = []
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys)
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next")
-        _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
+        if self.critic_network is not None:
+            _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
         _maybe_add_or_extend_key(keys, self.tensor_keys.action)
         _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
         _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
@@ -1211,21 +1240,24 @@ def _set_in_keys(self):
         _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")
 
         # Get the parameter keys from the actor dist
-        actor_dist_module = None
-        for module in self.actor_network.modules():
-            # Ideally we should combine them if there is more than one
-            if isinstance(module, ProbabilisticTensorDictModule):
-                if actor_dist_module is not None:
-                    raise RuntimeError(
-                        "Actors with one and only one distribution are currently supported "
-                        f"in {type(self).__name__}. If you need to use more than one "
-                        f"distributions over the action space please submit an issue "
-                        f"on github."
-                    )
-                actor_dist_module = module
-        if actor_dist_module is None:
-            raise RuntimeError("Could not find the probabilistic module in the actor.")
-        keys += list(actor_dist_module.in_keys)
+        # actor_dist_module = None
+        # for module in self.actor_network.modules():
+        #     # Ideally we should combine them if there is more than one
+        #     if isinstance(module, ProbabilisticTensorDictModule):
+        #         if actor_dist_module is not None:
+        #             raise RuntimeError(
+        #                 "Actors with one and only one distribution are currently supported "
+        #                 f"in {type(self).__name__}. If you need to use more than one "
+        #                 f"distributions over the action space please submit an issue "
+        #                 f"on github."
+        #             )
+        #         actor_dist_module = module
+        # if actor_dist_module is None:
+        #     if hasattr(self.actor_network, "in_keys"):
+        #         actor_dist_module = self.actor_network
+        #     else:
+        #         raise RuntimeError("Could not find the probabilistic module in the actor.")
+        keys += list(self.actor_network.in_keys)
         self._in_keys = list(set(keys))
 
     @property
@@ -1350,3 +1382,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
 
     def reset(self) -> None:
         self.beta = self._beta_init
+
+
+class GRPO(ClipPPOLoss):
+    """GRPO (Group Relative Policy Optimization) loss, built as a critic-less ClipPPOLoss."""
+    def __init__(
+        self,
+        actor_network: TensorDictModuleBase,
+        # Default value of LLMData
+        log_prob_key="log_probs",
+    ):
+        super().__init__(
+            actor_network=actor_network,
+            critic_network=None,
+            critic_coef=0.0,
+            functional=False,
+        )
+        self.set_keys(sample_log_prob=log_prob_key)
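
Usage sketch (assumes the patch above is applied; the `Linear` stand-in and the
"observation" key below are illustrative placeholders, not part of the change):

    import torch
    from tensordict.nn import TensorDictModule

    from torchrl.objectives.ppo import GRPO

    # Stand-in for an LLM wrapper: a real actor is expected to write both "logits"
    # and the current per-token log-probs under "log_probs" (the LLMData default)
    # when called on a batch; the Linear layer here only exercises construction.
    policy = TensorDictModule(
        torch.nn.Linear(16, 32), in_keys=["observation"], out_keys=["logits"]
    )

    loss_fn = GRPO(actor_network=policy, log_prob_key="log_probs")

    # GRPO is ClipPPOLoss with critic_network=None and critic_coef=0.0, so the
    # sampled tensordict must already carry "advantage" (e.g. group-normalized
    # returns) and the behaviour-policy "log_probs" when calling loss_fn(data).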