Skip to content

Commit 2e49437

Browse files
committed
interface(nyz): add naive interface about grpo/rloo
1 parent 2a51392 commit 2e49437

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

ding/rl_utils/grpo.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional, Tuple
2+
from collections import namedtuple
3+
import torch
4+
import torch.nn as nn
5+
from torch.distributions import Independent, Normal
6+
7+
# Input record consumed by ``grpo_policy_error``.
# NOTE(review): judging by the names only, ``logit_new`` / ``logit_old`` /
# ``logit_ref`` presumably come from the current, behaviour and reference
# policies respectively -- confirm once the loss itself is implemented.
grpo_policy_data = namedtuple(
    'grpo_policy_data',
    [
        'logit_new',
        'logit_old',
        'logit_ref',
        'action',
        'adv',
        'weight',
    ],
)
8+
9+
10+
def grpo_policy_error(
    data: namedtuple,
    clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Interface placeholder for the GRPO policy loss. Only the signature is fixed so far; \
        calling this function always raises ``NotImplementedError``.
    Arguments:
        - data (:obj:`namedtuple`): Input of the loss, presumably a ``grpo_policy_data`` instance \
            (TODO: confirm once the loss is implemented).
        - clip_ratio (:obj:`float`): Defaults to ``0.2``. Currently unused; presumably the \
            clipping range of the new/old policy ratio (TODO: confirm).
    Returns:
        - result (:obj:`Tuple[namedtuple, namedtuple]`): Declared return type only, nothing is \
            returned by the current placeholder.
    Raises:
        - NotImplementedError: Always, until the loss is implemented.

    .. note::
        Each element in this input data is a group of response samples from the same prompt.
    """
    # Same exception type as the bare ``raise`` it replaces; the message only makes the
    # failure self-explanatory in tracebacks and logs.
    raise NotImplementedError(
        "grpo_policy_error is an interface placeholder and has not been implemented yet"
    )

ding/rl_utils/ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def ppo_policy_error(
142142
>>> loss, info = ppo_policy_error(data)
143143
144144
.. note::
145-
This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
145+
This function can be extended from `B` to more parallel dimensions, like `(B, S)`, where `S` is the
146146
sequence length in LLM/VLM.
147147
148148
.. note::

ding/rl_utils/rloo.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional, Tuple
2+
from collections import namedtuple
3+
import torch
4+
import torch.nn as nn
5+
from torch.distributions import Independent, Normal
6+
7+
# Input record consumed by ``rloo_policy_error``.
# NOTE(review): judging by the names only, ``logit_new`` / ``logit_old``
# presumably come from the current and behaviour policies -- confirm once
# the loss itself is implemented.
rloo_policy_data = namedtuple(
    'rloo_policy_data',
    [
        'logit_new',
        'logit_old',
        'action',
        'adv',
        'weight',
    ],
)
8+
9+
10+
def rloo_policy_error(
    data: namedtuple,
    clip_ratio: float = 0.2,
) -> Tuple[namedtuple, namedtuple]:
    """
    Overview:
        Interface placeholder for the RLOO policy loss. Only the signature is fixed so far; \
        calling this function always raises ``NotImplementedError``.
    Arguments:
        - data (:obj:`namedtuple`): Input of the loss, presumably a ``rloo_policy_data`` instance \
            (TODO: confirm once the loss is implemented).
        - clip_ratio (:obj:`float`): Defaults to ``0.2``. Currently unused; presumably the \
            clipping range of the new/old policy ratio (TODO: confirm).
    Returns:
        - result (:obj:`Tuple[namedtuple, namedtuple]`): Declared return type only, nothing is \
            returned by the current placeholder.
    Raises:
        - NotImplementedError: Always, until the loss is implemented.

    .. note::
        Each element in this input data is a group of response samples from the same prompt.
    """
    # Same exception type as the bare ``raise`` it replaces; the message only makes the
    # failure self-explanatory in tracebacks and logs.
    raise NotImplementedError(
        "rloo_policy_error is an interface placeholder and has not been implemented yet"
    )

0 commit comments

Comments
 (0)