
Commit abcf972

feature(nyz): add rlhf dataset (#854)
* feature(nyz): add rlhf dataset
* fix(nyz): fix import bugs
* feature(nyz): add vision input support and fix bugs
* style(nyz): add comments for rlhf dataset
1 parent 64efcb3 commit abcf972

File tree

6 files changed: +532 -0 lines changed

ding/utils/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,3 +2,5 @@
 from .dataloader import AsyncDataLoader
 from .dataset import NaiveRLDataset, D4RLDataset, HDF5Dataset, BCODataset, \
     create_dataset, hdf5_save, offline_data_save_type
+from .rlhf_online_dataset import OnlineRLDataset
+from .rlhf_offline_dataset import OfflineRLDataset
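
With this change both classes are re-exported from the package root; a minimal import check, assuming DI-engine with this commit is installed:

from ding.utils.data import OnlineRLDataset, OfflineRLDataset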
ding/utils/data/rlhf_offline_dataset.py

Lines changed: 277 additions & 0 deletions (new file)
from typing import Iterable, Dict, List, Union, Any, Callable
from functools import partial
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
    """
    Overview:
        Pad sequences with zeros to create a batch tensor of uniform length.
    Arguments:
        - sequences (List[torch.Tensor]): A list of PyTorch tensors to be padded.
        - side (str): The side to pad ('left' or 'right'), default is 'left'.
        - value (int): The padding value to use, default is 0.
    Returns:
        - padded_sequences (torch.Tensor): A padded tensor of shape [batch_size, max_sequence_length].
    """
    assert side in ("left", "right"), side
    max_len = max(seq.size(-1) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(-1)
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding, value=value))
    return torch.stack(padded_sequences, dim=0)
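
# Illustrative example of zero_pad_sequences (not part of this commit):
#   >>> zero_pad_sequences([torch.tensor([1, 2, 3]), torch.tensor([4])], side="left")
#   tensor([[1, 2, 3],
#           [0, 0, 4]])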


class OfflineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for offline RL LLM training, such as KTO and DPO.
        This dataset supports pure text input, as well as image, video, audio, etc.
    """

    def __init__(
            self,
            dataset: Iterable[Dict],
            tokenizer: AutoTokenizer,
            max_length: int,
            input_key: str = "input",
            extra_input_keys: List[str] = [],
            output_key: str = "output",
            label_key: str = "label",
            apply_chat_template: bool = False,
            tokenizer_chat_template: str = None,
            input_template: str = None,
            num_processors: int = 8,
            parallel_load: bool = True
    ) -> None:
        """
        Overview:
            Initialize the OfflineRLDataset.
        Arguments:
            - dataset (Iterable[Dict]): The iterable dataset object to be used, such as a list or a \
                HuggingFace dataset.
            - tokenizer (AutoTokenizer): The tokenizer to be used.
            - max_length (int): The maximum length of the input.
            - input_key (str): The key of the input, default is "input".
            - extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
            - output_key (str): The key of the output, default is "output".
            - label_key (str): The key of the label, default is "label".
            - apply_chat_template (bool): Whether to apply the chat template, default is False.
            - tokenizer_chat_template (str): The chat template to be used.
            - input_template (str): The input template to be used.
            - num_processors (int): The number of processors to be used, default is 8.
            - parallel_load (bool): Whether to load the dataset in parallel in the `__init__` method, \
                default is True. Parallel loading is usually used for HuggingFace datasets.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.extra_input_keys = extra_input_keys

        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template
            if tokenizer_chat_template:
                self.tokenizer.chat_template = tokenizer_chat_template

        # parallel loading of the dataset (usually a HuggingFace dataset)
        if parallel_load:
            preprocess_data_fn = partial(
                self._preprocess_data,
                input_template=input_template,
                input_key=input_key,
                extra_input_keys=extra_input_keys,
                output_key=output_key,
                label_key=label_key,
                apply_chat_template=apply_chat_template
            )
            processed_dataset = dataset.map(
                preprocess_data_fn, remove_columns=dataset.column_names, num_proc=num_processors
            )
            # the preprocess function may return None, so filter out such samples
            processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None)

            self.prompts = processed_dataset["prompt"]
            self.responses = processed_dataset["response"]
            self.labels = processed_dataset["label"]
            self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
            for key in extra_input_keys:
                setattr(self, key, processed_dataset[key])
        else:
            self.prompts = []
            self.responses = []
            self.labels = []
            self.prompt_ids_lens = []
            for key in extra_input_keys:
                setattr(self, key, [])
            for data in tqdm(dataset, desc="Preprocessing data", disable=not get_rank() == 0):
                processed_data = self._preprocess_data(data)
                if processed_data["prompt"] is not None:
                    self.prompts.append(processed_data["prompt"])
                    self.responses.append(processed_data["response"])
                    self.labels.append(processed_data["label"])
                    self.prompt_ids_lens.append(processed_data["prompt_ids_len"])
                    for key in extra_input_keys:
                        getattr(self, key).append(processed_data[key])

    def _preprocess_data(
            self,
            data: Dict[str, Any],
            input_template: str = None,
            input_key: str = "input",
            extra_input_keys: List[str] = [],
            output_key: str = "output",
            label_key: str = "label",
            apply_chat_template: Union[bool, Callable] = False,
    ) -> Dict[str, Any]:
        """
        Overview:
            Preprocess the data and return the processed data.
        Arguments:
            - data (Dict[str, Any]): The data to be processed.
            - input_template (str): The input template to be used.
            - input_key (str): The key of the input, default is "input".
            - extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
            - output_key (str): The key of the output, default is "output".
            - label_key (str): The key of the label, default is "label".
            - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
                tokenizer's default template. If a Callable is provided, uses that function to apply the template \
                (typically tokenizer.apply_chat_template).
        Returns:
            - processed_data (Dict[str, Any]): The processed data.
        """
        label = data[label_key]
        if extra_input_keys:
            extra_inputs = {key: data[key] for key in extra_input_keys}
        else:
            extra_inputs = {}

        if apply_chat_template:
            if output_key:
                prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt):]
            else:
                prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key], tokenize=False)[len(prompt):]
        else:
            prompt = data[input_key]
            response = data[output_key]
            if input_template:
                prompt = input_template.format(prompt)

        prompt_token = self.tokenizer(
            prompt,
            max_length=self.max_length,
            # use the batch max length (in `collate_fn`) to pad rather than the global max length
            padding=False,
            truncation=True,
            return_tensors="pt",
            # add special tokens for the prompt in `collate_fn`
            add_special_tokens=False,
        )
        prompt_ids_len = prompt_token["attention_mask"].int().sum().item()

        # filter out samples whose prompt length exceeds max_length (reserving 2 tokens for the answer)
        if prompt_ids_len >= self.max_length - 2:
            prompt = None

        return {
            "prompt": prompt,
            "response": response,
            "label": label,
            "prompt_ids_len": prompt_ids_len,
            **extra_inputs
        }

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Overview:
            Get the item at the given index.
        Arguments:
            - idx (int): The index of the item to get.
        Returns:
            - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
        """
        # extra inputs: usually image, video, audio, etc.
        if self.extra_input_keys:
            extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
        else:
            extra_inputs = {}
        return {
            "prompt": self.prompts[idx],
            "response": self.responses[idx],
            "label": self.labels[idx],
            "prompt_ids_len": self.prompt_ids_lens[idx],
            **extra_inputs
        }

    def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
        """
        Overview:
            Collate the items into a batch, which is used to create a batch for training.
        Arguments:
            - item_list (List[Dict[str, Union[torch.Tensor, int]]]): The list of items to be collated.
        Returns:
            - collated_items (Dict[str, Union[torch.Tensor, int]]): The collated items.
        """

        def tokenizer(prompt: str, response: str):
            text = (prompt + response).rstrip("\n")
            if not text.endswith(self.tokenizer.eos_token):
                text += " " + self.tokenizer.eos_token
            inputs = self.tokenizer(
                text,
                max_length=self.max_length,
                padding=False,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False,
            )

            inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id
            inputs["attention_mask"][0][-1] = True
            return inputs["input_ids"], inputs["attention_mask"]

        # tot_extra_inputs: Dict[str, List[torch.Tensor]]
        tot_ids, tot_masks, tot_labels, prompt_ids_lens, tot_extra_inputs = [], [], [], [], {}
        for item in item_list:
            input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(item["label"])
            prompt_ids_lens.append(item["prompt_ids_len"])
            for key in self.extra_input_keys:
                if key not in tot_extra_inputs:
                    tot_extra_inputs[key] = []
                tot_extra_inputs[key].append(item[key])

        # add unmatched y' | x pairs (used to estimate the KL divergence between policy and reference)
        for idx in range(len(item_list)):
            next_idx = (idx + 1) % len(item_list)
            input_ids, attention_mask = tokenizer(item_list[idx]["prompt"], item_list[next_idx]["response"])
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(-1)
            prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])
            for key in self.extra_input_keys:
                if key not in tot_extra_inputs:
                    tot_extra_inputs[key] = []
                tot_extra_inputs[key].append(item_list[idx][key])

        input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
        attention_mask = zero_pad_sequences(tot_masks, side="right")
        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens, tot_extra_inputs
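
A minimal usage sketch for OfflineRLDataset and its collate_fn, assuming a plain Python list of KTO-style samples; the tokenizer choice and sample values below are illustrative, not taken from this commit:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from ding.utils.data import OfflineRLDataset

# the non-parallel loading path calls torch.distributed.get_rank(), so start a
# single-process gloo group when running outside a distributed launcher
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any causal-LM tokenizer works
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token by default

toy_data = [
    {"input": "What is RL?", "output": " Trial-and-error learning.", "label": 1},
    {"input": "2 + 2 =", "output": " 5", "label": 0},
]

# parallel_load=False uses the plain-Python preprocessing loop instead of datasets.map
dataset = OfflineRLDataset(toy_data, tokenizer, max_length=64, parallel_load=False)
loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collate_fn)

input_ids, attention_mask, labels, prompt_ids_lens, extra = next(iter(loader))
# collate_fn doubles the batch: each prompt is re-paired with the next item's
# response and labeled -1, yielding the unmatched (x, y') pairs for the KL term
assert input_ids.size(0) == 4 and labels.tolist() == [1, 0, -1, -1]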
ding/utils/data/rlhf_online_dataset.py

Lines changed: 99 additions & 0 deletions (new file)
from typing import Any, Dict, Union, Callable, Iterable
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
from transformers import AutoTokenizer


class OnlineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for online RL LLM training, such as PPO.
        This dataset only supports pure text input for now.
    """

    def __init__(
            self,
            dataset: Iterable[Dict],
            tokenizer: AutoTokenizer,
            input_key: str = "input",
            apply_chat_template: bool = False,
            input_template: str = None,
    ) -> None:
        """
        Overview:
            Initialize the OnlineRLDataset.
        Arguments:
            - dataset (Iterable[Dict]): The dataset to preprocess.
            - tokenizer (AutoTokenizer): The tokenizer to preprocess the data.
            - input_key (str): The key of the input data, default is "input".
            - apply_chat_template (bool): Whether to apply the chat template, default is False.
            - input_template (str): The template to format the data.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.input_template = input_template

        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template

        self.prompts = []
        try:
            rank = get_rank()
        except ValueError:  # not initialized yet, which is the case in unit tests
            rank = 0
        for data in tqdm(dataset, desc="Preprocessing data", disable=not rank == 0):
            prompt = self._preprocess_data(data, input_template, input_key, apply_chat_template)
            self.prompts.append(prompt)

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> str:
        """
        Overview:
            Get the item at the given index.
        Arguments:
            - idx (int): The index of the item to get.
        Returns:
            - item (str): The item at the given index.
        """
        return self.prompts[idx]

    def _preprocess_data(
            self,
            data: Dict[str, Any],
            input_template: str = None,
            input_key: str = "input",
            apply_chat_template: Union[bool, Callable] = False,
    ) -> str:
        """
        Overview:
            Preprocess the data to get the formatted prompt.
        Arguments:
            - data (Dict[str, Any]): The data to preprocess.
            - input_template (str): The template to format the data.
            - input_key (str): The key of the input data.
            - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
                tokenizer's default template. If a Callable is provided, uses that function to apply the template \
                (typically tokenizer.apply_chat_template).
        Returns:
            - prompt (str): The formatted prompt.
        """
        if apply_chat_template:
            chat = data[input_key]
            if isinstance(chat, str):
                chat = [{"role": "user", "content": chat}]
            assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict"
            prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        else:
            prompt = data[input_key]
            if input_template:
                prompt = input_template.format(prompt)
        return prompt
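
A matching sketch for OnlineRLDataset; the model name and prompts below are illustrative assumptions, and any tokenizer that ships a chat template will do:

from transformers import AutoTokenizer
from ding.utils.data import OnlineRLDataset

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

toy_data = [
    {"input": "Explain PPO in one sentence."},  # plain string input
    {"input": [{"role": "user", "content": "What is RLHF?"}]},  # chat-format input
]

# apply_chat_template=True wraps plain strings as single-turn user messages and
# renders every prompt with the tokenizer's chat template
dataset = OnlineRLDataset(toy_data, tokenizer, apply_chat_template=True)
print(len(dataset), dataset[0])  # 2, plus a formatted prompt ready for generation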
