From d3f6f3fe9fafe0c9c7f121556499ab7e5f63d70f Mon Sep 17 00:00:00 2001
From: Jinnyyy-zhang
Date: Thu, 13 Feb 2025 20:23:37 +0800
Subject: [PATCH] feat(rlhf): add unit tests for GRPO and RLOO

- Add test_grpo_rlhf.py for GRPO unit tests
- Add test_rloo_rlhf.py for RLOO unit tests
- Update GRPO implementation
- Update RLOO implementation
---
 ding/rl_utils/grpo.py                 |  67 +++++++++++++++
 ding/rl_utils/rloo.py                 |  55 +++++++++++++
 ding/rl_utils/tests/test_grpo_rlhf.py | 113 ++++++++++++++++++++++++++
 ding/rl_utils/tests/test_rloo_rlhf.py |  88 ++++++++++++++++++++
 4 files changed, 323 insertions(+)
 create mode 100644 ding/rl_utils/tests/test_grpo_rlhf.py
 create mode 100644 ding/rl_utils/tests/test_rloo_rlhf.py

diff --git a/ding/rl_utils/grpo.py b/ding/rl_utils/grpo.py
index 6ad7204161..7e7852d9e1 100644
--- a/ding/rl_utils/grpo.py
+++ b/ding/rl_utils/grpo.py
@@ -10,9 +10,76 @@
 def grpo_policy_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
+        beta: float = 0.1,  # weight coefficient of the KL divergence term
 ) -> Tuple[namedtuple, namedtuple]:
     """
     .. note::
         Each element in this input data is a group of response samples from the same prompt.
     """
+    """Compute the policy loss for GRPO (Group Relative Policy Optimization).
+
+    Args:
+        data (grpo_policy_data): Data containing the following fields:
+            - logit_new: Logits of the current policy [batch_size, seq_len, vocab_size]
+            - logit_old: Logits of the old policy [batch_size, seq_len, vocab_size]
+            - logit_ref: Logits of the reference policy [batch_size, seq_len, vocab_size]
+            - action: Tokens actually taken [batch_size, seq_len]
+            - adv: Advantage values [batch_size]
+            - weight: Attention mask [batch_size, seq_len]
+        clip_ratio (float): PPO clipping ratio, defaults to 0.2.
+        beta (float): Weight coefficient of the KL divergence term, defaults to 0.1.
+
+    Returns:
+        Tuple[namedtuple, namedtuple]:
+            - The first namedtuple contains policy_loss.
+            - The second namedtuple contains additional metric information.
+    """
+    # Compute per-token log probabilities
+    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
+    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
+    log_prob_ref = torch.log_softmax(data.logit_ref, dim=-1)
+
+    # Gather the log probabilities of the selected actions
+    action = data.action.unsqueeze(-1)  # [batch_size, seq_len, 1]
+    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)  # [batch_size, seq_len]
+    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
+    per_token_ref_logps = torch.gather(log_prob_ref, -1, action).squeeze(-1)
+
+    # KL divergence estimate: exp(q - p) - (q - p) - 1, where p is the current policy log-prob
+    # and q is the reference policy log-prob
+    per_token_kl = torch.exp(per_token_ref_logps - per_token_logps) - \
+        (per_token_ref_logps - per_token_logps) - 1
+
+    # Compute the policy ratio
+    ratio = torch.exp(per_token_logps - per_token_old_logps)
+    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
+
+    # Compute the per-token loss
+    advantages = data.adv.unsqueeze(1)  # [batch_size, 1]
+    per_token_loss_unclipped = ratio * advantages
+    per_token_loss_clipped = ratio_clipped * advantages
+    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)
+
+    # Add the KL regularization term
+    per_token_loss = per_token_loss + beta * per_token_kl
+
+    # Average the loss using the attention mask (weight)
+    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
+    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()
+
+    # Compute additional metrics
+    metrics = {
+        'mean_kl': ((per_token_kl * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
+        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
+        'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + \
+        (ratio < (1 - clip_ratio)).float().mean().item(),
+    }
+
+    # Build the returned namedtuple objects
+    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
+    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)
+
+    return loss_info, metric_info
     raise NotImplementedError
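For reference, the per-token objective that the added `grpo_policy_error` body computes can be summarized as follows. This only restates the code above (it is not part of the patch): pi, pi_old, and pi_ref denote the per-token probabilities of the taken token under the current, old, and reference policies; epsilon is `clip_ratio`, beta is `beta`, A is the per-sequence advantage `adv`, and w_t is the mask `weight`.

\begin{aligned}
r_t &= \exp\bigl(\log \pi(a_t) - \log \pi_{\mathrm{old}}(a_t)\bigr), \qquad
\mathrm{KL}_t = \frac{\pi_{\mathrm{ref}}(a_t)}{\pi(a_t)} - \log\frac{\pi_{\mathrm{ref}}(a_t)}{\pi(a_t)} - 1, \\
L_t &= -\min\bigl(r_t A,\ \operatorname{clip}(r_t,\, 1-\epsilon,\, 1+\epsilon)\, A\bigr) + \beta\,\mathrm{KL}_t, \qquad
\text{loss} = \operatorname{mean}_{\text{batch}}\!\left(\frac{\sum_t w_t L_t}{\sum_t w_t}\right).
\end{aligned}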
diff --git a/ding/rl_utils/rloo.py b/ding/rl_utils/rloo.py
index 85f7b95aa9..3cb8272373 100644
--- a/ding/rl_utils/rloo.py
+++ b/ding/rl_utils/rloo.py
@@ -11,8 +11,63 @@ def rloo_policy_error(
         data: namedtuple,
         clip_ratio: float = 0.2,
 ) -> Tuple[namedtuple, namedtuple]:
     """
     .. note::
         Each element in this input data is a group of response samples from the same prompt.
     """
+    """Compute the policy loss for RLOO (REINFORCE Leave-One-Out).
+
+    Args:
+        data (rloo_policy_data): Data containing the following fields:
+            - logit_new: Logits of the current policy [batch_size, seq_len, vocab_size]
+            - logit_old: Logits of the old policy [batch_size, seq_len, vocab_size]
+            - action: Tokens actually taken [batch_size, seq_len]
+            - adv: Advantage values [batch_size]
+            - weight: Attention mask [batch_size, seq_len]
+        clip_ratio (float): PPO clipping ratio, defaults to 0.2.
+
+    Returns:
+        Tuple[namedtuple, namedtuple]:
+            - The first namedtuple contains policy_loss.
+            - The second namedtuple contains additional metric information.
+    """
+    # Compute per-token log probabilities
+    log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
+    log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
+
+    # Gather the log probabilities of the selected actions
+    action = data.action.unsqueeze(-1)  # [batch_size, seq_len, 1]
+    per_token_logps = torch.gather(log_prob_new, -1, action).squeeze(-1)  # [batch_size, seq_len]
+    per_token_old_logps = torch.gather(log_prob_old, -1, action).squeeze(-1)
+
+    # Compute the policy ratio
+    ratio = torch.exp(per_token_logps - per_token_old_logps)
+    ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
+
+    # Compute the per-token loss
+    advantages = data.adv.unsqueeze(1)  # [batch_size, 1]
+    per_token_loss_unclipped = ratio * advantages
+    per_token_loss_clipped = ratio_clipped * advantages
+    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)
+
+    # Average the loss using the attention mask (weight)
+    weight = data.weight if data.weight is not None else torch.ones_like(per_token_loss)
+    loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()
+
+    # Compute additional metrics
+    metrics = {
+        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
+        'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + \
+        (ratio < (1 - clip_ratio)).float().mean().item(),
+        'mean_advantage': advantages.mean().item(),
+    }
+
+    # Build the returned namedtuple objects
+    loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
+    metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)
+
+    return loss_info, metric_info
     raise NotImplementedError
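Note that `rloo_policy_error` consumes a single advantage value per sequence and does not compute it itself. In RLOO that advantage is conventionally the response reward minus a leave-one-out baseline over the k responses sampled for the same prompt. Below is a minimal sketch of how such advantages might be produced; it is hypothetical and not part of this patch, and the helper name and the [batch_size, k] reward layout are assumptions for illustration only.

import torch


def leave_one_out_advantage(rewards: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not part of this patch): leave-one-out advantages.

    Args:
        rewards: reward of each sampled response, shape [batch_size, k],
            where the k responses in a row share the same prompt.
    Returns:
        Tensor of shape [batch_size, k]: each reward minus the mean reward
        of the other k - 1 responses to the same prompt.
    """
    k = rewards.shape[-1]
    # Leave-one-out baseline: (row sum minus own reward) / (k - 1)
    baseline = (rewards.sum(dim=-1, keepdim=True) - rewards) / (k - 1)
    return rewards - baseline

The result would then be flattened to shape [batch_size * k] before being passed as `adv`, so that each sequence in the batch carries its own scalar advantage.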
diff --git a/ding/rl_utils/tests/test_grpo_rlhf.py b/ding/rl_utils/tests/test_grpo_rlhf.py
new file mode 100644
index 0000000000..fa4fe259a2
--- /dev/null
+++ b/ding/rl_utils/tests/test_grpo_rlhf.py
@@ -0,0 +1,113 @@
+import pytest
+import numpy as np
+import torch
+from ding.rl_utils.grpo import grpo_policy_data, grpo_policy_error  # import GRPO-related functions
+
+
+@pytest.fixture
+def batch_size():
+    return 4
+
+
+@pytest.fixture
+def seq_length():
+    return 8
+
+
+@pytest.fixture
+def dictionary_num():
+    return 1000
+
+
+@pytest.mark.unittest
+def test_grpo_policy_loss_with_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
+    """Test the GRPO policy loss computation with a mask."""
+    # 1. Create test data
+    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)  # logits of the current policy
+    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # logits of the old policy (slightly perturbed)
+    logit_ref = logit_new + torch.randn_like(logit_new) * 0.2  # logits of the reference policy
+    action = torch.randint(0, vocab_size, (batch_size, seq_length))  # randomly sampled tokens
+    adv = torch.randn(batch_size)  # advantage value of each sequence
+    weight = torch.ones(batch_size, seq_length)  # mask
+    weight[:, -2:] = 0  # mark the last two timesteps as padding
+
+    # 2. Build the grpo_policy_data instance
+    data = grpo_policy_data(
+        logit_new=logit_new,  # output of the current policy
+        logit_old=logit_old,  # output of the old policy
+        logit_ref=logit_ref,  # output of the reference policy
+        action=action,  # sampled tokens
+        adv=adv,  # advantage values
+        weight=weight  # mask
+    )
+
+    # 3. Compute the GRPO loss
+    loss, info = grpo_policy_error(
+        data=data,
+        clip_ratio=0.2,  # PPO clipping ratio
+        beta=0.1  # KL divergence weight
+    )
+
+    # 4. Verify the outputs
+    assert isinstance(loss.policy_loss, torch.Tensor)
+    assert loss.policy_loss.shape == torch.Size([])  # make sure the loss is a scalar
+    assert not torch.isnan(loss.policy_loss)
+    assert not torch.isinf(loss.policy_loss)
+
+    # 5. Test the gradient
+    assert logit_new.grad is None
+    loss.policy_loss.backward()
+    assert isinstance(logit_new.grad, torch.Tensor)
+
+    # 6. Verify the metrics
+    assert 'mean_kl' in info._asdict()
+    assert 'mean_ratio' in info._asdict()
+    assert 'mean_clipped' in info._asdict()
+    assert all([np.isscalar(v) for v in info._asdict().values()])
+
+
+@pytest.mark.unittest
+def test_grpo_policy_loss_without_mask(batch_size: int = 4, seq_length: int = 8, vocab_size: int = 1000):
+    """Test the GRPO policy loss computation without a mask."""
+    # 1. Create test data
+    logit_new = torch.randn(batch_size, seq_length, vocab_size).requires_grad_(True)  # logits of the current policy
+    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # logits of the old policy (slightly perturbed)
+    logit_ref = logit_new + torch.randn_like(logit_new) * 0.2  # logits of the reference policy
+    action = torch.randint(0, vocab_size, (batch_size, seq_length))  # randomly sampled tokens
+    adv = torch.randn(batch_size)  # advantage value of each sequence
+
+    # 2. Build the grpo_policy_data instance
+    data = grpo_policy_data(
+        logit_new=logit_new,  # output of the current policy
+        logit_old=logit_old,  # output of the old policy
+        logit_ref=logit_ref,  # output of the reference policy
+        action=action,  # sampled tokens
+        adv=adv,  # advantage values
+        weight=None  # no mask
+    )
+
+    # 3. Compute the GRPO loss
+    loss, info = grpo_policy_error(
+        data=data,
+        clip_ratio=0.2,  # PPO clipping ratio
+        beta=0.1  # KL divergence weight
+    )
+
+    # 4. Verify the outputs
+    assert isinstance(loss.policy_loss, torch.Tensor)
+    assert loss.policy_loss.shape == torch.Size([])  # make sure the loss is a scalar
+    assert not torch.isnan(loss.policy_loss)
+    assert not torch.isinf(loss.policy_loss)
+
+    # 5. Test the gradient
+    assert logit_new.grad is None
+    loss.policy_loss.backward()
+    assert isinstance(logit_new.grad, torch.Tensor)
+
+    # 6. Verify the metrics
+    assert 'mean_kl' in info._asdict()
+    assert 'mean_ratio' in info._asdict()
+    assert 'mean_clipped' in info._asdict()
+    assert all([np.isscalar(v) for v in info._asdict().values()])
\ No newline at end of file
diff --git a/ding/rl_utils/tests/test_rloo_rlhf.py b/ding/rl_utils/tests/test_rloo_rlhf.py
new file mode 100644
index 0000000000..c92ef56d61
--- /dev/null
+++ b/ding/rl_utils/tests/test_rloo_rlhf.py
@@ -0,0 +1,88 @@
+import pytest
+import numpy as np
+import torch
+from ding.rl_utils.rloo import rloo_policy_data, rloo_policy_error
+
+
+@pytest.fixture
+def batch_size():
+    return 4
+
+
+@pytest.fixture
+def seq_length():
+    return 8
+
+
+@pytest.fixture
+def dictionary_num():
+    return 1000
+
+
+@pytest.mark.unittest
+def test_rloo_policy_loss_without_mask(batch_size: int, seq_length: int, dictionary_num: int):
+    """Test the RLOO policy loss computation without a mask."""
+    # Create test data
+    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
+    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # slightly perturbed from the current policy
+    action = torch.randint(0, dictionary_num, (batch_size, seq_length))
+    advantages = torch.randn(batch_size)  # in RLOO each sequence has a single advantage value
+
+    # Compute the loss
+    data = rloo_policy_data(
+        logit_new=logit_new,
+        logit_old=logit_old,
+        action=action,
+        adv=advantages,
+        weight=None
+    )
+    loss, info = rloo_policy_error(data, clip_ratio=0.2)
+
+    # Verify the outputs
+    assert isinstance(loss.policy_loss, torch.Tensor)
+    assert loss.policy_loss.shape == torch.Size([])  # scalar
+    assert not torch.isnan(loss.policy_loss)
+    assert not torch.isinf(loss.policy_loss)
+    assert logit_new.grad is None
+    loss.policy_loss.backward()
+    assert isinstance(logit_new.grad, torch.Tensor)
+    assert all([np.isscalar(v) for v in info._asdict().values()])
+
+
+@pytest.mark.unittest
+def test_rloo_policy_loss_with_mask(batch_size: int, seq_length: int, dictionary_num: int):
+    """Test the RLOO policy loss computation with a mask."""
+    # Create test data
+    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
+    logit_old = logit_new + torch.randn_like(logit_new) * 0.1  # slightly perturbed from the current policy
+    action = torch.randint(0, dictionary_num, (batch_size, seq_length))
+    advantages = torch.randn(batch_size)  # in RLOO each sequence has a single advantage value
+    action_mask = torch.ones(batch_size, seq_length)
+    action_mask[:, -2:] = 0  # mark the last two timesteps as padding
+
+    # Compute the loss
+    data = rloo_policy_data(
+        logit_new=logit_new,
+        logit_old=logit_old,
+        action=action,
+        adv=advantages,
+        weight=action_mask
+    )
+    loss, info = rloo_policy_error(data, clip_ratio=0.2)
+
+    # Verify the outputs
+    assert isinstance(loss.policy_loss, torch.Tensor)
+    assert loss.policy_loss.shape == torch.Size([])  # scalar
+    assert not torch.isnan(loss.policy_loss)
+    assert not torch.isinf(loss.policy_loss)
+    assert logit_new.grad is None
+    loss.policy_loss.backward()
+    assert isinstance(logit_new.grad, torch.Tensor)
+
+    # Verify the metrics
+    assert 'mean_ratio' in info._asdict()
+    assert 'mean_clipped' in info._asdict()
+    assert 'mean_advantage' in info._asdict()
+    assert all([np.isscalar(v) for v in info._asdict().values()])