test(nyz): polish ppo and add rlhf ppo loss test
PaParaZz1 committed Feb 13, 2025
1 parent 3b7903a commit e8ef818
Showing 3 changed files with 122 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -1429,3 +1429,8 @@ collect_demo_data_config.py
events.*

evogym/*
ding/example/*
ding/framework/middleware/tests/wandb/
ding/.style.yapf
ding/format.sh
ding/framework/middleware_v3/
43 changes: 29 additions & 14 deletions ding/rl_utils/ppo.py
@@ -104,17 +104,21 @@ def ppo_error(
    return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info


def ppo_policy_error(data: namedtuple,
                     clip_ratio: float = 0.2,
                     dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
    '''
def ppo_policy_error(
    data: namedtuple,
    clip_ratio: float = 0.2,
    dual_clip: Optional[float] = None,
    entropy_bonus: bool = True
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Get PPO policy loss
        Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
    Arguments:
        - data (:obj:`namedtuple`): ppo input data with fieids shown in ``ppo_policy_data``
        - clip_ratio (:obj:`float`): clip value for ratio
        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
            defaults to 5.0, if you don't want to use it, set this parameter to None
        - data (:obj:`namedtuple`): PPO input data with fields shown in ``ppo_policy_data``.
        - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
            defaults to 5.0; if you don't want to use it, set this parameter to None.
        - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
    Returns:
        - ppo_policy_loss (:obj:`namedtuple`): The PPO policy loss item; all of its fields are differentiable 0-dim tensors.
        - ppo_info (:obj:`namedtuple`): The PPO optim information for monitoring; all of its fields are Python scalars.
@@ -136,18 +140,29 @@ def ppo_policy_error(data: namedtuple,
        >>> weight=torch.ones(3),
        >>> )
        >>> loss, info = ppo_policy_error(data)
    '''
    .. note::
        This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
        sequence length in LLM/VLM.
    .. note::
        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
    """
    logit_new, logit_old, action, adv, weight = data
    if weight is None:
        weight = torch.ones_like(adv)
    dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
    dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
    logp_new = dist_new.log_prob(action)
    logp_old = dist_old.log_prob(action)
    dist_new_entropy = dist_new.entropy()
    if dist_new_entropy.shape != weight.shape:
        dist_new_entropy = dist_new.entropy().mean(dim=1)
    entropy_loss = (dist_new_entropy * weight).mean()

    if entropy_bonus:
        dist_new_entropy = dist_new.entropy()
        if dist_new_entropy.shape != weight.shape:  # for the multi-agent rl case
            dist_new_entropy = dist_new.entropy().mean(dim=1)
        entropy_loss = (dist_new_entropy * weight).mean()
    else:
        entropy_loss = torch.tensor(0.0)
    # policy_loss
    ratio = torch.exp(logp_new - logp_old)
    if ratio.shape != adv.shape:
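
As the new docstring notes, ppo_policy_error now serves both classical RL and LLM/VLM RLHF: the inputs may carry an extra sequence dimension (B, S), weight doubles as the per-token action mask, and the entropy bonus is usually switched off. Below is a minimal usage sketch of that RLHF path; the shapes and vocabulary size are illustrative, and dual_clip is left at None (the standard clipped objective).

import torch
from ding.rl_utils import ppo_policy_data, ppo_policy_error

# Illustrative RLHF-style shapes: 4 prompts, 8 response tokens, a 1000-token vocabulary.
B, S, V = 4, 8, 1000
logit_new = torch.randn(B, S, V).requires_grad_(True)
logit_old = (logit_new + 0.1 * torch.randn_like(logit_new)).detach()
action = torch.randint(0, V, (B, S))  # sampled response tokens
adv = torch.randn(B, S)  # per-token advantages
action_mask = torch.ones(B, S)
action_mask[:, -2:] = 0  # mask out padding positions, passed via `weight`

data = ppo_policy_data(logit_new, logit_old, action, adv, weight=action_mask)
# entropy_bonus=False: LLM RLHF usually drops the entropy term, as the docstring notes.
# If dual_clip=c were set instead of None, the surrogate for negative advantages would be
# lower-bounded as max(min(r * A, clip(r, 1 - eps, 1 + eps) * A), c * A), per the
# arXiv:1912.09729 Equ. 5 reference above.
loss, info = ppo_policy_error(data, clip_ratio=0.2, dual_clip=None, entropy_bonus=False)
loss.policy_loss.backward()  # gradients flow back into logit_new

The new test file below exercises the same masked and unmasked paths with concrete assertions.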
88 changes: 88 additions & 0 deletions ding/rl_utils/tests/test_ppo_rlhf.py
@@ -0,0 +1,88 @@
import pytest
import numpy as np
import torch
from ding.rl_utils import ppo_policy_data, ppo_value_data, ppo_policy_error, ppo_value_error


@pytest.fixture
def batch_size():
    return 4


@pytest.fixture
def seq_length():
    return 8


@pytest.fixture
def dictionary_num():
    return 1000


@pytest.mark.unittest
def test_policy_loss_without_mask(batch_size: int, seq_length: int, dictionary_num: int):
    # Create test data
    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
    action = torch.randint(0, 10, (batch_size, seq_length))
    advantages = torch.randn(batch_size, seq_length)

    # Compute loss
    data = ppo_policy_data(logit_new, logit_old, action, advantages, weight=None)
    loss, info = ppo_policy_error(data, clip_ratio=0.2, entropy_bonus=False)

    # Verify output
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # scalar
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)
    assert all([np.isscalar(i) for i in info])


@pytest.mark.unittest
def test_policy_loss_with_mask(batch_size: int, seq_length: int, dictionary_num: int):
    # Create test data
    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
    action = torch.randint(0, 10, (batch_size, seq_length))
    advantages = torch.randn(batch_size, seq_length)
    action_mask = torch.ones(batch_size, seq_length)
    action_mask[:, -2:] = 0  # Set last two timesteps as padding

    # Compute loss
    data = ppo_policy_data(logit_new, logit_old, action, advantages, weight=action_mask)
    loss, info = ppo_policy_error(data, clip_ratio=0.2, entropy_bonus=False)

    # Verify output
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # scalar
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)
    assert all([np.isscalar(i) for i in info])


@pytest.mark.unittest
def test_value_loss(batch_size: int, seq_length: int):
    # Create test data
    values = torch.randn(batch_size, seq_length).requires_grad_(True)
    old_values = values + torch.randn_like(values) * 0.1
    returns = torch.randn(batch_size, seq_length)

    # Compute loss
    data = ppo_value_data(values, old_values, returns, weight=None)
    value_loss = ppo_value_error(data, clip_ratio=0.2, use_value_clip=True)

    # Verify output
    assert isinstance(value_loss, torch.Tensor)
    assert value_loss.shape == torch.Size([])
    assert not torch.isnan(value_loss)
    assert not torch.isinf(value_loss)
    assert values.grad is None
    value_loss.backward()
    assert isinstance(values.grad, torch.Tensor)
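
For reference, a sketch of how one might run just this new file locally, assuming the unittest marker used by the decorators above is registered in the repository's pytest configuration:

import pytest

# Equivalent to: pytest -m unittest ding/rl_utils/tests/test_ppo_rlhf.py
pytest.main(["-m", "unittest", "ding/rl_utils/tests/test_ppo_rlhf.py"])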
