
Commit d125ef5

fix #1494
1 parent 3743b74 commit d125ef5

2 files changed: 32 additions and 22 deletions

src/llmtuner/dsets/preprocess.py

Lines changed: 25 additions & 22 deletions
@@ -1,7 +1,7 @@
 import os
 import tiktoken
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
 
 from datasets import load_from_disk
 
@@ -19,6 +19,22 @@
 logger = get_logger(__name__)
 
 
+def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
+    for i in range(len(examples["prompt"])):
+        query, response = examples["prompt"][i], examples["response"][i]
+        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
+        history = examples["history"][i] if "history" in examples else None
+        system = examples["system"][i] if "system" in examples else None
+        yield query, response, history, system
+
+
+def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
+    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
+    max_target_len = max(max_target_len, data_args.reserved_label_len)
+    max_source_len = data_args.cutoff_len - max_target_len
+    return max_source_len, max_target_len
+
+
 def preprocess_dataset(
     dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
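
A quick sanity check of the new helper (a standalone sketch, not part of the commit; FakeArgs is a made-up stand-in for DataArguments): with a very long prompt and a short answer, the old proportional split could round the label budget down to zero, whereas infer_max_len always keeps at least reserved_label_len tokens for the label.

from typing import Tuple

class FakeArgs:  # hypothetical stand-in for DataArguments
    cutoff_len = 1024
    reserved_label_len = 1

def infer_max_len(source_len: int, target_len: int, data_args) -> Tuple[int, int]:
    # same arithmetic as the helper added in this commit
    max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
    max_target_len = max(max_target_len, data_args.reserved_label_len)
    max_source_len = data_args.cutoff_len - max_target_len
    return max_source_len, max_target_len

# 2000-token prompt, 1-token answer:
# old formula: int(1024 * 1 / 2001) == 0, so the label would have been truncated away
print(infer_max_len(2000, 1, FakeArgs()))  # (1023, 1): one token stays reserved for the label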
@@ -31,14 +47,6 @@ def preprocess_dataset(
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
-    def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
-        for i in range(len(examples["prompt"])):
-            query, response = examples["prompt"][i], examples["response"][i]
-            query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
-            history = examples["history"][i] if "history" in examples else None
-            system = examples["system"][i] if "system" in examples else None
-            yield query, response, history, system
-
     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
         # build grouped texts with format `X1 X2 X3 ...`
         if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
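
For context, construct_example (now a module-level function rather than a closure inside preprocess_dataset) just turns the columnar batch back into per-example tuples. A minimal usage sketch with a made-up batch:

from typing import Any, Dict, Generator, List

def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
    # same logic as the helper moved to module level in this commit
    for i in range(len(examples["prompt"])):
        query, response = examples["prompt"][i], examples["response"][i]
        query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
        history = examples["history"][i] if "history" in examples else None
        system = examples["system"][i] if "system" in examples else None
        yield query, response, history, system

batch = {  # illustrative values only
    "prompt": ["Translate:", "Summarize:"],
    "query": ["bonjour", ""],
    "response": ["hello", "a short text"],
}
for query, response, history, system in construct_example(batch):
    print((query, response, history, system))
# ('Translate:\nbonjour', 'hello', None, None)
# ('Summarize:', 'a short text', None, None)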
@@ -79,13 +87,11 @@ def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, L
             for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
                 tokenizer, query, response, history, system
             )):
-                total_len = len(source_ids) + len(target_ids)
-                max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
-                max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
-
-                if len(source_ids) > max_source_len:
+                source_len, target_len = len(source_ids), len(target_ids)
+                max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
+                if source_len > max_source_len:
                     source_ids = source_ids[:max_source_len]
-                if len(target_ids) > max_target_len:
+                if target_len > max_target_len:
                     target_ids = target_ids[:max_target_len]
 
                 if data_args.train_on_prompt:
@@ -187,15 +193,12 @@ def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Lis
                 chosen_ids += [tokenizer.eos_token_id]
                 rejected_ids += [tokenizer.eos_token_id]
 
-            total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
-            max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
-            max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
-
-            if len(prompt_ids) > max_source_len:
+            source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
+            max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
+            if source_len > max_source_len:
                 prompt_ids = prompt_ids[:max_source_len]
-            if len(chosen_ids) > max_target_len:
+            if target_len > max_target_len:
                 chosen_ids = chosen_ids[:max_target_len]
-            if len(rejected_ids) > max_target_len:
                 rejected_ids = rejected_ids[:max_target_len]
 
             model_inputs["prompt_ids"].append(prompt_ids)
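
One detail of the pairwise branch (my reading of the diff, not stated in the commit message): the target budget is derived from the longer of the two responses, so chosen_ids and rejected_ids are now cut to the same max_target_len. A rough sketch with made-up token lists and assumed cutoff values:

# assumed values for illustration; the arithmetic mirrors infer_max_len from this commit
cutoff_len, reserved_label_len = 12, 1
prompt_ids = list(range(10))    # pretend 10 prompt tokens
chosen_ids = list(range(8))     # 8-token preferred answer
rejected_ids = list(range(5))   # 5-token rejected answer

source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
max_target_len = max(int(cutoff_len * target_len / (source_len + target_len)), reserved_label_len)
max_source_len = cutoff_len - max_target_len
prompt_ids = prompt_ids[:max_source_len]
chosen_ids = chosen_ids[:max_target_len]
rejected_ids = rejected_ids[:max_target_len]
print(len(prompt_ids), len(chosen_ids), len(rejected_ids))  # 7 5 5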

src/llmtuner/hparams/data_args.py

Lines changed: 7 additions & 0 deletions
@@ -52,6 +52,10 @@ class DataArguments:
         default=1024,
         metadata={"help": "The maximum length of the model inputs after tokenization."}
     )
+    reserved_label_len: Optional[int] = field(
+        default=1,
+        metadata={"help": "The maximum length reserved for label after tokenization."}
+    )
     train_on_prompt: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -110,6 +114,9 @@ class DataArguments:
     )
 
     def __post_init__(self):
+        if self.reserved_label_len >= self.cutoff_len:
+            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
+
         if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
             raise ValueError("Streaming mode should have an integer val size.")
 
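
As a quick illustration of the new guard (a trimmed-down sketch; the real DataArguments defines many more fields), a reserved_label_len at or above cutoff_len now fails fast in __post_init__:

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class MiniDataArguments:  # hypothetical stand-in for DataArguments
    cutoff_len: Optional[int] = field(default=1024)
    reserved_label_len: Optional[int] = field(default=1)

    def __post_init__(self):
        if self.reserved_label_len >= self.cutoff_len:
            raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")

MiniDataArguments(cutoff_len=1024, reserved_label_len=8)   # accepted
MiniDataArguments(cutoff_len=512, reserved_label_len=512)  # raises ValueError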
