5
5
from onmt .opts import translate_opts
6
6
from onmt .constants import CorpusTask
7
7
from onmt .inputters .dynamic_iterator import build_dynamic_dataset_iter
8
- from onmt .transforms import get_transforms_cls , make_transforms , TransformPipe
8
+ from onmt .transforms import get_transforms_cls
9
9
10
10
11
11
class ScoringPreparator :
@@ -19,16 +19,12 @@ def __init__(self, vocabs, opt):
19
19
if self .opt .dump_preds is not None :
20
20
if not os .path .exists (self .opt .dump_preds ):
21
21
os .makedirs (self .opt .dump_preds )
22
- self .transforms = opt .transforms
23
- transforms_cls = get_transforms_cls (self .transforms )
24
- transforms = make_transforms (self .opt , transforms_cls , self .vocabs )
25
- self .transform = TransformPipe .build_from (transforms .values ())
22
+ self .transforms = None
23
+ self .transforms_cls = None
26
24
27
25
def warm_up (self , transforms ):
28
26
self .transforms = transforms
29
- transforms_cls = get_transforms_cls (self .transforms )
30
- transforms = make_transforms (self .opt , transforms_cls , self .vocabs )
31
- self .transform = TransformPipe .build_from (transforms .values ())
27
+ self .transforms_cls = get_transforms_cls (transforms )
32
28
33
29
def translate (self , model , gpu_rank , step ):
34
30
"""Compute and save the sentences predicted by the
@@ -84,7 +80,7 @@ def translate(self, model, gpu_rank, step):
84
80
85
81
# Reinstantiate the validation iterator
86
82
87
- transforms_cls = get_transforms_cls (model_opt ._all_transform )
83
+ # transforms_cls = get_transforms_cls(model_opt._all_transform)
88
84
model_opt .num_workers = 0
89
85
model_opt .tgt = None
90
86
@@ -100,7 +96,7 @@ def translate(self, model, gpu_rank, step):
100
96
101
97
valid_iter = build_dynamic_dataset_iter (
102
98
model_opt ,
103
- transforms_cls ,
99
+ self . transforms_cls ,
104
100
translator .vocabs ,
105
101
task = CorpusTask .VALID ,
106
102
tgt = "" , # This force to clear the target side (needed when using tgt_file_prefix)
@@ -125,12 +121,11 @@ def translate(self, model, gpu_rank, step):
125
121
126
122
# Flatten predictions
127
123
preds = [x .lstrip () for sublist in preds for x in sublist ]
128
-
129
124
# Save results
130
125
if len (preds ) > 0 and self .opt .scoring_debug :
131
126
path = os .path .join (self .opt .dump_preds , f"preds.valid_step_{ step } .txt" )
132
127
with open (path , "a" ) as file :
133
- for i in range (len (preds )):
128
+ for i in range (len (raw_srcs )):
134
129
file .write ("SOURCE: {}\n " .format (raw_srcs [i ]))
135
130
file .write ("REF: {}\n " .format (raw_refs [i ]))
136
131
file .write ("PRED: {}\n \n " .format (preds [i ]))
0 commit comments