Skip to content

Commit 428c1a6

Browse files
committed
style(nyz): add comments for rlhf dataset
1 parent 6e0e14b commit 428c1a6

File tree

2 files changed

+69
-10
lines changed

2 files changed

+69
-10
lines changed

ding/utils/data/rlhf_offline_dataset.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,23 @@
33
from tqdm import tqdm
44
from torch.utils.data import Dataset
55
from torch.distributed import get_rank
6+
from transformers import AutoTokenizer
67
import torch
78
import torch.nn.functional as F
89

910

1011
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
11-
assert side in ("left", "right")
12+
"""
13+
Overview:
14+
Pad sequences with zeros to create a batch tensor of uniform length.
15+
Arguments:
16+
- sequences (List[torch.Tensor]): A list of PyTorch tensors to be padded.
17+
- side (str): The side to pad ('left' or 'right'), default is 'left'.
18+
- value (int): The padding value to use, default is 0.
19+
Returns:
20+
- padded_sequences (torch.Tensor): A padded tensor of shape [batch_size, max_sequence_length].
21+
"""
22+
assert side in ("left", "right"), side
1223
max_len = max(seq.size(-1) for seq in sequences)
1324
padded_sequences = []
1425
for seq in sequences:
@@ -22,12 +33,13 @@ class OfflineRLDataset(Dataset):
2233
"""
2334
Overview:
2435
PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
36+
This dataset supports pure text input, as well as image, video, audio, etc.
2537
"""
2638

2739
def __init__(
2840
self,
2941
dataset: Iterable[Dict],
30-
tokenizer,
42+
tokenizer: AutoTokenizer,
3143
max_length: int,
3244
input_key: str = "input",
3345
extra_input_keys: List[str] = [],
@@ -39,6 +51,24 @@ def __init__(
3951
num_processors: int = 8,
4052
parallel_load: bool = True
4153
) -> None:
54+
"""
55+
Overview:
56+
Initialize the OfflineRLDataset.
57+
Arguments:
58+
- dataset (Iterable[Dict]): The iterable dataset object to be used, such as list or huggingface dataset.
59+
- tokenizer (AutoTokenizer): The tokenizer to be used.
60+
- max_length (int): The maximum length of the input.
61+
- input_key (str): The key of the input, default is "input".
62+
- extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
63+
- output_key (str): The key of the output, default is "output".
64+
- label_key (str): The key of the label, default is "label".
65+
- apply_chat_template (bool): Whether to apply the chat template, default is False.
66+
- tokenizer_chat_template (str): The chat template to be used.
67+
- input_template (str): The input template to be used.
68+
- num_processors (int): The number of processors to be used, default is 8.
69+
- parallel_load (bool): Whether to parallel load the dataset in the `__init__` method, default is True.
70+
Parallel loading is usually used for huggingface dataset.
71+
"""
4272
super().__init__()
4373
self.tokenizer = tokenizer
4474
self.max_length = max_length
@@ -98,7 +128,23 @@ def _preprocess_data(
98128
output_key: str = "output",
99129
label_key: str = "label",
100130
apply_chat_template: Union[bool, Callable] = False,
101-
) -> str:
131+
) -> Dict[str, Any]:
132+
"""
133+
Overview:
134+
Preprocess the data and return the processed data.
135+
Arguments:
136+
- data (Dict[str, Any]): The data to be processed.
137+
- input_template (str): The input template to be used.
138+
- input_key (str): The key of the input, default is "input".
139+
- extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
140+
- output_key (str): The key of the output, default is "output".
141+
- label_key (str): The key of the label, default is "label".
142+
- apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
143+
tokenizer's default template. If a Callable is provided, uses that function to apply the template \
144+
(typically tokenizer.apply_chat_template).
145+
Returns:
146+
- processed_data (Dict[str, Any]): The processed data.
147+
"""
102148
label = data[label_key]
103149
if extra_input_keys:
104150
extra_inputs = {key: data[key] for key in extra_input_keys}
@@ -160,6 +206,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
160206
Returns:
161207
- item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
162208
"""
209+
# extra inputs: usually image, video, audio, etc.
163210
if self.extra_input_keys:
164211
extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
165212
else:
@@ -173,6 +220,14 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
173220
}
174221

175222
def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
223+
"""
224+
Overview:
225+
Collate the items into a batch, which is used to create a batch for training.
226+
Arguments:
227+
- item_list (List[Dict[str, Union[torch.Tensor, int]]]): The list of items to be collated.
228+
Returns:
229+
- collated_items (Dict[str, Union[torch.Tensor, int]]): The collated items.
230+
"""
176231

177232
def tokenizer(prompt: str, response: str):
178233
text = (prompt + response).rstrip("\n")
@@ -191,6 +246,7 @@ def tokenizer(prompt: str, response: str):
191246
inputs["attention_mask"][0][-1] = True
192247
return inputs["input_ids"], inputs["attention_mask"]
193248

249+
# tot_extra_inputs: Dict[str, List[torch.Tensor]]
194250
tot_ids, tot_masks, tot_labels, prompt_ids_lens, tot_extra_inputs = [], [], [], [], {}
195251
for item in item_list:
196252
input_ids, attention_mask = tokenizer(item["prompt"], item["response"])

ding/utils/data/rlhf_online_dataset.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,20 @@
22
from tqdm import tqdm
33
from torch.utils.data import Dataset
44
from torch.distributed import get_rank
5+
from transformers import AutoTokenizer
56

67

78
class OnlineRLDataset(Dataset):
89
"""
910
Overview:
1011
PyTorch Dataset for OnlineRL LLM training like PPO.
12+
This dataset only supports pure text input now.
1113
"""
1214

1315
def __init__(
1416
self,
1517
dataset: Iterable[Dict],
16-
tokenizer,
18+
tokenizer: AutoTokenizer,
1719
input_key: str = "input",
1820
apply_chat_template: bool = False,
1921
input_template: str = None,
@@ -23,9 +25,9 @@ def __init__(
2325
Initialize the OnlineRLDataset.
2426
Arguments:
2527
- dataset (torch.utils.data.Dataset): The dataset to preprocess.
26-
- tokenizer (): The tokenizer to preprocess the data.
27-
- input_key (str): The key of the input data.
28-
- apply_chat_template (bool): Whether to apply the chat template.
28+
- tokenizer (AutoTokenizer): The tokenizer to preprocess the data.
29+
- input_key (str): The key of the input data, default is "input".
30+
- apply_chat_template (bool): Whether to apply the chat template, default is False.
2931
- input_template (str): The template to format the data.
3032
"""
3133
super().__init__()
@@ -57,7 +59,7 @@ def __getitem__(self, idx: int) -> str:
5759
"""
5860
Overview:
5961
Get the item at the given index.
60-
Args:
62+
Arguments:
6163
- idx (int): The index of the item to get.
6264
Returns:
6365
- item (str): The item at the given index.
@@ -78,8 +80,9 @@ def _preprocess_data(
7880
- data (Dict[str, Any]): The data to preprocess.
7981
- input_template (str): The template to format the data.
8082
- input_key (str): The key of the input data.
81-
- apply_chat_template (Union[bool, Callable]): The function to apply the chat template, \
82-
usually is the `tokenizer.apply_chat_template`.
83+
- apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
84+
tokenizer's default template. If a Callable is provided, uses that function to apply the template \
85+
(typically tokenizer.apply_chat_template).
8386
Returns:
8487
- prompt (str): The formatted prompt.
8588
"""

0 commit comments

Comments
 (0)