diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py
index 0b5959bb900..bc9141b1373 100644
--- a/torchrl/envs/transforms/rlhf.py
+++ b/torchrl/envs/transforms/rlhf.py
@@ -2,23 +2,82 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from copy import copy, deepcopy
 
 import torch
 from tensordict import TensorDict, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
+from tensordict.nn import (
+    ProbabilisticTensorDictModule,
+    ProbabilisticTensorDictSequential,
+    TensorDictModuleBase,
+    TensorDictParams,
+    TensorDictSequential,
+)
 from tensordict.utils import is_seq_of_nested_key
 from torch import nn
 from torchrl.data.tensor_specs import Composite, Unbounded
 from torchrl.envs.transforms.transforms import Transform
 from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param
 
 
+# TODO: This should live somewhere else
+class ReferenceModelLogProbTransform(Transform):
+    """A transform to compute and store the log-probabilities from the reference model."""
+
+    def __init__(
+        self,
+        frozen_model: ProbabilisticTensorDictModule,
+    ):
+        super().__init__(in_keys=frozen_model.in_keys, out_keys=frozen_model.out_keys)
+        self.frozen_model: ProbabilisticTensorDictModule = frozen_model
+
+    def _call(self, inputs: TensorDict) -> TensorDict:
+        # Compute the log-prob given the reference model
+        return self.frozen_model(inputs)
+
+class KLDivergenceTransform(Transform):
+    """A transform to compute the KL divergence between the current and reference policies."""
+
+    ...
+
+
+class RewardAdjustmentTransform(Transform):
+    """A transform to adjust the reward based on the computed KL divergence."""
+
+    ...
+
+
+class KLConstrainedTransform(Composite):
+    """A composite transform to apply KL-based constraints on the policy."""
+
+    ...
+
 class KLRewardTransform(Transform):
-    """A transform to add a KL[pi_current||pi_0] correction term to the reward.
+    r"""A transform to add a KL divergence correction term to the reward.
 
     This transform is used to constrain the policy to remain close to its original
-    configuration which limits overfitting when fine-tuning using RLHF.
+    configuration, which helps limit overfitting when fine-tuning using Reinforcement Learning from Human Feedback
+    (RLHF) or other forms of post-training (e.g., GRPO).
+    The KL divergence between the current policy distribution and the reference policy distribution is used to adjust the reward:
+
+    .. math::
+
+        R_{\text{adjusted}} = R - \text{coef} \times \text{KL}(\pi_{\text{current}} \| \pi_0)
+
+    where :math:`R_{\text{adjusted}}` is the adjusted reward, :math:`R` is the original reward, and
+    :math:`\text{KL}(\pi_{\text{current}} \| \pi_0)` is the Kullback-Leibler divergence between the current policy
+    distribution :math:`\pi_{\text{current}}` and the reference policy distribution :math:`\pi_0`.
+
+    The KL divergence can be estimated using the difference in log-probabilities of the actions:
+
+    .. math::
+
+        \text{KL}(\pi_{\text{current}} \| \pi_0) \approx \log p(a \mid \theta_{\text{current}}) - \log p(a \mid \theta_0)
+
+    where :math:`\log p(a \mid \theta_{\text{current}})` is the log-probability of action :math:`a` under the current model, and
+    :math:`\log p(a \mid \theta_0)` is the log-probability of action :math:`a` under the reference model.
+
 
     Args:
         actor (ProbabilisticTensorDictModule): a probabilistic actor. It must
@@ -86,6 +145,11 @@ def __init__(
         out_keys=None,
         requires_grad=False,
     ):
+        warnings.warn(
+            "This class will be removed in a future release (v0.10.0). Please use torchrl.envs.KLConstrainedTransform "
+            "instead.",
+            category=FutureWarning,
+        )
         if in_keys is None:
             in_keys = self.DEFAULT_IN_KEYS
         if out_keys is None:
@@ -160,7 +224,7 @@ def _reset(
 
     def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
         # run the actor on the tensordict
-        action = next_tensordict.get("action", None)
+        action = next_tensordict.get("action")
         if action is None:
            # being called after reset or without action, skipping
            if self.out_keys[0] != ("reward",) and self.parent is not None:
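
Note on the docstring math added above: the reward correction it describes reduces to a one-line operation on per-transition log-probabilities. The sketch below is illustrative only and is not part of this diff; the tensor names, shapes, and the value of coef are assumptions made for the example.

    import torch

    # Illustrative sketch of the adjustment described in the updated KLRewardTransform
    # docstring: R_adjusted = R - coef * KL(pi_current || pi_0), with the KL term
    # approximated by log p(a | theta_current) - log p(a | theta_0).
    # All tensors below are hypothetical placeholders.
    coef = 0.1
    reward = torch.randn(8, 1)            # original reward, one entry per transition
    log_prob_current = torch.randn(8, 1)  # log-prob of each action under the current policy
    log_prob_ref = torch.randn(8, 1)      # log-prob of the same action under the frozen reference model

    kl_estimate = log_prob_current - log_prob_ref  # sample-based KL approximation
    adjusted_reward = reward - coef * kl_estimate  # R_adjusted = R - coef * KL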