
Commit 8d34eac

你的名字 and Berit-chengyi authored and committed
test&implement(dcy): add unit tests for GRPO and RLOO
- Add test_grpo_rlhf.py for GRPO unit tests
- Add test_rloo_rlhf.py for RLOO unit tests
- Update GRPO implementation
- Update RLOO implementation
1 parent 2e49437 commit 8d34eac

File tree

4 files changed: +323 additions, −0 deletions


ding/rl_utils/grpo.py

Lines changed: 67 additions & 0 deletions
@@ -10,9 +10,76 @@
def grpo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
        beta: float = 0.1,  # weight coefficient of the KL divergence term
) -> Tuple[namedtuple, namedtuple]:
    """
    .. note::
        Each element in this input data is a group of response samples from the same prompt.
    """
    """Compute the GRPO (Group Relative Policy Optimization) policy loss.

    Args:
        data (grpo_policy_data): data with the following fields:
            - logit_new: logits of the current policy [batch_size, seq_len, vocab_size]
            - logit_old: logits of the old policy [batch_size, seq_len, vocab_size]
            - logit_ref: logits of the reference policy [batch_size, seq_len, vocab_size]
            - action: tokens actually taken [batch_size, seq_len]
            - adv: advantage values [batch_size]
            - weight: attention mask [batch_size, seq_len]
        clip_ratio (float): PPO clipping ratio, default 0.2
        beta (float): weight coefficient of the KL divergence term, default 0.1

    Returns:
        Tuple[namedtuple, namedtuple]:
            - the first namedtuple contains policy_loss
            - the second namedtuple contains additional metrics
    """
    # Per-token log probabilities
    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
    log_prob_ref = torch.log_softmax(data.logit_ref, dim=-1)

    # Log probabilities of the selected actions
    action = data.action.unsqueeze(-1)  # [batch_size, seq_len, 1]
    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)  # [batch_size, seq_len]
    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
    per_token_ref_logps = torch.gather(log_prob_ref, -1, action).squeeze(-1)

    # KL divergence estimator: exp(q - p) - (q - p) - 1, where p is the current policy and q the reference policy
    per_token_kl = torch.exp(per_token_ref_logps - per_token_logps) - \
        (per_token_ref_logps - per_token_logps) - 1

    # Policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Per-token loss
    advantages = data.adv.unsqueeze(1)  # [batch_size, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Add the KL regularization term
    per_token_loss = per_token_loss + beta * per_token_kl

    # Average the loss using the weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Additional metrics
    metrics = {
        'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + \
            (ratio < (1 - clip_ratio)).float().mean().item(),
    }

    # Build the returned namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

    return loss_info, metric_info

    raise NotImplementedError  # leftover from the original stub, now unreachable
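
For reference, the per-token loss implemented above is the PPO-style clipped surrogate plus a KL penalty toward the reference policy, where the KL is estimated with the exp(q - p) - (q - p) - 1 (k3) estimator. In the code's notation (p_t = per_token_logps, p_t^old = per_token_old_logps, q_t = per_token_ref_logps, epsilon = clip_ratio, beta = beta, and \hat{A} the per-sequence advantage broadcast over tokens):

\begin{aligned}
r_t &= \exp\bigl(p_t - p_t^{\mathrm{old}}\bigr), \\
\widehat{\mathrm{KL}}_t &= \exp\bigl(q_t - p_t\bigr) - \bigl(q_t - p_t\bigr) - 1, \\
\mathcal{L}_t &= -\min\bigl(r_t\,\hat{A},\ \operatorname{clip}(r_t,\,1-\epsilon,\,1+\epsilon)\,\hat{A}\bigr) + \beta\,\widehat{\mathrm{KL}}_t.
\end{aligned}

The returned policy_loss is the weight-masked mean of \mathcal{L}_t over tokens, averaged over the batch.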

ding/rl_utils/rloo.py

Lines changed: 55 additions & 0 deletions
@@ -11,8 +11,63 @@ def rloo_policy_error(
        data: namedtuple,
        clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
    """
    .. note::
        Each element in this input data is a group of response samples from the same prompt.
    """
    """Compute the RLOO (REINFORCE Leave-One-Out) policy loss.

    Args:
        data (rloo_policy_data): data with the following fields:
            - logit_new: logits of the current policy [batch_size, seq_len, vocab_size]
            - logit_old: logits of the old policy [batch_size, seq_len, vocab_size]
            - action: tokens actually taken [batch_size, seq_len]
            - adv: advantage values [batch_size]
            - weight: attention mask [batch_size, seq_len]
        clip_ratio (float): PPO clipping ratio, default 0.2

    Returns:
        Tuple[namedtuple, namedtuple]:
            - the first namedtuple contains policy_loss
            - the second namedtuple contains additional metrics
    """
    # Per-token log probabilities
    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)

    # Log probabilities of the selected actions
    action = data.action.unsqueeze(-1)  # [batch_size, seq_len, 1]
    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)  # [batch_size, seq_len]
    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)

    # Policy ratio
    ratio = torch.exp(per_token_logps - per_token_old_logps)
    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

    # Per-token loss
    advantages = data.adv.unsqueeze(1)  # [batch_size, 1]
    per_token_loss_unclipped = ratio * advantages
    per_token_loss_clipped = ratio_clipped * advantages
    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

    # Average the loss using the weight mask
    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

    # Additional metrics
    metrics = {
        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
        'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + \
            (ratio < (1 - clip_ratio)).float().mean().item(),
        'mean_advantage': advantages.mean().item(),
    }

    # Build the returned namedtuples
    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

    return loss_info, metric_info

    raise NotImplementedError  # leftover from the original stub, now unreachable
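
rloo_policy_error consumes a precomputed per-sequence advantage (adv) and does not derive it from rewards. That step is outside this commit; the sketch below shows the usual REINFORCE Leave-One-Out baseline, assuming a hypothetical rewards tensor of shape [num_prompts, group_size] holding one scalar reward per sampled response.

import torch


def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """Leave-one-out advantage: each sample's reward minus the mean reward of the
    other samples drawn for the same prompt. rewards: [num_prompts, group_size]."""
    k = rewards.shape[1]
    # Mean of the other k - 1 samples: (group sum - own reward) / (k - 1)
    baseline = (rewards.sum(dim=1, keepdim=True) - rewards) / (k - 1)
    adv = rewards - baseline  # [num_prompts, group_size]
    return adv.reshape(-1)  # flatten so each entry matches one response sequence

Each entry of the flattened result would then line up with one sequence in rloo_policy_data.adv.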

ding/rl_utils/tests/test_grpo_rlhf.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
import pytest
import numpy as np
import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error  # GRPO helpers


@pytest.fixture
def batch_size():
    return 4


@pytest.fixture
def seq_length():
    return 8


@pytest.fixture
def dictionary_num():
    return 1000


@pytest.mark.unittest
def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
    """Test the GRPO policy loss computation with a mask."""
    # 1. Create test data
    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)  # logits of the current policy
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # logits of the old policy (slightly perturbed)
    logit_ref = logit_new + torch.randn_like(logit_new) * 0.2  # logits of the reference policy
    action = torch.randint(0, vocab_size, (batch_size, seq_length))  # randomly sampled tokens
    adv = torch.randn(batch_size)  # advantage value per sequence
    weight = torch.ones(batch_size, seq_length)  # mask
    weight[:, -2:] = 0  # treat the last two timesteps as padding

    # 2. Build the grpo_policy_data instance
    data = grpo_policy_data(
        logit_new=logit_new,  # output of the current policy
        logit_old=logit_old,  # output of the old policy
        logit_ref=logit_ref,  # output of the reference policy
        action=action,  # sampled tokens
        adv=adv,  # advantages
        weight=weight  # mask
    )

    # 3. Compute the GRPO loss
    loss, info = grpo_policy_error(
        data=data,
        clip_ratio=0.2,  # PPO clipping ratio
        beta=0.1  # KL divergence weight
    )

    # 4. Check the output
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # make sure it is a scalar
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)

    # 5. Check gradients
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)

    # 6. Check the metrics
    assert 'mean_kl' in info._asdict()
    assert 'mean_ratio' in info._asdict()
    assert 'mean_clipped' in info._asdict()
    assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
    """Test the GRPO policy loss computation without a mask."""
    # 1. Create test data
    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)  # logits of the current policy
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # logits of the old policy (slightly perturbed)
    logit_ref = logit_new + torch.randn_like(logit_new) * 0.2  # logits of the reference policy
    action = torch.randint(0, vocab_size, (batch_size, seq_length))  # randomly sampled tokens
    adv = torch.randn(batch_size)  # advantage value per sequence

    # 2. Build the grpo_policy_data instance
    data = grpo_policy_data(
        logit_new=logit_new,  # output of the current policy
        logit_old=logit_old,  # output of the old policy
        logit_ref=logit_ref,  # output of the reference policy
        action=action,  # sampled tokens
        adv=adv,  # advantages
        weight=None  # no mask
    )

    # 3. Compute the GRPO loss
    loss, info = grpo_policy_error(
        data=data,
        clip_ratio=0.2,  # PPO clipping ratio
        beta=0.1  # KL divergence weight
    )

    # 4. Check the output
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # make sure it is a scalar
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)

    # 5. Check gradients
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)

    # 6. Check the metrics
    assert 'mean_kl' in info._asdict()
    assert 'mean_ratio' in info._asdict()
    assert 'mean_clipped' in info._asdict()
    assert all([np.isscalar(v) for v in info._asdict().values()])

ding/rl_utils/tests/test_rloo_rlhf.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import pytest
import numpy as np
import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error


@pytest.fixture
def batch_size():
    return 4


@pytest.fixture
def seq_length():
    return 8


@pytest.fixture
def dictionary_num():
    return 1000


@pytest.mark.unittest
def test_rloo_policy_loss_without_mask(batch_size: int, seq_length: int, dictionary_num: int):
    """Test the RLOO policy loss computation without a mask."""
    # Create test data
    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # slightly perturbed old policy
    action = torch.randint(0, dictionary_num, (batch_size, seq_length))
    advantages = torch.randn(batch_size)  # in RLOO each sequence has a single advantage value

    # Compute the loss
    data = rloo_policy_data(
        logit_new=logit_new,
        logit_old=logit_old,
        action=action,
        adv=advantages,
        weight=None
    )
    loss, info = rloo_policy_error(data, clip_ratio=0.2)

    # Check the output
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # scalar
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)
    assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_rloo_policy_loss_with_mask(batch_size: int, seq_length: int, dictionary_num: int):
    """Test the RLOO policy loss computation with a mask."""
    # Create test data
    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # slightly perturbed old policy
    action = torch.randint(0, dictionary_num, (batch_size, seq_length))
    advantages = torch.randn(batch_size)  # in RLOO each sequence has a single advantage value
    action_mask = torch.ones(batch_size, seq_length)
    action_mask[:, -2:] = 0  # treat the last two timesteps as padding

    # Compute the loss
    data = rloo_policy_data(
        logit_new=logit_new,
        logit_old=logit_old,
        action=action,
        adv=advantages,
        weight=action_mask
    )
    loss, info = rloo_policy_error(data, clip_ratio=0.2)

    # Check the output
    assert isinstance(loss.policy_loss, torch.Tensor)
    assert loss.policy_loss.shape == torch.Size([])  # scalar
    assert not torch.isnan(loss.policy_loss)
    assert not torch.isinf(loss.policy_loss)
    assert logit_new.grad is None
    loss.policy_loss.backward()
    assert isinstance(logit_new.grad, torch.Tensor)

    # Check the metrics
    assert 'mean_ratio' in info._asdict()
    assert 'mean_clipped' in info._asdict()
    assert 'mean_advantage' in info._asdict()
    assert all([np.isscalar(v) for v in info._asdict().values()])
