@@ -30,6 +30,7 @@ def __init__(
         tokenizer,
         max_length: int,
         input_key: str = "input",
+        extra_input_keys: List[str] = [],
         output_key: str = "output",
         label_key: str = "label",
         apply_chat_template: bool = False,
@@ -41,6 +42,7 @@ def __init__(
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.extra_input_keys = extra_input_keys

         if apply_chat_template:
             apply_chat_template = self.tokenizer.apply_chat_template
@@ -53,6 +55,7 @@ def __init__(
             self._preprocess_data,
             input_template=input_template,
             input_key=input_key,
+            extra_input_keys=extra_input_keys,
             output_key=output_key,
             label_key=label_key,
             apply_chat_template=apply_chat_template
@@ -67,29 +70,40 @@ def __init__(
             self.responses = processed_dataset["response"]
             self.labels = processed_dataset["label"]
             self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
+            for key in extra_input_keys:
+                setattr(self, key, processed_dataset[key])
         else:
             self.prompts = []
             self.responses = []
             self.labels = []
             self.prompt_ids_lens = []
+            for key in extra_input_keys:
+                setattr(self, key, [])
             for data in tqdm(dataset, desc="Preprocessing data", disable=not get_rank() == 0):
                 processed_data = self._preprocess_data(data)
                 if processed_data["prompt"] is not None:
                     self.prompts.append(processed_data["prompt"])
                     self.responses.append(processed_data["response"])
                     self.labels.append(processed_data["label"])
                     self.prompt_ids_lens.append(processed_data["prompt_ids_len"])
+                    for key in extra_input_keys:
+                        getattr(self, key).append(processed_data[key])

     def _preprocess_data(
         self,
         data: Dict[str, Any],
         input_template: str = None,
         input_key: str = "input",
+        extra_input_keys: List[str] = [],
         output_key: str = "output",
         label_key: str = "label",
         apply_chat_template: Union[bool, Callable] = False,
     ) -> str:
         label = data[label_key]
+        if extra_input_keys:
+            extra_inputs = {key: data[key] for key in extra_input_keys}
+        else:
+            extra_inputs = {}

         if apply_chat_template:
             if output_key:
@@ -120,7 +134,13 @@ def _preprocess_data(
         if prompt_ids_len >= self.max_length - 2:
             prompt = None

-        return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len}
+        return {
+            "prompt": prompt,
+            "response": response,
+            "label": label,
+            "prompt_ids_len": prompt_ids_len,
+            **extra_inputs
+        }

     def __len__(self) -> int:
         """
@@ -135,14 +155,21 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
         """
         Overview:
             Get the item at the given index.
+        Arguments:
+            - idx (int): The index of the item to get.
         Returns:
             - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
         """
+        if self.extra_input_keys:
+            extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
+        else:
+            extra_inputs = {}
         return {
             "prompt": self.prompts[idx],
             "response": self.responses[idx],
             "label": self.labels[idx],
-            "prompt_ids_len": self.prompt_ids_lens[idx]
+            "prompt_ids_len": self.prompt_ids_lens[idx],
+            **extra_inputs
         }

     def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
@@ -164,13 +191,17 @@ def tokenizer(prompt: str, response: str):
             inputs["attention_mask"][0][-1] = True
             return inputs["input_ids"], inputs["attention_mask"]

-        tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
+        tot_ids, tot_masks, tot_labels, prompt_ids_lens, tot_extra_inputs = [], [], [], [], {}
         for item in item_list:
             input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
             tot_ids.append(input_ids)
             tot_masks.append(attention_mask)
             tot_labels.append(item["label"])
             prompt_ids_lens.append(item["prompt_ids_len"])
+            for key in self.extra_input_keys:
+                if key not in tot_extra_inputs:
+                    tot_extra_inputs[key] = []
+                tot_extra_inputs[key].append(item[key])

         # add unmatched y'| x (used to estimate the KL divergence between policy and reference)
         for idx in range(len(item_list)):
@@ -180,7 +211,11 @@ def tokenizer(prompt: str, response: str):
             tot_masks.append(attention_mask)
             tot_labels.append(-1)
             prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])
+            for key in self.extra_input_keys:
+                if key not in tot_extra_inputs:
+                    tot_extra_inputs[key] = []
+                tot_extra_inputs[key].append(item_list[idx][key])

         input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
         attention_mask = zero_pad_sequences(tot_masks, side="right")
-        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens
+        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens, tot_extra_inputs
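
A minimal usage sketch of the new extra_input_keys path follows; it is not part of the diff. It assumes the constructor's leading arguments are the raw dataset and tokenizer, and the class name UnpairedPreferenceDataset plus the field "sample_id" are placeholders invented for illustration.

# Hypothetical usage sketch; class name, tokenizer object, and the "sample_id" field are placeholders.
raw_dataset = [
    {"input": "What is 2 + 2?", "output": "4", "label": 1, "sample_id": "q-001"},
    {"input": "Name an even prime.", "output": "2", "label": 1, "sample_id": "q-002"},
]

dataset = UnpairedPreferenceDataset(   # placeholder name for the dataset class in this diff
    raw_dataset,
    tokenizer=tokenizer,               # any HF-style tokenizer assumed here
    max_length=512,
    extra_input_keys=["sample_id"],    # new argument introduced by this PR
)

# __getitem__ now returns the extra field alongside the usual keys.
item = dataset[0]
# item -> {"prompt": ..., "response": ..., "label": ..., "prompt_ids_len": ..., "sample_id": "q-001"}

# collate_fn now returns a fifth element: a dict mapping each extra key to its batched values.
input_ids, attention_mask, labels, prompt_ids_lens, extra_inputs = dataset.collate_fn([dataset[0], dataset[1]])
# extra_inputs -> {"sample_id": ["q-001", "q-002", "q-001", "q-002"]}
# (values repeat because collate_fn also appends the unmatched y'|x samples)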