[Feature,Deprecation] Split KLRewardTransform in more modules #2813

Open
wants to merge 1 commit into base: gh/vmoens/92/base
Update
[ghstack-poisoned]
vmoens committed Feb 27, 2025
commit 9b8056d507462fca775ddd405b0a2932b83c24a5
72 changes: 68 additions & 4 deletions torchrl/envs/transforms/rlhf.py
@@ -2,23 +2,82 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from copy import copy, deepcopy

import torch
from tensordict import TensorDict, TensorDictBase, unravel_key
from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModuleBase,
    TensorDictParams,
    TensorDictSequential,
)
from tensordict.utils import is_seq_of_nested_key
from torch import nn
from torchrl.data.tensor_specs import Composite, Unbounded
from torchrl.envs.transforms.transforms import Transform
from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param

# TODO: This should live somewhere else
class ReferenceModelLogProbTransform(Transform):
    """A transform to compute and store the log-probabilities from the reference model."""

    def __init__(
        self,
        frozen_model: ProbabilisticTensorDictModule,
    ):
        super().__init__(in_keys=frozen_model.in_keys, out_keys=frozen_model.out_keys)
        self.frozen_model: ProbabilisticTensorDictModule = frozen_model

    def _call(self, inputs: TensorDict) -> TensorDict:
        # Compute the log-prob given the reference model
        return self.frozen_model(inputs)
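
As an illustration of what the frozen reference model contributes, here is a minimal sketch (not part of this diff) using a hypothetical categorical policy head; `policy_head`, the shapes, and the distribution choice are assumptions made for the example:

import copy

import torch
from torch import nn
from torch.distributions import Categorical

policy_head = nn.Linear(16, 4)  # hypothetical current policy head
ref_head = copy.deepcopy(policy_head).requires_grad_(False)  # frozen reference copy

obs = torch.randn(8, 16)
action = torch.randint(0, 4, (8,))

with torch.no_grad():
    # log p(a | theta_0): log-probability of the taken actions under the reference model
    ref_log_prob = Categorical(logits=ref_head(obs)).log_prob(action)

ReferenceModelLogProbTransform plays the analogous role on tensordicts: it runs the frozen model so the reference log-probabilities are stored alongside the current policy's outputs.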

class KLDivergenceTransform(Transform):
    """A transform to compute the KL divergence between the current and reference policies."""

    ...


class RewardAdjustmentTransform(Transform):
    """A transform to adjust the reward based on the computed KL divergence."""

    ...


class KLConstrainedTransform(Composite):
    """A composite transform to apply KL-based constraints on the policy."""

    ...


class KLRewardTransform(Transform):
"""A transform to add a KL[pi_current||pi_0] correction term to the reward.
r"""A transform to add a KL divergence correction term to the reward.
This transform is used to constrain the policy to remain close to its original
configuration which limits overfitting when fine-tuning using RLHF.
configuration, which helps limit overfitting when fine-tuning using Reinforcement Learning with Human Feedback
(RLHF) or other forms of post-training (e.g., GRPO).
The KL divergence between the current policy distribution and the reference policy distribution is used to adjust the reward:
.. math::
R_{\text{adjusted}} = R - \text{coef} \times \text{KL}(\pi_{\text{current}} || \pi_0)
where \( R_{\text{adjusted}} \) is the adjusted reward, \( R \) is the original reward, and
\(\text{KL}(\pi_{\text{current}} || \pi_0)\) is the Kullback-Leibler divergence between the current policy
distribution \( \pi_{\text{current}} \) and the reference policy distribution \( \pi_0 \).
The KL divergence can be estimated using the difference in log probabilities of the actions:
.. math::
\text{KL}(\pi_{\text{current}} || \pi_0) \approx \log p(a \mid \theta_{\text{current}}) - \log p(a \mid \theta_0)
where \( \log p(a \mid \theta_{\text{current}}) \) is the log probability of action \( a \) under the current model, and
\( \log p(a \mid \theta_0) \) is the log probability of action \( a \) under the reference model.
Args:
actor (ProbabilisticTensorDictModule): a probabilistic actor. It must
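
As a concrete sketch of the adjustment above (illustrative numbers only, not part of this diff):

import torch

reward = torch.tensor([1.0, 0.5, 2.0])
log_prob_current = torch.tensor([-1.2, -0.7, -2.1])  # log p(a | theta_current)
log_prob_ref = torch.tensor([-1.0, -0.9, -2.5])      # log p(a | theta_0)
coef = 0.1

kl_estimate = log_prob_current - log_prob_ref   # approx KL(pi_current || pi_0) per sample
adjusted_reward = reward - coef * kl_estimate    # R_adjusted = R - coef * KL
# adjusted_reward ~= tensor([1.0200, 0.4800, 1.9600])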
@@ -86,6 +145,11 @@ def __init__(
        out_keys=None,
        requires_grad=False,
    ):
        warnings.warn(
            "This class will be removed in a future release (v0.10.0). Please use torchrl.envs.KLConstrainedTransform "
            "instead.",
            category=FutureWarning,
        )
        if in_keys is None:
            in_keys = self.DEFAULT_IN_KEYS
        if out_keys is None:
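
Since KLRewardTransform now emits a FutureWarning, downstream test suites can escalate it to an error to catch remaining call sites early; a minimal sketch using only the standard-library warnings module:

import warnings

# Turn the deprecation warning raised from torchrl modules into an error.
warnings.filterwarnings("error", category=FutureWarning, module="torchrl")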
@@ -160,7 +224,7 @@ def _reset(

    def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
        # run the actor on the tensordict
        action = next_tensordict.get("action")
        if action is None:
            # being called after reset or without action, skipping
            if self.out_keys[0] != ("reward",) and self.parent is not None: