3
3
from tqdm import tqdm
4
4
from torch .utils .data import Dataset
5
5
from torch .distributed import get_rank
6
+ from transformers import AutoTokenizer
6
7
import torch
7
8
import torch .nn .functional as F
8
9
9
10
10
11
def zero_pad_sequences (sequences : List [torch .Tensor ], side : str = "left" , value : int = 0 ) -> torch .Tensor :
11
- assert side in ("left" , "right" )
12
+ """
13
+ Overview:
14
+ Pad sequences with zeros to create a batch tensor of uniform length.
15
+ Arguments:
16
+ - sequences (List[torch.Tensor]): A list of PyTorch tensors to be padded.
17
+ - side (str): The side to pad ('left' or 'right'), default is 'left'.
18
+ - value (int): The padding value to use, default is 0.
19
+ Returns:
20
+ - padded_sequences (torch.Tensor): A padded tensor of shape [batch_size, max_sequence_length].
21
+ """
22
+ assert side in ("left" , "right" ), side
12
23
max_len = max (seq .size (- 1 ) for seq in sequences )
13
24
padded_sequences = []
14
25
for seq in sequences :
@@ -22,12 +33,13 @@ class OfflineRLDataset(Dataset):
22
33
"""
23
34
Overview:
24
35
PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
36
+ This dataset supports pure text input, as well as image, video, audio, etc.
25
37
"""
26
38
27
39
def __init__ (
28
40
self ,
29
41
dataset : Iterable [Dict ],
30
- tokenizer ,
42
+ tokenizer : AutoTokenizer ,
31
43
max_length : int ,
32
44
input_key : str = "input" ,
33
45
extra_input_keys : List [str ] = [],
@@ -39,6 +51,24 @@ def __init__(
39
51
num_processors : int = 8 ,
40
52
parallel_load : bool = True
41
53
) -> None :
54
+ """
55
+ Overview:
56
+ Initialize the OfflineRLDataset.
57
+ Arguments:
58
+ - dataset (Iterable[Dict]): The iterable dataset object to be used, such as list or huggingface dataset.
59
+ - tokenizer (AutoTokenizer): The tokenizer to be used.
60
+ - max_length (int): The maximum length of the input.
61
+ - input_key (str): The key of the input, default is "input".
62
+ - extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
63
+ - output_key (str): The key of the output, default is "output".
64
+ - label_key (str): The key of the label, default is "label".
65
+ - apply_chat_template (bool): Whether to apply the chat template, default is False.
66
+ - tokenizer_chat_template (str): The chat template to be used.
67
+ - input_template (str): The input template to be used.
68
+ - num_processors (int): The number of processors to be used, default is 8.
69
+ - parallel_load (bool): Whether to parallel load the dataset in the `__init__` method, default is True.
70
+ Parallel loading is usually used for huggingface dataset.
71
+ """
42
72
super ().__init__ ()
43
73
self .tokenizer = tokenizer
44
74
self .max_length = max_length
@@ -98,7 +128,23 @@ def _preprocess_data(
98
128
output_key : str = "output" ,
99
129
label_key : str = "label" ,
100
130
apply_chat_template : Union [bool , Callable ] = False ,
101
- ) -> str :
131
+ ) -> Dict [str , Any ]:
132
+ """
133
+ Overview:
134
+ Preprocess the data and return the processed data.
135
+ Arguments:
136
+ - data (Dict[str, Any]): The data to be processed.
137
+ - input_template (str): The input template to be used.
138
+ - input_key (str): The key of the input, default is "input".
139
+ - extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
140
+ - output_key (str): The key of the output, default is "output".
141
+ - label_key (str): The key of the label, default is "label".
142
+ - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
143
+ tokenizer's default template. If a Callable is provided, uses that function to apply the template \
144
+ (typically tokenizer.apply_chat_template).
145
+ Returns:
146
+ - processed_data (Dict[str, Any]): The processed data.
147
+ """
102
148
label = data [label_key ]
103
149
if extra_input_keys :
104
150
extra_inputs = {key : data [key ] for key in extra_input_keys }
@@ -160,6 +206,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
160
206
Returns:
161
207
- item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
162
208
"""
209
+ # extra inputs: usually image, video, audio, etc.
163
210
if self .extra_input_keys :
164
211
extra_inputs = {key : getattr (self , key )[idx ] for key in self .extra_input_keys }
165
212
else :
@@ -173,6 +220,14 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
173
220
}
174
221
175
222
def collate_fn (self , item_list : List [Dict [str , Union [torch .Tensor , int ]]]):
223
+ """
224
+ Overview:
225
+ Collate the items into a batch, which is used to create a batch for training.
226
+ Arguments:
227
+ - item_list (List[Dict[str, Union[torch.Tensor, int]]]): The list of items to be collated.
228
+ Returns:
229
+ - collated_items (Dict[str, Union[torch.Tensor, int]]): The collated items.
230
+ """
176
231
177
232
def tokenizer (prompt : str , response : str ):
178
233
text = (prompt + response ).rstrip ("\n " )
@@ -191,6 +246,7 @@ def tokenizer(prompt: str, response: str):
191
246
inputs ["attention_mask" ][0 ][- 1 ] = True
192
247
return inputs ["input_ids" ], inputs ["attention_mask" ]
193
248
249
+ # tot_extra_inputs: Dict[str, List[torch.Tensor]]
194
250
tot_ids , tot_masks , tot_labels , prompt_ids_lens , tot_extra_inputs = [], [], [], [], {}
195
251
for item in item_list :
196
252
input_ids , attention_mask = tokenizer (item ["prompt" ], item ["response" ])
0 commit comments