1
1
import os
2
2
import tiktoken
3
3
from itertools import chain
4
- from typing import TYPE_CHECKING , Any , Dict , Generator , List , Literal , Union
4
+ from typing import TYPE_CHECKING , Any , Dict , Generator , List , Literal , Tuple , Union
5
5
6
6
from datasets import load_from_disk
7
7
19
19
logger = get_logger (__name__ )
20
20
21
21
22
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
    """Yield one ``(query, response, history, system)`` tuple per row of a columnar batch.

    ``examples`` is a dict of parallel column lists keyed by field name; the
    "prompt" column defines the row count. An optional non-empty "query" column
    entry is appended to the prompt on a new line, while "history" and "system"
    fall back to ``None`` when their columns are absent.
    """
    num_rows = len(examples["prompt"])
    has_history = "history" in examples
    has_system = "system" in examples
    for idx in range(num_rows):
        prompt = examples["prompt"][idx]
        answer = examples["response"][idx]
        # Fold the auxiliary "query" field into the prompt only when present and truthy.
        if "query" in examples and examples["query"][idx]:
            prompt = prompt + "\n" + examples["query"][idx]
        past_turns = examples["history"][idx] if has_history else None
        sys_prompt = examples["system"][idx] if has_system else None
        yield prompt, answer, past_turns, sys_prompt
29
+
30
+
31
def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
    """Split ``data_args.cutoff_len`` between source and target tokens.

    The target side gets a share proportional to its length, floored at
    ``data_args.reserved_label_len`` so labels always keep a minimum budget;
    the source side receives whatever remains of the cutoff.

    Returns:
        A ``(max_source_len, max_target_len)`` pair summing to ``cutoff_len``.
    """
    total_len = source_len + target_len
    proportional_target = int(data_args.cutoff_len * (target_len / total_len))
    # Never let label truncation drop below the reserved budget.
    if proportional_target < data_args.reserved_label_len:
        max_target_len = data_args.reserved_label_len
    else:
        max_target_len = proportional_target
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len
36
+
37
+
22
38
def preprocess_dataset (
23
39
dataset : Union ["Dataset" , "IterableDataset" ],
24
40
tokenizer : "PreTrainedTokenizer" ,
@@ -31,14 +47,6 @@ def preprocess_dataset(
31
47
if data_args .train_on_prompt and template .efficient_eos :
32
48
raise ValueError ("Current template does not support `train_on_prompt`." )
33
49
34
- def construct_example (examples : Dict [str , List [Any ]]) -> Generator [Any , None , None ]:
35
- for i in range (len (examples ["prompt" ])):
36
- query , response = examples ["prompt" ][i ], examples ["response" ][i ]
37
- query = query + "\n " + examples ["query" ][i ] if "query" in examples and examples ["query" ][i ] else query
38
- history = examples ["history" ][i ] if "history" in examples else None
39
- system = examples ["system" ][i ] if "system" in examples else None
40
- yield query , response , history , system
41
-
42
50
def preprocess_pretrain_dataset (examples : Dict [str , List [Any ]]) -> Dict [str , List [List [int ]]]:
43
51
# build grouped texts with format `X1 X2 X3 ...`
44
52
if isinstance (getattr (tokenizer , "tokenizer" , None ), tiktoken .Encoding ): # for tiktoken tokenizer (Qwen)
@@ -79,13 +87,11 @@ def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, L
79
87
for turn_idx , (source_ids , target_ids ) in enumerate (template .encode_multiturn (
80
88
tokenizer , query , response , history , system
81
89
)):
82
- total_len = len (source_ids ) + len (target_ids )
83
- max_source_len = int (data_args .cutoff_len * (len (source_ids ) / total_len ))
84
- max_target_len = int (data_args .cutoff_len * (len (target_ids ) / total_len ))
85
-
86
- if len (source_ids ) > max_source_len :
90
+ source_len , target_len = len (source_ids ), len (target_ids )
91
+ max_source_len , max_target_len = infer_max_len (source_len , target_len , data_args )
92
+ if source_len > max_source_len :
87
93
source_ids = source_ids [:max_source_len ]
88
- if len ( target_ids ) > max_target_len :
94
+ if target_len > max_target_len :
89
95
target_ids = target_ids [:max_target_len ]
90
96
91
97
if data_args .train_on_prompt :
@@ -187,15 +193,12 @@ def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Lis
187
193
chosen_ids += [tokenizer .eos_token_id ]
188
194
rejected_ids += [tokenizer .eos_token_id ]
189
195
190
- total_len = len (prompt_ids ) + max (len (chosen_ids ), len (rejected_ids ))
191
- max_source_len = int (data_args .cutoff_len * (len (prompt_ids ) / total_len ))
192
- max_target_len = int (data_args .cutoff_len * (max (len (chosen_ids ), len (rejected_ids )) / total_len ))
193
-
194
- if len (prompt_ids ) > max_source_len :
196
+ source_len , target_len = len (prompt_ids ), max (len (chosen_ids ), len (rejected_ids ))
197
+ max_source_len , max_target_len = infer_max_len (source_len , target_len , data_args )
198
+ if source_len > max_source_len :
195
199
prompt_ids = prompt_ids [:max_source_len ]
196
- if len ( chosen_ids ) > max_target_len :
200
+ if target_len > max_target_len :
197
201
chosen_ids = chosen_ids [:max_target_len ]
198
- if len (rejected_ids ) > max_target_len :
199
202
rejected_ids = rejected_ids [:max_target_len ]
200
203
201
204
model_inputs ["prompt_ids" ].append (prompt_ids )
0 commit comments