feat(rlhf): add unit tests for GRPO and RLOO
- Add test_grpo_rlhf.py for GRPO unit tests
- Add test_rloo_rlhf.py for RLOO unit tests
- Update GRPO implementation
- Update RLOO implementation
Jinnyyy-zhang committed Feb 13, 2025
1 parent 2e49437 commit d3f6f3f
Showing 4 changed files with 323 additions and 0 deletions.
67 changes: 67 additions & 0 deletions ding/rl_utils/grpo.py
@@ -10,9 +10,76 @@
def grpo_policy_error(
data: namedtuple,
clip_ratio: float = 0.2,
beta: float = 0.1,  # weight coefficient of the KL divergence penalty
) -> Tuple[namedtuple, namedtuple]:
"""Compute the policy loss for GRPO (Group Relative Policy Optimization).

.. note::
    Each element in this input data is a group of response samples from the same prompt.

Args:
    data (grpo_policy_data): data containing the following fields:
        - logit_new: logits of the current policy [batch_size, seq_len, vocab_size]
        - logit_old: logits of the old policy [batch_size, seq_len, vocab_size]
        - logit_ref: logits of the reference policy [batch_size, seq_len, vocab_size]
        - action: tokens actually taken [batch_size, seq_len]
        - adv: advantage values [batch_size]
        - weight: attention mask [batch_size, seq_len]
    clip_ratio (float): PPO clipping ratio, default 0.2
    beta (float): weight coefficient of the KL divergence penalty, default 0.1
Returns:
    Tuple[namedtuple, namedtuple]:
        - the first namedtuple contains policy_loss
        - the second namedtuple contains additional metric information
"""
# Compute per-token log probabilities
log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
log_prob_ref = torch.log_softmax(data.logit_ref, dim=-1)

# Gather the log probabilities of the selected actions
action = data.action.unsqueeze(-1) # [batch_size, seq_len, 1]
per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1) # [batch_size, seq_len]
per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
per_token_ref_logps = torch.gather(log_prob_ref, -1, action).squeeze(-1)

# KL divergence estimator: exp(q - p) - (q - p) - 1, where p is the current policy log-prob and q is the reference policy log-prob
per_token_kl = torch.exp(per_token_ref_logps - per_token_logps) - \
(per_token_ref_logps - per_token_logps) - 1

# Compute the policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

# Compute the per-token loss
advantages = data.adv.unsqueeze(1) # [batch_size, 1]
per_token_loss_unclipped = ratio * advantages
per_token_loss_clipped = ratio_clipped * advantages
per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

# Add the KL divergence regularization term
per_token_loss = per_token_loss + beta * per_token_kl

# Compute the mask-weighted average loss
weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

# Compute additional metrics
metrics = {
'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + \
(ratio < (1 - clip_ratio)).float().mean().item(),
}

# Build the namedtuple objects to return
loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

return loss_info, metric_info


raise NotImplementedError
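Note: the diff shows only the loss function; the grpo_policy_data container it consumes lives elsewhere in ding/rl_utils/grpo.py and is not part of this diff. Below is a minimal sketch of what it presumably looks like (field names inferred from the docstring and the unit tests), plus a hypothetical helper showing how the per-group advantage fed into adv is typically produced in GRPO, namely reward normalization within a group of responses to the same prompt. Both are illustrative assumptions, not code from this commit.

from collections import namedtuple

import torch

# Assumed container definition; field names inferred from the docstring and the tests below.
grpo_policy_data = namedtuple(
    'grpo_policy_data', ['logit_new', 'logit_old', 'logit_ref', 'action', 'adv', 'weight']
)


def group_relative_advantage(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical helper: normalize scalar rewards within one group of responses
    sampled from the same prompt, the usual way GRPO builds its advantages.

    rewards: [group_size] scalar reward per response.
    Returns: [group_size] normalized advantage per response.
    """
    return (rewards - rewards.mean()) / (rewards.std() + eps)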
55 changes: 55 additions & 0 deletions ding/rl_utils/rloo.py
@@ -11,8 +11,63 @@ def rloo_policy_error(
data: namedtuple,
clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
"""Compute the policy loss for RLOO (REINFORCE Leave-One-Out).

.. note::
    Each element in this input data is a group of response samples from the same prompt.

Args:
    data (rloo_policy_data): data containing the following fields:
        - logit_new: logits of the current policy [batch_size, seq_len, vocab_size]
        - logit_old: logits of the old policy [batch_size, seq_len, vocab_size]
        - action: tokens actually taken [batch_size, seq_len]
        - adv: advantage values [batch_size]
        - weight: attention mask [batch_size, seq_len]
    clip_ratio (float): PPO clipping ratio, default 0.2
Returns:
    Tuple[namedtuple, namedtuple]:
        - the first namedtuple contains policy_loss
        - the second namedtuple contains additional metric information
"""
# Compute per-token log probabilities
log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
log_prob_old = torch.log_softmax(data.logit_old, dim=-1)

# Gather the log probabilities of the selected actions
action = data.action.unsqueeze(-1) # [batch_size, seq_len, 1]
per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1) # [batch_size, seq_len]
per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)

# Compute the policy ratio
ratio = torch.exp(per_token_logps - per_token_old_logps)
ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)

# Compute the per-token loss
advantages = data.adv.unsqueeze(1) # [batch_size, 1]
per_token_loss_unclipped = ratio * advantages
per_token_loss_clipped = ratio_clipped * advantages
per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)

# Compute the mask-weighted average loss
weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()

# Compute additional metrics
metrics = {
'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + \
(ratio < (1 - clip_ratio)).float().mean().item(),
'mean_advantage': advantages.mean().item(),
}

# Build the namedtuple objects to return
loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)

return loss_info, metric_info


raise NotImplementedError
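As with GRPO, the rloo_policy_data container is defined elsewhere in ding/rl_utils/rloo.py and does not appear in this diff. Below is a minimal sketch of the assumed definition, plus a hypothetical helper illustrating the leave-one-out baseline that usually produces adv in RLOO: each response in a group of k samples from the same prompt is baselined by the mean reward of the other k - 1 responses. Both are illustrative assumptions, not code from this commit.

from collections import namedtuple

import torch

# Assumed container definition; field names inferred from the docstring and the tests below.
rloo_policy_data = namedtuple(
    'rloo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight']
)


def leave_one_out_advantage(rewards: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: RLOO advantage for a group of k responses to one prompt (k > 1).

    rewards: [k] scalar reward per response.
    Returns: [k] reward minus the mean reward of the other k - 1 responses.
    """
    k = rewards.numel()
    baseline = (rewards.sum() - rewards) / (k - 1)
    return rewards - baseline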
113 changes: 113 additions & 0 deletions ding/rl_utils/tests/test_grpo_rlhf.py
@@ -0,0 +1,113 @@
import pytest
import numpy as np
import torch
from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error  # import the GRPO data container and loss function


@pytest.fixture
def batch_size():
return 4


@pytest.fixture
def seq_length():
return 8


@pytest.fixture
def dictionary_num():
return 1000



@pytest.mark.unittest
def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""测试GRPO策略损失的计算"""
# 1. 创建测试数据
logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True) # 当前策略的logits
logit_old = logit_new + torch.randn_like(logit_new) * 0.1 # 旧策略的logits(稍微偏离当前策略)
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2 # 参考策略的logits
action = torch.randint(0, vocab_size, (batch_size, seq_length)) # 随机采样的token
adv = torch.randn(batch_size) # 每个序列的优势值
weight = torch.ones(batch_size, seq_length) # 掩码
weight[:, -2:] = 0 # 设置最后两个时间步为padding

# 2. Create a grpo_policy_data instance
data = grpo_policy_data(
    logit_new=logit_new,  # output of the current policy
    logit_old=logit_old,  # output of the old policy
    logit_ref=logit_ref,  # output of the reference policy
    action=action,  # actually sampled tokens
    adv=adv,  # advantage values
    weight=weight  # mask
)

# 3. Compute the GRPO loss
loss, info = grpo_policy_error(
    data=data,
    clip_ratio=0.2,  # PPO clipping ratio
    beta=0.1  # KL divergence weight
)

# 4. Check the outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([])  # ensure it is a scalar
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)

# 5. Check gradients
assert logit_new.grad is None
loss.policy_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# 6. Check the metrics
assert 'mean_kl' in info._asdict()
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
"""测试GRPO策略损失的计算"""
# 1. 创建测试数据
logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True) # 当前策略的logits
logit_old = logit_new + torch.randn_like(logit_new) * 0.1 # 旧策略的logits(稍微偏离当前策略)
logit_ref = logit_new + torch.randn_like(logit_new) * 0.2 # 参考策略的logits
action = torch.randint(0, vocab_size, (batch_size, seq_length)) # 随机采样的token
adv = torch.randn(batch_size) # 每个序列的优势值


# 2. Create a grpo_policy_data instance
data = grpo_policy_data(
    logit_new=logit_new,  # output of the current policy
    logit_old=logit_old,  # output of the old policy
    logit_ref=logit_ref,  # output of the reference policy
    action=action,  # actually sampled tokens
    adv=adv,  # advantage values
    weight=None  # no mask
)

# 3. Compute the GRPO loss
loss, info = grpo_policy_error(
    data=data,
    clip_ratio=0.2,  # PPO clipping ratio
    beta=0.1  # KL divergence weight
)

# 4. Check the outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([])  # ensure it is a scalar
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)

# 5. Check gradients
assert logit_new.grad is None
loss.policy_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# 6. Check the metrics
assert 'mean_kl' in info._asdict()
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])
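The GRPO loss above penalizes divergence from the reference policy with the estimator exp(q - p) - (q - p) - 1 (sometimes referred to as the k3 estimator), where p and q are per-token log probabilities under the current and reference policies. The short standalone sketch below illustrates the two properties this relies on: the estimate is pointwise non-negative (up to floating-point rounding) and vanishes when the two policies coincide. It is an illustration only, not part of this commit.

import torch

# Per-token log probabilities under a "current" and a perturbed "reference" policy.
logits = torch.randn(4, 8, 16)
logp = torch.log_softmax(logits, dim=-1)
logq = torch.log_softmax(logits + 0.3 * torch.randn_like(logits), dim=-1)

# The k3 estimator used in grpo_policy_error: exp(q - p) - (q - p) - 1.
k3 = torch.exp(logq - logp) - (logq - logp) - 1
assert (k3 > -1e-6).all()  # non-negative up to float rounding

# When the reference equals the current policy, the penalty is exactly zero.
zero = torch.exp(logp - logp) - (logp - logp) - 1
assert torch.allclose(zero, torch.zeros_like(zero))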
88 changes: 88 additions & 0 deletions ding/rl_utils/tests/test_rloo_rlhf.py
@@ -0,0 +1,88 @@
import pytest
import numpy as np
import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error


@pytest.fixture
def batch_size():
return 4


@pytest.fixture
def seq_length():
return 8


@pytest.fixture
def dictionary_num():
return 1000


@pytest.mark.unittest
def test_rloo_policy_loss_without_mask(batch_size: int, seq_length: int, dictionary_num: int):
"""测试不带掩码的RLOO策略损失计算"""
# 创建测试数据
logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1 # 稍微偏离当前策略
action = torch.randint(0, dictionary_num, (batch_size, seq_length))
advantages = torch.randn(batch_size) # RLOO中每个序列只有一个优势值

# 计算损失
data = rloo_policy_data(
logit_new=logit_new,
logit_old=logit_old,
action=action,
adv=advantages,
weight=None
)
loss, info = rloo_policy_error(data, clip_ratio=0.2)

# Check the outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([])  # scalar
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)
assert logit_new.grad is None
loss.policy_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)
assert all([np.isscalar(v) for v in info._asdict().values()])


@pytest.mark.unittest
def test_rloo_policy_loss_with_mask(batch_size: int, seq_length: int, dictionary_num: int):
"""测试带掩码的RLOO策略损失计算"""
# 创建测试数据
logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
logit_old = logit_new + torch.randn_like(logit_new) * 0.1 # 稍微偏离当前策略
action = torch.randint(0, dictionary_num, (batch_size, seq_length))
advantages = torch.randn(batch_size) # RLOO中每个序列只有一个优势值
action_mask = torch.ones(batch_size, seq_length)
action_mask[:, -2:] = 0 # 设置最后两个时间步为padding

# 计算损失
data = rloo_policy_data(
logit_new=logit_new,
logit_old=logit_old,
action=action,
adv=advantages,
weight=action_mask
)
loss, info = rloo_policy_error(data, clip_ratio=0.2)

# Check the outputs
assert isinstance(loss.policy_loss, torch.Tensor)
assert loss.policy_loss.shape == torch.Size([])  # scalar
assert not torch.isnan(loss.policy_loss)
assert not torch.isinf(loss.policy_loss)
assert logit_new.grad is None
loss.policy_loss.backward()
assert isinstance(logit_new.grad, torch.Tensor)

# Check the metrics
assert 'mean_ratio' in info._asdict()
assert 'mean_clipped' in info._asdict()
assert 'mean_advantage' in info._asdict()
assert all([np.isscalar(v) for v in info._asdict().values()])
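A property the masked test exercises only implicitly is that positions with weight == 0 should not influence the loss at all. A standalone sketch of such a check is shown below, reusing the field names from the tests above; it is illustrative only and not part of this commit.

import torch
from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error

torch.manual_seed(0)
logit_new = torch.randn(2, 6, 32)
logit_old = logit_new + 0.1 * torch.randn_like(logit_new)
action = torch.randint(0, 32, (2, 6))
adv = torch.randn(2)
weight = torch.ones(2, 6)
weight[:, -2:] = 0  # last two timesteps are padding

loss_a, _ = rloo_policy_error(
    rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, adv=adv, weight=weight),
    clip_ratio=0.2
)

# Perturb the current-policy logits only at the masked (padding) positions.
logit_perturbed = logit_new.clone()
logit_perturbed[:, -2:, :] += torch.randn(2, 2, 32)
loss_b, _ = rloo_policy_error(
    rloo_policy_data(logit_new=logit_perturbed, logit_old=logit_old, action=action, adv=adv, weight=weight),
    clip_ratio=0.2
)

# The padding positions are multiplied by weight == 0, so the loss is unchanged.
assert torch.allclose(loss_a.policy_loss, loss_b.policy_loss)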

