-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhugging_face_nmt_model_trainer.py
426 lines (372 loc) · 17.9 KB
/
hugging_face_nmt_model_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
import gc
import json
import logging
import os
import re
from pathlib import Path
from typing import Any, Callable, List, Optional, Union, cast
import torch # pyright: ignore[reportMissingImports]
from datasets.arrow_dataset import Dataset
from sacremoses import MosesPunctNormalizer
from torch import Tensor # pyright: ignore[reportMissingImports]
from torch.utils.checkpoint import checkpoint # pyright: ignore[reportMissingImports] # noqa: F401
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
DataCollatorForSeq2Seq,
M2M100ForConditionalGeneration,
M2M100Tokenizer,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
NllbTokenizer,
NllbTokenizerFast,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
TrainerCallback,
set_seed,
)
from transformers.models.mbart50 import MBart50Tokenizer
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from transformers.training_args import TrainingArguments
from ...corpora.parallel_text_corpus import ParallelTextCorpus
from ...utils.progress_status import ProgressStatus
from ..trainer import Trainer, TrainStats
# Module-level logger used for training progress and tokenizer-adjustment messages.
logger = logging.getLogger(__name__)
def prepare_decoder_input_ids_from_labels(self: "M2M100ForConditionalGeneration", labels: Tensor) -> Tensor:
    """Build decoder input ids by shifting ``labels`` one position to the right.

    Position 0 of every row is set to ``config.decoder_start_token_id``. Any ``-100`` (the label padding
    value the trainer's data collator uses) that is shifted into the sequence is replaced with
    ``config.pad_token_id``. The trainer warns when a model lacks this method and label smoothing is enabled;
    it is presumably called by ``DataCollatorForSeq2Seq`` to precompute decoder inputs.

    Args:
        self: The model whose ``config`` supplies the start and pad token ids.
        labels: Label ids; assumed to be 2-D ``(batch, seq_len)`` given the ``[:, ...]`` indexing below.

    Returns:
        A new tensor with the same shape, dtype and device as ``labels``; ``labels`` is not modified.

    Raises:
        ValueError: If ``config.decoder_start_token_id`` or ``config.pad_token_id`` is ``None``.
    """
    decoder_start_token_id = self.config.decoder_start_token_id
    pad_token_id = self.config.pad_token_id
    # Validate before touching any tensor, with real exceptions: an ``assert`` is stripped under ``python -O``.
    if decoder_start_token_id is None:
        raise ValueError("self.config.decoder_start_token_id has to be defined.")
    if pad_token_id is None:
        raise ValueError("self.config.pad_token_id has to be defined.")

    # shift ids to the right
    shifted_input_ids = labels.new_zeros(labels.shape)
    shifted_input_ids[:, 1:] = labels[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids
# Install prepare_decoder_input_ids_from_labels (defined above) as a method of M2M100ForConditionalGeneration.
# The patch is applied to the class at import time, so it affects every M2M100 model in this process. The trainer
# checks for this method with hasattr() when label smoothing is enabled.
setattr(
    M2M100ForConditionalGeneration,
    "prepare_decoder_input_ids_from_labels",
    prepare_decoder_input_ids_from_labels,
)
# Tokenizer classes treated as multilingual by the trainer: for these it registers the source/target language
# codes as tokens, requires both languages to be set, assigns tokenizer.src_lang/tgt_lang, and forces the target
# language token as the first generated token.
MULTILINGUAL_TOKENIZERS = (
    MBartTokenizer,
    MBartTokenizerFast,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    M2M100Tokenizer,
    NllbTokenizer,
    NllbTokenizerFast,
)
class HuggingFaceNmtModelTrainer(Trainer):
    """Fine-tune a Hugging Face sequence-to-sequence translation model with ``Seq2SeqTrainer``.

    ``train()`` does the following:

    * Loads the model, or reuses the one passed in.
    * Optionally extends the tokenizer with characters from the training data that are missing from its
      vocabulary.
    * Registers language codes on multilingual tokenizers.
    * Tokenizes the corpus and trains, resuming from the latest checkpoint in ``output_dir`` when one
      exists.

    ``save()`` then writes the model, the metrics and the trainer state.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, str],
        training_args: Seq2SeqTrainingArguments,
        corpus: Union[ParallelTextCorpus, Dataset],
        src_lang: Optional[str] = None,
        tgt_lang: Optional[str] = None,
        max_src_length: Optional[int] = None,
        max_tgt_length: Optional[int] = None,
        add_unk_src_tokens: bool = False,
        add_unk_tgt_tokens: bool = True,
    ) -> None:
        """
        Args:
            model: A loaded model, or a name/path passed to ``AutoConfig``/``AutoModelForSeq2SeqLM.from_pretrained``.
            training_args: Hugging Face training arguments; ``output_dir`` must be set before ``train()``.
            corpus: Training data. Either a ``ParallelTextCorpus`` (empty rows are filtered out), or a
                ``Dataset`` whose ``translation`` column holds ``{lang_code: text}`` dicts.
            src_lang: Source language code; the dataset key defaults to ``"src"`` when ``None``.
            tgt_lang: Target language code; the dataset key defaults to ``"tgt"`` when ``None``.
            max_src_length: Source truncation length; falls back to ``model.config.max_length``.
            max_tgt_length: Target truncation length; falls back to ``model.config.max_length``.
            add_unk_src_tokens: Add source-side characters missing from the vocabulary as new tokens.
            add_unk_tgt_tokens: Add target-side characters missing from the vocabulary as new tokens.
        """
        self._model = model
        self._training_args = training_args
        self._corpus = corpus
        self._src_lang = src_lang
        self._tgt_lang = tgt_lang
        self._trainer: Optional[Seq2SeqTrainer] = None
        self._metrics = {}
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self._add_unk_src_tokens = add_unk_src_tokens
        self._add_unk_tgt_tokens = add_unk_tgt_tokens
        # use_cache of a caller-supplied model, captured by train() and restored by save().
        self._original_use_cache: Optional[bool] = None
        self._mpn = MosesPunctNormalizer()
        # Replace sacremoses' pattern strings with compiled regexes (presumably to avoid re-parsing them on every
        # normalize() call); entries that are not (str, str) pairs are dropped.
        self._mpn.substitutions = [  # type: ignore
            (re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str)
        ]
        self._stats = TrainStats()

    @property
    def stats(self) -> TrainStats:
        """Training statistics; ``train_corpus_size`` is filled in by ``train()``."""
        return self._stats

    def train(
        self,
        progress: Optional[Callable[[ProgressStatus], None]] = None,
        check_canceled: Optional[Callable[[], None]] = None,
    ) -> None:
        """Prepare the tokenizer and data, then fine-tune the model.

        Args:
            progress: Optional callback that receives a ``ProgressStatus`` when training begins and after every
                step, on the local main process only.
            check_canceled: Optional hook invoked at the same points; presumably raises to abort training.

        Raises:
            ValueError: In any of these cases:
                * ``output_dir`` is unset.
                * ``output_dir`` contains files but no checkpoint while ``overwrite_output_dir`` is False.
                * ``config.decoder_start_token_id`` cannot be determined.
                * A multilingual tokenizer is used without both ``src_lang`` and ``tgt_lang``.
        """
        output_dir = self._training_args.output_dir
        if output_dir is None:
            raise ValueError("Output directory is not set")

        last_checkpoint = None
        if os.path.isdir(output_dir) and not self._training_args.overwrite_output_dir:
            last_checkpoint = get_last_checkpoint(output_dir)
            # os.listdir() returns bare entry names; join them to output_dir so isfile() inspects the output
            # directory rather than the current working directory.
            if last_checkpoint is None and any(
                os.path.isfile(os.path.join(output_dir, name)) for name in os.listdir(output_dir)
            ):
                raise ValueError(
                    f"Output directory ({output_dir}) already exists and is not empty. "
                    "Use --overwrite_output_dir to overcome."
                )
            elif last_checkpoint is not None and self._training_args.resume_from_checkpoint is None:
                logger.info(
                    f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                    "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
                )

        # Set seed before initializing model.
        set_seed(self._training_args.seed)

        if isinstance(self._model, PreTrainedModel):
            model = self._model
            # Remember the caller's setting so save() can restore it; caching is turned off whenever gradient
            # checkpointing is on.
            self._original_use_cache = model.config.use_cache
            model.config.use_cache = not self._training_args.gradient_checkpointing
        else:
            config = AutoConfig.from_pretrained(
                self._model,
                use_cache=not self._training_args.gradient_checkpointing,
                label2id={},
                id2label={},
                num_labels=0,
            )
            model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(self._model, config=config))

        logger.info("Initializing tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)

        # Keys used to look up each side of a "translation" example.
        src_lang = self._src_lang if self._src_lang is not None else "src"
        tgt_lang = self._tgt_lang if self._tgt_lang is not None else "tgt"

        if isinstance(self._corpus, Dataset):
            train_dataset = self._corpus
        else:
            train_dataset = self._corpus.filter_nonempty().to_hf_dataset(src_lang, tgt_lang)

        if self._add_unk_src_tokens or self._add_unk_tgt_tokens:
            logger.info("Checking for missing tokens")
            if not isinstance(tokenizer, PreTrainedTokenizerFast):
                logger.warning(
                    "Tokenizer can not be updated from default configuration: "
                    f"tokenizer type {type(tokenizer)} is not an instance of PreTrainedTokenizerFast."
                )
            else:
                norm_tok = PreTrainedTokenizerFast.from_pretrained(
                    str(Path(os.path.dirname(os.path.abspath(__file__))) / "custom_normalizer"),
                    use_fast=True,
                )
                # using unofficially supported behavior to set the normalizer
                tokenizer.backend_tokenizer.normalizer = norm_tok.backend_tokenizer.normalizer  # type: ignore
                if self._add_unk_src_tokens and self._add_unk_tgt_tokens:
                    lang_codes = [src_lang, tgt_lang]
                elif self._add_unk_src_tokens:
                    lang_codes = [src_lang]
                else:
                    lang_codes = [tgt_lang]
                missing_tokens = self._find_missing_characters(tokenizer, train_dataset, lang_codes)
                if missing_tokens:
                    tokenizer = self._add_tokens(tokenizer, missing_tokens)

        if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
            logger.info("Add new language codes as tokens")
            if self._src_lang is not None:
                add_lang_code_to_tokenizer(tokenizer, self._src_lang)
            if self._tgt_lang is not None:
                add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)

        # We resize the embeddings only when necessary to avoid index errors.
        embedding_size = cast(Any, model.get_input_embeddings()).weight.shape[0]
        if len(tokenizer) > embedding_size:
            model.resize_token_embeddings(len(tokenizer))

        # Set decoder_start_token_id
        if (
            self._tgt_lang is not None
            and model.config.decoder_start_token_id is None
            and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast))
        ):
            if isinstance(tokenizer, MBartTokenizer):
                model.config.decoder_start_token_id = tokenizer.lang_code_to_id[self._tgt_lang]
            else:
                model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(self._tgt_lang)

        if model.config.decoder_start_token_id is None:
            raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

        # For translation we set the codes of our source and target languages (only useful for mBART, the others will
        # ignore those attributes).
        if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
            if self._src_lang is None or self._tgt_lang is None:
                raise ValueError(
                    f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires source_lang and "
                    "target_lang to be set."
                )
            tokenizer.src_lang = self._src_lang
            tokenizer.tgt_lang = self._tgt_lang

            # Multilingual models (mBART-50, M2M100, NLLB) must emit the target language token first. M2M100 wraps
            # language codes as "__<code>__" (the same token add_lang_code_to_tokenizer registers); the other
            # tokenizers use the code itself as the token.
            if isinstance(tokenizer, M2M100Tokenizer):
                tgt_lang_token = "__" + self._tgt_lang + "__"
            else:
                tgt_lang_token = self._tgt_lang
            forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang_token)
            model.config.forced_bos_token_id = forced_bos_token_id
            if model.generation_config is not None:
                model.generation_config.forced_bos_token_id = forced_bos_token_id

        prefix = ""
        if model.name_or_path.startswith(("t5-", "google/mt5-")):
            prefix = f"translate {self._src_lang} to {self._tgt_lang}: "

        max_src_length = self.max_src_length
        if max_src_length is None:
            max_src_length = model.config.max_length
        max_tgt_length = self.max_tgt_length
        if max_tgt_length is None:
            max_tgt_length = model.config.max_length

        if self._training_args.label_smoothing_factor > 0 and not hasattr(
            model, "prepare_decoder_input_ids_from_labels"
        ):
            logger.warning(
                "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
                f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more "
                "memory"
            )

        def preprocess_function(examples):
            # NLLB inputs are Moses punctuation-normalized, matching _find_missing_characters.
            if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
                inputs = [self._mpn.normalize(prefix + ex[src_lang]) for ex in examples["translation"]]
                targets = [self._mpn.normalize(ex[tgt_lang]) for ex in examples["translation"]]
            else:
                inputs = [prefix + ex[src_lang] for ex in examples["translation"]]
                targets = [ex[tgt_lang] for ex in examples["translation"]]
            model_inputs = tokenizer(inputs, max_length=max_src_length, truncation=True)
            # Tokenize targets with the `text_target` keyword argument
            labels = tokenizer(text_target=targets, max_length=max_tgt_length, truncation=True)
            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        logger.info("Run tokenizer")
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=train_dataset.column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on train dataset",
        )

        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=-100,
            pad_to_multiple_of=8 if self._training_args.fp16 else None,
        )

        self._trainer = Seq2SeqTrainer(
            model=model,
            args=self._training_args,
            train_dataset=cast(Any, train_dataset),
            processing_class=tokenizer,
            data_collator=data_collator,
            callbacks=[
                _ProgressCallback(
                    self._training_args.max_steps if self._training_args.max_steps > 0 else None,
                    progress,
                    check_canceled,
                )
            ],
        )

        logger.info("Train NMT model")
        # An explicit resume_from_checkpoint wins over the checkpoint auto-detected in output_dir.
        ckpt = self._training_args.resume_from_checkpoint
        if ckpt is None:
            ckpt = last_checkpoint
        train_result = self._trainer.train(
            resume_from_checkpoint=ckpt,
        )

        self._metrics = train_result.metrics
        self._metrics["train_samples"] = len(train_dataset)
        self._stats.train_corpus_size = self._metrics["train_samples"]

        self._trainer.log_metrics("train", self._metrics)
        logger.info("Model training finished")

    def _find_missing_characters(self, tokenizer: Any, train_dataset: Dataset, lang_codes: List[str]) -> List[str]:
        """Return, sorted, the non-whitespace characters of the training text that are not vocabulary tokens.

        Text is normalized as the tokenizer will see it: Moses punctuation normalization for NLLB tokenizers,
        then the tokenizer's backend normalizer.

        Args:
            tokenizer: A fast tokenizer (``backend_tokenizer`` is used).
            train_dataset: Dataset with a ``translation`` column of ``{lang_code: text}`` dicts.
            lang_codes: Which sides of each example to scan.
        """
        vocab = tokenizer.get_vocab().keys()
        normalize_punct = isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast))
        backend_normalizer = tokenizer.backend_tokenizer.normalizer
        charset = set()
        for ex in train_dataset["translation"]:
            for lang_code in lang_codes:
                text = ex[lang_code]
                if normalize_punct:
                    text = self._mpn.normalize(text)
                # In-place update avoids rebuilding the whole set for every example.
                charset.update(backend_normalizer.normalize_str(text))
        # strip() turns whitespace characters into "", which is then discarded.
        charset = {char.strip() for char in charset}
        charset.discard("")
        return sorted(charset - vocab)

    def _add_tokens(self, tokenizer: Any, missing_tokens: List[str]) -> Any:
        """Append ``missing_tokens`` to the tokenizer's vocabulary and return a freshly loaded fast tokenizer.

        The tokenizer is saved to ``output_dir``. Its ``tokenizer.json`` vocabulary is extended in place, with
        new values starting at ``len(tokenizer)``, and the tokenizer is then reloaded from that directory.

        Raises:
            ValueError: If ``output_dir`` is not set.
        """
        if self._training_args.output_dir is None:
            raise ValueError("Output directory is not set")
        tokenizer_dir = Path(self._training_args.output_dir)
        tokenizer.save_pretrained(str(tokenizer_dir))
        with open(tokenizer_dir / "tokenizer.json", "r+", encoding="utf-8") as file:
            data = json.load(file)
            vocab = data["model"]["vocab"]
            vocab_len = len(tokenizer)
            # The vocab is either a token -> id mapping or a list of [token, value] pairs depending on the
            # tokenizer model. NOTE(review): in the list form the second slot may be a score rather than an id
            # (e.g. Unigram); confirm that vocab_len + i is the intended value there.
            if isinstance(vocab, dict):
                vocab.update({token: vocab_len + i for i, token in enumerate(missing_tokens)})
            elif isinstance(vocab, list):
                vocab.extend([token, vocab_len + i] for i, token in enumerate(missing_tokens))
            file.seek(0)
            json.dump(data, file, ensure_ascii=False, indent=4)
            file.truncate()
        logger.info(f"Added {len(missing_tokens)} tokens to the tokenizer: {missing_tokens}")
        return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)

    def save(self) -> None:
        """Save the trained model, the training metrics and the trainer state.

        If the trainer was constructed with a model instance, that instance's ``name_or_path`` is pointed at
        ``output_dir`` and its original ``use_cache`` setting is restored.

        Raises:
            RuntimeError: If ``train()`` has not run (or the trainer was released by ``__exit__``).
        """
        if self._trainer is None:
            raise RuntimeError("The model has not been trained.")
        self._trainer.save_model()
        self._trainer.save_metrics("train", self._metrics)
        self._trainer.save_state()
        if isinstance(self._model, PreTrainedModel):
            self._model.name_or_path = self._training_args.output_dir
            self._model.config.name_or_path = self._training_args.output_dir
            self._model.config.use_cache = self._original_use_cache

    def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
        """Release the underlying trainer and free cached GPU memory.

        The trainer reference is reset to ``None`` rather than deleted, so a later ``save()`` raises the intended
        ``RuntimeError`` and a repeated exit is harmless.
        """
        if self._trainer is not None:
            self._trainer = None
        gc.collect()
        torch.cuda.empty_cache()
class _ProgressCallback(TrainerCallback):
    """Trainer callback that polls for cancellation and reports step progress.

    When training begins and at the end of every training step it first calls ``check_canceled`` (if given).
    Then, on the local main process only, it passes a ``ProgressStatus`` to ``progress`` (if given). With a
    known ``max_steps`` the status is built by ``ProgressStatus.from_step``; otherwise only the raw step count
    is reported.
    """

    def __init__(
        self,
        max_steps: Optional[int],
        progress: Optional[Callable[[ProgressStatus], None]],
        check_canceled: Optional[Callable[[], None]],
    ) -> None:
        self._max_steps = max_steps
        self._progress = progress
        self._check_canceled = check_canceled

    def _poll_and_report(self, state: TrainerState, step: int) -> None:
        """Run the cancellation hook, then report ``step`` from the local main process."""
        if self._check_canceled is not None:
            self._check_canceled()
        if self._progress is None or not state.is_local_process_zero:
            return
        if self._max_steps is None:
            status = ProgressStatus(step)
        else:
            status = ProgressStatus.from_step(step, self._max_steps)
        self._progress(status)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        self._poll_and_report(state, 0)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        self._poll_and_report(state, state.global_step)
def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
    """Register ``lang_code`` as an additional special token on ``tokenizer`` unless it is already present.

    M2M100 tokenizers wrap the code as ``__<code>__``; every other tokenizer uses the code itself as the token.
    For slow mBART/mBART-50 and M2M100 tokenizers, the language lookup tables are also updated with the new
    id. This is presumably so that their ``src_lang``/``tgt_lang`` handling can resolve the new code.
    """
    is_m2m100 = isinstance(tokenizer, M2M100Tokenizer)
    lang_token = "__" + lang_code + "__" if is_m2m100 else lang_code

    # Nothing to do if the token was added before.
    if lang_token in tokenizer.added_tokens_encoder:
        return

    extended_specials = tokenizer.additional_special_tokens + [lang_token]  # type: ignore
    tokenizer.add_special_tokens({"additional_special_tokens": extended_specials})  # type: ignore
    lang_id = cast(int, tokenizer.convert_tokens_to_ids(lang_token))

    if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
        # Slow mBART tokenizers keep code <-> id tables plus fairseq-style token <-> id tables.
        tokenizer.lang_code_to_id[lang_code] = lang_id
        tokenizer.id_to_lang_code[lang_id] = lang_code
        tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
        tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
    elif is_m2m100:
        # M2M100 maps both the bare code and the wrapped token.
        tokenizer.lang_code_to_id[lang_code] = lang_id
        tokenizer.lang_code_to_token[lang_code] = lang_token
        tokenizer.lang_token_to_id[lang_token] = lang_id
        tokenizer.id_to_lang_token[lang_id] = lang_token