Commit

(dcy) redefine adv
Berit-chengyi committed Feb 18, 2025
1 parent 2cbd9fb commit 7bcd64d
Showing 2 changed files with 39 additions and 16 deletions.
27 changes: 18 additions & 9 deletions ding/rl_utils/rloo.py
```diff
@@ -2,7 +2,7 @@
 from collections import namedtuple
 import torch
 
-rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
+rloo_policy_data = namedtuple('rloo_policy_data', ['logit_new', 'logit_old', 'action', 'reward', 'weight'])
 
 
 def rloo_policy_error(
@@ -16,7 +16,7 @@ def rloo_policy_error(
             - logit_new: Current policy logits [B, L, V]
             - logit_old: Old policy logits [B, L, V]
             - action: Actions taken [B, L]
-            - adv: Advantage values [B]
+            - reward: Reward values [B]
             - weight: Attention mask [B, L]
         clip_ratio (float): PPO clipping ratio, default 0.2
@@ -25,6 +25,12 @@
         - First namedtuple contains policy_loss
         - Second namedtuple contains additional metrics
     """
+    # Calculate the advantage of each sample with a leave-one-out baseline
+    rloo_k = data.reward.size(0)
+    baseline = (data.reward.sum(0) - data.reward) / (rloo_k - 1)
+    adv = data.reward - baseline
+    adv = adv.flatten()
+
     # Calculate log probabilities for each token
     log_prob_new = torch.log_softmax(data.logit_new, dim=-1)
     log_prob_old = torch.log_softmax(data.logit_old, dim=-1)
@@ -39,25 +45,28 @@
     ratio_clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
 
     # Calculate loss for each token
-    advantages = data.adv.unsqueeze(1)  # [B, 1]
+    advantages = adv.unsqueeze(1)  # [B, 1]
     per_token_loss_unclipped = ratio * advantages
     per_token_loss_clipped = ratio_clipped * advantages
-    per_token_loss = -torch.min(per_token_loss_unclipped, per_token_loss_clipped)
+    per_token_loss = -torch.min(per_token_loss_unclipped,
+                                per_token_loss_clipped)
 
     # Calculate average loss using weight mask
-    weight = data.weight if data.weight is not None else (torch.ones_like(per_token_loss))
+    weight = data.weight if data.weight is not None else (
+        torch.ones_like(per_token_loss))
     loss = ((per_token_loss * weight).sum(dim=1) / weight.sum(dim=1)).mean()
 
     # Calculate additional metrics
     metrics = {
-        'mean_ratio': ((ratio * weight).sum(dim=1) / weight.sum(dim=1)).mean().item(),
-        'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() + (ratio <
-        (1 - clip_ratio)).float().mean().item(),
+        'mean_ratio': ((ratio * weight).sum(dim=1) /
+                       weight.sum(dim=1)).mean().item(),
+        'mean_clipped': (ratio > (1 + clip_ratio)).float().mean().item() +
+                        (ratio < (1 - clip_ratio)).float().mean().item(),
         'mean_advantage': advantages.mean().item(),
     }
 
     # Create return namedtuples
     loss_info = namedtuple('LossInfo', ['policy_loss'])(policy_loss=loss)
     metric_info = namedtuple('MetricInfo', list(metrics.keys()))(**metrics)
 
-    return loss_info, metric_info
+    return loss_info, metric_info
\ No newline at end of file
```
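For reference, the leave-one-out baseline introduced above is the usual RLOO advantage estimator: each sample's advantage is its reward minus the mean reward of the other `rloo_k - 1` samples in the group (here the whole batch, presumably responses to the same prompt). In our notation, which is not taken from the repository:

```latex
% RLOO leave-one-out advantage for sample i in a group of k rewards r_1, ..., r_k
b_i = \frac{1}{k-1} \sum_{j \neq i} r_j, \qquad
A_i = r_i - b_i = \frac{k}{k-1}\,\bigl(r_i - \bar{r}\bigr), \qquad
\bar{r} = \frac{1}{k} \sum_{j=1}^{k} r_j
```

The resulting `adv` has shape [B] and is broadcast over the sequence dimension via `unsqueeze(1)` before entering the clipped PPO-style per-token loss.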
28 changes: 21 additions & 7 deletions ding/rl_utils/tests/test_rloo_rlhf.py
```diff
@@ -23,13 +23,20 @@ def dictionary_num():
 def test_rloo_policy_loss_without_mask(batch_size, seq_length, dictionary_num):
     """Test RLOO policy loss calculation without mask"""
     # Create test data
-    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
+    logit_new = torch.randn(batch_size,
+                            seq_length, dictionary_num).requires_grad_(True)
     logit_old = logit_new + torch.randn_like(logit_new) * 0.1
     action = torch.randint(0, dictionary_num, (batch_size, seq_length))
-    advantages = torch.randn(batch_size)
+    reward = torch.randn(batch_size)
 
     # Calculate loss
-    data = rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, adv=advantages, weight=None)
+    data = rloo_policy_data(
+        logit_new=logit_new,
+        logit_old=logit_old,
+        action=action,
+        reward=reward,
+        weight=None
+    )
     loss, info = rloo_policy_error(data, clip_ratio=0.2)
 
     # Verify outputs
@@ -47,15 +54,22 @@ def test_rloo_policy_loss_without_mask(batch_size, seq_length, dictionary_num):
 def test_rloo_policy_loss_with_mask(batch_size, seq_length, dictionary_num):
     """Test RLOO policy loss calculation with mask"""
     # Create test data
-    logit_new = torch.randn(batch_size, seq_length, dictionary_num).requires_grad_(True)
+    logit_new = torch.randn(batch_size,
+                            seq_length, dictionary_num).requires_grad_(True)
     logit_old = logit_new + torch.randn_like(logit_new) * 0.1
     action = torch.randint(0, dictionary_num, (batch_size, seq_length))
-    advantages = torch.randn(batch_size)
+    reward = torch.randn(batch_size)
     action_mask = torch.ones(batch_size, seq_length)
     action_mask[:, -2:] = 0
 
     # Calculate loss
-    data = rloo_policy_data(logit_new=logit_new, logit_old=logit_old, action=action, adv=advantages, weight=action_mask)
+    data = rloo_policy_data(
+        logit_new=logit_new,
+        logit_old=logit_old,
+        action=action,
+        reward=reward,
+        weight=action_mask
+    )
     loss, info = rloo_policy_error(data, clip_ratio=0.2)
 
     # Verify outputs
@@ -71,4 +85,4 @@ def test_rloo_policy_loss_with_mask(batch_size, seq_length, dictionary_num):
     assert 'mean_ratio' in info._asdict()
     assert 'mean_clipped' in info._asdict()
     assert 'mean_advantage' in info._asdict()
-    assert all([np.isscalar(v) for v in info._asdict().values()])
+    assert all([np.isscalar(v) for v in info._asdict().values()])
\ No newline at end of file
```
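As an illustrative sanity check, not part of this commit, the vectorized baseline added in `rloo.py` can be compared against an explicit per-sample loop; all names below are ours:

```python
import torch

# Hypothetical check (ours): vectorized leave-one-out baseline vs. explicit loop.
reward = torch.randn(8)  # one scalar reward per sampled response, as in the tests
k = reward.size(0)

# Vectorized form, mirroring the logic added in rloo.py
baseline = (reward.sum(0) - reward) / (k - 1)
adv_vec = reward - baseline

# Explicit loop: each baseline is the mean of the other k - 1 rewards
adv_loop = torch.stack([
    reward[i] - (reward.sum() - reward[i]) / (k - 1) for i in range(k)
])

assert torch.allclose(adv_vec, adv_loop)
# Leave-one-out advantages sum to (numerically) zero across the group
assert torch.isclose(adv_vec.sum(), torch.tensor(0.0), atol=1e-5)
```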
