|
27 | 27 |
|
28 | 28 | def prepare_transforms_vocabs(opt, transforms_cls): |
29 | 29 | """Prepare or dump transforms before training.""" |
30 | | - # if transform + options set in 'valid' we need to copy in main |
31 | | - # transform / options for scoring considered as inference |
32 | | - validset_transforms = opt.data.get("valid", {}).get("transforms", None) |
33 | | - if validset_transforms: |
34 | | - opt.transforms = validset_transforms |
35 | | - if opt.data.get("valid", {}).get("tgt_prefix", None): |
36 | | - opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None) |
37 | | - opt.tgt_file_prefix = True |
38 | | - if opt.data.get("valid", {}).get("src_prefix", None): |
39 | | - opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None) |
40 | | - if opt.data.get("valid", {}).get("tgt_suffix", None): |
41 | | - opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None) |
42 | | - if opt.data.get("valid", {}).get("src_suffix", None): |
43 | | - opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None) |
44 | 30 | specials = get_specials(opt, transforms_cls) |
45 | 31 |
|
46 | 32 | vocabs = build_vocab(opt, specials) |
@@ -77,6 +63,20 @@ def _init_train(opt): |
77 | 63 | """ |
78 | 64 | ArgumentParser.validate_prepare_opts(opt) |
79 | 65 | transforms_cls = get_transforms_cls(opt._all_transform) |
| 66 | + # if transform + options set in 'valid' we need to copy in main |
| 67 | + # transform / options for scoring considered as inference |
| 68 | + validset_transforms = opt.data.get("valid", {}).get("transforms", None) |
| 69 | + if validset_transforms: |
| 70 | + opt.transforms = validset_transforms |
| 71 | + if opt.data.get("valid", {}).get("tgt_prefix", None): |
| 72 | + opt.tgt_prefix = opt.data.get("valid", {}).get("tgt_prefix", None) |
| 73 | + opt.tgt_file_prefix = True |
| 74 | + if opt.data.get("valid", {}).get("src_prefix", None): |
| 75 | + opt.src_prefix = opt.data.get("valid", {}).get("src_prefix", None) |
| 76 | + if opt.data.get("valid", {}).get("tgt_suffix", None): |
| 77 | + opt.tgt_suffix = opt.data.get("valid", {}).get("tgt_suffix", None) |
| 78 | + if opt.data.get("valid", {}).get("src_suffix", None): |
| 79 | + opt.src_suffix = opt.data.get("valid", {}).get("src_suffix", None) |
80 | 80 | if opt.train_from: |
81 | 81 | # Load checkpoint if we resume from a previous training. |
82 | 82 | checkpoint = load_checkpoint(ckpt_path=opt.train_from) |
|
0 commit comments