From 3ed348acd617945277a4d07fb673005dfaf7801c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 14 Mar 2025 16:34:38 +0000
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 torchrl/data/tensor_specs.py               |  2 +-
 torchrl/modules/llm/transformers_policy.py | 11 +++++++----
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index 67cbda48f14..aec148f9523 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -4941,7 +4941,7 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
                 spec.shape = self.shape
             else:
                 raise ValueError(
-                    f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
+                    f"The shapes of the spec {type(spec).__name__} and the {type(self).__name__} mismatch: the first "
                     f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
                     f"Composite.shape={self.shape}."
                 )
diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py
index c506481746e..e977fc18be6 100644
--- a/torchrl/modules/llm/transformers_policy.py
+++ b/torchrl/modules/llm/transformers_policy.py
@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
         - "tokens_out", "scores"

     """
-    # TODO: how do we avoid getting these?
     tokens_out = td["tokens_out", "sequences"]
     seq_len = tokens_out.shape[1]

     del td["tokens_out", "past_key_values"]
+
     scores = dict(td["tokens_out", "scores"].items())
     scores = torch.stack(
         [scores[str(k)] for k in range(len(scores))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
         - "forward", "past_key_values"
         - "forward"
     """
-    # TODO: how do we avoid getting these?
+    tokens_out = td["tokens_response", "input_ids"]
+    seq_len = tokens_out.shape[-1]
+
     del td["forward", "past_key_values"]
+
     scores = td["forward", "logits"]
+    scores = scores[..., -seq_len:, :]
     logits = scores - scores.logsumexp(dim=-1, keepdim=True)
     td["logits"] = scores
     del td["forward"]
     scores.shape[1]
-    tokens = td["tokens_in", "input_ids"]
-    log_probs = logits.gather(-1, tokens.unsqueeze(-1))
+    log_probs = logits.gather(-1, tokens_out.unsqueeze(-1))
     td["log_probs"] = log_probs
     return td
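
A minimal standalone sketch (not part of the patch) of the computation that `log_probs_from_logits` now performs: read the generated tokens, keep only the last `seq_len` positions of the logits, normalize, and gather the per-token log-probabilities. Tensor names here are illustrative, not the TensorDict keys used by the module:

import torch

def response_log_probs(logits: torch.Tensor, response_ids: torch.Tensor) -> torch.Tensor:
    # logits: (batch, prompt_len + response_len, vocab), from the model forward pass
    # response_ids: (batch, response_len), the generated tokens
    seq_len = response_ids.shape[-1]
    # keep only the positions that scored the response tokens
    logits = logits[..., -seq_len:, :]
    # normalize to log-probabilities over the vocabulary
    log_softmax = logits - logits.logsumexp(dim=-1, keepdim=True)
    # pick the log-probability of each generated token
    return log_softmax.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
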
From 7a4051e1880f351071e8a5394a9800cd3d6b50de Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 17 Mar 2025 13:09:12 +0000
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 torchrl/envs/transforms/llm.py |  81 +++++++-----------
 torchrl/objectives/ppo.py      | 151 +++++++++++----------------------
 2 files changed, 80 insertions(+), 152 deletions(-)

diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py
index 31f9f2ce055..1fb7666d24d 100644
--- a/torchrl/envs/transforms/llm.py
+++ b/torchrl/envs/transforms/llm.py
@@ -578,8 +578,6 @@ def __init__(
         in_keys=None,
         out_keys=None,
         requires_grad=False,
-        log_prob_key: NestedKey = "sample_log_prob",
-        action_key: NestedKey = "action",
     ):
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
@@ -606,38 +604,35 @@ def __init__(
             self.in_keys = self.in_keys + actor.in_keys

         # check that the model has parameters
-        # params = TensorDict.from_module(actor)
-        # with params.apply(
-        #     _stateless_param, device="meta", filter_empty=False
-        # ).to_module(actor):
-        #     # copy a stateless actor
-        #     self.__dict__["functional_actor"] = deepcopy(actor)
-        self.__dict__["functional_actor"] = actor
-
+        params = TensorDict.from_module(actor)
+        with params.apply(
+            _stateless_param, device="meta", filter_empty=False
+        ).to_module(actor):
+            # copy a stateless actor
+            self.__dict__["functional_actor"] = deepcopy(actor)
         # we need to register these params as buffer to have `to` and similar
         # methods work properly
-        # def _make_detached_param(x):
-        #
-        #     if isinstance(x, nn.Parameter):
-        #         # we need an nn.Parameter since some modules (RNN) require nn.Parameters
-        #         return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
-        #     elif x.requires_grad:
-        #         raise ValueError(
-        #             "Encountered a value that requires gradients but is not an nn.Parameter instance."
-        #         )
-        #     return x.clone()
-        # self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
-        # if requires_grad:
-        #     # includes the frozen params/buffers in the module parameters/buffers
-        #     self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+        def _make_detached_param(x):

-        # self._buffers["actor_params"] = params.clone().detach()
+            if isinstance(x, nn.Parameter):
+                # we need an nn.Parameter since some modules (RNN) require nn.Parameters
+                return nn.Parameter(x.data.clone(), requires_grad=requires_grad)
+            elif x.requires_grad:
+                raise ValueError(
+                    "Encountered a value that requires gradients but is not an nn.Parameter instance."
+                )
+            return x.clone()

-        self.action_key = action_key
+        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
+        if requires_grad:
+            # includes the frozen params/buffers in the module parameters/buffers
+            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
+
+        # self._buffers["actor_params"] = params.clone().detach()

         # find the sample log-prob key
-        self.sample_log_prob_key = log_prob_key
+        self.sample_log_prob_key = "sample_log_prob"

         def find_sample_log_prob(module):
             if hasattr(module, "log_prob_key"):
@@ -658,25 +653,16 @@ def _reset(

     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get(self.action_key, None)
+        action = next_tensordict.get("action", None)
         if action is None:
             # being called after reset or without action, skipping
             if self.out_keys[0] != ("reward",) and self.parent is not None:
                 next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
             return next_tensordict
-        # with self.frozen_params.to_module(self.functional_actor):
-        if isinstance(
-            self.functional_actor,
-            (ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
-        ):
-            dist = self.functional_actor.get_dist(next_tensordict.copy())
-            # get the log_prob given the original model
-            log_prob = dist.log_prob(action)
-        else:
-            log_prob = self.functional_actor(next_tensordict.copy()).get(
-                self.sample_log_prob_key
-            )
-
+        with self.frozen_params.to_module(self.functional_actor):
+            dist = self.functional_actor.get_dist(next_tensordict.clone(False))
+            # get the log_prob given the original model
+            log_prob = dist.log_prob(action)
         reward_key = self.in_keys[0]
         reward = next_tensordict.get("next").get(reward_key)
         curr_log_prob = next_tensordict.get(self.sample_log_prob_key)
@@ -699,21 +685,12 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

         if in_key == "reward" and out_key == "reward":
             parent = self.parent
-
-            reward_keys = parent.reward_keys
-            if len(reward_keys) == 1:
-                reward_key = reward_keys[0]
-            elif "reward" in reward_keys:
-                reward_key = "reward"
-            else:
-                raise KeyError("Couln't find the reward key.")
-
             reward_spec = Unbounded(
                 device=output_spec.device,
-                shape=output_spec["full_reward_spec"][reward_key].shape,
+                shape=output_spec["full_reward_spec"][parent.reward_key].shape,
             )
             output_spec["full_reward_spec"] = Composite(
-                {reward_key: reward_spec},
+                {parent.reward_key: reward_spec},
                 shape=output_spec["full_reward_spec"].shape,
             )
         elif in_key == "reward":
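
The code restored above keeps a detached copy of the actor's parameters and swaps them into the module to score the taken action under the reference weights, then uses the log-probability gap as a KL-style penalty on the reward. A rough sketch of that pattern (not the transform's actual code path), assuming a probabilistic tensordict actor; `actor`, `coef` and the literal keys are placeholders:

import torch
from tensordict import TensorDict

# detach a frozen copy of the actor's parameters (the reference model)
frozen_params = TensorDict.from_module(actor).detach().clone()

def kl_penalized_reward(td, coef=0.1):
    # score the taken action under the original (frozen) weights
    with frozen_params.to_module(actor):
        ref_log_prob = actor.get_dist(td.copy()).log_prob(td["action"])
    # per-sample KL estimate between the current policy and the reference
    kl = td["sample_log_prob"] - ref_log_prob
    # penalize the environment reward with the KL term
    td["next", "reward"] = td["next", "reward"] - coef * kl.unsqueeze(-1)
    return td
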
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 22191e049da..db887cf0fba 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -24,7 +24,6 @@
     ProbabilisticTensorDictSequential,
     set_composite_lp_aggregate,
     TensorDictModule,
-    TensorDictModuleBase,
 )
 from tensordict.utils import NestedKey
 from torch import distributions as d
@@ -328,7 +327,7 @@ def __init__(
         entropy_bonus: bool = True,
         samples_mc_entropy: int = 1,
         entropy_coef: float = 0.01,
-        critic_coef: float | None = None,
+        critic_coef: float = 1.0,
         loss_critic_type: str = "smooth_l1",
         normalize_advantage: bool = False,
         normalize_advantage_exclude_dims: tuple[int] = (),
@@ -350,15 +349,10 @@ def __init__(
         if critic is not None:
             critic_network = critic
             del critic
-        if actor_network is None or (
-            critic_network is None and critic_coef not in (None, 0.0)
-        ):
+        if actor_network is None or critic_network is None:
             raise TypeError(
                 "Missing positional arguments actor_network or critic_network."
             )
-        critic_coef = (
-            1.0 if critic_coef is None and critic_network is not None else critic_coef
-        )
         if reduction is None:
             reduction = "mean"
@@ -379,7 +373,7 @@ def __init__(
             policy_params = list(actor_network.parameters())
         else:
             policy_params = None
-        if functional and critic_network is not None:
+        if functional:
             self.convert_to_functional(
                 critic_network, "critic_network", compare_against=policy_params
             )
@@ -430,15 +424,12 @@ def __init__(
                     f"clip_value must be a float or a scalar tensor, got {clip_value}."
                 )
             self.register_buffer("clip_value", clip_value)
-        try:
-            log_prob_keys = self.actor_network.log_prob_keys
-            action_keys = self.actor_network.dist_sample_keys
-            if len(log_prob_keys) > 1:
-                self.set_keys(sample_log_prob=log_prob_keys, action=action_keys)
-            else:
-                self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0])
-        except AttributeError:
-            pass
+        log_prob_keys = self.actor_network.log_prob_keys
+        action_keys = self.actor_network.dist_sample_keys
+        if len(log_prob_keys) > 1:
+            self.set_keys(sample_log_prob=log_prob_keys, action=action_keys)
+        else:
+            self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0])

     @property
     def functional(self):
@@ -448,8 +439,7 @@ def _set_in_keys(self):
         keys = []
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys)
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next")
-        if self.critic_network is not None:
-            _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
+        _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
         _maybe_add_or_extend_key(keys, self.tensor_keys.action)
         _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
         _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
@@ -526,50 +516,28 @@ def _get_entropy(
             entropy.batch_size = adv_shape
         return entropy.unsqueeze(-1)

-    def _get_cur_log_prob(self, tensordict):
-
-        if isinstance(
-            self.actor_network,
-            (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
-        ):
-            with self.actor_network_params.to_module(
-                self.actor_network
-            ) if self.functional else contextlib.nullcontext():
-                dist = self.actor_network.get_dist(tensordict)
-
-            is_composite = isinstance(dist, CompositeDistribution)
-
-            if is_composite:
-                action = tensordict.select(
-                    *(
-                        (self.tensor_keys.action,)
-                        if isinstance(self.tensor_keys.action, NestedKey)
-                        else self.tensor_keys.action
-                    )
-                )
-            else:
-                action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
-
-            if action.requires_grad:
-                raise RuntimeError(
-                    f"tensordict stored {self.tensor_keys.action} requires grad."
-                )
-            log_prob = dist.log_prob(action)
-        else:
-            with self.actor_network_params.to_module(
-                self.actor_network
-            ) if self.functional else contextlib.nullcontext():
-                td = self.actor_network(tensordict)
-                log_prob = td.get(self.tensor_keys.sample_log_prob)
-                # TODO: decustomize this
-                dist = torch.distributions.Categorical(td.get("logits"))
-                is_composite = False
-        return log_prob, dist, is_composite
-
     def _log_weight(
         self, tensordict: TensorDictBase, adv_shape: torch.Size
     ) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
+        with self.actor_network_params.to_module(
+            self.actor_network
+        ) if self.functional else contextlib.nullcontext():
+            dist = self.actor_network.get_dist(tensordict)
+
+        is_composite = isinstance(dist, CompositeDistribution)
+
+        if is_composite:
+            action = tensordict.select(
+                *(
+                    (self.tensor_keys.action,)
+                    if isinstance(self.tensor_keys.action, NestedKey)
+                    else self.tensor_keys.action
+                )
+            )
+        else:
+            action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
+
         prev_log_prob = _maybe_get_or_select(
             tensordict,
             self.tensor_keys.sample_log_prob,
@@ -581,8 +549,11 @@ def _log_weight(
                 f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
             )

-        log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
-
+        if action.requires_grad:
+            raise RuntimeError(
+                f"tensordict stored {self.tensor_keys.action} requires grad."
+            )
+        log_prob = dist.log_prob(action)
         if is_composite:
             with set_composite_lp_aggregate(False):
                 if log_prob.batch_size != adv_shape:
@@ -599,6 +570,7 @@ def _log_weight(
         if is_tensor_collection(log_prob):
             log_prob = _sum_td_features(log_prob)
             log_prob.view_as(prev_log_prob)
+
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
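
For context, `_log_weight` above produces the log importance weight that the clipped PPO objective consumes downstream. A condensed sketch of that objective with placeholder tensors (not the loss module's actual code path):

import torch

def clipped_surrogate(log_prob, prev_log_prob, advantage, clip_epsilon=0.2):
    # importance ratio between the current and behavior policies
    log_weight = log_prob - prev_log_prob
    ratio = log_weight.exp()
    # clip the ratio and take the pessimistic (minimum) objective
    clipped = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon)
    return torch.min(ratio * advantage, clipped * advantage)
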
@@ -1030,7 +1002,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
                 entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
-        if self.critic_coef is not None and self.critic_coef > 0:
+        if self.critic_coef is not None:
             loss_critic, value_clip_fraction = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
             if value_clip_fraction is not None:
@@ -1231,8 +1203,7 @@ def _set_in_keys(self):
         keys = []
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys)
         _maybe_add_or_extend_key(keys, self.actor_network.in_keys, "next")
-        if self.critic_network is not None:
-            _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
+        _maybe_add_or_extend_key(keys, self.critic_network.in_keys)
         _maybe_add_or_extend_key(keys, self.tensor_keys.action)
         _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
         _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
@@ -1240,24 +1211,21 @@ def _set_in_keys(self):
         _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")

         # Get the parameter keys from the actor dist
-        # actor_dist_module = None
-        # for module in self.actor_network.modules():
-        #     # Ideally we should combine them if there is more than one
-        #     if isinstance(module, ProbabilisticTensorDictModule):
-        #         if actor_dist_module is not None:
-        #             raise RuntimeError(
-        #                 "Actors with one and only one distribution are currently supported "
-        #                 f"in {type(self).__name__}. If you need to use more than one "
-        #                 f"distributions over the action space please submit an issue "
-        #                 f"on github."
-        #             )
-        #         actor_dist_module = module
-        # if actor_dist_module is None:
-        #     if hasattr(self.actor_network, "in_keys"):
-        #         actor_dist_module = self.actor_network
-        #     else:
-        #         raise RuntimeError("Could not find the probabilistic module in the actor.")
-        keys += list(self.actor_network.in_keys)
+        actor_dist_module = None
+        for module in self.actor_network.modules():
+            # Ideally we should combine them if there is more than one
+            if isinstance(module, ProbabilisticTensorDictModule):
+                if actor_dist_module is not None:
+                    raise RuntimeError(
+                        "Actors with one and only one distribution are currently supported "
+                        f"in {type(self).__name__}. If you need to use more than one "
+                        f"distributions over the action space please submit an issue "
+                        f"on github."
+                    )
+                actor_dist_module = module
+        if actor_dist_module is None:
+            raise RuntimeError("Could not find the probabilistic module in the actor.")
+        keys += list(actor_dist_module.in_keys)
         self._in_keys = list(set(keys))

     @property
@@ -1382,20 +1350,3 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:

     def reset(self) -> None:
         self.beta = self._beta_init
-
-
-class GRPO(ClipPPOLoss):
-    """TODO"""
-    def __init__(
-        self,
-        actor_network: TensorDictModuleBase,
-        # Default value of LLMData
-        log_prob_key="log_probs",
-    ):
-        super().__init__(
-            actor_network=actor_network,
-            critic_network=None,
-            critic_coef=0.0,
-            functional=False,
-        )
-        self.set_keys(log_prob_key=log_prob_key)
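
The removed `GRPO` stub above pointed at a critic-free use of `ClipPPOLoss`. Purely for orientation, and not part of the patch: the group-relative advantage that GRPO-style objectives typically substitute for a learned value baseline can be sketched as:

import torch

def group_relative_advantage(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # rewards: (num_groups, group_size) scalar rewards for completions of the same prompt
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    # standardize within each group; this replaces the critic's value baseline
    return (rewards - mean) / (std + eps)

This advantage would then feed the same clipped importance-weighted objective sketched earlier, with the critic terms disabled.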