-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathhugging_face_nmt_engine.py
413 lines (373 loc) · 17.1 KB
/
hugging_face_nmt_engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
from __future__ import annotations
import gc
import logging
import re
from math import exp, prod
from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast
import torch # pyright: ignore[reportMissingImports]
from sacremoses import MosesPunctNormalizer
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
M2M100Tokenizer,
NllbTokenizer,
NllbTokenizerFast,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TranslationPipeline,
)
from transformers.generation import BeamSearchEncoderDecoderOutput, GreedySearchEncoderDecoderOutput
from transformers.tokenization_utils import BatchEncoding, TruncationStrategy
from ...annotations.range import Range
from ...utils.typeshed import StrPath
from ..translation_engine import TranslationEngine
from ..translation_result import TranslationResult
from ..translation_result_builder import TranslationResultBuilder
from ..translation_sources import TranslationSources
from ..word_alignment_matrix import WordAlignmentMatrix
# Module-level logger; used to report out-of-memory batch-size backoff.
logger = logging.getLogger(__name__)
class HuggingFaceNmtEngine(TranslationEngine):
    """Translation engine backed by a Hugging Face sequence-to-sequence model.

    Translations come from a ``_TranslationPipeline``. Each hypothesis becomes a
    ``TranslationResult`` whose token confidences are ``exp(score)`` of the pipeline's
    per-token scores. Its word alignment links each target token to the source token
    with the highest cross-attention weight.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, StrPath, str],
        oom_batch_size_backoff_mult: float = 1.0,
        **pipeline_kwargs,
    ) -> None:
        """Load (or adopt) the model and tokenizer, validate languages, and build the pipeline.

        Args:
            model: A loaded ``PreTrainedModel``, used as-is (the engine does not own it), or a
                model name/path passed to ``AutoModelForSeq2SeqLM.from_pretrained``. A model
                loaded from a name/path is owned by the engine and released by ``close``.
            oom_batch_size_backoff_mult: Multiplier applied to the batch size after a CUDA
                out-of-memory error before retrying. Values of about 1.0 or more (the default)
                disable retrying, and the error propagates.
            **pipeline_kwargs: Extra pipeline arguments. ``batch_size`` (default 1) is consumed
                by the engine. ``src_lang``/``tgt_lang`` are checked against the tokenizer.

        Raises:
            ValueError: If ``src_lang`` or ``tgt_lang`` is not supported by the model's tokenizer.
        """
        self._model = model
        self._pipeline_kwargs = pipeline_kwargs
        if isinstance(self._model, PreTrainedModel):
            self._model.eval()
            self._is_model_owned = False
        else:
            # Load the config with empty label maps / zero labels before building the model.
            model_config = AutoConfig.from_pretrained(str(self._model), label2id={}, id2label={}, num_labels=0)
            self._model = cast(
                PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
            )
            self._is_model_owned = True

        self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
        if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)):
            # Only NLLB models get Moses punctuation normalization. The normalizer's string
            # (pattern, replacement) pairs are compiled once; non-string entries are dropped.
            self._mpn = MosesPunctNormalizer()
            self._mpn.substitutions = [  # type: ignore
                (re.compile(r), sub)
                for r, sub in self._mpn.substitutions
                if isinstance(r, str) and isinstance(sub, str)
            ]
        else:
            self._mpn = None

        src_lang = self._pipeline_kwargs.get("src_lang")
        tgt_lang = self._pipeline_kwargs.get("tgt_lang")
        if (
            src_lang is not None
            and tgt_lang is not None
            and "prefix" not in self._pipeline_kwargs
            and self._model.name_or_path.startswith(("t5-", "google/mt5-"))
        ):
            # T5/mT5 checkpoints select the language pair with a text prompt, not language tokens.
            self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
        else:
            # Source is checked before target, so the first unsupported code is the one reported.
            for lang in (src_lang, tgt_lang):
                if lang is not None and not self._is_lang_supported(lang):
                    raise ValueError(f"The specified model does not support the language code '{lang}'")

        self._batch_size = int(self._pipeline_kwargs.pop("batch_size", 1))
        self._oom_batch_size_backoff_mult = oom_batch_size_backoff_mult
        self._pipeline = self._create_pipeline()

    def _is_lang_supported(self, lang: str) -> bool:
        """Return True if the tokenizer has an added or additional special token for ``lang``."""
        if isinstance(self._tokenizer, M2M100Tokenizer):
            # M2M100 maps language codes to separate token strings; unknown codes map to None.
            lang_token = self._tokenizer.lang_code_to_token.get(lang)
        else:
            lang_token = lang
        if lang_token is None:
            return False
        return (
            lang_token in self._tokenizer.added_tokens_encoder
            or lang_token in self._tokenizer.additional_special_tokens
        )

    def _create_pipeline(self) -> _TranslationPipeline:
        """Build a translation pipeline from the current model, tokenizer, normalizer and batch size."""
        return _TranslationPipeline(
            model=self._model,
            tokenizer=self._tokenizer,
            mpn=self._mpn,
            batch_size=self._batch_size,
            **self._pipeline_kwargs,
        )

    @property
    def tokenizer(self) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
        """The tokenizer loaded for the model."""
        return self._tokenizer

    def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
        """Translate one segment (a string or a token sequence) and return the best result."""
        return self.translate_batch([segment])[0]

    def translate_n(self, n: int, segment: Union[str, Sequence[str]]) -> Sequence[TranslationResult]:
        """Translate one segment and return up to ``n`` hypotheses."""
        return self.translate_n_batch(n, [segment])[0]

    def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequence[TranslationResult]:
        """Translate several segments and return the best result for each."""
        return [results[0] for results in self.translate_n_batch(1, segments)]

    def translate_n_batch(
        self, n: int, segments: Sequence[Union[str, Sequence[str]]]
    ) -> Sequence[Sequence[TranslationResult]]:
        """Translate segments in batches of the current batch size, with up to ``n`` results each.

        On a CUDA out-of-memory error the batch size may be reduced and the failed batch
        retried. This happens only if backoff is enabled (multiplier below about 1.0) and the
        batch size is above 1. The new size is the rounded product, but always at least one
        smaller and never below 1. The pipeline is then rebuilt (keeping the punctuation
        normalizer). Batches that already succeeded are kept rather than re-translated.
        Otherwise the error is re-raised.
        """
        if isinstance(segments, str):
            segments = [segments]
        else:
            segments = list(segments)

        all_results: List[Sequence[TranslationResult]] = []
        start = 0
        while start < len(segments):
            batch = segments[start : start + self._batch_size]
            try:
                batch_results = self._try_translate_n_batch(n, batch)
            except torch.cuda.OutOfMemoryError:
                if self._oom_batch_size_backoff_mult >= 0.9999 or self._batch_size <= 1:
                    raise
                # Shrink by at least one. Rounding alone can leave the size unchanged
                # (e.g. round(10 * 0.99) == 10), which would retry the same batch forever.
                reduced = int(round(self._batch_size * self._oom_batch_size_backoff_mult))
                self._batch_size = max(min(reduced, self._batch_size - 1), 1)
                logger.warning(f"Out of memory error caught. Reducing batch size to {self._batch_size} and retrying.")
                self._pipeline = self._create_pipeline()
                continue
            all_results.extend(batch_results)
            start += len(batch)
        return all_results

    def _try_translate_n_batch(
        self, n: int, segments: Sequence[Union[str, Sequence[str]]]
    ) -> Sequence[Sequence[TranslationResult]]:
        """Run the pipeline once on ``segments`` and convert its records to TranslationResults.

        May raise ``torch.cuda.OutOfMemoryError``; ``translate_n_batch`` handles the backoff.
        """
        all_results: List[List[TranslationResult]] = []
        pipeline_outputs = cast(
            Iterable[Union[List[dict], dict]],
            self._pipeline(segments, num_return_sequences=n),
        )
        for outputs in pipeline_outputs:
            # A segment's hypotheses may arrive as a single dict instead of a list of dicts.
            if isinstance(outputs, dict):
                outputs = [outputs]
            segment_results: List[TranslationResult] = []
            for output in outputs:
                input_tokens: Sequence[str] = output["input_tokens"]
                translation_tokens = output["translation_tokens"]
                output_length = len(translation_tokens)
                builder = TranslationResultBuilder(input_tokens)
                # Scores are treated as log-probabilities; exp() turns them into confidences.
                for token, score in zip(translation_tokens, output["token_scores"]):
                    builder.append_token(token, TranslationSources.NMT, exp(score))
                # For each target token (row), align to the source column with the highest attention.
                src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
                wa_matrix = WordAlignmentMatrix.from_word_pairs(
                    len(input_tokens), output_length, set(zip(src_indices, range(output_length)))
                )
                builder.mark_phrase(Range.create(0, len(input_tokens)), wa_matrix)
                segment_results.append(builder.to_result(output["translation_text"]))
            all_results.append(segment_results)
        return all_results

    def __enter__(self) -> HuggingFaceNmtEngine:
        return self

    def close(self) -> None:
        """Release the pipeline and, if owned, the model, then clear the CUDA cache."""
        del self._pipeline
        if self._is_model_owned:
            del self._model
        gc.collect()
        with torch.no_grad():
            torch.cuda.empty_cache()
class _TranslationPipeline(TranslationPipeline):
    """TranslationPipeline whose records also carry tokens, per-token scores and attentions.

    Inputs may be raw strings or pre-tokenized token sequences. Each output record holds:
    the source tokens, the generated target tokens (special tokens removed), one score per
    target token, a target-by-source cross-attention matrix, and the decoded text.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, StrPath, str],
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        batch_size: int,
        mpn: Optional[MosesPunctNormalizer] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            model: Model (or name/path) forwarded to ``TranslationPipeline``.
            tokenizer: Tokenizer forwarded to ``TranslationPipeline``.
            batch_size: Batch size forwarded to ``TranslationPipeline``.
            mpn: Optional Moses punctuation normalizer applied to string inputs in ``preprocess``.
            **kwargs: Any other ``TranslationPipeline`` arguments.
        """
        super().__init__(model=model, tokenizer=tokenizer, batch_size=batch_size, **kwargs)
        self._mpn = mpn

    def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
        """Encode the inputs and attach the source tokens used later for alignment.

        String inputs are punctuation-normalized when a normalizer is configured. Token-sequence
        inputs are converted to ids and decoded back to text before being encoded again. The
        returned ``BatchEncoding`` gets an extra ``"input_tokens"`` entry with one list per
        input: either the non-special tokens of the encoded string, or the caller's own tokens
        for pre-tokenized input.

        Raises:
            RuntimeError: If the pipeline has no tokenizer.
        """
        if self.tokenizer is None:
            raise RuntimeError("No tokenizer is specified.")
        if self._mpn:
            sentences = [
                (
                    self._mpn.normalize(s)
                    if isinstance(s, str)
                    else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
                )
                for s in args
            ]
        else:
            sentences = [
                (
                    s
                    if isinstance(s, str)
                    else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
                )
                for s in args
            ]
        inputs = cast(
            BatchEncoding,
            super().preprocess(*sentences, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang),
        )
        # NOTE(review): for pre-tokenized inputs the caller's tokens are reported as-is. Their
        # count may differ from the non-special ids produced by re-encoding the decoded text,
        # which would misalign attention columns with input_tokens. Confirm this can't happen.
        if inputs.encodings is not None:
            # Fast tokenizer: take the tokens of the encoded (possibly normalized) text, minus special tokens.
            inputs["input_tokens"] = [
                _get_encoding_fast_tokens(inputs.encodings[i]) if isinstance(args[i], str) else args[i]
                for i in range(len(args))
            ]
        else:
            # NOTE(review): this slow-tokenizer path tokenizes the original `args`, not the
            # normalized `sentences`, so tokens may not match what the model saw when a
            # normalizer is active.
            inputs["input_tokens"] = [self.tokenizer.tokenize(s) if isinstance(s, str) else s for s in args]
        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        """Run generation and extract each returned hypothesis's token scores and cross-attentions.

        Returns:
            A dict with:
            - ``input_ids``: the encoded source ids (2-D).
            - ``input_tokens``: the tokens attached by ``preprocess``.
            - ``output_ids``: generated ids reshaped to ``(in_b, n_sequences, seq_len)``.
            - ``scores``: ``(in_b, n_sequences, seq_len)``, the score of each generated token
              taken along its own beam.
            - ``attentions``: ``(in_b, n_sequences, seq_len, src_len)``, head-averaged
              cross-attention from one decoder layer.

            Here ``seq_len == output_ids.shape[1]``. When the model has a
            ``decoder_start_token_id``, position 0 is filled with zeros.

        Raises:
            RuntimeError: If ``generate`` returns neither a beam-search nor a greedy-search output.
        """
        in_b, input_length = model_inputs["input_ids"].shape
        input_tokens = model_inputs["input_tokens"]
        # Not a model input; remove it so it isn't splatted into generate().
        del model_inputs["input_tokens"]
        if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
            config = self.model.generation_config
        else:
            config = self.model.config
        generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
        generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
        self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
        output = self.model.generate(
            **model_inputs,
            **generate_kwargs,
            output_scores=True,
            output_attentions=True,
            return_dict_in_generate=True,
        )
        if isinstance(output, BeamSearchEncoderDecoderOutput):
            output_ids = output.sequences
            beam_indices = output.beam_indices
            scores = output.scores
            attentions = output.cross_attentions
        elif isinstance(output, GreedySearchEncoderDecoderOutput):
            output_ids = output.sequences
            # Greedy search has a single beam, so every position comes from beam 0.
            beam_indices = torch.zeros_like(output_ids)
            assert output.scores is not None
            # Greedy scores are raw logits; convert them to log-probabilities.
            scores = tuple(torch.nn.functional.log_softmax(logits, dim=-1) for logits in output.scores)
            attentions = output.cross_attentions
        else:
            raise RuntimeError("Cannot postprocess the output of the model.")
        assert beam_indices is not None and scores is not None
        out_b = output_ids.shape[0]
        # Assumes each per-step score tensor's first dim is in_b * num_beams (TODO confirm across versions).
        num_beams = scores[0].shape[0] // in_b
        n_sequences = out_b // in_b
        # Presumably the first output position is the decoder start token, for which generate()
        # produces no score; skip it when gathering and re-insert a zero afterwards.
        start_index = 0
        if self.model.config.decoder_start_token_id is not None:
            start_index = 1
        # (in_b, n_sequences, steps, 3): a (step, beam, token id) coordinate per generated token.
        indices = torch.stack(
            (
                torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
                torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
                torch.reshape(output_ids[:, start_index:], (in_b, n_sequences, -1)),
            ),
            dim=3,
        )
        # Stack steps, then reshape/transpose to (in_b, steps, num_beams, vocab).
        scores = torch.stack(scores, dim=0).reshape(len(scores), in_b, num_beams, -1).transpose(0, 1)
        # Gather with one batch dim -> (in_b, n_sequences, steps).
        scores = torch_gather_nd(scores, indices, 1)
        if self.model.config.decoder_start_token_id is not None:
            scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)
        assert attentions is not None
        num_heads = attentions[0][0].shape[1]
        # (in_b, n_sequences, steps, 2): a (step, beam) coordinate per generated token.
        indices = torch.stack(
            (
                torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
                torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
            ),
            dim=3,
        )
        num_layers = len(attentions[0])
        # Use a single decoder layer about two-thirds of the way up. Presumably a heuristic
        # for alignment quality; not documented here.
        layer = (2 * num_layers) // 3
        # For each step, take that layer's attention from the last query position, then
        # reshape/transpose to (in_b, steps, num_beams, num_heads, src_len).
        attentions = (
            torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
            .squeeze()
            .reshape(len(attentions), in_b, num_beams, num_heads, -1)
            .transpose(0, 1)
        )
        # Average over heads -> (in_b, steps, num_beams, src_len).
        attentions = torch.mean(attentions, dim=3)
        # Gather along each hypothesis's beam path -> (in_b, n_sequences, steps, src_len).
        attentions = torch_gather_nd(attentions, indices, 1)
        if self.model.config.decoder_start_token_id is not None:
            attentions = torch.cat(
                (
                    torch.zeros(
                        (attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
                        device=attentions.device,
                    ),
                    attentions,
                ),
                dim=2,
            )
        output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
        return {
            "input_ids": model_inputs["input_ids"],
            "input_tokens": input_tokens,
            "output_ids": output_ids,
            "scores": scores,
            "attentions": attentions,
        }

    def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
        """Turn one input's model outputs into a list of records, one per returned hypothesis.

        Special tokens are dropped on both sides. Scores and attention rows keep only
        non-special output positions; attention columns keep only non-special input positions.
        Each record has the keys ``input_tokens``, ``translation_tokens``, ``token_scores``,
        ``token_attentions`` and ``translation_text``.

        Raises:
            RuntimeError: If the pipeline has no tokenizer.
        """
        if self.tokenizer is None:
            raise RuntimeError("No tokenizer is specified.")
        all_special_ids = set(self.tokenizer.all_special_ids)
        # NOTE(review): only batch entry 0 is read. Presumably the base pipeline unbatches
        # outputs before calling postprocess; confirm.
        input_ids = model_outputs["input_ids"][0]
        # Positions of non-special source ids; these select the attention columns below.
        input_indices: List[int] = []
        for i, input_id in enumerate(input_ids):
            id = cast(int, input_id.item())
            if id not in all_special_ids:
                input_indices.append(i)
        input_tokens = model_outputs["input_tokens"][0]
        records = []
        output_ids: torch.Tensor
        scores: torch.Tensor
        attentions: torch.Tensor
        for output_ids, scores, attentions in zip(
            model_outputs["output_ids"][0],
            model_outputs["scores"][0],
            model_outputs["attentions"][0],
        ):
            output_tokens: List[str] = []
            output_indices: List[int] = []
            for i, output_id in enumerate(output_ids):
                id = cast(int, output_id.item())
                if id not in all_special_ids:
                    output_tokens.append(self.tokenizer.convert_ids_to_tokens(id))
                    output_indices.append(i)
            # Keep only non-special target positions, then non-special source columns.
            scores = scores[output_indices]
            attentions = attentions[output_indices]
            attentions = attentions[:, input_indices]
            records.append(
                {
                    "input_tokens": input_tokens,
                    "translation_tokens": output_tokens,
                    "token_scores": scores,
                    "token_attentions": attentions,
                    "translation_text": self.tokenizer.decode(
                        output_ids,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    ),
                }
            )
        return records
def torch_gather_nd(params: torch.Tensor, indices: torch.Tensor, batch_dim: int = 0) -> torch.Tensor:
    """Gather slices of ``params`` addressed by ``indices``, like ``tf.gather_nd``.

    The first ``batch_dim`` dimensions of ``params`` and ``indices`` are shared batch
    dimensions. The last dimension of ``indices`` holds a coordinate tuple into the next
    ``indices.shape[-1]`` dimensions of ``params``. Any dimensions of ``params`` after
    those are gathered whole.

    Args:
        params: Tensor to gather from.
        indices: Integer coordinate tensor whose leading ``batch_dim`` dims match ``params``.
        batch_dim: Number of leading batch dimensions.

    Returns:
        A tensor of shape ``indices.shape[:-1] + params.shape[batch_dim + indices.shape[-1]:]``.
        ``indices`` itself is never modified.
    """
    index_shape = indices.shape[:-1]
    num_dim = indices.size(-1)
    tail_sizes = params.shape[batch_dim + num_dim :]

    # Extend each coordinate tuple with an explicit index for every trailing dimension,
    # so that every element to gather gets its own coordinate.
    for s in tail_sizes:
        row_indices = torch.arange(s, device=params.device)
        indices = indices.unsqueeze(-2)
        indices = indices.repeat(*[1 for _ in range(indices.dim() - 2)], s, 1)
        row_indices = row_indices.expand(*indices.shape[:-2], -1).unsqueeze(-1)
        indices = torch.cat((indices, row_indices), dim=-1)
        num_dim += 1

    # Turn coordinate tuples into offsets within the flattened non-batch dims (row-major strides).
    # This is done out of place; the previous in-place `*=` mutated the caller's tensor
    # whenever `params` had no trailing dims.
    strides = torch.tensor(
        [prod(params.shape[batch_dim + i + 1 : batch_dim + num_dim]) for i in range(num_dim)],
        dtype=indices.dtype,
        device=indices.device,
    )
    flat_indices = (indices * strides).sum(dim=-1)

    params = params.flatten(batch_dim, -1)
    flat_indices = flat_indices.flatten(batch_dim, -1)
    out = torch.gather(params, dim=batch_dim, index=flat_indices)
    return out.reshape(*index_shape, *tail_sizes)
def _get_encoding_fast_tokens(encoding) -> List[str]:
return [token for (token, mask) in zip(encoding.tokens, encoding.special_tokens_mask) if not mask]