Source features support for V2.0 #2090

Merged · 24 commits · Sep 9, 2021
Changes from 4 commits
7 changes: 6 additions & 1 deletion onmt/bin/build_vocab.py
@@ -32,11 +32,13 @@ def build_vocab_main(opts):
    transforms = make_transforms(opts, transforms_cls, fields)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
-    src_counter, tgt_counter = build_vocab(
+    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample)

    logger.info(f"Counters src:{len(src_counter)}")
    logger.info(f"Counters tgt:{len(tgt_counter)}")
+    for feat_name, feat_counter in src_feats_counter.items():
+        logger.info(f"Counters {feat_name}:{len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -52,6 +54,9 @@ def save_counter(counter, save_path):
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)
+
+        for k, v in src_feats_counter.items():
+            save_counter(v, opts.src_feats_vocab[k])


def _get_parser():
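For illustration, a minimal sketch of the per-feature save step above, assuming `opts.src_feats_vocab` maps each feature name to an output path; the feature names, counts, and paths here are invented:

from collections import Counter

# Hypothetical result of build_vocab: one Counter per source feature.
src_feats_counter = {
    "feat0": Counter({"B": 120, "I": 80, "O": 300}),
}
# Stand-in for opts.src_feats_vocab: feature name -> vocab file path.
src_feats_vocab = {"feat0": "feat0.vocab"}

def save_counter(counter, save_path):
    # One "<word>\t<count>" per line, most frequent first.
    with open(save_path, "w", encoding="utf-8") as fo:
        for token, count in counter.most_common():
            fo.write(f"{token}\t{count}\n")

for k, v in src_feats_counter.items():
    save_counter(v, src_feats_vocab[k])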
1 change: 1 addition & 0 deletions onmt/constants.py
@@ -22,6 +22,7 @@ class CorpusName(object):
class SubwordMarker(object):
    SPACER = '▁'
    JOINER = '■'
+    CASE_MARKUP = ["⦅mrk_case_modifier_C⦆",
+                   "⦅mrk_begin_case_region_U⦆",
+                   "⦅mrk_end_case_region_U⦆"]


class ModelTask(object):
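`CASE_MARKUP` collects the case-markup placeholders that tokenizers such as pyonmttok emit when casing is encoded as standalone tokens. A hedged sketch of one way the constant can be used, skipping markup tokens that carry no feature value of their own; the token stream is invented:

from onmt.constants import SubwordMarker

tokens = ["⦅mrk_case_modifier_C⦆", "hello", "■!", "world"]
# Markup placeholders are not content tokens, so drop them before
# pairing the remaining tokens with word-level feature labels.
content_tokens = [t for t in tokens if t not in SubwordMarker.CASE_MARKUP]
print(content_tokens)  # ['hello', '■!', 'world']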
46 changes: 37 additions & 9 deletions onmt/inputters/corpus.py
@@ -7,7 +7,7 @@
from torchtext.data import Dataset as TorchtextDataset, \
    Example as TorchtextExample

-from collections import Counter
+from collections import Counter, defaultdict
from contextlib import contextmanager

import multiprocessing as mp
@@ -74,6 +74,9 @@ def _process(item, is_train):
        maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
        if 'align' in maybe_example:
            maybe_example['align'] = ' '.join(maybe_example['align'])
+        if 'src_feats' in maybe_example:
+            for k in maybe_example['src_feats'].keys():
+                maybe_example['src_feats'][k] = \
+                    ' '.join(maybe_example['src_feats'][k])
        return maybe_example

    def _maybe_add_dynamic_dict(self, example, fields):
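To make the new branch concrete, a small sketch (values invented) of what `_process` does to an example carrying feature streams — every token list, including each entry under `src_feats`, becomes one whitespace-joined string:

maybe_example = {
    "src": ["hello", "world"],
    "tgt": ["hallo", "welt"],
    "src_feats": {"feat0": ["A", "B"]},
}

maybe_example["src"] = " ".join(maybe_example["src"])
maybe_example["tgt"] = " ".join(maybe_example["tgt"])
for k in maybe_example["src_feats"]:
    maybe_example["src_feats"][k] = " ".join(maybe_example["src_feats"][k])

# -> {'src': 'hello world', 'tgt': 'hallo welt',
#     'src_feats': {'feat0': 'A B'}}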
@@ -107,23 +110,32 @@ def __call__(self, bucket):
class ParallelCorpus(object):
    """A parallel corpus file pair that can be loaded to iterate."""

-    def __init__(self, name, src, tgt, align=None):
+    def __init__(self, name, src, tgt, align=None, src_feats=None):
        """Initialize src & tgt side file path."""
        self.id = name
        self.src = src
        self.tgt = tgt
        self.align = align
+        self.src_feats = src_feats

    def load(self, offset=0, stride=1):
        """
        Load file and iterate by lines.
        `offset` and `stride` allow to iterate only on every
        `stride` example, starting from `offset`.
        """
+        if self.src_feats:
+            features_names = []
+            features_files = []
+            for feat_name, feat_path in self.src_feats.items():
+                features_names.append(feat_name)
+                features_files.append(open(feat_path, mode='rb'))
+        else:
+            features_files = []
        with exfile_open(self.src, mode='rb') as fs,\
                exfile_open(self.tgt, mode='rb') as ft,\
                exfile_open(self.align, mode='rb') as fa:
-            for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
+            for i, (sline, tline, align, *features) in enumerate(
+                    zip(fs, ft, fa, *features_files)):
                if (i % stride) == offset:
                    sline = sline.decode('utf-8')
                    tline = tline.decode('utf-8')
@@ -133,12 +145,18 @@ def load(self, offset=0, stride=1):
                    }
                    if align is not None:
                        example['align'] = align.decode('utf-8')
+                    if features:
+                        example["src_feats"] = dict()
+                        for j, feat in enumerate(features):
+                            example["src_feats"][features_names[j]] = \
+                                feat.decode("utf-8")
                    yield example
+        for f in features_files:
+            f.close()

    def __str__(self):
        cls_name = type(self).__name__
-        return '{}({}, {}, align={})'.format(
-            cls_name, self.src, self.tgt, self.align)
+        return '{}({}, {}, align={}, src_feats={})'.format(
+            cls_name, self.src, self.tgt, self.align, self.src_feats)
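A usage sketch of the extended loader, with invented paths: `src_feats` maps each feature name to a file that must be line-aligned with the source, and each yielded example gains a matching `src_feats` entry:

from onmt.inputters.corpus import ParallelCorpus

corpus = ParallelCorpus(
    "corpus_1",
    "data/train.src",
    "data/train.tgt",
    src_feats={"feat0": "data/train.feat0"},
)
for example in corpus.load():
    # example is a dict shaped like:
    # {"src": "...", "tgt": "...", "src_feats": {"feat0": "..."}}
    pass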

def get_corpora(opts, is_train=False):
@@ -150,7 +168,8 @@ def get_corpora(opts, is_train=False):
                    corpus_id,
                    corpus_dict["path_src"],
                    corpus_dict["path_tgt"],
-                    corpus_dict["path_align"])
+                    corpus_dict["path_align"],
+                    corpus_dict["src_feats"])
    else:
        if CorpusName.VALID in opts.data.keys():
            corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -193,6 +212,9 @@ def _tokenize(self, stream):
            example['src'], example['tgt'] = src, tgt
            if 'align' in example:
                example['align'] = example['align'].strip('\n').split()
+            if 'src_feats' in example:
+                for k in example['src_feats'].keys():
+                    example['src_feats'][k] = \
+                        example['src_feats'][k].strip('\n').split()
            yield example

    def _transform(self, stream):
@@ -286,6 +308,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
+    sub_counter_src_feats = defaultdict(Counter)
    datasets_iterables = build_corpora_iters(
        corpora, transforms, opts.data,
        skip_empty_level=opts.skip_empty_level,
@@ -298,6 +321,9 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
                    build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
+            if 'src_feats' in maybe_example:
+                for feat_name, feat_line in maybe_example["src_feats"].items():
+                    sub_counter_src_feats[feat_name].update(
+                        feat_line.split(' '))
            sub_counter_src.update(src_line.split(' '))
            sub_counter_tgt.update(tgt_line.split(' '))
            if opts.dump_samples:
@@ -309,7 +335,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
-    return sub_counter_src, sub_counter_tgt
+    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
@@ -333,6 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
    corpora = get_corpora(opts, is_train=True)
    counter_src = Counter()
    counter_tgt = Counter()
+    counter_src_feats = defaultdict(Counter)
    from functools import partial
    queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
                       for i in range(opts.num_threads)]
@@ -349,13 +376,14 @@ def build_vocab(opts, transforms, n_sample=3):
        func = partial(
            build_sub_vocab, corpora, transforms,
            opts, n_sample, opts.num_threads)
-        for sub_counter_src, sub_counter_tgt in p.imap(
+        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
                func, range(0, opts.num_threads)):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
+            # Merge per-feature counts; a plain dict.update() here would
+            # overwrite the Counters gathered from earlier workers.
+            for feat_name, feat_counter in sub_counter_src_feats.items():
+                counter_src_feats[feat_name].update(feat_counter)
    if opts.dump_samples:
        write_process.join()
-    return counter_src, counter_tgt
+    return counter_src, counter_tgt, counter_src_feats


def save_transformed_sample(opts, transforms, n_sample=3):
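Tying the loader to the configuration: `get_corpora` above indexes `corpus_dict["src_feats"]` unconditionally for training corpora, so each corpus entry in `opts.data` needs the key even when no features are used. A sketch of the expected structure as a plain dict (paths invented; in practice this comes from the YAML config):

opts_data = {
    "corpus_1": {
        "path_src": "data/train.src",
        "path_tgt": "data/train.tgt",
        "path_align": None,
        # One line-aligned file per feature; None when unused, since
        # get_corpora reads the key directly.
        "src_feats": {"feat0": "data/train.feat0"},
    },
}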
13 changes: 9 additions & 4 deletions onmt/inputters/fields.py
@@ -8,11 +8,10 @@


def _get_dynamic_fields(opts):
-    # NOTE: not support nfeats > 0 yet
-    src_nfeats = 0
-    tgt_nfeats = 0
+    # NOTE: tgt feats are not supported yet
+    tgt_feats = None
    with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
-    fields = get_fields('text', src_nfeats, tgt_nfeats,
+    fields = get_fields('text', opts.src_feats_vocab, tgt_feats,
                        dynamic_dict=opts.copy_attn,
                        src_truncate=opts.src_seq_length_trunc,
                        tgt_truncate=opts.tgt_seq_length_trunc,
@@ -33,6 +32,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
        opts.src_vocab, 'src', counters,
        min_freq=opts.src_words_min_frequency)

+    if opts.src_feats_vocab:
+        for feat_name, filepath in opts.src_feats_vocab.items():
+            _, _ = _load_vocab(
+                filepath, feat_name, counters,
+                min_freq=0)
+
    if opts.tgt_vocab:
        _tgt_vocab, _tgt_vocab_size = _load_vocab(
            opts.tgt_vocab, 'tgt', counters,
12 changes: 6 additions & 6 deletions onmt/inputters/inputter.py
@@ -111,8 +111,8 @@ def get_task_spec_tokens(data_task, pad, bos, eos):

def get_fields(
    src_data_type,
-    n_src_feats,
-    n_tgt_feats,
+    src_feats,
+    tgt_feats,
    pad=DefaultTokens.PAD,
    bos=DefaultTokens.BOS,
    eos=DefaultTokens.EOS,
@@ -125,11 +125,11 @@ def get_fields(
    """
    Args:
        src_data_type: type of the source input. Options are [text].
-        n_src_feats (int): the number of source features (not counting tokens)
+        src_feats (dict): the names of the source features
            to create a :class:`torchtext.data.Field` for. (If
            ``src_data_type=="text"``, these fields are stored together
            as a ``TextMultiField``).
-        n_tgt_feats (int): See above.
+        tgt_feats (dict): See above.
        pad (str): Special pad symbol. Used on src and tgt side.
        bos (str): Special beginning of sequence symbol. Only relevant
            for tgt.
@@ -158,7 +158,7 @@ def get_fields(
    task_spec_tokens = get_task_spec_tokens(data_task, pad, bos, eos)

    src_field_kwargs = {
-        "n_feats": n_src_feats,
+        "feats": src_feats,
        "include_lengths": True,
        "pad": task_spec_tokens["src"]["pad"],
        "bos": task_spec_tokens["src"]["bos"],
@@ -169,7 +169,7 @@ def get_fields(
    fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

    tgt_field_kwargs = {
-        "n_feats": n_tgt_feats,
+        "feats": tgt_feats,
        "include_lengths": False,
        "pad": task_spec_tokens["tgt"]["pad"],
        "bos": task_spec_tokens["tgt"]["bos"],
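Under the new signature, callers hand `get_fields` the feature-name mapping (or `None`) where a feature count used to go; only the dict keys are consulted when the fields are built. A hedged call sketch:

from onmt.inputters.inputter import get_fields

fields = get_fields(
    "text",
    src_feats={"feat0": None},  # keys name the features; values unused here
    tgt_feats=None,             # target-side features are still unsupported
)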
48 changes: 32 additions & 16 deletions onmt/inputters/text_dataset.py
@@ -152,7 +152,7 @@ def text_fields(**kwargs):

    Args:
        base_name (str): Name associated with the field.
-        n_feats (int): Number of word level feats (not counting the tokens)
+        feats (dict): Word-level feature names.
        include_lengths (bool): Optionally return the sequence lengths.
        pad (str, optional): Defaults to ``"<blank>"``.
        bos (str or NoneType, optional): Defaults to ``"<s>"``.
@@ -163,28 +163,44 @@ def text_fields(**kwargs):
        TextMultiField
    """

-    n_feats = kwargs["n_feats"]
+    feats = kwargs["feats"]
    include_lengths = kwargs["include_lengths"]
    base_name = kwargs["base_name"]
    pad = kwargs.get("pad", DefaultTokens.PAD)
    bos = kwargs.get("bos", DefaultTokens.BOS)
    eos = kwargs.get("eos", DefaultTokens.EOS)
    truncate = kwargs.get("truncate", None)
    fields_ = []
-    feat_delim = u"│" if n_feats > 0 else None
-    for i in range(n_feats + 1):
-        name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
-        tokenize = partial(
-            _feature_tokenize,
-            layer=i,
-            truncate=truncate,
-            feat_delim=feat_delim)
-        use_len = i == 0 and include_lengths
-        feat = Field(
-            init_token=bos, eos_token=eos,
-            pad_token=pad, tokenize=tokenize,
-            include_lengths=use_len)
-        fields_.append((name, feat))
+
+    feat_delim = None  # was: u"│" if n_feats > 0 else None
+
+    # Base field
+    tokenize = partial(
+        _feature_tokenize,
+        layer=None,
+        truncate=truncate,
+        feat_delim=feat_delim)
+    feat = Field(
+        init_token=bos, eos_token=eos,
+        pad_token=pad, tokenize=tokenize,
+        include_lengths=include_lengths)
+    fields_.append((base_name, feat))
+
+    # Feats fields
+    if feats:
+        for feat_name in feats.keys():
+            # Legacy tokenizer hook; features arrive pre-tokenized,
+            # so this is not strictly necessary.
+            tokenize = partial(
+                _feature_tokenize,
+                layer=None,
+                truncate=truncate,
+                feat_delim=feat_delim)
+            feat = Field(
+                init_token=bos, eos_token=eos,
+                pad_token=pad, tokenize=tokenize,
+                include_lengths=False)
+            fields_.append((feat_name, feat))

    assert fields_[0][0] == base_name  # sanity check
    field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
    return field
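The rewritten `text_fields` keeps the base field first and appends one sub-field per feature name, so the resulting `TextMultiField` addresses features by name instead of a `_feat_N` suffix. A small sketch with invented feature names:

from onmt.inputters.text_dataset import text_fields

field = text_fields(
    base_name="src",
    feats={"feat0": None, "feat1": None},  # only the keys are used
    include_lengths=True,
)
# Internally, fields_ is
#   [("src", Field), ("feat0", Field), ("feat1", Field)],
# and only the base "src" Field reports sequence lengths.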
5 changes: 5 additions & 0 deletions onmt/opts.py
@@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False):
    group.add("-share_vocab", "--share_vocab", action="store_true",
              help="Share source and target vocabulary.")

+    group.add("-src_feats_vocab", "--src_feats_vocab",
+              help=("List of paths to save" if build_vocab_only
+                    else "List of paths to")
+              + " src features vocabulary files. "
+              "Files format: one <word> or <word>\t<count> per line.")

    if not build_vocab_only:
        group.add("-src_vocab_size", "--src_vocab_size",
                  type=int, default=50000,
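In practice `src_feats_vocab` resolves to a mapping from feature name to vocabulary file (written by the build-vocab step, read back at training time). A sketch of the value and file format, with placeholder names and paths; the mapping normally comes from the YAML config:

# Equivalent YAML:
#   src_feats_vocab:
#     feat0: data/feat0.vocab
src_feats_vocab = {"feat0": "data/feat0.vocab"}

# Each vocab file holds one "<word>" or "<word>\t<count>" per line, e.g.:
#   B\t120
#   I\t80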
3 changes: 3 additions & 0 deletions onmt/train_single.py
@@ -58,6 +58,9 @@ def main(opt, fields, transforms_cls, checkpoint, device_id,
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
+
+    # import pdb
+    # pdb.set_trace()
    configure_process(opt, device_id)
    init_logger(opt.log_file)

Expand Down
Loading