
Commit e8ef818

test(nyz): polish ppo and add rlhf ppo loss test
1 parent 3b7903a commit e8ef818

File tree: 3 files changed (+122, -14 lines)

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -1429,3 +1429,8 @@ collect_demo_data_config.py
 events.*

 evogym/*
+ding/example/*
+ding/framework/middleware/tests/wandb/
+ding/.style.yapf
+ding/format.sh
+ding/framework/middleware_v3/

ding/rl_utils/ppo.py

Lines changed: 29 additions & 14 deletions
@@ -104,17 +104,21 @@ def ppo_error(
     return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info


-def ppo_policy_error(data: namedtuple,
-                     clip_ratio: float = 0.2,
-                     dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
-    '''
+def ppo_policy_error(
+        data: namedtuple,
+        clip_ratio: float = 0.2,
+        dual_clip: Optional[float] = None,
+        entropy_bonus: bool = True
+) -> Tuple[namedtuple, namedtuple]:
+    """
     Overview:
-        Get PPO policy loss
+        Get PPO policy loss (both for classical RL in control/video games and LLM/VLM RLHF).
     Arguments:
-        - data (:obj:`namedtuple`): ppo input data with fieids shown in ``ppo_policy_data``
-        - clip_ratio (:obj:`float`): clip value for ratio
-        - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, shoule be in [1, inf),\
-            defaults to 5.0, if you don't want to use it, set this parameter to None
+        - data (:obj:`namedtuple`): PPO input data with fields shown in ``ppo_policy_data``.
+        - clip_ratio (:obj:`float`): Clip value for ratio, defaults to 0.2.
+        - dual_clip (:obj:`float`): A parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf), \
+            defaults to 5.0; if you don't want to use it, set this parameter to None.
+        - entropy_bonus (:obj:`bool`): Whether to use entropy bonus, defaults to True. LLM RLHF usually does not use it.
     Returns:
         - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
         - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
@@ -136,18 +140,29 @@ def ppo_policy_error(data: namedtuple,
         >>> weight=torch.ones(3),
         >>> )
         >>> loss, info = ppo_policy_error(data)
-    '''
+
+    .. note::
+        This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
+        sequence length in LLM/VLM.
+
+    .. note::
+        For the action mask often used in LLM/VLM, users can set the `weight` to the action mask.
+    """
     logit_new, logit_old, action, adv, weight = data
     if weight is None:
         weight = torch.ones_like(adv)
     dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
     dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
     logp_new = dist_new.log_prob(action)
     logp_old = dist_old.log_prob(action)
-    dist_new_entropy = dist_new.entropy()
-    if dist_new_entropy.shape != weight.shape:
-        dist_new_entropy = dist_new.entropy().mean(dim=1)
-    entropy_loss = (dist_new_entropy * weight).mean()
+
+    if entropy_bonus:
+        dist_new_entropy = dist_new.entropy()
+        if dist_new_entropy.shape != weight.shape:  # for the multi-agent rl case
+            dist_new_entropy = dist_new.entropy().mean(dim=1)
+        entropy_loss = (dist_new_entropy * weight).mean()
+    else:
+        entropy_loss = torch.tensor(0.0)
     # policy_loss
     ratio = torch.exp(logp_new - logp_old)
     if ratio.shape != adv.shape:

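For context, a minimal usage sketch of the updated ``ppo_policy_error`` in the RLHF setting (batch ``B``, sequence length ``S``), with the action mask passed as ``weight`` and the entropy bonus disabled. The shapes and vocabulary size below are illustrative values, not part of the commit:

import torch
from ding.rl_utils import ppo_policy_data, ppo_policy_error

B, S, V = 2, 6, 100  # illustrative batch size, sequence length, vocabulary size (assumed values)
logit_new = torch.randn(B, S, V, requires_grad=True)         # per-token logits from the current policy
logit_old = logit_new.detach() + 0.1 * torch.randn(B, S, V)  # per-token logits from the behavior policy
action = torch.randint(0, V, (B, S))                         # sampled token ids
adv = torch.randn(B, S)                                      # per-token advantages
action_mask = torch.ones(B, S)
action_mask[:, -2:] = 0                                      # mask out padding tokens

data = ppo_policy_data(logit_new, logit_old, action, adv, weight=action_mask)
# RLHF-style call: action mask as weight, no entropy bonus
loss, info = ppo_policy_error(data, clip_ratio=0.2, dual_clip=None, entropy_bonus=False)
loss.policy_loss.backward()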
ding/rl_utils/tests/test_ppo_rlhf.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+import pytest
+import numpy as np
+import torch
+from ding.rl_utils import ppo_policy_data, ppo_value_data, ppo_policy_error, ppo_value_error
+
+
+@pytest.fixture
+def batch_size():
+    return 4
+
+
+@pytest.fixture
+def seq_length():
+    return 8
+
+
+@pytest.fixture
+def dictionary_num():
+    return 1000
+
+
+@pytest.mark.unittest
+def test_policy_loss_without_mask(batch_size: int, seq_length: int, dictionary_num: int):
+    # Create test data
+    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
+    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
+    action = torch.randint(0, 10, (batch_size, seq_length))
+    advantages = torch.randn(batch_size, seq_length)
+
+    # Compute loss
+    data = ppo_policy_data(logit_new, logit_old, action, advantages, weight=None)
+    loss, info = ppo_policy_error(data, clip_ratio=0.2, entropy_bonus=False)
+
+    # Verify output
+    assert isinstance(loss.policy_loss, torch.Tensor)
+    assert loss.policy_loss.shape == torch.Size([])  # scalar
+    assert not torch.isnan(loss.policy_loss)
+    assert not torch.isinf(loss.policy_loss)
+    assert logit_new.grad is None
+    loss.policy_loss.backward()
+    assert isinstance(logit_new.grad, torch.Tensor)
+    assert all([np.isscalar(i) for i in info])
+
+
+@pytest.mark.unittest
+def test_policy_loss_with_mask(batch_size: int, seq_length: int, dictionary_num: int):
+    # Create test data
+    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
+    logit_old = logit_new + torch.randn_like(logit_new) * 0.1
+    action = torch.randint(0, 10, (batch_size, seq_length))
+    advantages = torch.randn(batch_size, seq_length)
+    action_mask = torch.ones(batch_size, seq_length)
+    action_mask[:, -2:] = 0  # Set last two timesteps as padding
+
+    # Compute loss
+    data = ppo_policy_data(logit_new, logit_old, action, advantages, weight=action_mask)
+    loss, info = ppo_policy_error(data, clip_ratio=0.2, entropy_bonus=False)
+
+    # Verify output
+    assert isinstance(loss.policy_loss, torch.Tensor)
+    assert loss.policy_loss.shape == torch.Size([])  # scalar
+    assert not torch.isnan(loss.policy_loss)
+    assert not torch.isinf(loss.policy_loss)
+    assert logit_new.grad is None
+    loss.policy_loss.backward()
+    assert isinstance(logit_new.grad, torch.Tensor)
+    assert all([np.isscalar(i) for i in info])
+
+
+@pytest.mark.unittest
+def test_value_loss(batch_size: int, seq_length: int):
+    # Create test data
+    values = torch.randn(batch_size, seq_length).requires_grad_(True)
+    old_values = values + torch.randn_like(values) * 0.1
+    returns = torch.randn(batch_size, seq_length)
+
+    # Compute loss
+    data = ppo_value_data(values, old_values, returns, weight=None)
+    value_loss = ppo_value_error(data, clip_ratio=0.2, use_value_clip=True)
+
+    # Verify output
+    assert isinstance(value_loss, torch.Tensor)
+    assert value_loss.shape == torch.Size([])
+    assert not torch.isnan(value_loss)
+    assert not torch.isinf(value_loss)
+    assert values.grad is None
+    value_loss.backward()
+    assert isinstance(values.grad, torch.Tensor)
