Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Source features support for V2.0 #2090

Merged
merged 24 commits
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions onmt/bin/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import onmt.opts as opts
from onmt.utils.parse import ArgumentParser
from collections import defaultdict


def translate(opt):
Expand All @@ -15,12 +16,21 @@ def translate(opt):
translator = build_translator(opt, logger=logger, report_score=True)
src_shards = split_corpus(opt.src, opt.shard_size)
tgt_shards = split_corpus(opt.tgt, opt.shard_size)
shard_pairs = zip(src_shards, tgt_shards)

for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
features_shards = []
features_names = []
for feat_name, feat_path in opt.src_feats.items():
features_shards.append(split_corpus(feat_path, opt.shard_size))
features_names.append(feat_name)
shard_pairs = zip(src_shards, tgt_shards, *features_shards)

for i, (src_shard, tgt_shard, *features_shard) in enumerate(shard_pairs):
features_shard_ = defaultdict(list)
for j, x in enumerate(features_shard):
features_shard_[features_names[j]] = x
logger.info("Translating shard %d." % i)
translator.translate(
src=src_shard,
src_feats=features_shard_,
tgt=tgt_shard,
batch_size=opt.batch_size,
batch_type=opt.batch_type,
Expand Down
28 changes: 18 additions & 10 deletions onmt/inputters/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from contextlib import contextmanager

import multiprocessing as mp
from collections import defaultdict


@contextmanager
Expand Down Expand Up @@ -70,13 +71,19 @@ def _process(item, is_train):
example, is_train=is_train, corpus_name=cid)
if maybe_example is None:
return None
maybe_example['src'] = ' '.join(maybe_example['src'])
maybe_example['tgt'] = ' '.join(maybe_example['tgt'])

maybe_example['src'] = {"src": ' '.join(maybe_example['src'])}

# Make features part of src as in MultiTextField
if 'src_feats' in maybe_example:
for feat_name, feat_value in maybe_example['src_feats'].items():
maybe_example['src'][feat_name] = ' '.join(feat_value)
del maybe_example["src_feats"]

maybe_example['tgt'] = {"tgt": ' '.join(maybe_example['tgt'])}
if 'align' in maybe_example:
maybe_example['align'] = ' '.join(maybe_example['align'])
if 'src_feats' in maybe_example:
for k in maybe_example['src_feats'].keys():
maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k])

return maybe_example

def _maybe_add_dynamic_dict(self, example, fields):
Expand Down Expand Up @@ -176,7 +183,8 @@ def get_corpora(opts, is_train=False):
CorpusName.VALID,
opts.data[CorpusName.VALID]["path_src"],
opts.data[CorpusName.VALID]["path_tgt"],
opts.data[CorpusName.VALID]["path_align"])
opts.data[CorpusName.VALID]["path_align"],
opts.data[CorpusName.VALID]["src_feats"])
else:
return None
return corpora_dict
Expand Down Expand Up @@ -321,11 +329,11 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
build_sub_vocab.queues[c_name][offset].put("blank")
continue
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
if 'src_feats' in maybe_example:
for feat_name, feat_line in maybe_example["src_feats"].items():
for feat_name, feat_line in maybe_example["src"].items():
if feat_name != "src":
sub_counter_src_feats[feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
sub_counter_src.update(src_line["src"].split(' '))
sub_counter_tgt.update(tgt_line["tgt"].split(' '))
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put(
(i, src_line, tgt_line))
Expand Down
4 changes: 2 additions & 2 deletions onmt/inputters/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(self, fields, readers, data, sort_key, filter_pred=None):
self.sort_key = sort_key
can_copy = 'src_map' in fields and 'alignment' in fields

read_iters = [r.read(dat[1], dat[0]) for r, dat in zip(readers, data)]
read_iters = [r.read(dat, name, feats) for r, (name, dat, feats) in zip(readers, data)]

# self.src_vocabs is used in collapse_copy_scores and Translator.py
self.src_vocabs = []
Expand Down Expand Up @@ -162,5 +162,5 @@ def config(fields):
for name, field in fields:
if field["data"] is not None:
readers.append(field["reader"])
data.append((name, field["data"]))
data.append((name, field["data"], field["features"]))
return readers, data
24 changes: 19 additions & 5 deletions onmt/inputters/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class TextDataReader(DataReaderBase):
def read(self, sequences, side):
def read(self, sequences, side, features=None):
    """Read text data from disk, optionally with per-token feature streams.

    Args:
        sequences (str or Iterable[str]): path to a text file, or an
            iterable of the sequences themselves (str or bytes).
        side (str): prefix used in the yielded dict, e.g. ``"src"``.
        features (dict or None): maps a feature name to a file path or an
            iterable of feature lines, one line per sequence.
            # NOTE(review): schema reconstructed from usage — confirm.

    Yields:
        dict: ``{side: {side: sequence, feat_name: feat_line, ...},
        "indices": i}`` for each input sequence.
    """
    # Avoid a mutable default argument; behavior is unchanged for callers.
    if features is None:
        features = {}
    if isinstance(sequences, str):
        sequences = DataReaderBase._read_file(sequences)

    features_names = []
    features_values = []
    for feat_name, v in features.items():
        features_names.append(feat_name)
        # Bug fix: read the individual feature file `v`, not the whole
        # `features` dict (was `_read_file(features)`).
        features_values.append(
            DataReaderBase._read_file(v) if isinstance(v, str) else v)
    for i, (seq, *feats) in enumerate(zip(sequences, *features_values)):
        ex_dict = {}
        if isinstance(seq, bytes):
            seq = seq.decode("utf-8")
        ex_dict[side] = seq
        # Bug fix: use a distinct loop variable `j` so the outer example
        # index `i` is not clobbered before it is yielded as "indices".
        for j, f in enumerate(feats):
            if isinstance(f, bytes):
                f = f.decode("utf-8")
            ex_dict[features_names[j]] = f
        yield {side: ex_dict, "indices": i}


def text_sort_key(ex):
Expand Down Expand Up @@ -140,8 +155,7 @@ def preprocess(self, x):
lists of tokens/feature tags for the sentence. The output
is ordered like ``self.fields``.
"""

return [f.preprocess(x) for _, f in self.fields]
return [f.preprocess(x[fn]) for fn, f in self.fields]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key part of the last change


def __getitem__(self, key):
    """Return the entry of ``self.fields`` at position *key*."""
    return self.fields[key]
Expand Down
3 changes: 3 additions & 0 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,9 @@ def translate_opts(parser):
group.add('--src', '-src', required=True,
help="Source sequence to decode (one line per "
"sequence)")
group.add("-src_feats", "--src_feats", required=False,
help="Source sequence features (one line per "
"sequence)")
anderleich marked this conversation as resolved. Show resolved Hide resolved
group.add('--tgt', '-tgt',
help='True target sequence (optional)')
group.add('--tgt_prefix', '-tgt_prefix', action='store_true',
Expand Down
5 changes: 3 additions & 2 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _gold_score(
def translate(
self,
src,
src_feats={},
tgt=None,
batch_size=None,
batch_type="sents",
Expand Down Expand Up @@ -363,8 +364,8 @@ def translate(
if self.tgt_prefix and tgt is None:
raise ValueError("Prefix should be feed to tgt if -tgt_prefix.")

src_data = {"reader": self.src_reader, "data": src}
tgt_data = {"reader": self.tgt_reader, "data": tgt}
src_data = {"reader": self.src_reader, "data": src, "features": src_feats}
tgt_data = {"reader": self.tgt_reader, "data": tgt, "features": {}}
_readers, _data = inputters.Dataset.config(
[("src", src_data), ("tgt", tgt_data)]
)
Expand Down
2 changes: 1 addition & 1 deletion onmt/utils/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,4 +320,4 @@ def validate_train_opts(cls, opt):

@classmethod
def validate_translate_opts(cls, opt):
    """Normalize translate-time options in place.

    ``opt.src_feats`` arrives from the command line as a Python dict
    literal string mapping feature names to file paths, e.g.
    ``"{'feat0': 'kk.txt.feats0'}"``; parse it into a real dict,
    defaulting to ``{}`` when unset/empty.
    """
    import ast  # local import: keeps the method self-contained
    # Security fix: ast.literal_eval only accepts Python literals, so a
    # crafted command-line value cannot execute arbitrary code (eval could).
    opt.src_feats = ast.literal_eval(opt.src_feats) if opt.src_feats else {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this line? What's the expected input for this argument?
I'm assuming it would be a list of file paths, then using nargs in the group.add("-src_feats", ...) of onmt/opts.py could do the trick.

Copy link
Contributor Author

@anderleich anderleich Aug 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not a list of paths, it's a dictionary mapping feature names with the corresponding file path. Like this:

--src_feats "{'feat0': '../kk.txt.feats0', 'feat1': '../kk.txt.feats1'}"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, that's reasonable then.