
Commit 211aeec

Fixes (#2569)

* fix rotary embeddings when the first input is very long
* fix the valid transforms used at scoring, and tokenization with onmt_tokenize when the docify transform is active

1 parent: bedbcc4

File tree: 6 files changed (+33 -26 lines)
onmt/modules/multi_headed_attn.py (+1 -1)

@@ -466,7 +466,7 @@ def forward(
         if self.max_relative_positions == -1:  # Rotary Embeddings
             if seqlen + start_pos > self.rope.size(0):
                 # Resize rotary embeddings.
-                self.rope, _, _ = rotaryembeddings(
+                self.rope, self.cos, self.sin = rotaryembeddings(
                     self.rotary_dim,
                     maxseqlen=(seqlen + start_pos + 2048),
                     base=self.rotary_theta,
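To see the class of bug this fixes, here is a minimal, hypothetical sketch (not the OpenNMT-py implementation; the rotaryembeddings stand-in below merely mimics its three-value return) of a rotary cache that must grow when the first input alone exceeds the preallocated length. The old code refreshed only rope, leaving self.cos and self.sin at their original, too-short size:

    import torch

    def rotaryembeddings(dim, maxseqlen=2048, base=10000.0):
        # Stand-in with the same shape of return as the real helper:
        # the angle table plus its cos/sin caches, all sized to maxseqlen.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(maxseqlen).float(), inv_freq)
        emb = torch.cat((angles, angles), dim=-1)
        return emb, emb.cos(), emb.sin()

    rope, cos, sin = rotaryembeddings(dim=64)  # preallocated for 2048 positions
    seqlen, start_pos = 5000, 0                # a very long first input
    if seqlen + start_pos > rope.size(0):
        # The fix: refresh all three caches together. Assigning only `rope`
        # (the old behavior) left cos/sin stale and too short.
        rope, cos, sin = rotaryembeddings(dim=64, maxseqlen=seqlen + start_pos + 2048)
    assert cos.size(0) >= seqlen + start_pos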

onmt/tests/test_transform.py (+4 -4)

@@ -327,12 +327,12 @@ def test_pyonmttok_bpe(self):
             "struc■",
             "tion■",
             ":",
-            "\n■",
+            "⦅newline⦆■",
             "in■",
             "struc■",
             "tion■",
-            "\n■",
-            "\n■",
+            "⦅newline⦆■",
+            "⦅newline⦆■",
             "#■",
             "#■",
             "#",
@@ -342,7 +342,7 @@ def test_pyonmttok_bpe(self):
             "on■",
             "se",
             ":",
-            "\n",
+            "⦅newline⦆",
             "<blank>",
             "respon■",
             "se",

onmt/trainer.py (+1 -1)

@@ -297,7 +297,7 @@ def train(
         logger.info(
             "Start training loop and validate every %d steps...", valid_steps
         )
-        logger.info("Scoring with: {}".format(self.scoring_preparator.transform))
+        logger.info("Scoring with: {}".format(self.scoring_preparator.transforms))

         total_stats = onmt.utils.Statistics()
         report_stats = onmt.utils.Statistics()

onmt/transforms/tokenize.py (+10 -1)

@@ -157,7 +157,7 @@ def _tokenize(self, tokens, side="src", is_train=False):
         """Tokenize a list of words."""
         # This method embeds a custom logic to correctly handle certain placeholders
         # in case the tokenizer doesn't preserve them.
-        sentence = " ".join(tokens).replace(DefaultTokens.SEP, "\n")
+        sentence = " ".join(tokens)
         # Locate the end-of-sentence placeholders.
         sent_list = sentence.split(DefaultTokens.EOS)
         # Tokenize each sentence separately.
@@ -257,6 +257,7 @@ def tokenize_string(self, string, side="src", is_train=False):
         """Apply subword sampling or deterministic subwording"""
         sp_model = self.load_models[side]
         nbest_size = self.tgt_subword_nbest if side == "tgt" else self.src_subword_nbest
+        string = string.replace(DefaultTokens.SEP, "\n")
         if is_train is False or nbest_size in [0, 1]:
             # derterministic subwording
             tokens = sp_model.encode(string, out_type=str)
@@ -441,6 +442,9 @@ def _parse_opts(self):
         self.src_other_kwargs = self.opts.src_onmttok_kwargs
         self.tgt_other_kwargs = self.opts.tgt_onmttok_kwargs
         self.gpt2_pretok = self.opts.gpt2_pretok
+        self.preserve_placeholders = self.opts.tgt_onmttok_kwargs.get(
+            "preserve_placeholders", False
+        )

     @classmethod
     def get_specials(cls, opts):
@@ -558,6 +562,11 @@ def tokenize_string(self, sentence, side="src", is_train=False):
                     segmented.extend(["Ċ", "Ċ"])
                 else:
                     segmented.append(s)
+        elif (
+            self.src_subword_type == "sentencepiece" and not self.preserve_placeholders
+        ):
+            sentence = sentence.replace(DefaultTokens.SEP, "\n")
+            segmented = tokenizer(sentence)
         else:
             segmented = tokenizer(sentence)
         return segmented
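The net effect of these four hunks is to move the SEP-to-"\n" replacement from the generic _tokenize path into the tokenizer-specific code: SentencePiece always converts the placeholder back to a literal newline before encoding, while the onmt_tokenize path only does so when preserve_placeholders is off. A simplified, hedged restatement of that decision outside the classes (names mirror the diff; tokenizer is assumed to be a plain callable):

    def segment(sentence, tokenizer, subword_type, preserve_placeholders):
        # Simplified restatement of the tokenize_string branches above
        # (the gpt2_pretok byte-level branch, unchanged by this commit, is elided).
        if subword_type == "sentencepiece" and not preserve_placeholders:
            # SentencePiece has no placeholder protection, so ⦅newline⦆ would be
            # shattered into subwords; hand it a real newline instead.
            sentence = sentence.replace("⦅newline⦆", "\n")
        return tokenizer(sentence)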

onmt/transforms/transform.py (+10 -7)

@@ -264,13 +264,16 @@ def _repr_args(self):
 def make_transforms(opts, transforms_cls, vocabs):
     """Build transforms in `transforms_cls` with vocab of `fields`."""
     transforms = {}
-    for name, transform_cls in transforms_cls.items():
-        if transform_cls.require_vocab() and vocabs is None:
-            logger.warning(f"{transform_cls.__name__} require vocab to apply, skip it.")
-            continue
-        transform_obj = transform_cls(opts)
-        transform_obj.warm_up(vocabs)
-        transforms[name] = transform_obj
+    if transforms_cls:
+        for name, transform_cls in transforms_cls.items():
+            if transform_cls.require_vocab() and vocabs is None:
+                logger.warning(
+                    f"{transform_cls.__name__} require vocab to apply, skip it."
+                )
+                continue
+            transform_obj = transform_cls(opts)
+            transform_obj.warm_up(vocabs)
+            transforms[name] = transform_obj
     return transforms
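For context, the practical effect of the new guard, assuming the scoring path can now reach make_transforms before any transform classes have been resolved (opts and vocabs below are placeholders):

    make_transforms(opts, None, vocabs)  # before: AttributeError on None.items(); after: {}
    make_transforms(opts, {}, vocabs)    # an empty mapping likewise falls through to {}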
onmt/utils/scoring_utils.py (+7 -12)

@@ -5,7 +5,7 @@
 from onmt.opts import translate_opts
 from onmt.constants import CorpusTask
 from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
-from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
+from onmt.transforms import get_transforms_cls


 class ScoringPreparator:
@@ -19,16 +19,12 @@ def __init__(self, vocabs, opt):
         if self.opt.dump_preds is not None:
             if not os.path.exists(self.opt.dump_preds):
                 os.makedirs(self.opt.dump_preds)
-        self.transforms = opt.transforms
-        transforms_cls = get_transforms_cls(self.transforms)
-        transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
-        self.transform = TransformPipe.build_from(transforms.values())
+        self.transforms = None
+        self.transforms_cls = None

     def warm_up(self, transforms):
         self.transforms = transforms
-        transforms_cls = get_transforms_cls(self.transforms)
-        transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
-        self.transform = TransformPipe.build_from(transforms.values())
+        self.transforms_cls = get_transforms_cls(transforms)

     def translate(self, model, gpu_rank, step):
         """Compute and save the sentences predicted by the
@@ -84,7 +80,7 @@ def translate(self, model, gpu_rank, step):

         # Reinstantiate the validation iterator

-        transforms_cls = get_transforms_cls(model_opt._all_transform)
+        # transforms_cls = get_transforms_cls(model_opt._all_transform)
         model_opt.num_workers = 0
         model_opt.tgt = None

@@ -100,7 +96,7 @@ def translate(self, model, gpu_rank, step):

         valid_iter = build_dynamic_dataset_iter(
             model_opt,
-            transforms_cls,
+            self.transforms_cls,
             translator.vocabs,
             task=CorpusTask.VALID,
             tgt="", # This force to clear the target side (needed when using tgt_file_prefix)
@@ -125,12 +121,11 @@ def translate(self, model, gpu_rank, step):

         # Flatten predictions
         preds = [x.lstrip() for sublist in preds for x in sublist]
-
         # Save results
         if len(preds) > 0 and self.opt.scoring_debug:
             path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt")
             with open(path, "a") as file:
-                for i in range(len(preds)):
+                for i in range(len(raw_srcs)):
                     file.write("SOURCE: {}\n".format(raw_srcs[i]))
                     file.write("REF: {}\n".format(raw_refs[i]))
                     file.write("PRED: {}\n\n".format(preds[i]))
