From 02a407949f1be43e68dc040bc9fac3b6c569accf Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 25 Aug 2022 18:48:28 -0700 Subject: [PATCH 01/35] Add UnitY implementation --- fairseq/criterions/ctc.py | 34 +- .../label_smoothed_cross_entropy.py | 85 ++- .../criterions/speech_to_speech_criterion.py | 262 ++++++- fairseq/data/audio/data_cfg.py | 14 + .../data/audio/speech_to_speech_dataset.py | 57 +- fairseq/data/audio/speech_to_text_dataset.py | 207 ++++- fairseq/data/audio/text_to_speech_dataset.py | 10 +- fairseq/dataclass/configs.py | 22 + fairseq/models/speech_to_speech/__init__.py | 3 +- .../models/speech_to_speech/s2s_conformer.py | 102 ++- .../speech_to_speech/s2s_conformer_t2.py | 711 ++++++++++++++++++ .../speech_to_speech/s2s_transformer.py | 27 +- fairseq/models/speech_to_text/__init__.py | 3 +- .../models/speech_to_text/convtransformer.py | 21 +- .../speech_to_text/modules/convolution.py | 126 ++++ .../models/speech_to_text/s2t_conformer.py | 94 ++- .../models/speech_to_text/s2t_transformer.py | 88 +-- .../models/speech_to_text/xm_transformer.py | 71 +- .../speech_to_text/xm_transformer_unity.py | 312 ++++++++ .../transformer/transformer_decoder_aug.py | 391 ++++++++++ fairseq/modules/transformer_layer_aug.py | 319 ++++++++ fairseq/sequence_generator.py | 116 ++- fairseq/sequence_generator_multi_decoder.py | 258 +++++++ fairseq/speech_generator.py | 200 ++++- fairseq/tasks/speech_to_speech.py | 227 ++++-- fairseq/tasks/speech_to_text.py | 160 +++- 26 files changed, 3576 insertions(+), 344 deletions(-) create mode 100644 fairseq/models/speech_to_speech/s2s_conformer_t2.py create mode 100644 fairseq/models/speech_to_text/modules/convolution.py create mode 100644 fairseq/models/speech_to_text/xm_transformer_unity.py create mode 100644 fairseq/models/transformer/transformer_decoder_aug.py create mode 100644 fairseq/modules/transformer_layer_aug.py create mode 100644 fairseq/sequence_generator_multi_decoder.py diff --git a/fairseq/criterions/ctc.py 
b/fairseq/criterions/ctc.py index e966e47cf2..19025a30a5 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -7,17 +7,18 @@ import math from argparse import Namespace from dataclasses import dataclass, field -from omegaconf import II from typing import Optional import torch import torch.nn.functional as F +from omegaconf import II + from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -from fairseq.dataclass import FairseqDataclass from fairseq.data.data_utils import post_process -from fairseq.tasks import FairseqTask +from fairseq.dataclass import FairseqDataclass from fairseq.logging.meters import safe_round +from fairseq.tasks import FairseqTask @dataclass @@ -64,7 +65,9 @@ class CtcCriterionConfig(FairseqDataclass): @register_criterion("ctc", dataclass=CtcCriterionConfig) class CtcCriterion(FairseqCriterion): - def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): + def __init__( + self, cfg: CtcCriterionConfig, task: FairseqTask, rdrop_alpha: bool = False + ): super().__init__(task) self.blank_idx = ( task.target_dictionary.index(task.blank_symbol) @@ -75,6 +78,8 @@ def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): self.eos_idx = task.target_dictionary.eos() self.post_process = cfg.post_process + self.rdrop_alpha = rdrop_alpha + if cfg.wer_args is not None: ( cfg.wer_kenlm_model, @@ -106,12 +111,31 @@ def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): self.zero_infinity = cfg.zero_infinity self.sentence_avg = cfg.sentence_avg - def forward(self, model, sample, reduce=True): + def forward(self, model, sample, reduce=True, net_output=None): net_output = model(**sample["net_input"]) lprobs = model.get_normalized_probs( net_output, log_probs=True ).contiguous() # (T, B, C) from the encoder + # CTC loss is calculated over duplicated inputs + # sample is already duplicated for R-Drop + if self.rdrop_alpha > 0: + for k, v in sample.items(): + if k in 
["target", "target_lengths"]: + sample[k] = torch.cat([v, v.clone()], dim=0) + elif k == "net_input": + if sample[k]["src_tokens"].size(1) != sample[k]["src_lengths"].size( + 0 + ): + # for decoder CTC loss + sample[k]["src_lengths"] = torch.cat( + [ + sample[k]["src_lengths"], + sample[k]["src_lengths"].clone(), + ], + dim=0, + ) + if "src_lengths" in sample["net_input"]: input_lengths = sample["net_input"]["src_lengths"] else: diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index cb43be0ca5..036dff943a 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -7,10 +7,11 @@ from dataclasses import dataclass, field import torch +from omegaconf import II + from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion from fairseq.dataclass import FairseqDataclass -from omegaconf import II @dataclass @@ -19,6 +20,10 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): default=0.0, metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, ) + rdrop_alpha: float = field( + default=0.0, + metadata={"help": "alpha for r-drop, 0 means no r-drop"}, + ) report_accuracy: bool = field( default=False, metadata={"help": "report accuracy metric"}, @@ -59,6 +64,7 @@ def __init__( task, sentence_avg, label_smoothing, + rdrop_alpha, ignore_prefix_size=0, report_accuracy=False, ): @@ -67,8 +73,9 @@ def __init__( self.eps = label_smoothing self.ignore_prefix_size = ignore_prefix_size self.report_accuracy = report_accuracy + self.rdrop_alpha = rdrop_alpha - def forward(self, model, sample, reduce=True): + def forward(self, model, sample, reduce=True, net_output=None): """Compute the loss for the given sample. 
Returns a tuple with three elements: @@ -76,8 +83,15 @@ def forward(self, model, sample, reduce=True): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - net_output = model(**sample["net_input"]) - loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) + if net_output is None: + if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size( + 0 + ) == sample["target"].size(0): + sample = duplicate_input(sample) + net_output = model(**sample["net_input"]) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, net_output, sample, reduce=reduce + ) sample_size = ( sample["target"].size(0) if self.sentence_avg else sample["ntokens"] ) @@ -92,11 +106,16 @@ def forward(self, model, sample, reduce=True): n_correct, total = self.compute_accuracy(model, net_output, sample) logging_output["n_correct"] = utils.item(n_correct.data) logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) return loss, sample_size, logging_output def get_lprobs_and_target(self, model, net_output, sample): lprobs = model.get_normalized_probs(net_output, log_probs=True) target = model.get_targets(sample, net_output) + if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0): + target = torch.cat([target, target.clone()], dim=0) + if self.ignore_prefix_size > 0: # lprobs: B x T x C lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() @@ -112,7 +131,15 @@ def compute_loss(self, model, net_output, sample, reduce=True): ignore_index=self.padding_idx, reduce=reduce, ) - return loss, nll_loss + + if self.rdrop_alpha > 0: + pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx) + rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask) + loss += self.rdrop_alpha * rdrop_kl_loss + else: + rdrop_kl_loss = loss.new_zeros(1) + + return loss, nll_loss, rdrop_kl_loss def 
compute_accuracy(self, model, net_output, sample): lprobs, target = self.get_lprobs_and_target(model, net_output, sample) @@ -156,6 +183,13 @@ def reduce_metrics(cls, logging_outputs) -> None: if meters["total"].sum > 0 else float("nan"), ) + rdrop_kl_loss = utils.item( + sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs) + / sample_size + / math.log(2) + ) + if rdrop_kl_loss > 0: + metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss) @staticmethod def logging_outputs_can_be_summed() -> bool: @@ -165,3 +199,44 @@ def logging_outputs_can_be_summed() -> bool: to True will improves distributed training speed. """ return True + + +def duplicate_input(sample): + if "net_input" in sample.keys(): + sample_input = sample["net_input"] + else: + sample_input = sample + + for k, v in sample_input.items(): + if isinstance(v, torch.Tensor): + sample_input[k] = torch.cat([v, v.clone()], dim=0) + if "net_input" in sample.keys(): + sample["net_input"] = sample_input + else: + sample = sample_input + return sample + + +def compute_kl_loss(model, net_output, pad_mask=None, reduce=True): + net_prob = model.get_normalized_probs(net_output, log_probs=True) + net_prob_tec = model.get_normalized_probs(net_output, log_probs=False) + + net_prob = net_prob.view(-1, net_prob.size(-1)) + net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1)) + + p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0) + p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0) + + p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none") + q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none") + + if pad_mask is not None: + p_loss.masked_fill_(pad_mask, 0.0) + q_loss.masked_fill_(pad_mask, 0.0) + + if reduce: + p_loss = p_loss.sum() + q_loss = q_loss.sum() + + loss = (p_loss + q_loss) / 2 + return loss diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index 7fba673d25..a6bc0cb73f 100644 --- 
a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -3,7 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import math +from collections import OrderedDict + import torch from fairseq import metrics, utils @@ -12,21 +15,35 @@ from fairseq.criterions.label_smoothed_cross_entropy import ( LabelSmoothedCrossEntropyCriterion, LabelSmoothedCrossEntropyCriterionConfig, + duplicate_input, ) from fairseq.criterions.tacotron2_loss import ( Tacotron2Criterion, Tacotron2CriterionConfig, ) +logger = logging.getLogger(__name__) + class MultitaskCriterion: - def __init__(self, multitask_tasks): - self.multitask_criterion = {} - self.multitask_loss_weight = {} + def __init__(self, multitask_tasks, rdrop_alpha=0.0): + self.rdrop_alpha = rdrop_alpha + self.rdrop_alpha_mtl = rdrop_alpha + + self.multitask_criterion = OrderedDict() + self.multitask_loss_weight = OrderedDict() for task_name, task_obj in multitask_tasks.items(): + rdrop_alpha_task = task_obj.args.rdrop_alpha + if rdrop_alpha_task is None: + rdrop_alpha_task = rdrop_alpha + self.rdrop_alpha_mtl = rdrop_alpha_task + logger.info(f"rdrop_alpha is set to {rdrop_alpha_task}") + if task_obj.args.decoder_type == "ctc": self.multitask_criterion[task_name] = CtcCriterion( - task_obj.args.criterion_cfg, task_obj + task_obj.args.criterion_cfg, + task_obj, + rdrop_alpha=rdrop_alpha_task, ) else: self.multitask_criterion[ @@ -35,6 +52,7 @@ def __init__(self, multitask_tasks): task_obj, task_obj.args.criterion_cfg.sentence_avg, label_smoothing=task_obj.args.criterion_cfg.label_smoothing, + rdrop_alpha=rdrop_alpha_task, ) def set_multitask_loss_weight(self, task_name, weight=0.0): @@ -47,8 +65,15 @@ def get_multitask_loss(self, model, sample, model_out): layer_id = task_criterion.task.args.input_layer if isinstance(task_criterion, CtcCriterion): if 
task_criterion.task.args.input_from == "encoder": - non_padding_mask = ~model_out["encoder_padding_mask"][0] - input_lengths = non_padding_mask.long().sum(-1) + if len(model_out["encoder_padding_mask"]) > 0: + non_padding_mask = ~model_out["encoder_padding_mask"][0] + input_lengths = non_padding_mask.long().sum(-1) + else: + out = model_out["encoder_states"][layer_id] + input_lengths = out.new_full( + (out.shape[1],), out.shape[0] + ).long() + task_sample = { "net_input": { "src_tokens": model_out["encoder_states"][ @@ -82,8 +107,12 @@ def get_multitask_loss(self, model, sample, model_out): for key in ["target", "target_lengths", "ntokens"]: task_sample[key] = sample["multitask"][task_name][key] + if task_name == getattr(model, "mt_task_name", None): + decoder_out = model_out["mt_decoder_out"] + else: + decoder_out = None task_loss, task_sample_size, task_logging_output = task_criterion( - model.multitask_decoders[task_name], task_sample + model.multitask_decoders[task_name], task_sample, net_output=decoder_out ) loss = loss + self.multitask_loss_weight[task_name] * task_loss @@ -134,24 +163,36 @@ def __init__( task, sentence_avg, label_smoothing, + rdrop_alpha, ignore_prefix_size=0, report_accuracy=False, ): super().__init__( - task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + task, + sentence_avg, + label_smoothing, + rdrop_alpha, + ignore_prefix_size, + report_accuracy, ) - MultitaskCriterion.__init__(self, task.multitask_tasks) + MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha) def forward(self, model, sample, reduce=True): - net_output, extra = model( - src_tokens=sample["net_input"]["src_tokens"], - src_lengths=sample["net_input"]["src_lengths"], - prev_output_tokens=sample["net_input"]["prev_output_tokens"], - tgt_speaker=sample["net_input"]["tgt_speaker"], - return_all_hiddens=True, - ) + net_input_concat = { + "src_tokens": sample["net_input"]["src_tokens"], + "src_lengths": sample["net_input"]["src_lengths"], + 
"prev_output_tokens": sample["net_input"]["prev_output_tokens"], + "tgt_speaker": sample["net_input"].get("tgt_speaker", None), + "return_all_hiddens": True, + } - loss, nll_loss = self.compute_loss(model, [net_output], sample, reduce=reduce) + if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0: + net_input_concat = duplicate_input(net_input_concat) + + net_output, extra = model(**net_input_concat) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, [net_output], sample, reduce=reduce + ) sample_size = ( sample["target"].size(0) if self.sentence_avg else sample["ntokens"] ) @@ -166,6 +207,8 @@ def forward(self, model, sample, reduce=True): n_correct, total = self.compute_accuracy(model, [net_output], sample) logging_output["n_correct"] = utils.item(n_correct.data) logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) if len(self.multitask_criterion) == 0: return loss, sample_size, logging_output @@ -208,6 +251,82 @@ def logging_outputs_can_be_summed() -> bool: return False +@register_criterion( + "speech_to_unit_translatotron2", dataclass=LabelSmoothedCrossEntropyCriterionConfig +) +class SpeechToUnitTranslatotron2MultitaskTaskCriterion( + SpeechToUnitMultitaskTaskCriterion +): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + rdrop_alpha, + ignore_prefix_size=0, + report_accuracy=False, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + rdrop_alpha, + ignore_prefix_size, + report_accuracy, + ) + + def forward(self, model, sample, reduce=True): + net_input_concat = { + "src_tokens": sample["net_input"]["src_tokens"], + "src_lengths": sample["net_input"]["src_lengths"], + "prev_output_tokens": sample["net_input"]["prev_output_tokens"], + "prev_output_tokens_mt": sample["multitask"][model.mt_task_name][ + "net_input" + ]["prev_output_tokens"], + "tgt_speaker": sample["net_input"].get("tgt_speaker", None), + "return_all_hiddens": 
True, + } + if getattr(model, "asr_task_name", None) is not None: + net_input_concat["prev_output_tokens_asr"] = sample["multitask"][ + model.asr_task_name + ]["net_input"]["prev_output_tokens"] + + if self.rdrop_alpha > 0 or self.rdrop_alpha_mtl > 0: + net_input_concat = duplicate_input(net_input_concat) + + net_output, extra = model(**net_input_concat) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, [net_output], sample, reduce=reduce + ) + + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, [net_output], sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + # multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + + return loss, sample_size, logging_output + + @register_criterion("speech_to_spectrogram", dataclass=Tacotron2CriterionConfig) class SpeechToSpectrogramMultitaskTaskCriterion(Tacotron2Criterion, MultitaskCriterion): def __init__( @@ -308,3 +427,112 @@ def reduce_metrics(cls, logging_outputs) -> None: return MultitaskCriterion.reduce_metrics(logging_outputs) + + +@register_criterion( + "speech_to_spectrogram_translatotron2", dataclass=Tacotron2CriterionConfig +) +class SpeechToSpectrogramTranslatotron2MultitaskTaskCriterion( + Tacotron2Criterion, MultitaskCriterion +): + def __init__( + self, + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + 
ctc_weight, + ): + super().__init__( + task, + sentence_avg, + use_guided_attention_loss, + guided_attention_loss_sigma, + bce_pos_weight, + ctc_weight, + ) + MultitaskCriterion.__init__(self, task.multitask_tasks) + + def forward(self, model, sample, reduction="mean"): + bsz, max_len, _ = sample["target"].size() + feat_tgt = sample["target"] + feat_len = sample["target_lengths"].view(bsz, 1).expand(-1, max_len) + eos_tgt = torch.arange(max_len).to(sample["target"].device) + eos_tgt = eos_tgt.view(1, max_len).expand(bsz, -1) + eos_tgt = (eos_tgt == (feat_len - 1)).float() + + feat_out, eos_out, extra = model( + src_tokens=sample["net_input"]["src_tokens"], + src_lengths=sample["net_input"]["src_lengths"], + prev_output_tokens=sample["net_input"]["prev_output_tokens"], + prev_output_tokens_mt=sample["multitask"][model.mt_task_name]["net_input"][ + "prev_output_tokens" + ], + tgt_speaker=sample["net_input"]["tgt_speaker"], + target_lengths=sample["target_lengths"], + return_all_hiddens=True, + ) + + l1_loss, mse_loss, eos_loss = self.compute_loss( + extra["feature_out"], + feat_out, + eos_out, + feat_tgt, + eos_tgt, + sample["target_lengths"], + reduction, + ) + attn_loss = torch.tensor(0.0).type_as(l1_loss) + if self.guided_attn is not None: + attn_loss = self.guided_attn( + extra["attn"], + sample["net_input"]["src_lengths"], + sample["target_lengths"], + reduction, + ) + loss = ( + l1_loss + mse_loss + eos_loss + attn_loss + ) # do not include ctc loss as there's no text target + + sample_size = sample["nsentences"] if self.sentence_avg else sample["ntokens"] + logging_output = { + "loss": utils.item(loss.data), + "ntokens": sample["ntokens"], + "nsentences": sample["nsentences"], + "sample_size": sample_size, + "l1_loss": utils.item(l1_loss.data), + "mse_loss": utils.item(mse_loss.data), + "eos_loss": utils.item(eos_loss.data), + "attn_loss": utils.item(attn_loss.data), + } + + if len(self.multitask_criterion) == 0: + return loss, sample_size, logging_output + + 
# multitask + multitask_loss, multitask_log = self.get_multitask_loss(model, sample, extra) + loss += multitask_loss + logging_output["multitask"] = multitask_log + return loss, sample_size, logging_output + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + super().reduce_metrics(logging_outputs) + + # inference metrics + if "targ_frames" in logging_outputs[0]: + n = sum(log.get("norm_frames", 0) for log in logging_outputs) + for key, new_key in [ + ("mcd_loss", "mcd_loss"), + ("pred_frames", "pred_ratio"), + ("nins", "ins_rate"), + ("ndel", "del_rate"), + ]: + val = sum(log.get(key, 0) for log in logging_outputs) + metrics.log_scalar(new_key, val / n, n, round=3) + + if "multitask" not in logging_outputs[0]: + return + + MultitaskCriterion.reduce_metrics(logging_outputs) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py index fba36dfcf0..c79f4f85df 100644 --- a/fairseq/data/audio/data_cfg.py +++ b/fairseq/data/audio/data_cfg.py @@ -297,3 +297,17 @@ def get_loss_weight(self, num_updates): loss_weight_min, ) return weight + + @property + def prepend_bos_and_append_tgt_lang_tag(self) -> bool: + """Prepend BOS and append target lang ID token to the target (e.g. 
mBART with language token pretraining).""" + return self.config.get("prepend_bos_and_append_tgt_lang_tag", False) + + @property + def eos_token(self): + """EOS token during generation""" + return self.config.get("eos_token", "") + + @property + def rdrop_alpha(self): + return self.config.get("rdrop_alpha", 0.0) diff --git a/fairseq/data/audio/speech_to_speech_dataset.py b/fairseq/data/audio/speech_to_speech_dataset.py index 4b7f8b6824..833bcedc54 100644 --- a/fairseq/data/audio/speech_to_speech_dataset.py +++ b/fairseq/data/audio/speech_to_speech_dataset.py @@ -12,11 +12,12 @@ from fairseq.data import ConcatDataset, Dictionary from fairseq.data import data_utils as fairseq_data_utils -from fairseq.data.audio.data_cfg import S2SDataConfig from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import S2SDataConfig from fairseq.data.audio.speech_to_text_dataset import ( SpeechToTextDataset, SpeechToTextDatasetCreator, + TextTargetMultitaskData, _collate_frames, ) @@ -231,57 +232,6 @@ def collater( return out -class TextTargetMultitaskData(object): - # mandatory columns - KEY_ID, KEY_TEXT = "id", "tgt_text" - - def __init__(self, args, split, tgt_dict): - samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split) - self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples} - self.dict = tgt_dict - self.append_eos = args.decoder_type != "ctc" - - def get(self, sample_id): - if sample_id in self.data: - return self.dict.encode_line( - self.data[sample_id], - add_if_not_exist=False, - append_eos=self.append_eos, - ) - else: - logger.warning(f"no target for {sample_id}") - return torch.IntTensor([]) - - def collater(self, samples: List[torch.Tensor]) -> torch.Tensor: - out = fairseq_data_utils.collate_tokens( - samples, - self.dict.pad(), - self.dict.eos(), - left_pad=False, - move_eos_to_beginning=False, - ).long() - - prev_out = fairseq_data_utils.collate_tokens( - samples, - self.dict.pad(), - 
self.dict.eos(), - left_pad=False, - move_eos_to_beginning=True, - ).long() - - target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long) - ntokens = sum(t.size(0) for t in samples) - - output = { - "prev_output_tokens": prev_out, - "target": out, - "target_lengths": target_lengths, - "ntokens": ntokens, - } - - return output - - class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset): def __init__(self, *argv): super().__init__(*argv) @@ -297,8 +247,9 @@ def __getitem__( multitask_target = {} sample_id = self.ids[index] + tgt_lang = self.tgt_langs[index] for task_name, task_dataset in self.multitask_data.items(): - multitask_target[task_name] = task_dataset.get(sample_id) + multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang) return s2s_data, multitask_target diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 53fd2ea203..536405c57a 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -4,13 +4,13 @@ # LICENSE file in the root directory of this source tree. 
import csv -import io import logging import re +from argparse import Namespace from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -18,6 +18,7 @@ from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset from fairseq.data import data_utils as fairseq_data_utils +from fairseq.data import encoders from fairseq.data.audio.audio_utils import get_features_or_waveform from fairseq.data.audio.data_cfg import S2TDataConfig from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform @@ -316,6 +317,165 @@ def prefetch(self, indices): raise False +class TextTargetMultitaskData(object): + # mandatory columns + KEY_ID, KEY_TEXT = "id", "tgt_text" + LANG_TAG_TEMPLATE = "" + LANG_TAG_MAPPING = { + "": "[en_XX]", + "": "[es_XX]", + "": "[ru_RU]", + } # FIXME: make this optional + + def __init__(self, args, split, tgt_dict): + samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split) + self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples} + self.dict = tgt_dict + self.append_eos = args.decoder_type != "ctc" + self.pre_tokenizer = self.build_tokenizer(args) + self.bpe_tokenizer = self.build_bpe(args) + self.prepend_bos_and_append_tgt_lang_tag = ( + args.prepend_bos_and_append_tgt_lang_tag + ) + self.eos_token = args.eos_token + + @classmethod + def is_lang_tag(cls, token): + pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") + return re.match(pattern, token) + + @classmethod + def tokenize(cls, tokenizer, text: str): + return text if tokenizer is None else tokenizer.encode(text) + + def get_tokenized_tgt_text(self, index: int): + text = self.tokenize(self.pre_tokenizer, self.data[index]) + text = self.tokenize(self.bpe_tokenizer, text) + return text + + @classmethod + def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): + 
lang_tag = cls.LANG_TAG_TEMPLATE.format(lang) + lang_tag = cls.LANG_TAG_MAPPING.get(lang_tag, lang_tag) + lang_tag_idx = dictionary.index(lang_tag) + assert lang_tag_idx != dictionary.unk(), (lang, lang_tag) + return lang_tag_idx + + def build_tokenizer(self, args): + pre_tokenizer = args.config.get("pre_tokenizer") + if pre_tokenizer is not None: + logger.info(f"pre-tokenizer: {pre_tokenizer}") + return encoders.build_tokenizer(Namespace(**pre_tokenizer)) + else: + return None + + def build_bpe(self, args): + bpe_tokenizer = args.config.get("bpe_tokenizer") + if bpe_tokenizer is not None: + logger.info(f"tokenizer: {bpe_tokenizer}") + return encoders.build_bpe(Namespace(**bpe_tokenizer)) + else: + return None + + def get(self, sample_id, tgt_lang): + if sample_id in self.data: + tokenized = self.get_tokenized_tgt_text(sample_id) + target = self.dict.encode_line( + tokenized, + add_if_not_exist=False, + append_eos=self.append_eos, + ) + if self.prepend_bos_and_append_tgt_lang_tag: + bos = torch.LongTensor([self.dict.bos()]) + lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict) + assert lang_tag_idx != self.dict.unk() + lang_tag_idx = torch.LongTensor([lang_tag_idx]) + target = torch.cat((bos, target, lang_tag_idx), 0) + return target + else: + logger.warning(f"no target for {sample_id}") + return torch.IntTensor([]) + + def collater(self, samples: List[torch.Tensor]) -> torch.Tensor: + out = fairseq_data_utils.collate_tokens( + samples, + self.dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=False, + ).long() + + prev_out = fairseq_data_utils.collate_tokens( + samples, + self.dict.pad(), + eos_idx=None, + left_pad=False, + move_eos_to_beginning=True, + ).long() + + target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long) + ntokens = sum(t.size(0) for t in samples) + + output = { + "prev_output_tokens": prev_out, + "target": out, + "target_lengths": target_lengths, + "ntokens": ntokens, + } + + return output + + +class 
SpeechToTextMultitaskDataset(SpeechToTextDataset): + def __init__(self, *argv): + super().__init__(*argv) + self.multitask_data = {} + + def add_multitask_dataset(self, task_name, task_data): + self.multitask_data[task_name] = task_data + + def __getitem__( + self, index: int + ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]: + s2t_data = super().__getitem__(index) + + multitask_target = {} + sample_id = self.ids[index] + tgt_lang = self.tgt_langs[index] + for task_name, task_dataset in self.multitask_data.items(): + multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang) + + return s2t_data, multitask_target + + def collater( + self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]] + ) -> Dict: + if len(samples) == 0: + return {} + + out = super().collater([s for s, _ in samples], return_order=True) + order = out["order"] + del out["order"] + + for task_name, task_dataset in self.multitask_data.items(): + if "multitask" not in out: + out["multitask"] = {} + d = [s[task_name] for _, s in samples] + task_target = task_dataset.collater(d) + out["multitask"][task_name] = { + "target": task_target["target"].index_select(0, order), + "target_lengths": task_target["target_lengths"].index_select(0, order), + "ntokens": task_target["ntokens"], + } + out["multitask"][task_name]["net_input"] = { + "prev_output_tokens": task_target["prev_output_tokens"].index_select( + 0, order + ), + } + + return out + + class SpeechToTextDatasetCreator(object): # mandatory columns KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" @@ -338,6 +498,7 @@ def _from_list( bpe_tokenizer, n_frames_per_step, speaker_to_id, + multitask: Optional[Dict] = None, ) -> SpeechToTextDataset: audio_root = Path(cfg.audio_root) ids = [s[cls.KEY_ID] for s in samples] @@ -348,25 +509,39 @@ def _from_list( speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] 
tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] - return SpeechToTextDataset( + + has_multitask = len(multitask) > 0 + dataset_cls = ( + SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset + ) + + ds = dataset_cls( split_name, is_train_split, cfg, audio_paths, n_frames, - src_texts=src_texts, - tgt_texts=tgt_texts, - speakers=speakers, - src_langs=src_langs, - tgt_langs=tgt_langs, - ids=ids, - tgt_dict=tgt_dict, - pre_tokenizer=pre_tokenizer, - bpe_tokenizer=bpe_tokenizer, - n_frames_per_step=n_frames_per_step, - speaker_to_id=speaker_to_id, + src_texts, + tgt_texts, + speakers, + src_langs, + tgt_langs, + ids, + tgt_dict, + pre_tokenizer, + bpe_tokenizer, + n_frames_per_step, + speaker_to_id, ) + if has_multitask: + for task_name, task_obj in multitask.items(): + task_data = TextTargetMultitaskData( + task_obj.args, split_name, task_obj.target_dictionary + ) + ds.add_multitask_dataset(task_name, task_data) + return ds + @classmethod def get_size_ratios( cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0 @@ -431,6 +606,7 @@ def _from_tsv( bpe_tokenizer, n_frames_per_step, speaker_to_id, + multitask: Optional[Dict] = None, ) -> SpeechToTextDataset: samples = cls._load_samples_from_tsv(root, split) return cls._from_list( @@ -443,6 +619,7 @@ def _from_tsv( bpe_tokenizer, n_frames_per_step, speaker_to_id, + multitask, ) @classmethod @@ -459,6 +636,7 @@ def from_tsv( seed: int, n_frames_per_step: int = 1, speaker_to_id=None, + multitask: Optional[Dict] = None, ) -> SpeechToTextDataset: datasets = [ cls._from_tsv( @@ -471,6 +649,7 @@ def from_tsv( bpe_tokenizer, n_frames_per_step, speaker_to_id, + multitask, ) for split in splits.split(",") ] diff --git a/fairseq/data/audio/text_to_speech_dataset.py b/fairseq/data/audio/text_to_speech_dataset.py index 27e52df1a3..13612b458b 100644 --- a/fairseq/data/audio/text_to_speech_dataset.py +++ b/fairseq/data/audio/text_to_speech_dataset.py @@ -5,21 +5,22 @@ # the root 
directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory.abs -from pathlib import Path -from typing import List, Dict, Optional, Any from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional import numpy as np import torch +from fairseq.data import Dictionary +from fairseq.data import data_utils as fairseq_data_utils from fairseq.data.audio.audio_utils import get_features_or_waveform from fairseq.data.audio.speech_to_text_dataset import ( + S2TDataConfig, SpeechToTextDataset, SpeechToTextDatasetCreator, - S2TDataConfig, _collate_frames, ) -from fairseq.data import Dictionary, data_utils as fairseq_data_utils @dataclass @@ -196,6 +197,7 @@ def _from_list( bpe_tokenizer, n_frames_per_step, speaker_to_id, + multitask=None, ) -> TextToSpeechDataset: audio_root = Path(cfg.audio_root) ids = [s[cls.KEY_ID] for s in samples] diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index 3079101db3..5fdfab38d3 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -811,6 +811,10 @@ class GenerationConfig(FairseqDataclass): default=5, metadata={"help": "beam size"}, ) + beam_mt: int = field( + default=0, + metadata={"help": "beam size for the first-pass decoder"}, + ) nbest: int = field( default=1, metadata={"help": "number of hypotheses to output"}, @@ -827,6 +831,18 @@ class GenerationConfig(FairseqDataclass): "help": "generate sequences of maximum length ax + b, where x is the source length" }, ) + max_len_a_mt: float = field( + default=0, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder" + }, + ) + max_len_b_mt: int = field( + default=200, + metadata={ + "help": "generate sequences of maximum length ax + b, where x is the source length for the first-pass decoder" + }, + ) min_len: int = field( default=1, metadata={"help": "minimum generation 
length"}, @@ -853,6 +869,12 @@ class GenerationConfig(FairseqDataclass): "help": "length penalty: <1.0 favors shorter, >1.0 favors longer sentences" }, ) + lenpen_mt: float = field( + default=1, + metadata={ + "help": "length penalty for the first-pass decoder: <1.0 favors shorter, >1.0 favors longer sentences" + }, + ) unkpen: float = field( default=0, metadata={ diff --git a/fairseq/models/speech_to_speech/__init__.py b/fairseq/models/speech_to_speech/__init__.py index 41be5e75c6..d3105bf429 100644 --- a/fairseq/models/speech_to_speech/__init__.py +++ b/fairseq/models/speech_to_speech/__init__.py @@ -4,5 +4,6 @@ # LICENSE file in the root directory of this source tree. from .modules import * # noqa -from .s2s_transformer import * # noqa from .s2s_conformer import * # noqa +from .s2s_conformer_t2 import * # noqa +from .s2s_transformer import * # noqa diff --git a/fairseq/models/speech_to_speech/s2s_conformer.py b/fairseq/models/speech_to_speech/s2s_conformer.py index a232412cc5..7f1d49c8bc 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer.py +++ b/fairseq/models/speech_to_speech/s2s_conformer.py @@ -5,23 +5,19 @@ import logging from pathlib import Path + import torch from fairseq import checkpoint_utils -from fairseq.models import ( - register_model, - register_model_architecture, -) -from fairseq.data.audio.data_cfg import S2SDataConfig -from fairseq.models.speech_to_text import S2TConformerEncoder -from fairseq.models.speech_to_speech import ( +from fairseq.models import register_model, register_model_architecture +from fairseq.models.speech_to_speech.s2s_transformer import ( + S2SpecTTransformerModel, S2UTTransformerModel, - s2ut_architecture_base as s2ut_transformer_architecture_base, -) -from fairseq.models.transformer import ( - Linear, + s2spect_architecture_base, + s2ut_architecture_base, ) - +from fairseq.models.speech_to_text import S2TConformerEncoder +from fairseq.models.transformer import Linear logger = logging.getLogger(__name__) @@ -63,25 
+59,62 @@ class S2UTConformerModel(S2UTTransformerModel): @staticmethod def add_args(parser): S2UTTransformerModel.add_args(parser) - parser.add_argument("--depthwise-conv-kernel-size", default=31) + parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31) parser.add_argument( "--attn-type", + type=str, default=None, help="If not specified uses fairseq MHA. Other valid option is espnet for using conformer", ) parser.add_argument( "--pos-enc-type", + type=str, default="abs", help="Must be specified in addition to attn-type=espnet for rel_pos and rope", ) @classmethod def build_encoder(cls, args): - print(args) - data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml) - args.input_feat_per_channel = data_cfg.input_feat_per_channel - args.input_channels = data_cfg.input_transformed_channels + encoder = S2SConformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") + return encoder + + +@register_model("s2spect_conformer") +class S2SpecTConformerModel(S2SpecTTransformerModel): + """ + Direct speech-to-speech translation model with S2T Conformer encoder + TTS Transformer decoder + """ + + @staticmethod + def add_args(parser): + S2SpecTTransformerModel.add_args(parser) + parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31) + parser.add_argument( + "--attn-type", + type=str, + default=None, + help="If not specified uses fairseq MHA. 
Other valid option is espnet for using conformer", + ) + parser.add_argument( + "--pos-enc-type", + type=str, + default="abs", + help="Must be specified in addition to attn-type=espnet for rel_pos and rope", + ) + @classmethod + def build_encoder(cls, args): encoder = S2SConformerEncoder(args) pretraining_path = getattr(args, "load_pretrained_encoder_from", None) if pretraining_path is not None: @@ -98,9 +131,27 @@ def build_encoder(cls, args): @register_model_architecture("s2ut_conformer", "s2ut_conformer") -def s2ut_base_architecture(args): +def s2ut_conformer_architecture_base(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.input_channels = getattr(args, "input_channels", 1) + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2ut_architecture_base(args) + + +@register_model_architecture("s2spect_conformer", "s2spect_conformer") +def s2spect_conformer_architecture_base(args): args.attn_type = getattr(args, "attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) + args.input_channels = getattr(args, "input_channels", 1) args.max_source_positions = getattr(args, "max_source_positions", 6000) args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) @@ -108,4 +159,17 @@ def s2ut_base_architecture(args): 
args.dropout = getattr(args, "dropout", 0.1) args.encoder_layers = getattr(args, "encoder_layers", 16) args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - s2ut_transformer_architecture_base(args) + s2spect_architecture_base(args) + + +@register_model_architecture("s2spect_conformer", "s2spect_conformer_fisher") +def s2spect_architecture_fisher(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + + # decoder + args.prenet_dim = getattr(args, "prenet_dim", 32) + + s2spect_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_t2.py b/fairseq/models/speech_to_speech/s2s_conformer_t2.py new file mode 100644 index 0000000000..cab15344e5 --- /dev/null +++ b/fairseq/models/speech_to_speech/s2s_conformer_t2.py @@ -0,0 +1,711 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import logging +from typing import Any, Dict, List, Optional + +import torch.nn as nn +from torch import Tensor + +from fairseq.models import ( + FairseqEncoder, + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules import CTCDecoder +from fairseq.models.speech_to_speech.s2s_conformer import ( + S2SpecTConformerModel, + S2UTConformerModel, +) +from fairseq.models.speech_to_speech.s2s_transformer import ( + TransformerUnitDecoder, + base_multitask_text_transformer_decoder_arch, + s2spect_architecture_base, + s2ut_architecture_base, +) +from fairseq.models.text_to_speech import TTSTransformerDecoder +from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase +from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder +from fairseq.modules import LayerNorm, TransformerEncoderLayer + +logger = logging.getLogger(__name__) + + +def multitask_text_transformer_decoder_arch( + args, decoder_layers, decoder_embed_dim=256, decoder_attention_heads=4 +): + args.decoder_layers = decoder_layers + args.decoder_embed_dim = decoder_embed_dim + args.decoder_attention_heads = decoder_attention_heads + base_multitask_text_transformer_decoder_arch(args) + + +@register_model("unity_conformer") +class UnitYConformerModel(S2UTConformerModel): + """ + Direct speech-to-speech translation model with S2T Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder + """ + + @staticmethod + def add_args(parser): + S2UTConformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer", + default="transformer", + choices=["transformer"], + help="", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + 
help="num encoder layers in the second-pass synthesizer module", + ) + parser.add_argument( + "--synthesizer-augmented-cross-attention", + action="store_true", + default=False, + help="augmented cross-attention over speech encoder output", + ) + + @classmethod + def build_multitask_decoder( + cls, + args, + tgt_dict, + in_dim, + is_mt_decoder, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + if is_mt_decoder: + multitask_text_transformer_decoder_arch( + decoder_args, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ) # 4L + else: + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + @classmethod + def build_decoder(cls, args, tgt_dict, aug_attn=False): + from fairseq.models.speech_to_speech.modules import StackedEmbedding + + num_embeddings = len(tgt_dict) + padding_idx = tgt_dict.pad() + embed_tokens = StackedEmbedding( + num_embeddings, + args.decoder_embed_dim, + padding_idx, + num_stacked=args.n_frames_per_step, + ) + + _args = copy.deepcopy(args) + _args.encoder_embed_dim = args.decoder_embed_dim + + decoder_cls = AugTransformerUnitDecoder if aug_attn else TransformerUnitDecoder + return decoder_cls( + _args, + tgt_dict, + embed_tokens, + ) + + @classmethod + def build_model(cls, args, task): + encoder = cls.build_encoder(args) + decoder = cls.build_decoder( + args, + task.target_dictionary, + aug_attn=getattr(args, "synthesizer_augmented_cross_attention", 
False), + ) + base_model = cls(encoder, decoder) + + base_model.t2u_augmented_cross_attn = getattr( + args, "synthesizer_augmented_cross_attention", False + ) + + # set up multitask decoders + is_mt_decoder = False + base_model.mt_task_name = None + base_model.multitask_decoders = {} + n_aux_tasks = len(list(task.multitask_tasks.items())) + for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): + if i == n_aux_tasks - 1: + is_mt_decoder = True + base_model.mt_task_name = task_name + assert "target" in task_name + assert task_obj.args.decoder_type == "transformer" + # NOTE: we assume that the last task is for the first-pass decoder + + in_dim = ( + args.encoder_embed_dim + if task_obj.args.input_from == "encoder" + else args.decoder_embed_dim + ) + task_decoder = cls.build_multitask_decoder( + task_obj.args, + task_obj.target_dictionary, + in_dim, + is_mt_decoder, + getattr(args, "translation_decoder_layers", 4), + getattr(args, "decoder_embed_dim", 256), + getattr(args, "decoder_attention_heads", 4), + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + assert is_mt_decoder, "set at least one intermediate non-CTC decoder" + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, "synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_text_encoder(args) + else: + base_model.synthesizer_encoder = None + + return base_model + + @classmethod + def build_text_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + 
_args.encoder_normalize_before = True + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + tgt_speaker=None, + return_all_hiddens=False, + ): + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + + # 1. MT decoder + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. T2U encoder + if hasattr(self, "synthesizer_encoder"): + t2u_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + return_all_hiddens=return_all_hiddens, + ) + else: + t2u_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. 
T2U decoder + if self.t2u_augmented_cross_attn: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out2=t2u_encoder_out, + ) + else: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=t2u_encoder_out, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +@register_model("spect2_conformer") +class SpecT2ConformerModel(S2SpecTConformerModel): + """ + Direct speech-to-speech translation model with S2T Conformer encoder + MT Transformer decoder + TTS Transformer decoder + """ + + @staticmethod + def add_args(parser): + S2SpecTConformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer", + default="transformer", + choices=["transformer"], + help="", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + help="num encoder layers in the second-pass synthesizer module", + ) + + @classmethod + def build_multitask_decoder( + cls, + args, + tgt_dict, + in_dim, + is_mt_decoder, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + if is_mt_decoder: + multitask_text_transformer_decoder_arch( + decoder_args, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ) # 4L + else: + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type 
== "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + @classmethod + def build_decoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_embed_dim = args.decoder_embed_dim + + if args.synthesizer == "transformer": + return TTSTransformerDecoder(_args, None, padding_idx=1) + else: + raise NotImplementedError(args.synthesizer) + + @classmethod + def build_model(cls, args, task): + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args) + base_model = cls(encoder, decoder) + + # set up multitask decoders + is_mt_decoder = False + base_model.mt_task_name = None + base_model.multitask_decoders = {} + n_aux_tasks = len(list(task.multitask_tasks.items())) + for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): + if i == n_aux_tasks - 1: + is_mt_decoder = True + base_model.mt_task_name = task_name + assert "target" in task_name + assert task_obj.args.decoder_type == "transformer" + # NOTE: we assume that the last task is for the first-pass decoder + + in_dim = ( + args.encoder_embed_dim + if task_obj.args.input_from == "encoder" + else args.decoder_embed_dim + ) + task_decoder = cls.build_multitask_decoder( + task_obj.args, + task_obj.target_dictionary, + in_dim, + is_mt_decoder, + getattr(args, "translation_decoder_layers", 4), + getattr(args, "decoder_embed_dim", 256), + getattr(args, "decoder_attention_heads", 4), + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + assert is_mt_decoder, "set at least one intermediate non-CTC decoder" + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, 
"synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_text_encoder(args) + + return base_model + + @classmethod + def build_text_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + _args.encoder_normalize_before = True + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + tgt_speaker=None, + incremental_state=None, + target_lengths=None, + speaker=None, + return_all_hiddens=False, + ): + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + + # 1. MT decoder + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. TTS encoder + if hasattr(self, "synthesizer_encoder"): + tts_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + return_all_hiddens=return_all_hiddens, + ) + else: + tts_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. 
TTS decoder + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=tts_encoder_out, + incremental_state=incremental_state, + target_lengths=target_lengths, + speaker=speaker, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +class TransformerEncoderNoEmb(FairseqEncoder): + """Transformer encoder without token embeddings.""" + + def __init__(self, args): + super().__init__(None) + + self.layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def forward(self, x, encoder_padding_mask, return_all_hiddens=False): + + encoder_states = [] + + for layer in self.layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None and encoder_padding_mask.any() + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + 
+ encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } + + +class AugTransformerUnitDecoder(AugTransformerDecoder): + """Based on Transformer decoder, with support to decoding stacked units""" + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn, output_projection + ) + self.n_frames_per_step = args.n_frames_per_step + + self.out_proj_n_frames = ( + Linear( + self.output_embed_dim, + self.output_embed_dim * self.n_frames_per_step, + bias=False, + ) + if self.n_frames_per_step > 1 + else None + ) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + encoder_out2: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). 
+ full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out2=encoder_out2, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + bsz, seq_len, d = x.size() + if self.out_proj_n_frames: + x = self.out_proj_n_frames(x) + x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d)) + x = x.view(bsz, seq_len * self.n_frames_per_step, -1) + if ( + incremental_state is None and self.n_frames_per_step > 1 + ): # teacher-forcing mode in training + x = x[ + :, : -(self.n_frames_per_step - 1), : + ] # remove extra frames after + + return x, extra + + def upgrade_state_dict_named(self, state_dict, name): + if self.n_frames_per_step > 1: + move_keys = [ + ( + f"{name}.project_in_dim.weight", + f"{name}.embed_tokens.project_in_dim.weight", + ) + ] + for from_k, to_k in move_keys: + if from_k in state_dict and to_k not in state_dict: + state_dict[to_k] = state_dict[from_k] + del state_dict[from_k] + + +@register_model_architecture(model_name="unity_conformer", arch_name="unity_conformer") +def unity_conformer_architecture_base(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + 
args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2ut_architecture_base(args) + + +@register_model_architecture( + model_name="unity_conformer", arch_name="s2ut_conformer_translatotron2" +) +def unity_conformer_architecture_base_legacy(args): + unity_conformer_architecture_base(args) + + +@register_model_architecture( + model_name="spect2_conformer", arch_name="spect2_conformer" +) +def translatotron2_conformer_architecture_base(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + s2spect_architecture_base(args) + + +@register_model_architecture( + model_name="spect2_conformer", arch_name="s2spect_conformer_translatotron2" +) +def translatotron2_conformer_architecture_base_legacy(args): + translatotron2_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py index a5954a83e5..5af07bb673 100644 --- a/fairseq/models/speech_to_speech/s2s_transformer.py +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -12,21 +12,16 @@ from fairseq import checkpoint_utils, utils from fairseq.models import ( - FairseqEncoderModel, FairseqEncoderDecoderModel, + FairseqEncoderModel, FairseqLanguageModel, register_model, register_model_architecture, ) -from fairseq.models.speech_to_text import S2TTransformerEncoder from fairseq.models.speech_to_speech.modules import CTCDecoder, StackedEmbedding +from 
fairseq.models.speech_to_text import S2TTransformerEncoder from fairseq.models.text_to_speech import TTSTransformerDecoder -from fairseq.models.transformer import ( - Linear, - TransformerDecoder, - TransformerModelBase, -) - +from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase logger = logging.getLogger(__name__) @@ -260,6 +255,13 @@ def add_args(parser): metavar="N", help="# of channels in Conv1d subsampling layers", ) + parser.add_argument( + "--conv-version", + type=str, + default="s2t_transformer", + choices=["s2t_transformer", "convtransformer"], + help="version of frontend convolutional layers", + ) # Transformer parser.add_argument( "--activation-fn", @@ -435,6 +437,13 @@ def add_args(parser): metavar="N", help="# of channels in Conv1d subsampling layers", ) + parser.add_argument( + "--conv-version", + type=str, + default="s2t_transformer", + choices=["s2t_transformer", "convtransformer"], + help="version of frontend convolutional layers", + ) # Transformer parser.add_argument( "--activation-fn", @@ -604,8 +613,10 @@ def base_s2st_transformer_encoder_architecture(args): args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) # Convolutional subsampler + args.input_channels = getattr(args, "input_channels", 1) args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.conv_channels = getattr(args, "conv_channels", 1024) + args.conv_version = getattr(args, "conv_version", "s2t_transformer") # Transformer args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) diff --git a/fairseq/models/speech_to_text/__init__.py b/fairseq/models/speech_to_text/__init__.py index f49c88e563..62ef663efb 100644 --- a/fairseq/models/speech_to_text/__init__.py +++ b/fairseq/models/speech_to_text/__init__.py @@ -6,7 +6,8 @@ from .berard import * # noqa from .convtransformer import * # noqa from .multi_modality_model import * # noqa 
+from .s2t_conformer import * # noqa from .s2t_transformer import * # noqa from .s2t_wav_transformer import * # noqa from .xm_transformer import * # noqa -from .s2t_conformer import * # noqa +from .xm_transformer_unity import * # noqa diff --git a/fairseq/models/speech_to_text/convtransformer.py b/fairseq/models/speech_to_text/convtransformer.py index eba000d7b0..29dd49cec0 100644 --- a/fairseq/models/speech_to_text/convtransformer.py +++ b/fairseq/models/speech_to_text/convtransformer.py @@ -1,4 +1,7 @@ -#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import logging import math @@ -7,6 +10,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor + from fairseq import checkpoint_utils, utils from fairseq.data.data_utils import lengths_to_padding_mask from fairseq.models import ( @@ -15,9 +20,9 @@ register_model, register_model_architecture, ) +from fairseq.models.speech_to_text.modules.convolution import infer_conv_output_dim from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerEncoderLayer -from torch import Tensor logger = logging.getLogger(__name__) @@ -251,7 +256,7 @@ def __init__(self, args): ), torch.nn.ReLU(), ) - transformer_input_dim = self.infer_conv_output_dim( + transformer_input_dim = infer_conv_output_dim( self.in_channels, self.input_dim, args.conv_out_channels ) self.out = torch.nn.Linear(transformer_input_dim, args.encoder_embed_dim) @@ -274,16 +279,6 @@ def __init__(self, args): def pooling_ratio(self): return 4 - def infer_conv_output_dim(self, in_channels, input_dim, out_channels): - sample_seq_len = 200 - sample_bsz = 10 - x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) - x = torch.nn.Conv2d(1, out_channels, 3, stride=2, padding=3 // 2)(x) - x = 
torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x) - x = x.transpose(1, 2) - mb, seq = x.size()[:2] - return x.contiguous().view(mb, seq, -1).size(-1) - def forward(self, src_tokens, src_lengths): """Encode input sequence. :param torch.Tensor xs: input tensor diff --git a/fairseq/models/speech_to_text/modules/convolution.py b/fairseq/models/speech_to_text/modules/convolution.py new file mode 100644 index 0000000000..526d7540c5 --- /dev/null +++ b/fairseq/models/speech_to_text/modules/convolution.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch +import torch.nn as nn + + +class Conv1dSubsampler(nn.Module): + """Convolutional subsampler: a stack of 1D convolution (along temporal + dimension) followed by non-linear activation via gated linear units + (https://arxiv.org/abs/1911.08460) + + Args: + in_channels (int): the number of input channels + mid_channels (int): the number of intermediate channels + out_channels (int): the number of output channels + kernel_sizes (List[int]): the kernel size for each convolutional layer + """ + + def __init__( + self, + in_channels: int, + mid_channels: int, + out_channels: int, + kernel_sizes: List[int] = (3, 3), + ): + super(Conv1dSubsampler, self).__init__() + self.n_layers = len(kernel_sizes) + self.conv_layers = nn.ModuleList( + nn.Conv1d( + in_channels if i == 0 else mid_channels // 2, + mid_channels if i < self.n_layers - 1 else out_channels * 2, + k, + stride=2, + padding=k // 2, + ) + for i, k in enumerate(kernel_sizes) + ) + + def get_out_seq_lens_tensor(self, in_seq_lens_tensor): + out = in_seq_lens_tensor.clone() + for _ in range(self.n_layers): + out = ((out.float() - 1) / 2 + 1).floor().long() + return out + + def forward(self, src_tokens, src_lengths): + bsz, in_seq_len, _ = src_tokens.size() # B x 
T x (C x D) + x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T + for conv in self.conv_layers: + x = conv(x) + x = nn.functional.glu(x, dim=1) + _, _, out_seq_len = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D) + return x, self.get_out_seq_lens_tensor(src_lengths) + + +def infer_conv_output_dim(in_channels, input_dim, out_channels): + sample_seq_len = 200 + sample_bsz = 10 + x = torch.randn(sample_bsz, in_channels, sample_seq_len, input_dim) + x = torch.nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = torch.nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=3 // 2)(x) + x = x.transpose(1, 2) + mb, seq = x.size()[:2] + return x.contiguous().view(mb, seq, -1).size(-1) + + +class Conv2dSubsampler(nn.Module): + """Convolutional subsampler: a stack of 2D convolution based on ESPnet implementation + (https://github.com/espnet/espnet) + + Args: + input_channels (int): the number of input channels + input_feat_per_channel (int): encoder input dimension per input channel + conv_out_channels (int): the number of output channels of conv layer + encoder_embed_dim (int): encoder dimentions + """ + + def __init__( + self, + input_channels: int, + input_feat_per_channel: int, + conv_out_channels: int, + encoder_embed_dim: int, + ): + super().__init__() + assert input_channels == 1, input_channels + self.conv = torch.nn.Sequential( + torch.nn.Conv2d( + input_channels, conv_out_channels, 3, stride=2, padding=3 // 2 + ), + torch.nn.ReLU(), + torch.nn.Conv2d( + conv_out_channels, + conv_out_channels, + 3, + stride=2, + padding=3 // 2, + ), + torch.nn.ReLU(), + ) + transformer_input_dim = infer_conv_output_dim( + input_channels, input_feat_per_channel, conv_out_channels + ) + self.out = torch.nn.Linear(transformer_input_dim, encoder_embed_dim) + + def forward(self, src_tokens, src_lengths): + B, T_i, C = src_tokens.size() + x = src_tokens.view(B, T_i, 1, C).transpose(1, 2).contiguous() + x = 
self.conv(x) + B, _, T_o, _ = x.size() + x = x.transpose(1, 2).transpose(0, 1).contiguous().view(T_o, B, -1) + x = self.out(x) + + subsampling_factor = int(T_i * 1.0 / T_o + 0.5) + input_len_0 = (src_lengths.float() / subsampling_factor).ceil().long() + input_len_1 = x.size(0) * torch.ones([src_lengths.size(0)]).long().to( + input_len_0.device + ) + input_lengths = torch.min(input_len_0, input_len_1) + return x, input_lengths diff --git a/fairseq/models/speech_to_text/s2t_conformer.py b/fairseq/models/speech_to_text/s2t_conformer.py index fbac61d5a7..15603d3be6 100644 --- a/fairseq/models/speech_to_text/s2t_conformer.py +++ b/fairseq/models/speech_to_text/s2t_conformer.py @@ -1,16 +1,28 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + import logging +import math +from pathlib import Path + import torch + +from fairseq import checkpoint_utils +from fairseq.data.data_utils import lengths_to_padding_mask +from fairseq.models import FairseqEncoder, register_model, register_model_architecture +from fairseq.models.speech_to_text.modules.convolution import ( + Conv1dSubsampler, + Conv2dSubsampler, +) from fairseq.models.speech_to_text.s2t_transformer import ( S2TTransformerEncoder, S2TTransformerModel, - Conv1dSubsampler, - base_architecture as transformer_base_architecture, + base_architecture, ) -from fairseq.data.data_utils import lengths_to_padding_mask -from fairseq.modules.conformer_layer import ConformerEncoderLayer -from fairseq.models import FairseqEncoder, register_model_architecture, register_model from fairseq.modules import PositionalEmbedding, RelPositionalEncoding -import math +from fairseq.modules.conformer_layer import ConformerEncoderLayer logger = logging.getLogger(__name__) @@ -20,16 +32,29 @@ class S2TConformerEncoder(FairseqEncoder): def __init__(self, args): super().__init__(None) + + self.encoder_freezing_updates 
= args.encoder_freezing_updates + self.num_updates = 0 + self.embed_scale = math.sqrt(args.encoder_embed_dim) if args.no_scale_embedding: self.embed_scale = 1.0 self.padding_idx = 1 - self.subsample = Conv1dSubsampler( - args.input_feat_per_channel * args.input_channels, - args.conv_channels, - args.encoder_embed_dim, - [int(k) for k in args.conv_kernel_sizes.split(",")], - ) + self.conv_version = args.conv_version + if self.conv_version == "s2t_transformer": + self.subsample = Conv1dSubsampler( + args.input_feat_per_channel * args.input_channels, + args.conv_channels, + args.encoder_embed_dim, + [int(k) for k in args.conv_kernel_sizes.split(",")], + ) + elif self.conv_version == "convtransformer": + self.subsample = Conv2dSubsampler( + args.input_channels, + args.input_feat_per_channel, + 256, + args.encoder_embed_dim, + ) self.pos_enc_type = args.pos_enc_type if self.pos_enc_type == "rel_pos": self.embed_positions = RelPositionalEncoding( @@ -61,7 +86,7 @@ def __init__(self, args): ] ) - def forward(self, src_tokens, src_lengths, return_all_hiddens=False): + def _forward(self, src_tokens, src_lengths, return_all_hiddens=False): """ Args: src_tokens: Input source tokens Tensor of shape B X T X C @@ -110,10 +135,30 @@ def forward(self, src_tokens, src_lengths, return_all_hiddens=False): "src_lengths": [], } + def forward(self, src_tokens, src_lengths, return_all_hiddens=False): + if self.num_updates < self.encoder_freezing_updates: + with torch.no_grad(): + x = self._forward( + src_tokens, + src_lengths, + return_all_hiddens=return_all_hiddens, + ) + else: + x = self._forward( + src_tokens, + src_lengths, + return_all_hiddens=return_all_hiddens, + ) + return x + def reorder_encoder_out(self, encoder_out, new_order): """Required method for a FairseqEncoder. 
Calls the method from the parent class""" return S2TTransformerEncoder.reorder_encoder_out(self, encoder_out, new_order) + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + self.num_updates = num_updates + @register_model("s2t_conformer") class S2TConformerModel(S2TTransformerModel): @@ -123,16 +168,18 @@ def __init__(self, encoder, decoder): @staticmethod def add_args(parser): S2TTransformerModel.add_args(parser) - parser.add_argument("--input-feat-per-channel", default=80) - parser.add_argument("--depthwise-conv-kernel-size", default=31) - parser.add_argument("--input-channels", default=1) + parser.add_argument("--input-feat-per-channel", type=int, default=80) + parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31) + parser.add_argument("--input-channels", type=int, default=1) parser.add_argument( "--attn-type", + type=str, default=None, help="If not specified uses fairseq MHA. Other valid option is espnet", ) parser.add_argument( "--pos-enc-type", + type=str, default="abs", help="Must be specified in addition to attn-type=espnet for rel_pos and rope", ) @@ -140,11 +187,22 @@ def add_args(parser): @classmethod def build_encoder(cls, args): encoder = S2TConformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded pretrained encoder from: {pretraining_path}") return encoder @register_model_architecture("s2t_conformer", "s2t_conformer") -def base_architecture(args): +def conformer_base_architecture(args): args.attn_type = getattr(args, "attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") args.input_feat_per_channel = getattr(args, 
"input_feat_per_channel", 80) @@ -156,4 +214,4 @@ def base_architecture(args): args.dropout = getattr(args, "dropout", 0.1) args.encoder_layers = getattr(args, "encoder_layers", 16) args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - transformer_base_architecture(args) + base_architecture(args) diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 4b43e1acb5..99c06878af 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -1,4 +1,7 @@ -#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import logging import math @@ -18,6 +21,10 @@ register_model_architecture, ) from fairseq.models.speech_to_text.hub_interface import S2THubInterface +from fairseq.models.speech_to_text.modules.convolution import ( + Conv1dSubsampler, + Conv2dSubsampler, +) from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.modules import ( FairseqDropout, @@ -29,55 +36,6 @@ logger = logging.getLogger(__name__) -class Conv1dSubsampler(nn.Module): - """Convolutional subsampler: a stack of 1D convolution (along temporal - dimension) followed by non-linear activation via gated linear units - (https://arxiv.org/abs/1911.08460) - - Args: - in_channels (int): the number of input channels - mid_channels (int): the number of intermediate channels - out_channels (int): the number of output channels - kernel_sizes (List[int]): the kernel size for each convolutional layer - """ - - def __init__( - self, - in_channels: int, - mid_channels: int, - out_channels: int, - kernel_sizes: List[int] = (3, 3), - ): - super(Conv1dSubsampler, self).__init__() - self.n_layers = len(kernel_sizes) - self.conv_layers = nn.ModuleList( - nn.Conv1d( - in_channels if i == 0 else mid_channels 
// 2, - mid_channels if i < self.n_layers - 1 else out_channels * 2, - k, - stride=2, - padding=k // 2, - ) - for i, k in enumerate(kernel_sizes) - ) - - def get_out_seq_lens_tensor(self, in_seq_lens_tensor): - out = in_seq_lens_tensor.clone() - for _ in range(self.n_layers): - out = ((out.float() - 1) / 2 + 1).floor().long() - return out - - def forward(self, src_tokens, src_lengths): - bsz, in_seq_len, _ = src_tokens.size() # B x T x (C x D) - x = src_tokens.transpose(1, 2).contiguous() # -> B x (C x D) x T - for conv in self.conv_layers: - x = conv(x) - x = nn.functional.glu(x, dim=1) - _, _, out_seq_len = x.size() - x = x.transpose(1, 2).transpose(0, 1).contiguous() # -> T x B x (C x D) - return x, self.get_out_seq_lens_tensor(src_lengths) - - @register_model("s2t_transformer") class S2TTransformerModel(FairseqEncoderDecoderModel): """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for @@ -136,6 +94,13 @@ def add_args(parser): metavar="N", help="# of channels in Conv1d subsampling layers", ) + parser.add_argument( + "--conv-version", + type=str, + default="s2t_transformer", + choices=["s2t_transformer", "convtransformer"], + help="version of frontend convolutional layers", + ) # Transformer parser.add_argument( "--activation-fn", @@ -339,12 +304,21 @@ def __init__(self, args): self.embed_scale = 1.0 self.padding_idx = 1 - self.subsample = Conv1dSubsampler( - args.input_feat_per_channel * args.input_channels, - args.conv_channels, - args.encoder_embed_dim, - [int(k) for k in args.conv_kernel_sizes.split(",")], - ) + self.conv_version = args.conv_version + if self.conv_version == "s2t_transformer": + self.subsample = Conv1dSubsampler( + args.input_feat_per_channel * args.input_channels, + args.conv_channels, + args.encoder_embed_dim, + [int(k) for k in args.conv_kernel_sizes.split(",")], + ) + elif self.conv_version == "convtransformer": + self.subsample = Conv2dSubsampler( + args.input_channels, + args.input_feat_per_channel, + 256, + 
args.encoder_embed_dim, + ) self.embed_positions = PositionalEmbedding( args.max_source_positions, args.encoder_embed_dim, self.padding_idx @@ -474,8 +448,10 @@ def extract_features( def base_architecture(args): args.encoder_freezing_updates = getattr(args, "encoder_freezing_updates", 0) # Convolutional subsampler + args.input_channels = getattr(args, "input_channels", 1) args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") args.conv_channels = getattr(args, "conv_channels", 1024) + args.conv_version = getattr(args, "conv_version", "s2t_transformer") # Transformer args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py index c82dea9ba4..5151063f66 100644 --- a/fairseq/models/speech_to_text/xm_transformer.py +++ b/fairseq/models/speech_to_text/xm_transformer.py @@ -20,14 +20,25 @@ register_model, register_model_architecture, ) +from fairseq.models.speech_to_speech.modules import CTCDecoder from fairseq.models.speech_to_text.hub_interface import S2THubInterface -from fairseq.models.transformer import Embedding, TransformerDecoder +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerModelBase, +) from fairseq.models.wav2vec import Wav2VecEncoder from fairseq.modules.layer_norm import LayerNorm logger = logging.getLogger(__name__) +def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + return Embedding(num_embeddings, embed_dim, padding_idx) + + class Conv1dAdaptor(nn.Module): def __init__( self, @@ -249,7 +260,10 @@ def add_wav2vec_asr_args(parser): help="if set, then the weight-norm (in one pos_conv layer) is removed from the model", ) parser.add_argument( - "--encoder-embed-dim", type=int, metavar="N", help="encoder embedding dimension to be used when w2v_path is None 
and no encoder_proj is set" + "--encoder-embed-dim", + type=int, + metavar="N", + help="encoder embedding dimension to be used when w2v_path is None and no encoder_proj is set", ) @@ -497,8 +511,7 @@ def hub_models(cls): "xm_transformer_s2ut_800m-es-en-st-asr-bt_h1_2022", "xm_transformer_s2ut_800m-en-es-st_plus_asr", "xm_transformer_s2ut_800m-hk-en-h1_2022", - "xm_transformer_s2ut_800m-en-hk-h1_2022" - + "xm_transformer_s2ut_800m-en-hk-h1_2022", ] return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} @@ -514,6 +527,7 @@ def from_pretrained( **kwargs, ): from fairseq import hub_utils + x = hub_utils.from_pretrained( model_name_or_path, checkpoint_file, @@ -557,7 +571,9 @@ def build_encoder(cls, args): if args.w2v_path: state = checkpoint_utils.load_checkpoint_to_cpu(args.w2v_path) if state.get("cfg") is not None: - encoder_embed_dim = state["cfg"]._content["model"]["encoder_embed_dim"] + encoder_embed_dim = state["cfg"]._content["model"][ + "encoder_embed_dim" + ] elif state.get("args") is not None: encoder_embed_dim = state["args"].encoder_embed_dim else: @@ -607,6 +623,7 @@ def build_decoder(cls, args, task, embed_tokens): _args.dropout = args.decoder_dropout _args.attention_dropout = args.decoder_attention_dropout _args.activation_dropout = args.decoder_activation_dropout + _args.layerdrop = _args.decoder_layerdrop decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens) decoder = cls.maybe_load_pretrained( @@ -623,15 +640,10 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - if getattr(args, "load_pretrained_decoder_from", None): - ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None)) - decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"]) - args = cls.override_decoder_args(args, decoder_args_dict) - - def build_embedding(dictionary, embed_dim): - num_embeddings = len(dictionary) - padding_idx = dictionary.pad() - return Embedding(num_embeddings, 
embed_dim, padding_idx) + # if getattr(args, "load_pretrained_decoder_from", None) is not None: + # ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None)) + # decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"]) + # args = cls.override_decoder_args(args, decoder_args_dict) decoder_embed_tokens = build_embedding( task.target_dictionary, args.decoder_embed_dim @@ -641,6 +653,37 @@ def build_embedding(dictionary, embed_dim): decoder = cls.build_decoder(args, task, decoder_embed_tokens) return cls(encoder, decoder) + @classmethod + def build_multitask_decoder(cls, args, tgt_dict, in_dim): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + from fairseq.models.speech_to_speech import ( + base_multitask_text_transformer_decoder_arch, + ) + + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + def get_normalized_probs( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py new file mode 100644 index 0000000000..752bee7602 --- /dev/null +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -0,0 +1,312 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import logging + +from fairseq.models import ( + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel +from fairseq.models.speech_to_text.xm_transformer import ( + base_architecture as xm_t_base_architecture, +) +from fairseq.models.speech_to_text.xm_transformer import ( + build_embedding, + need_finetuning, + set_default_adaptor_args, + set_default_general_args, + set_default_transformer_decoder_args, + set_default_w2v_encoder_args, +) +from fairseq.models.transformer import Linear, TransformerDecoder +from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder + +logger = logging.getLogger(__name__) + + +def unit_transformer_decoder_arch_base( + args, decoder_layers=6, decoder_embed_dim=768, decoder_attention_heads=12 +): + args.encoder_layers = decoder_layers + args.decoder_layers = decoder_layers + args.decoder_embed_dim = decoder_embed_dim + args.decoder_ffn_embed_dim = decoder_embed_dim * 4 + args.decoder_attention_heads = decoder_attention_heads + args.encoder_embed_dim = args.decoder_embed_dim + args.decoder_output_dim = decoder_embed_dim + args.decoder_input_dim = decoder_embed_dim + + +def unit_transformer_decoder_arch_large( + args, decoder_layers=12, decoder_embed_dim=1024, decoder_attention_heads=16 +): + args.encoder_layers = decoder_layers + args.decoder_layers = decoder_layers + args.decoder_embed_dim = decoder_embed_dim + args.decoder_ffn_embed_dim = decoder_embed_dim * 4 + args.decoder_attention_heads = decoder_attention_heads + args.encoder_embed_dim = args.decoder_embed_dim + args.decoder_output_dim = decoder_embed_dim + args.decoder_input_dim = decoder_embed_dim + + +@register_model("unity_xm_transformer") +class UnitYXMTransformerModel(XMTransformerModel): + @classmethod + def hub_models(cls): + base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t" + model_ids = [] + return {i: 
f"{base_url}/{i}.tar.gz" for i in model_ids} + + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + XMTransformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + help="num encoder layers in the second-pass synthesizer module", + ) + parser.add_argument( + "--synthesizer-augmented-cross-attention", + action="store_true", + default=False, + help="augmented cross-attention over speech encoder output", + ) + parser.add_argument( + "--load-pretrained-aux-decoder-from", + type=str, + metavar="STR", + help="model to take decoder weights from (for initialization)", + ) + + @classmethod + def build_text_decoder(cls, args, task): + _args = copy.deepcopy(args) + + if args.adaptor_proj or args.encoder_proj: # not V0 arch + _args.encoder_embed_dim = _args.decoder_embed_dim + _args.dropout = args.decoder_dropout + _args.attention_dropout = args.decoder_attention_dropout + _args.activation_dropout = args.decoder_activation_dropout + _args.layerdrop = _args.decoder_layerdrop + _args.decoder_layers = _args.translation_decoder_layers + + embed_tokens = build_embedding(task.target_dictionary, _args.decoder_embed_dim) + decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens) + + if getattr(args, "load_pretrained_aux_decoder_from", None) is not None: + decoder = cls.maybe_load_pretrained( + decoder, getattr(args, "load_pretrained_aux_decoder_from", None) + ) + + for k, p in decoder.named_parameters(): + p.requires_grad = need_finetuning(args.finetune_decoder_params, k) + return decoder + + @classmethod + def build_decoder(cls, args, task, aug_attn=False): + _args = copy.deepcopy(args) + _args.layerdrop = 0.0 # 
turn off layerdrop for shallow layers + + _args.encoder_embed_dim = args.decoder_embed_dim + + proj = None + if args.decoder_embed_dim != _args.decoder_embed_dim: + proj = Linear(args.decoder_embed_dim, _args.decoder_embed_dim) + + embed_tokens = build_embedding(task.target_dictionary, _args.decoder_embed_dim) + decoder_cls = AugTransformerDecoder if aug_attn else TransformerDecoder + decoder = decoder_cls(_args, task.target_dictionary, embed_tokens) + + if getattr(args, "load_pretrained_decoder_from", None) is not None: + # load all layers first and then discard the bottom layers + embed_tokens = build_embedding( + task.target_dictionary, _args.decoder_embed_dim + ) + decoder_tmp = decoder_cls(_args, task.target_dictionary, embed_tokens) + decoder_tmp = cls.maybe_load_pretrained( + decoder_tmp, getattr(_args, "load_pretrained_decoder_from", None) + ) + state_dict = decoder_tmp.state_dict() + for k, p in decoder.named_parameters(): + p.data = state_dict[k].data + p.requires_grad = need_finetuning(_args.finetune_decoder_params, k) + decoder.layers = decoder.layers[-_args.decoder_layers :] + + return decoder, proj, _args + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + + # make sure all arguments are present in older models + xm_t_base_architecture(args) + + encoder = cls.build_encoder(args) + decoder, proj, unit_args = cls.build_decoder( + args, + task, + aug_attn=getattr(args, "synthesizer_augmented_cross_attention", False), + ) + base_model = cls(encoder, decoder) + setattr(base_model, "proj", proj) + + base_model.t2u_augmented_cross_attn = getattr( + args, "synthesizer_augmented_cross_attention", False + ) + + # set up multitask decoders + base_model.mt_task_name = None + base_model.multitask_decoders = {} + n_aux_tasks = len(list(task.multitask_tasks.items())) + for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): + + if i < n_aux_tasks - 1: + task_decoder = cls.build_multitask_decoder( + 
task_obj.args, task_obj.target_dictionary, args.decoder_embed_dim + ) + else: + base_model.mt_task_name = task_name + assert "target" in task_name + assert task_obj.args.decoder_type == "transformer" + # NOTE: we assume that the last task is for the first-pass decoder + + task_decoder = cls.build_text_decoder(args, task_obj) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, "synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_t2u_encoder(unit_args) + else: + base_model.synthesizer_encoder = None + + return base_model + + @classmethod + def build_t2u_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = _args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + _args.encoder_normalize_before = True + + from fairseq.models.speech_to_speech import TransformerEncoderNoEmb + + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + return_all_hiddens=False, + tgt_speaker=None, + **kwargs, + ): + """ + The forward method inherited from the base class has a **kwargs + argument in its input, which is not supported in torchscript. This + method overwrites the forward method definition without **kwargs. + """ + encoder_out = self.encoder( + src_tokens=src_tokens, src_lengths=src_lengths, **kwargs + ) + + # 1. 
MT decoder + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + if self.proj is not None: + x = self.proj(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. T2U encoder + if self.synthesizer_encoder is not None: + t2u_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + ) + else: + t2u_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. T2U decoder + if self.t2u_augmented_cross_attn: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out2=t2u_encoder_out, + ) + else: + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=t2u_encoder_out, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_out"] + # NOTE: from the top layer + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +@register_model_architecture( + model_name="unity_xm_transformer", arch_name="unity_xm_transformer" +) +def base_architecture_unity(args): + set_default_general_args(args) + set_default_w2v_encoder_args(args) + set_default_adaptor_args(args) + set_default_transformer_decoder_args(args) + + args.layernorm_embedding = False + args.decoder_learned_pos = False + + +@register_model_architecture( + model_name="unity_xm_transformer", arch_name="xm_transformer_t2" +) +def base_architecture_legacy(args): + base_architecture_unity(args) diff --git a/fairseq/models/transformer/transformer_decoder_aug.py b/fairseq/models/transformer/transformer_decoder_aug.py new file mode 100644 
index 0000000000..0a35db13c7 --- /dev/null +++ b/fairseq/models/transformer/transformer_decoder_aug.py @@ -0,0 +1,391 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import utils +from fairseq.distributed import fsdp_wrap +from fairseq.models.transformer import TransformerConfig +from fairseq.models.transformer.transformer_decoder import TransformerDecoderBase +from fairseq.modules import ( + LayerDropModuleList, + SinusoidalPositionalEmbedding, + transformer_layer_aug, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper + + +class AugTransformerDecoderBase(TransformerDecoderBase): + """ + Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). 
+ """ + + def __init__( + self, + cfg, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + encoder_attn_merge_type="sequential", + dropnet_ratio=0, + ): + super().__init__( + cfg, + dictionary, + embed_tokens, + no_encoder_attn=no_encoder_attn, + output_projection=output_projection, + ) + # assert cfg.cross_self_attention + self.cross_self_attention = cfg.cross_self_attention + + if self.decoder_layerdrop > 0.0: + self.layers = LayerDropModuleList(p=self.decoder_layerdrop) + else: + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + self.build_decoder_layer( + cfg, no_encoder_attn, encoder_attn_merge_type, dropnet_ratio + ) + for _ in range(cfg.decoder.layers) + ] + ) + + def build_decoder_layer( + self, + cfg, + no_encoder_attn=False, + encoder_attn_merge_type="sequential", + dropnet_ratio=0, + ): + layer = transformer_layer_aug.AugTransformerDecoderLayerBase( + cfg, + no_encoder_attn, + encoder_attn_merge_type=encoder_attn_merge_type, + dropnet_ratio=dropnet_ratio, + ) + checkpoint = cfg.checkpoint_activations + if checkpoint: + offload_to_cpu = cfg.offload_activations + layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu) + # if we are checkpointing, enforce that FSDP always wraps the + # checkpointed layer, regardless of layer size + min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0 + layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap) + return layer + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + encoder_out2: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder 
outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out2=encoder_out2, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + x = self.output_layer(x) + return x, extra + + def extract_features( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + encoder_out2: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + return self.extract_features_scriptable( + prev_output_tokens, + encoder_out, + encoder_out2, + incremental_state, + full_context_alignment, + alignment_layer, + alignment_heads, + ) + + """ + A scriptable subclass of this class has an extract_features method and calls + super().extract_features, but super() is not supported in torchscript. A copy of + this function is made to be used in the subclass instead. 
+ """ + + def extract_features_scriptable( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]], + encoder_out2: Optional[Dict[str, List[Tensor]]], + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + ): + """ + Similar to *forward* but only return features. + + Includes several features from "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). + alignment_layer (int, optional): return mean alignment over + heads at this layer (default: last layer). + alignment_heads (int, optional): only average alignment over + this many heads (default: all heads). + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + bs, slen = prev_output_tokens.size() + if alignment_layer is None: + alignment_layer = self.num_layers - 1 + + enc: Optional[Tensor] = None + padding_mask: Optional[Tensor] = None + if encoder_out is not None and len(encoder_out["encoder_out"]) > 0: + enc = encoder_out["encoder_out"][0] + if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: + padding_mask = encoder_out["encoder_padding_mask"][0] + + enc2: Optional[Tensor] = None + padding_mask2: Optional[Tensor] = None + if encoder_out2 is not None and len(encoder_out2["encoder_out"]) > 0: + enc2 = encoder_out2["encoder_out"][0] + if encoder_out2 is not None and len(encoder_out2["encoder_padding_mask"]) > 0: + padding_mask2 = encoder_out2["encoder_padding_mask"][0] + + # embed positions + positions = None + if self.embed_positions is not None: + positions = self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + + if incremental_state is 
not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # Prevent torchscript exporting issue for dynamic quant embedding + prev_output_tokens = prev_output_tokens.contiguous() + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.quant_noise is not None: + x = self.quant_noise(x) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + + if self.layernorm_embedding is not None: + x = self.layernorm_embedding(x) + + x = self.dropout_module(x) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + self_attn_padding_mask: Optional[Tensor] = None + if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any(): + self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) + + # decoder layers + attn: Optional[Tensor] = None + attn2: Optional[Tensor] = None + inner_states: List[Optional[Tensor]] = [x] + for idx, layer in enumerate(self.layers): + if incremental_state is None and not full_context_alignment: + self_attn_mask = self.buffered_future_mask(x) + else: + self_attn_mask = None + + x, layer_attn, layer_attn2, _ = layer( + x, + enc, + padding_mask, + enc2, + padding_mask2, + incremental_state, + self_attn_mask=self_attn_mask, + self_attn_padding_mask=self_attn_padding_mask, + need_attn=bool((idx == alignment_layer)), + need_head_weights=bool((idx == alignment_layer)), + ) + inner_states.append(x) + if layer_attn is not None and idx == alignment_layer: + attn = layer_attn.float().to(x) + if layer_attn2 is not None and idx == alignment_layer: + attn2 = layer_attn2.float().to(x) + + if attn is not None: + if alignment_heads is not None: + attn = attn[:alignment_heads] + + # average probabilities over heads + attn = attn.mean(dim=0) + + if attn2 is not None: + if alignment_heads is not None: + attn2 = attn2[:alignment_heads] + + # average probabilities over heads + 
attn2 = attn2.mean(dim=0) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + if self.project_out_dim is not None: + x = self.project_out_dim(x) + + return x, {"attn": [attn], "attn2": [attn2], "inner_states": inner_states} + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): + weights_key = "{}.embed_positions.weights".format(name) + if weights_key in state_dict: + del state_dict[weights_key] + state_dict[ + "{}.embed_positions._float_tensor".format(name) + ] = torch.FloatTensor(1) + + if f"{name}.output_projection.weight" not in state_dict: + if self.share_input_output_embed: + embed_out_key = f"{name}.embed_tokens.weight" + else: + embed_out_key = f"{name}.embed_out" + if embed_out_key in state_dict: + state_dict[f"{name}.output_projection.weight"] = state_dict[ + embed_out_key + ] + if not self.share_input_output_embed: + del state_dict[embed_out_key] + + for i in range(self.num_layers): + # update layer norms + layer_norm_map = { + "0": "self_attn_layer_norm", + "1": "encoder_attn_layer_norm", + "2": "encoder_attn_layer_norm2", + "3": "final_layer_norm", + } + for old, new in layer_norm_map.items(): + for m in ("weight", "bias"): + k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m) + if k in state_dict: + state_dict[ + "{}.layers.{}.{}.{}".format(name, i, new, m) + ] = state_dict[k] + del state_dict[k] + + version_key = "{}.version".format(name) + if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: + # earlier checkpoints did not normalize after the stack of layers + self.layer_norm = None + self.normalize = False + state_dict[version_key] = torch.Tensor([1]) + + return state_dict + + +class AugTransformerDecoder(AugTransformerDecoderBase): + def __init__( + self, + args, + dictionary, + embed_tokens, + 
no_encoder_attn=False, + output_projection=None, + ): + self.args = args + super().__init__( + TransformerConfig.from_namespace(args), + dictionary, + embed_tokens, + no_encoder_attn=no_encoder_attn, + output_projection=output_projection, + encoder_attn_merge_type=getattr( + args, "synthesizer_augmented_cross_attention_merge_type", "sequential" + ), + dropnet_ratio=getattr(args, "dropnet_ratio", 0), + ) + + def build_output_projection(self, args, dictionary, embed_tokens): + super().build_output_projection( + TransformerConfig.from_namespace(args), dictionary, embed_tokens + ) + + def build_decoder_layer( + self, + args, + no_encoder_attn=False, + encoder_attn_merge_type="sequential", + dropnet_ratio=0, + ): + return super().build_decoder_layer( + TransformerConfig.from_namespace(args), + no_encoder_attn=no_encoder_attn, + encoder_attn_merge_type=encoder_attn_merge_type, + dropnet_ratio=dropnet_ratio, + ) diff --git a/fairseq/modules/transformer_layer_aug.py b/fairseq/modules/transformer_layer_aug.py new file mode 100644 index 0000000000..b63bdbd77f --- /dev/null +++ b/fairseq/modules/transformer_layer_aug.py @@ -0,0 +1,319 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional + +import torch +from numpy.random import uniform +from torch import Tensor + +from fairseq.modules import LayerNorm +from fairseq.modules.transformer_layer import TransformerDecoderLayerBase + + +class AugTransformerDecoderLayerBase(TransformerDecoderLayerBase): + """Decoder layer block. + + In the original paper each operation (multi-head attention, encoder + attention or FFN) is postprocessed with: `dropout -> add residual -> + layernorm`. In the tensor2tensor code they suggest that learning is more + robust when preprocessing each layer with layernorm and postprocessing with: + `dropout -> add residual`. 
We default to the approach in the paper, but the + tensor2tensor approach can be enabled by setting + *cfg.decoder.normalize_before* to ``True``. + + Args: + cfg: layer configuration, forwarded to TransformerDecoderLayerBase + no_encoder_attn (bool, optional): must be False, this layer always builds + both encoder attentions (default: False). + """ + + def __init__( + self, + cfg, + no_encoder_attn=False, + add_bias_kv=False, + add_zero_attn=False, + encoder_attn_merge_type="sequential", + dropnet_ratio=0, + ): + super().__init__( + cfg, + no_encoder_attn=no_encoder_attn, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + ) + assert not no_encoder_attn + + self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) + self.encoder_attn2 = self.build_encoder_attention(self.embed_dim, cfg) + if encoder_attn_merge_type == "sequential": + self.encoder_attn_layer_norm2 = LayerNorm(self.embed_dim, export=cfg.export) + else: + self.encoder_attn_layer_norm2 = None + + self.encoder_attn_merge_type = encoder_attn_merge_type + self.dropnet_ratio = dropnet_ratio + + def forward( + self, + x, + encoder_out: Optional[torch.Tensor] = None, + encoder_padding_mask: Optional[torch.Tensor] = None, + encoder_out2: Optional[torch.Tensor] = None, + encoder_padding_mask2: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + prev_self_attn_state: Optional[List[torch.Tensor]] = None, + prev_attn_state: Optional[List[torch.Tensor]] = None, + self_attn_mask: Optional[torch.Tensor] = None, + self_attn_padding_mask: Optional[torch.Tensor] = None, + need_attn: bool = False, + need_head_weights: bool = False, + ): + """ + Args: + x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_padding_mask (ByteTensor, optional): binary + ByteTensor of shape `(batch, src_len)` where padding + elements are indicated by ``1``.
+ need_attn (bool, optional): return attention weights + need_head_weights (bool, optional): return attention weights + for each head (default: return average over heads). + + Returns: + encoded output of shape `(seq_len, batch, embed_dim)` + """ + if need_head_weights: + need_attn = True + + residual = x + if self.normalize_before: + x = self.self_attn_layer_norm(x) + if prev_self_attn_state is not None: + prev_key, prev_value = prev_self_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_self_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_self_attn_state[2] + assert incremental_state is not None + self.self_attn._set_input_buffer(incremental_state, saved_state) + _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state) + if self.cross_self_attention and not ( + incremental_state is not None + and _self_attn_input_buffer is not None + and "prev_key" in _self_attn_input_buffer + ): + if self_attn_mask is not None: + assert encoder_out is not None + self_attn_mask = torch.cat( + (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1 + ) + if self_attn_padding_mask is not None: + if encoder_padding_mask is None: + assert encoder_out is not None + encoder_padding_mask = self_attn_padding_mask.new_zeros( + encoder_out.size(1), encoder_out.size(0) + ) + self_attn_padding_mask = torch.cat( + (encoder_padding_mask, self_attn_padding_mask), dim=1 + ) + assert encoder_out is not None + y = torch.cat((encoder_out, x), dim=0) + else: + y = x + + x, attn = self.self_attn( + query=x, + key=y, + value=y, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + need_weights=False, + attn_mask=self_attn_mask, + ) + if self.c_attn is not None: + tgt_len, bsz = x.size(0), x.size(1) + x = x.view(tgt_len, bsz, self.nh, self.head_dim) + x = torch.einsum("tbhd,h->tbhd", x, self.c_attn) + x = x.reshape(tgt_len, bsz, 
self.embed_dim) + if self.attn_ln is not None: + x = self.attn_ln(x) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.self_attn_layer_norm(x) + + assert encoder_out is not None + assert encoder_out2 is not None + + if self.encoder_attn_merge_type == "sequential": + ratios = self.get_dropnet_ratio() + attn2 = None # stays None (instead of unbound at return) if DropNet skips the second attention + # first encoder attention + if ratios[0] > 0: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + x = ratios[0] * x + + # second encoder attention + if ratios[1] > 0: + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm2(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn2._set_input_buffer(incremental_state, saved_state) + + x, attn2 = self.encoder_attn2( + query=x, + key=encoder_out2, + value=encoder_out2, + key_padding_mask=encoder_padding_mask2, +
incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x = self.dropout_module(x) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm2(x) + x = ratios[1] * x + + elif self.encoder_attn_merge_type == "parallel": + residual = x + if self.normalize_before: + x = self.encoder_attn_layer_norm(x) + if prev_attn_state is not None: + prev_key, prev_value = prev_attn_state[:2] + saved_state: Dict[str, Optional[Tensor]] = { + "prev_key": prev_key, + "prev_value": prev_value, + } + if len(prev_attn_state) >= 3: + saved_state["prev_key_padding_mask"] = prev_attn_state[2] + assert incremental_state is not None + self.encoder_attn._set_input_buffer(incremental_state, saved_state) + + x1, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x2, attn2 = self.encoder_attn2( + query=x, + key=encoder_out2, + value=encoder_out2, + key_padding_mask=encoder_padding_mask2, + incremental_state=incremental_state, + static_kv=True, + need_weights=need_attn or (not self.training and self.need_attn), + need_head_weights=need_head_weights, + ) + x1 = self.dropout_module(x1) + x2 = self.dropout_module(x2) + ratios = self.get_dropnet_ratio() + x = ratios[0] * x1 + ratios[1] * x2 + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.encoder_attn_layer_norm(x) + + else: + raise NotImplementedError(self.encoder_attn_merge_type) + + residual = x + if self.normalize_before: + x = self.final_layer_norm(x) + + x = self.activation_fn(self.fc1(x)) + x = self.activation_dropout_module(x) + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + x = self.fc2(x) + x 
= self.dropout_module(x) + if self.w_resid is not None: + residual = torch.mul(self.w_resid, residual) + x = self.residual_connection(x, residual) + if not self.normalize_before: + x = self.final_layer_norm(x) + if self.onnx_trace and incremental_state is not None: + saved_state = self.self_attn._get_input_buffer(incremental_state) + assert saved_state is not None + if self_attn_padding_mask is not None: + self_attn_state = [ + saved_state["prev_key"], + saved_state["prev_value"], + saved_state["prev_key_padding_mask"], + ] + else: + self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]] + return x, attn, attn2, self_attn_state + return x, attn, attn2, None + + def get_dropnet_ratio(self): + if self.encoder_attn_merge_type == "sequential": + if self.dropnet_ratio > 0: + frand = float(uniform(0, 1)) + if frand < self.dropnet_ratio and self.training: + return [2, 0] + elif frand > 1 - self.dropnet_ratio and self.training: + return [0, 2] + else: + return [1, 1] + else: + return [1, 1] + + elif self.encoder_attn_merge_type == "parallel": + if self.dropnet_ratio > 0: + frand = float(uniform(0, 1)) + if frand < self.dropnet_ratio and self.training: + return [1, 0] + elif frand > 1 - self.dropnet_ratio and self.training: + return [0, 1] + else: + return [0.5, 0.5] + else: + return [0.5, 0.5] diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 5176f5d267..e01f2fd113 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -91,7 +91,6 @@ def __init__( ).long() self.vocab_size = len(tgt_dict) - self.beam_size = beam_size # the max beam size is the dictionary size - 1, since we never select pad self.beam_size = min(beam_size, self.vocab_size - 1) self.model.set_decoder_beam_size(self.beam_size) @@ -210,13 +209,6 @@ def _generate( constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, ): - incremental_states = torch.jit.annotate( - List[Dict[str, Dict[str, Optional[Tensor]]]], - [ - 
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) - for i in range(self.model.models_size) - ], - ) net_input = sample["net_input"] if "src_tokens" in net_input: @@ -245,18 +237,55 @@ def _generate( + str(net_input.keys()) ) - # bsz: total number of sentences in beam - # Note that src_tokens may have more than 2 dimensions (i.e. audio features) - bsz, src_len = src_tokens.size()[:2] - beam_size = self.beam_size - if constraints is not None and not self.search.supports_constraints: raise NotImplementedError( "Target-side constraints were provided, but search method doesn't support them" ) # Initialize constraints, when active - self.search.init_constraints(constraints, beam_size) + self.search.init_constraints(constraints, self.beam_size) + + # compute the encoder output for each beam + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.model.forward_encoder(net_input) + + finalized = self.generate_decoder( + encoder_outs, + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + ) + return finalized + + def generate_decoder( + self, + encoder_outs, + src_tokens, + src_lengths, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + aux_task_name="", + encoder_outs2: Optional[Tensor] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder" max_len: int = -1 if self.match_source_len: @@ -269,9 +298,6 @@ def _generate( assert ( self.min_len <= max_len ), "min_len cannot be larger than max_len, please adjust these!" - # compute the encoder output for each beam - with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): - encoder_outs = self.model.forward_encoder(net_input) # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) @@ -279,6 +305,8 @@ def _generate( encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) # ensure encoder_outs is a List. assert encoder_outs is not None + if encoder_outs2 is not None: + encoder_outs2 = self.model.reorder_encoder_out(encoder_outs2, new_order) # initialize buffers scores = ( @@ -344,10 +372,16 @@ def _generate( corr.unsqueeze(-1) * beam_size ) original_batch_idxs = original_batch_idxs[batch_idxs] - self.model.reorder_incremental_state(incremental_states, reorder_state) + self.model.reorder_incremental_state( + incremental_states, reorder_state, decoder_name + ) encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) + if encoder_outs2 is not None: + encoder_outs2 = self.model.reorder_encoder_out( + encoder_outs2, reorder_state + ) with torch.autograd.profiler.record_function( "EnsembleModel: forward_decoder" ): @@ -356,9 +390,11 @@ def _generate( encoder_outs, incremental_states, self.temperature, + decoder_name=decoder_name, + encoder_outs2=encoder_outs2, ) - if self.lm_model is not None: + if self.lm_model is not None and not aux_task_name: lm_out = self.lm_model(tokens[:, : step + 1]) probs = self.lm_model.get_normalized_probs( lm_out, log_probs=True, sample=None @@ -374,7 +410,7 @@ def _generate( # handle max length constraint if step >= 
max_len: lprobs[:, : self.eos] = -math.inf - lprobs[:, self.eos + 1:] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf # handle prefix tokens (possibly with different lengths) if ( @@ -604,7 +640,7 @@ def _prefix_tokens( if eos_mask.any(): # validate that the first beam matches the prefix first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ - :, 0, 1: step + 1 + :, 0, 1 : step + 1 ] eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] @@ -649,12 +685,12 @@ def finalize_hypos( # tokens is (batch * beam, max_len). So the index_select # gets the newly EOS rows, then selects cols 1..{step + 2} tokens_clone = tokens.index_select(0, bbsz_idx)[ - :, 1: step + 2 + :, 1 : step + 2 ] # skip the first index, which is EOS tokens_clone[:, step] = self.eos attn_clone = ( - attn.index_select(0, bbsz_idx)[:, :, 1: step + 2] + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] if attn is not None else None ) @@ -807,23 +843,38 @@ def forward_decoder( encoder_outs: List[Dict[str, List[Tensor]]], incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], temperature: float = 1.0, + decoder_name="decoder", + encoder_outs2: List[Dict[str, List[Tensor]]] = None, ): log_probs = [] avg_attn: Optional[Tensor] = None encoder_out: Optional[Dict[str, List[Tensor]]] = None + encoder_out2: Optional[Dict[str, List[Tensor]]] = None for i, model in enumerate(self.models): if self.has_encoder(): encoder_out = encoder_outs[i] + if encoder_outs2 is not None: + encoder_out2 = encoder_outs2[i] # decode each model if self.has_incremental_states(): - decoder_out = model.decoder.forward( - tokens, - encoder_out=encoder_out, - incremental_state=incremental_states[i], - ) + if encoder_out2 is not None: + decoder_out = getattr(model, decoder_name).forward( + tokens, + encoder_out=encoder_out, + encoder_out2=encoder_out2, + incremental_state=incremental_states[i], + ) + else: + decoder_out = getattr(model, 
decoder_name).forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) else: - if hasattr(model, "decoder"): - decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + if hasattr(model, decoder_name): + decoder_out = getattr(model, decoder_name).forward( + tokens, encoder_out=encoder_out + ) else: decoder_out = model.forward(tokens) @@ -845,7 +896,7 @@ def forward_decoder( decoder_out[0][:, -1:, :].div_(temperature), None if decoder_len <= 1 else decoder_out[1], ) - probs = model.get_normalized_probs( + probs = getattr(model, decoder_name).get_normalized_probs( decoder_out_tuple, log_probs=True, sample=None ) probs = probs[:, -1, :] @@ -896,11 +947,12 @@ def reorder_incremental_state( self, incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], new_order, + decoder_name="decoder", ): if not self.has_incremental_states(): return for i, model in enumerate(self.models): - model.decoder.reorder_incremental_state_scripting( + getattr(model, decoder_name).reorder_incremental_state_scripting( incremental_states[i], new_order ) diff --git a/fairseq/sequence_generator_multi_decoder.py b/fairseq/sequence_generator_multi_decoder.py new file mode 100644 index 0000000000..9fae5d5f7a --- /dev/null +++ b/fairseq/sequence_generator_multi_decoder.py @@ -0,0 +1,258 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import search +from fairseq.sequence_generator import SequenceGenerator + + +class MultiDecoderSequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + tgt_dict_mt, + beam_size=1, + beam_size_mt=1, + max_len_a=0, + max_len_b=200, + max_len_a_mt=0, + max_len_b_mt=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + len_penalty_mt=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + eos=None, + eos_mt=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + ): + """Generates translations of a given source sentence. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length for the second pass + max_len_a_mt/b_mt (int, optional): generate sequences of maximum length + ax + b, where x is the source length for the first pass + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty in the second pass, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + len_penalty_mt (float, optional): length penalty in the first pass, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more
uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + self.generator = SequenceGenerator( + models, + tgt_dict, + beam_size=beam_size, + max_len_a=max_len_a, + max_len_b=max_len_b, + max_len=max_len, + min_len=min_len, + normalize_scores=normalize_scores, + len_penalty=len_penalty, + unk_penalty=unk_penalty, + temperature=temperature, + match_source_len=match_source_len, + no_repeat_ngram_size=no_repeat_ngram_size, + search_strategy=search.BeamSearch(tgt_dict), + eos=eos, + symbols_to_strip_from_output=symbols_to_strip_from_output, + lm_model=lm_model, + lm_weight=lm_weight, + ) + self.eos = self.generator.eos + + self.generator_mt = SequenceGenerator( + models, + tgt_dict_mt, + beam_size=beam_size_mt, + max_len_a=max_len_a_mt, + max_len_b=max_len_b_mt, + max_len=max_len, + min_len=min_len, + normalize_scores=normalize_scores, + len_penalty=len_penalty_mt, + unk_penalty=unk_penalty, + temperature=temperature, + match_source_len=match_source_len, + no_repeat_ngram_size=no_repeat_ngram_size, + search_strategy=search.BeamSearch(tgt_dict_mt), + eos=eos_mt, + symbols_to_strip_from_output=symbols_to_strip_from_output, + ) + + @torch.no_grad() + def generate( + self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs + ) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.generator.eos) & src_tokens.ne(self.generator.pad)) + .long() + .sum(dim=1) + ) + else: + raise Exception( + "expected src_tokens or source in net input. input keys: " + + str(net_input.keys()) + ) + + if constraints is not None and not self.generator.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.generator.search.init_constraints(constraints, self.generator.beam_size) + self.generator_mt.search.init_constraints( + constraints, self.generator_mt.beam_size + ) + + # compute the encoder output for each beam + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.generator.model.forward_encoder(net_input) + + single_model = self.generator.model.single_model + mt_decoder = getattr(single_model, f"{single_model.mt_task_name}_decoder") + + # 1. 
MT decoder + finalized_mt = self.generator_mt.generate_decoder( + encoder_outs, + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + aux_task_name=single_model.mt_task_name, + ) + + # extract decoder output corresponding to the best hypothesis + max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt]) + prev_output_tokens_mt = ( + src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len) + .fill_(mt_decoder.padding_idx) + .int() + ) # B x T + for i, hypo in enumerate(finalized_mt): + i_beam = 0 + tmp = hypo[i_beam]["tokens"].int() # hyp + eos + prev_output_tokens_mt[i, 0] = self.generator_mt.eos + if tmp[-1] == self.generator_mt.eos: + tmp = tmp[:-1] + prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp + + text = "".join([self.generator_mt.tgt_dict[c] for c in tmp]) + text = text.replace("_", " ") + text = text.replace("▁", " ") + text = text.replace("<unk>", " ") + text = text.replace("<s>", "") + text = text.replace("</s>", "") + if len(text) > 0 and text[0] == " ": + text = text[1:] + sample_id = sample["id"].tolist()[i] + print("{} (None-{})".format(text, sample_id)) + + x = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_outs[0], + features_only=True, + )[0].transpose(0, 1) + + if getattr(single_model, "proj", None) is not None: + x = single_model.proj(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2.
T2U encoder + if getattr(single_model, "synthesizer_encoder", None) is not None: + t2u_encoder_out = single_model.synthesizer_encoder( + x, + mt_decoder_padding_mask, + ) + else: + t2u_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask] + if mt_decoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + if getattr(single_model, "t2u_augmented_cross_attn", False): + encoder_outs2 = [t2u_encoder_out] + else: + encoder_outs = [t2u_encoder_out] + encoder_outs2 = None + + # 3. T2U decoder + finalized = self.generator.generate_decoder( + encoder_outs, + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + encoder_outs2=encoder_outs2, + ) + return finalized diff --git a/fairseq/speech_generator.py b/fairseq/speech_generator.py index 90ec914e60..8951eefbff 100644 --- a/fairseq/speech_generator.py +++ b/fairseq/speech_generator.py @@ -3,8 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch import numpy as np +import torch from fairseq.data.audio.speech_to_text_dataset import S2TDataConfig @@ -76,7 +76,203 @@ def generate(self, model, sample, has_targ=False, **kwargs): incremental_state=incremental_state, target_lengths=cur_out_lens, speaker=sample["speaker"], - **kwargs + **kwargs, + ) + cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) + feat.append(cur_extra["feature_out"]) + attn.append(cur_extra["attn"]) + eos_prob.append(cur_eos_prob) + + cur_finished = cur_eos_prob.squeeze(1) > self.eos_prob_threshold + out_lens.masked_fill_((~finished) & cur_finished, step + 1) + finished = finished | cur_finished + if finished.sum().item() == bsz: + break + prev_feat_out = cur_extra["feature_out"] + + feat = torch.cat(feat, dim=1) + feat = model.decoder.postnet(feat) + feat + eos_prob = torch.cat(eos_prob, dim=1) + attn = torch.cat(attn, dim=2) + alignment = attn.max(dim=1)[1] + + feat = feat.reshape(bsz, -1, raw_dim) + feat = self.gcmvn_denormalize(feat) + + eos_prob = eos_prob.repeat_interleave(n_frames_per_step, dim=1) + attn = attn.repeat_interleave(n_frames_per_step, dim=2) + alignment = alignment.repeat_interleave(n_frames_per_step, dim=1) + out_lens = out_lens * n_frames_per_step + + finalized = [ + { + "feature": feat[b, :out_len], + "eos_prob": eos_prob[b, :out_len], + "attn": attn[b, :, :out_len], + "alignment": alignment[b, :out_len], + "waveform": self.get_waveform(feat[b, :out_len]), + } + for b, out_len in zip(range(bsz), out_lens) + ] + + if has_targ: + assert sample["target"].size(-1) == out_dim + tgt_feats = sample["target"].view(bsz, -1, raw_dim) + tgt_feats = self.gcmvn_denormalize(tgt_feats) + tgt_lens = sample["target_lengths"] * n_frames_per_step + for b, (f, l) in enumerate(zip(tgt_feats, tgt_lens)): + finalized[b]["targ_feature"] = f[:l] + finalized[b]["targ_waveform"] = self.get_waveform(f[:l]) + return finalized + + +class Translatotron2SpeechGenerator(SpeechGenerator): + def __init__( + self, + models, + args, + 
vocoder, + data_cfg, + tgt_dict_mt, + max_iter: int = 6000, + eos_prob_threshold: float = 0.5, + eos_mt=None, + symbols_to_strip_from_output=None, + ): + super().__init__(models[0], vocoder, data_cfg) + self.max_iter = max_iter + self.eos_prob_threshold = eos_prob_threshold + + self.tgt_dict_mt = tgt_dict_mt + self.eos_mt = eos_mt + + from fairseq import search + from fairseq.sequence_generator import SequenceGenerator + + self.text_generator = SequenceGenerator( + models, + tgt_dict_mt, + beam_size=max(1, getattr(args, "beam", 5)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search.BeamSearch(tgt_dict_mt), + eos=eos_mt, + symbols_to_strip_from_output=symbols_to_strip_from_output, + ) + + @torch.no_grad() + def generate(self, model, sample, has_targ=False, **kwargs): + model.eval() + + src_tokens = sample["net_input"]["src_tokens"] + src_lengths = sample["net_input"]["src_lengths"] + bsz, src_len = src_tokens.size()[:2] + n_frames_per_step = model.decoder.n_frames_per_step + out_dim = model.decoder.out_dim + raw_dim = out_dim // n_frames_per_step + + # initialize + encoder_out = model.forward_encoder( + src_tokens, src_lengths, speaker=sample["speaker"] + ) + + prefix_tokens = None + constraints = None + bos_token = None + + mt_decoder = getattr(model, f"{model.mt_task_name}_decoder") + + # 1. 
MT decoder + finalized_mt = self.text_generator.generate_decoder( + [encoder_out], + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + aux_task_name=model.mt_task_name, + ) + + # extract decoder output corresponding to the best hypothesis + max_tgt_len = max([len(hypo[0]["tokens"]) for hypo in finalized_mt]) + prev_output_tokens_mt = ( + src_tokens.new_zeros(src_tokens.shape[0], max_tgt_len) + .fill_(mt_decoder.padding_idx) + .int() + ) # B x T + for i, hypo in enumerate(finalized_mt): + i_beam = 0 + tmp = hypo[i_beam]["tokens"].int() # hyp + eos + prev_output_tokens_mt[i, 0] = self.text_generator.eos + if tmp[-1] == self.text_generator.eos: + tmp = tmp[:-1] + prev_output_tokens_mt[i, 1 : len(tmp) + 1] = tmp + + text = "".join([self.tgt_dict_mt[c] for c in tmp]) + text = text.replace("_", " ") + text = text.replace("▁", " ") + text = text.replace("<unk>", " ") + text = text.replace("<s>", "") + text = text.replace("</s>", "") + if len(text) > 0 and text[0] == " ": + text = text[1:] + sample_id = sample["id"].tolist()[i] + print("{} (None-{})".format(text, sample_id)) + + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + features_only=True, + ) + x = mt_decoder_out[0].transpose(0, 1) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. TTS encoder + if getattr(model, "synthesizer_encoder", None) is not None: + synthesizer_encoder_out = model.synthesizer_encoder( + x, + mt_decoder_padding_mask, + ) + else: + synthesizer_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask] + if mt_decoder_padding_mask is not None + else [], # B x T + "encoder_embedding": [], + "encoder_states": [], + "src_tokens": [], + "src_lengths": [], + } + + # 3.
TTS decoder + incremental_state = {} + feat, attn, eos_prob = [], [], [] + finished = src_tokens.new_zeros((bsz,)).bool() + out_lens = src_lengths.new_zeros((bsz,)).long().fill_(self.max_iter) + + prev_feat_out = encoder_out["encoder_out"][0].new_zeros(bsz, 1, out_dim) + for step in range(self.max_iter): + cur_out_lens = out_lens.clone() + cur_out_lens.masked_fill_(cur_out_lens.eq(self.max_iter), step + 1) + _, cur_eos_out, cur_extra = model.forward_decoder( + prev_feat_out, + encoder_out=synthesizer_encoder_out, + incremental_state=incremental_state, + target_lengths=cur_out_lens, + speaker=sample["speaker"], + **kwargs, ) cur_eos_prob = torch.sigmoid(cur_eos_out).squeeze(2) feat.append(cur_extra["feature_out"]) diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index d9e2325686..8da5fecbd6 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -8,6 +8,7 @@ import math from argparse import Namespace from pathlib import Path +from typing import List import torch import torch.nn as nn @@ -16,8 +17,12 @@ from fairseq.data import Dictionary from fairseq.data.audio.data_cfg import MultitaskConfig, S2SDataConfig from fairseq.data.audio.speech_to_speech_dataset import SpeechToSpeechDatasetCreator -from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset +from fairseq.data.audio.speech_to_text_dataset import ( + SpeechToTextDataset, + TextTargetMultitaskData, +) from fairseq.tasks import LegacyFairseqTask, register_task +from fairseq.tasks.speech_to_text import DummyMultiTask from fairseq.tasks.text_to_speech import batch_mel_cepstral_distortion logger = logging.getLogger(__name__) @@ -142,6 +147,12 @@ def add_args(cls, parser): default="config.yaml", help="Configuration YAML filename (under manifest root)", ) + parser.add_argument( + "--multitask-config-yaml", + type=str, + default=None, + help="Configuration YAML filename for the multitasks (under manifest root)", + ) 
parser.add_argument( "--max-source-positions", default=6000, @@ -170,12 +181,6 @@ def add_args(cls, parser): default=1, help="# stacked frames, use 0 for reduced discrete unit sequence", ) - parser.add_argument( - "--multitask-config-yaml", - type=str, - default=None, - help="Configuration YAML filename for the multitasks (under manifest root)", - ) parser.add_argument("--eval-inference", action="store_true") parser.add_argument( "--eval-args", @@ -208,15 +213,28 @@ def __init__(self, args, tgt_dict, infer_tgt_lang_id=None): super().__init__(args) self.tgt_dict = tgt_dict self.data_cfg = S2SDataConfig(Path(args.data) / args.config_yaml) + self.multitask_tasks = {} + self.tgt_dict_mt = None + self.eos_token_mt = None if getattr(args, "multitask_config_yaml", None) is not None: multitask_cfg = MultitaskConfig( Path(args.data) / args.multitask_config_yaml ) for task_name, task_config in multitask_cfg.get_all_tasks().items(): - self.multitask_tasks[task_name] = DummyMultiTask( - task_config, task_config.tgt_dict - ) + task_obj = DummyMultiTask(task_config, task_config.tgt_dict) + self.multitask_tasks[task_name] = task_obj + if "target" in task_name and task_obj.args.decoder_type != "ctc": + self.tgt_dict_mt = task_obj.target_dictionary + if task_config.prepend_bos_and_append_tgt_lang_tag: + self.eos_token_mt = task_config.eos_token + assert not isinstance(self.eos_token_mt, List) + + if not self.eos_token_mt: + raise Warning( + "Please provide --eos_token to replace eos in sequence generator" + ) + self._infer_tgt_lang_id = infer_tgt_lang_id @classmethod @@ -265,15 +283,15 @@ def setup_task(cls, args, **kwargs): def build_criterion(self, args): from fairseq import criterions - if len(self.multitask_tasks) > 0: - if self.args.target_is_code and args._name != "speech_to_unit": - raise ValueError( - "set --criterion speech_to_unit for speech-to-unit loss with multitask" - ) - elif not self.args.target_is_code and args._name != "speech_to_spectrogram": - raise ValueError( 
- "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask" - ) + # if len(self.multitask_tasks) > 0: + # if self.args.target_is_code and args._name != "speech_to_unit": + # raise ValueError( + # "set --criterion speech_to_unit for speech-to-unit loss with multitask" + # ) + # elif not self.args.target_is_code and args._name != "speech_to_spectrogram": + # raise ValueError( + # "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask" + # ) return criterions.build_criterion(args, self) @@ -295,6 +313,10 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): def target_dictionary(self): return self.tgt_dict + @property + def target_dictionary_mt(self): + return self.tgt_dict_mt + @property def source_dictionary(self): return None @@ -325,6 +347,36 @@ def build_model(self, args, from_checkpoint=False): return model + def build_generator_translatotron2( + self, + models, + args, + extra_gen_cls_kwargs=None, + ): + from fairseq.sequence_generator_multi_decoder import ( + MultiDecoderSequenceGenerator, + ) + + return MultiDecoderSequenceGenerator( + models, + self.target_dictionary, + self.target_dictionary_mt, + beam_size=max(1, getattr(args, "beam", 1)), + beam_size_mt=max(1, getattr(args, "beam_mt", 1)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + max_len_a_mt=getattr(args, "max_len_a_mt", 0), + max_len_b_mt=getattr(args, "max_len_b_mt", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + **extra_gen_cls_kwargs, + ) + def build_generator( self, models, @@ -343,14 +395,26 @@ def build_generator( else self.vocoder.cpu() ) + from 
fairseq.models.speech_to_speech import ( + SpecT2ConformerModel, + UnitYConformerModel, + ) + if self.args.target_is_code: if self.args.n_frames_per_step == 1: - seq_generator = super().build_generator( - models, - args, - seq_gen_cls=None, - extra_gen_cls_kwargs=extra_gen_cls_kwargs, - ) + if isinstance(models[0], UnitYConformerModel): + seq_generator = self.build_generator_translatotron2( + models, + args, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + seq_generator = super().build_generator( + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) else: assert ( getattr(args, "beam", 1) == 1 and getattr(args, "nbest", 1) == 1 @@ -360,24 +424,64 @@ def build_generator( self.args.target_code_size, ) else: - if getattr(args, "teacher_forcing", False): - from fairseq.speech_generator import ( - TeacherForcingAutoRegressiveSpeechGenerator, + if isinstance(models[0], SpecT2ConformerModel): + if getattr(args, "teacher_forcing", False): + raise NotImplementedError + else: + from fairseq.speech_generator import Translatotron2SpeechGenerator + + generator = Translatotron2SpeechGenerator + + lang_token_ids_aux = { + i + for s, i in self.tgt_dict_mt.indices.items() + if TextTargetMultitaskData.is_lang_tag(s) + } + + if extra_gen_cls_kwargs is None: + extra_gen_cls_kwargs = {} + extra_gen_cls_kwargs[ + "symbols_to_strip_from_output" + ] = lang_token_ids_aux + + eos_id_mt = ( + self.tgt_dict_mt.index(self.eos_token_mt) + if self.eos_token_mt + else None ) + assert eos_id_mt != self.tgt_dict_mt.unk() + extra_gen_cls_kwargs["eos_mt"] = eos_id_mt - generator = TeacherForcingAutoRegressiveSpeechGenerator - logger.info("Teacher forcing mode for generation") + seq_generator = generator( + models, + args, + self.vocoder, + self.data_cfg, + self.target_dictionary_mt, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + **extra_gen_cls_kwargs, + ) else: - from fairseq.speech_generator import 
AutoRegressiveSpeechGenerator - - generator = AutoRegressiveSpeechGenerator - seq_generator = generator( - models[0], - self.vocoder, - self.data_cfg, - max_iter=self.args.max_target_positions, - eos_prob_threshold=self.args.eos_prob_threshold, - ) + if getattr(args, "teacher_forcing", False): + from fairseq.speech_generator import ( + TeacherForcingAutoRegressiveSpeechGenerator, + ) + + generator = TeacherForcingAutoRegressiveSpeechGenerator + logger.info("Teacher forcing mode for generation") + else: + from fairseq.speech_generator import AutoRegressiveSpeechGenerator + + generator = AutoRegressiveSpeechGenerator + + seq_generator = generator( + models[0], + self.vocoder, + self.data_cfg, + max_iter=self.args.max_target_positions, + eos_prob_threshold=self.args.eos_prob_threshold, + ) return seq_generator @@ -388,6 +492,8 @@ def train_step( criterion.set_multitask_loss_weight( task_name, task_obj.args.get_loss_weight(update_num) ) + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].train() loss, sample_size, logging_output = super().train_step( sample, model, criterion, optimizer, update_num, ignore_grad @@ -395,6 +501,9 @@ def train_step( return loss, sample_size, logging_output def valid_step(self, sample, model, criterion): + for task_name, task_obj in self.multitask_tasks.items(): + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].eval() loss, sample_size, logging_output = super().valid_step(sample, model, criterion) if self.args.eval_inference: @@ -480,41 +589,3 @@ def inference_step( prefix_tokens=prefix_tokens, constraints=constraints, ) - - -class DummyMultiTask(LegacyFairseqTask): - def __init__(self, args, tgt_dict): - super().__init__(args) - self.tgt_dict = tgt_dict - - @property - def target_dictionary(self): - return self.tgt_dict - - def inference_step( - self, generator, models, sample, prefix_tokens=None, constraints=None - ): - if self.args.decoder_type == "ctc": - model = models[0] # 
only support single model - encoder_out = model(**sample) - if hasattr(model, "get_logits"): - emissions = model.get_logits( - encoder_out - ) # no need to normalize emissions - else: - emissions = model.get_normalized_probs(encoder_out, log_probs=True) - return generator.decode( - emissions.transpose(0, 1).float().cpu().contiguous() - ) - else: - raise NotImplementedError("only ctc decoder is supported at the moment") - - def build_generator( - self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None - ): - if self.args.decoder_type == "ctc": - from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder - - return W2lViterbiDecoder(args, self.tgt_dict) - else: - raise NotImplementedError("only ctc decoder is supported at the moment") diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 80e18dc072..5ca6c03c93 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -4,19 +4,21 @@ # LICENSE file in the root directory of this source tree. 
import logging -from pathlib import Path from argparse import Namespace +from pathlib import Path +from typing import List from fairseq.data import Dictionary, encoders from fairseq.data.audio.audio_utils import get_features_or_waveform +from fairseq.data.audio.data_cfg import MultitaskConfig from fairseq.data.audio.speech_to_text_dataset import ( S2TDataConfig, SpeechToTextDataset, SpeechToTextDatasetCreator, + TextTargetMultitaskData, ) from fairseq.tasks import LegacyFairseqTask, register_task - logger = logging.getLogger(__name__) @@ -31,6 +33,12 @@ def add_args(cls, parser): default="config.yaml", help="Configuration YAML filename (under manifest root)", ) + parser.add_argument( + "--multitask-config-yaml", + type=str, + default=None, + help="Configuration YAML filename for the multitasks (under manifest root)", + ) parser.add_argument( "--max-source-positions", default=6000, @@ -59,6 +67,27 @@ def __init__(self, args, tgt_dict): "Please set only one of the two options to avoid adding target token multiple times" ) + self.multitask_tasks = {} + self.tgt_dict_mt = None + self.eos_token_mt = None + if getattr(args, "multitask_config_yaml", None) is not None: + multitask_cfg = MultitaskConfig( + Path(args.data) / args.multitask_config_yaml + ) + for task_name, task_config in multitask_cfg.get_all_tasks().items(): + task_obj = DummyMultiTask(task_config, task_config.tgt_dict) + self.multitask_tasks[task_name] = task_obj + if "target" in task_name and task_config.decoder_type != "ctc": + self.tgt_dict_mt = task_obj.target_dictionary + if task_config.prepend_bos_and_append_tgt_lang_tag: + self.eos_token_mt = task_config.eos_token + assert not isinstance(self.eos_token_mt, List) + + if not self.eos_token_mt: + raise Warning( + "Please provide --eos_token to replace eos in sequence generator" + ) + def _get_speaker_to_id(self): speaker_to_id = None speaker_set_filename = self.data_cfg.config.get("speaker_set_filename") @@ -109,12 +138,17 @@ def load_dataset(self, 
split, epoch=1, combine=False, **kwargs): epoch=epoch, seed=self.args.seed, speaker_to_id=self.speaker_to_id, + multitask=self.multitask_tasks, ) @property def target_dictionary(self): return self.tgt_dict + @property + def target_dictionary_mt(self): + return self.tgt_dict_mt + @property def source_dictionary(self): return None @@ -128,6 +162,51 @@ def build_model(self, args, from_checkpoint=False): args.speaker_to_id = self.speaker_to_id return super(SpeechToTextTask, self).build_model(args, from_checkpoint) + def build_generator_translatotron2( + self, + models, + args, + extra_gen_cls_kwargs, + ): + from fairseq.sequence_generator_multi_decoder import ( + MultiDecoderSequenceGenerator, + ) + + lang_token_ids_aux = { + i + for s, i in self.tgt_dict_mt.indices.items() + if TextTargetMultitaskData.is_lang_tag(s) + } + + extra_gen_cls_kwargs["symbols_to_strip_from_output"].update(lang_token_ids_aux) + + eos_id_mt = ( + self.tgt_dict_mt.index(self.eos_token_mt) if self.eos_token_mt else None + ) + assert eos_id_mt != self.tgt_dict_mt.unk() + extra_gen_cls_kwargs["eos_mt"] = eos_id_mt + + return MultiDecoderSequenceGenerator( + models, + self.target_dictionary, + self.target_dictionary_mt, + beam_size=max(1, getattr(args, "beam", 1)), + beam_size_mt=max(1, getattr(args, "beam_mt", 1)), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + max_len_a_mt=getattr(args, "max_len_a_mt", 0), + max_len_b_mt=getattr(args, "max_len_b_mt", 0), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + len_penalty_mt=getattr(args, "lenpen_mt", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + **extra_gen_cls_kwargs, + ) + def build_generator( self, models, @@ -164,9 +243,44 @@ def 
build_generator( eos_id = self.tgt_dict.index(eos_token) if eos_token else None extra_gen_cls_kwargs["eos"] = eos_id - return super().build_generator( - models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs + from fairseq.models.speech_to_text import UnitYXMTransformerModel + + if isinstance(models[0], UnitYXMTransformerModel): + return self.build_generator_translatotron2( + models, + args, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + else: + return super().build_generator( + models, + args, + seq_gen_cls=None, + extra_gen_cls_kwargs=extra_gen_cls_kwargs, + ) + + def train_step( + self, sample, model, criterion, optimizer, update_num, ignore_grad=False + ): + for task_name, task_obj in self.multitask_tasks.items(): + criterion.set_multitask_loss_weight( + task_name, task_obj.args.get_loss_weight(update_num) + ) + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].train() + + loss, sample_size, logging_output = super().train_step( + sample, model, criterion, optimizer, update_num, ignore_grad ) + return loss, sample_size, logging_output + + def valid_step(self, sample, model, criterion): + for task_name, task_obj in self.multitask_tasks.items(): + if task_name in model.multitask_decoders: + model.multitask_decoders[task_name].eval() + loss, sample_size, logging_output = super().valid_step(sample, model, criterion) + + return loss, sample_size, logging_output def build_tokenizer(self, args): logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}") @@ -184,3 +298,41 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): return SpeechToTextDataset( "interactive", False, self.data_cfg, src_tokens, src_lengths ) + + +class DummyMultiTask(LegacyFairseqTask): + def __init__(self, args, tgt_dict): + super().__init__(args) + self.tgt_dict = tgt_dict + + @property + def target_dictionary(self): + return self.tgt_dict + + def inference_step( + self, generator, models, sample, prefix_tokens=None, 
constraints=None + ): + if self.args.decoder_type == "ctc": + model = models[0] # only support single model + encoder_out = model(**sample) + if hasattr(model, "get_logits"): + emissions = model.get_logits( + encoder_out + ) # no need to normalize emissions + else: + emissions = model.get_normalized_probs(encoder_out, log_probs=True) + return generator.decode( + emissions.transpose(0, 1).float().cpu().contiguous() + ) + else: + raise NotImplementedError("only ctc decoder is supported at the moment") + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + ): + if self.args.decoder_type == "ctc": + from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder + + return W2lViterbiDecoder(args, self.tgt_dict) + else: + raise NotImplementedError("only ctc decoder is supported at the moment") From ff62dfe391d7bc846f0e7f2b542b48d020529d35 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Sat, 27 Aug 2022 04:42:24 -0700 Subject: [PATCH 02/35] Rename for consistency --- .../criterions/speech_to_speech_criterion.py | 6 +- fairseq/models/speech_to_speech/__init__.py | 3 +- .../models/speech_to_speech/s2s_conformer.py | 42 ++- .../s2s_conformer_translatotron2.py | 261 ++++++++++++++++++ ...conformer_t2.py => s2s_conformer_unity.py} | 242 +--------------- .../speech_to_speech/s2s_transformer.py | 28 +- .../speech_to_text/xm_transformer_unity.py | 3 +- fairseq/tasks/speech_to_speech.py | 4 +- fairseq/tasks/speech_to_text.py | 4 +- 9 files changed, 324 insertions(+), 269 deletions(-) create mode 100644 fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py rename fairseq/models/speech_to_speech/{s2s_conformer_t2.py => s2s_conformer_unity.py} (65%) diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index a6bc0cb73f..564e5f1ce2 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -252,7 +252,7 @@ 
def logging_outputs_can_be_summed() -> bool: @register_criterion( - "speech_to_unit_translatotron2", dataclass=LabelSmoothedCrossEntropyCriterionConfig + "speech_to_unit_2pass", dataclass=LabelSmoothedCrossEntropyCriterionConfig ) class SpeechToUnitTranslatotron2MultitaskTaskCriterion( SpeechToUnitMultitaskTaskCriterion @@ -429,9 +429,7 @@ def reduce_metrics(cls, logging_outputs) -> None: MultitaskCriterion.reduce_metrics(logging_outputs) -@register_criterion( - "speech_to_spectrogram_translatotron2", dataclass=Tacotron2CriterionConfig -) +@register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig) class SpeechToSpectrogramTranslatotron2MultitaskTaskCriterion( Tacotron2Criterion, MultitaskCriterion ): diff --git a/fairseq/models/speech_to_speech/__init__.py b/fairseq/models/speech_to_speech/__init__.py index d3105bf429..76fd1ef7ec 100644 --- a/fairseq/models/speech_to_speech/__init__.py +++ b/fairseq/models/speech_to_speech/__init__.py @@ -5,5 +5,6 @@ from .modules import * # noqa from .s2s_conformer import * # noqa -from .s2s_conformer_t2 import * # noqa +from .s2s_conformer_translatotron2 import * # noqa +from .s2s_conformer_unity import * # noqa from .s2s_transformer import * # noqa diff --git a/fairseq/models/speech_to_speech/s2s_conformer.py b/fairseq/models/speech_to_speech/s2s_conformer.py index 7f1d49c8bc..2ba0bd389d 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer.py +++ b/fairseq/models/speech_to_speech/s2s_conformer.py @@ -11,10 +11,10 @@ from fairseq import checkpoint_utils from fairseq.models import register_model, register_model_architecture from fairseq.models.speech_to_speech.s2s_transformer import ( - S2SpecTTransformerModel, S2UTTransformerModel, - s2spect_architecture_base, + TranslatotronTransformerModel, s2ut_architecture_base, + translatotron_architecture_base, ) from fairseq.models.speech_to_text import S2TConformerEncoder from fairseq.models.transformer import Linear @@ -53,7 +53,7 @@ def forward( 
@register_model("s2ut_conformer") class S2UTConformerModel(S2UTTransformerModel): """ - Direct speech-to-speech translation model with S2T Conformer encoder + Transformer discrete unit decoder + Direct speech-to-speech translation model with Conformer encoder + Transformer discrete unit decoder (S2UT) """ @staticmethod @@ -90,15 +90,15 @@ def build_encoder(cls, args): return encoder -@register_model("s2spect_conformer") -class S2SpecTConformerModel(S2SpecTTransformerModel): +@register_model("translatotron_conformer") +class TranslatotronConformerModel(TranslatotronTransformerModel): """ - Direct speech-to-speech translation model with S2T Conformer encoder + TTS Transformer decoder + Direct speech-to-speech translation model with Conformer encoder + TTS Transformer decoder (Translatotron) """ @staticmethod def add_args(parser): - S2SpecTTransformerModel.add_args(parser) + TranslatotronTransformerModel.add_args(parser) parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31) parser.add_argument( "--attn-type", @@ -146,8 +146,8 @@ def s2ut_conformer_architecture_base(args): s2ut_architecture_base(args) -@register_model_architecture("s2spect_conformer", "s2spect_conformer") -def s2spect_conformer_architecture_base(args): +@register_model_architecture("translatotron_conformer", "translatotron_conformer") +def translatotron_conformer_architecture_base(args): args.attn_type = getattr(args, "attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) @@ -159,11 +159,13 @@ def s2spect_conformer_architecture_base(args): args.dropout = getattr(args, "dropout", 0.1) args.encoder_layers = getattr(args, "encoder_layers", 16) args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - s2spect_architecture_base(args) + translatotron_architecture_base(args) -@register_model_architecture("s2spect_conformer", "s2spect_conformer_fisher") -def 
s2spect_architecture_fisher(args): +@register_model_architecture( + "translatotron_conformer", "translatotron_conformer_fisher" +) +def translatotron_architecture_fisher(args): args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) @@ -172,4 +174,18 @@ def s2spect_architecture_fisher(args): # decoder args.prenet_dim = getattr(args, "prenet_dim", 32) - s2spect_conformer_architecture_base(args) + translatotron_conformer_architecture_base(args) + + +# for old models +@register_model_architecture( + model_name="translatotron_conformer", arch_name="s2spect_conformer" +) +def translatotron_conformer_architecture_base_legacy(args): + translatotron_conformer_architecture_base(args) + + +# for old models +@register_model_architecture("translatotron_conformer", "s2spect_conformer_fisher") +def translatotron_architecture_fisher_legacy(args): + translatotron_architecture_fisher(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py new file mode 100644 index 0000000000..c3286c491f --- /dev/null +++ b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py @@ -0,0 +1,261 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import logging + +from fairseq.models import ( + FairseqEncoderModel, + FairseqLanguageModel, + register_model, + register_model_architecture, +) +from fairseq.models.speech_to_speech.modules import CTCDecoder +from fairseq.models.speech_to_speech.s2s_conformer import TranslatotronConformerModel +from fairseq.models.speech_to_speech.s2s_conformer_unity import ( + TransformerEncoderNoEmb, + multitask_text_transformer_decoder_arch, +) +from fairseq.models.speech_to_speech.s2s_transformer import ( + base_multitask_text_transformer_decoder_arch, + translatotron_architecture_base, +) +from fairseq.models.text_to_speech import TTSTransformerDecoder +from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase + +logger = logging.getLogger(__name__) + + +@register_model("translatotron2_conformer") +class Translatotron2ConformerModel(TranslatotronConformerModel): + """ + Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder (Translatotron2) + """ + + @staticmethod + def add_args(parser): + TranslatotronConformerModel.add_args(parser) + parser.add_argument( + "--translation-decoder-layers", + type=int, + default=4, + metavar="N", + help="num decoder layers in the first-pass translation module", + ) + parser.add_argument( + "--synthesizer", + default="transformer", + choices=["transformer"], + help="", + ) + parser.add_argument( + "--synthesizer-encoder-layers", + type=int, + default=0, + metavar="N", + help="num encoder layers in the second-pass synthesizer module", + ) + + @classmethod + def build_multitask_decoder( + cls, + args, + tgt_dict, + in_dim, + is_mt_decoder, + decoder_layers, + decoder_embed_dim, + decoder_attention_heads, + ): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + if is_mt_decoder: + multitask_text_transformer_decoder_arch( + decoder_args, + decoder_layers, + decoder_embed_dim, 
+ decoder_attention_heads, + ) # 4L + else: + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + + @classmethod + def build_decoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_embed_dim = args.decoder_embed_dim + + if args.synthesizer == "transformer": + return TTSTransformerDecoder(_args, None, padding_idx=1) + else: + raise NotImplementedError(args.synthesizer) + + @classmethod + def build_model(cls, args, task): + encoder = cls.build_encoder(args) + decoder = cls.build_decoder(args) + base_model = cls(encoder, decoder) + + # set up multitask decoders + is_mt_decoder = False + base_model.mt_task_name = None + base_model.multitask_decoders = {} + n_aux_tasks = len(list(task.multitask_tasks.items())) + for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): + if i == n_aux_tasks - 1: + is_mt_decoder = True + base_model.mt_task_name = task_name + assert "target" in task_name + assert task_obj.args.decoder_type == "transformer" + # NOTE: we assume that the last task is for the first-pass decoder + + in_dim = ( + args.encoder_embed_dim + if task_obj.args.input_from == "encoder" + else args.decoder_embed_dim + ) + task_decoder = cls.build_multitask_decoder( + task_obj.args, + task_obj.target_dictionary, + in_dim, + is_mt_decoder, + getattr(args, "translation_decoder_layers", 4), + getattr(args, "decoder_embed_dim", 256), + getattr(args, "decoder_attention_heads", 4), + ) + + setattr(base_model, f"{task_name}_decoder", task_decoder) + decoder_model_cls = ( + FairseqEncoderModel + if 
task_obj.args.decoder_type == "ctc" + else FairseqLanguageModel + ) + base_model.multitask_decoders[task_name] = decoder_model_cls( + getattr(base_model, f"{task_name}_decoder") + ) + + assert is_mt_decoder, "set at least one intermediate non-CTC decoder" + + # set up encoder on top of the auxiliary MT decoder + if getattr(args, "synthesizer_encoder_layers", 0) > 0: + base_model.synthesizer_encoder = cls.build_text_encoder(args) + + return base_model + + @classmethod + def build_text_encoder(cls, args): + _args = copy.deepcopy(args) + _args.encoder_layers = args.synthesizer_encoder_layers + _args.encoder_embed_dim = args.decoder_embed_dim + _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim + _args.encoder_attention_heads = args.decoder_attention_heads + _args.encoder_normalize_before = True + return TransformerEncoderNoEmb(_args) + + def forward( + self, + src_tokens, + src_lengths, + prev_output_tokens, + prev_output_tokens_mt, + tgt_speaker=None, + incremental_state=None, + target_lengths=None, + speaker=None, + return_all_hiddens=False, + ): + encoder_out = self.encoder( + src_tokens, + src_lengths=src_lengths, + tgt_speaker=tgt_speaker, + return_all_hiddens=return_all_hiddens, + ) + + # 1. MT decoder + mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") + mt_decoder_out = mt_decoder( + prev_output_tokens_mt, + encoder_out=encoder_out, + ) + x = mt_decoder_out[1]["inner_states"][-1] + if mt_decoder.layer_norm is not None: + x = mt_decoder.layer_norm(x) + + mt_decoder_padding_mask = None + if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): + mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) + + # 2. TTS encoder + if hasattr(self, "synthesizer_encoder"): + tts_encoder_out = self.synthesizer_encoder( + x, + mt_decoder_padding_mask, + return_all_hiddens=return_all_hiddens, + ) + else: + tts_encoder_out = { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [mt_decoder_padding_mask], # B x T + } + + # 3. 
TTS decoder + decoder_out = self.decoder( + prev_output_tokens, + encoder_out=tts_encoder_out, + incremental_state=incremental_state, + target_lengths=target_lengths, + speaker=speaker, + ) + if return_all_hiddens: + decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] + decoder_out[-1]["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ] + decoder_out[-1]["mt_decoder_out"] = mt_decoder_out + return decoder_out + + +@register_model_architecture( + model_name="translatotron2_conformer", arch_name="translatotron2_conformer" +) +def translatotron2_conformer_architecture_base(args): + args.attn_type = getattr(args, "attn_type", None) + args.pos_enc_type = getattr(args, "pos_enc_type", "abs") + args.max_source_positions = getattr(args, "max_source_positions", 6000) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) + args.dropout = getattr(args, "dropout", 0.1) + args.encoder_layers = getattr(args, "encoder_layers", 16) + args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) + translatotron_architecture_base(args) + + +# for old models +@register_model_architecture( + model_name="translatotron2_conformer", arch_name="s2spect_conformer_translatotron2" +) +def translatotron2_conformer_architecture_base_legacy(args): + translatotron2_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_t2.py b/fairseq/models/speech_to_speech/s2s_conformer_unity.py similarity index 65% rename from fairseq/models/speech_to_speech/s2s_conformer_t2.py rename to fairseq/models/speech_to_speech/s2s_conformer_unity.py index cab15344e5..f069851868 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_t2.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -18,17 +18,12 @@ register_model_architecture, ) from 
fairseq.models.speech_to_speech.modules import CTCDecoder -from fairseq.models.speech_to_speech.s2s_conformer import ( - S2SpecTConformerModel, - S2UTConformerModel, -) +from fairseq.models.speech_to_speech.s2s_conformer import S2UTConformerModel from fairseq.models.speech_to_speech.s2s_transformer import ( TransformerUnitDecoder, base_multitask_text_transformer_decoder_arch, - s2spect_architecture_base, s2ut_architecture_base, ) -from fairseq.models.text_to_speech import TTSTransformerDecoder from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder from fairseq.modules import LayerNorm, TransformerEncoderLayer @@ -48,7 +43,7 @@ def multitask_text_transformer_decoder_arch( @register_model("unity_conformer") class UnitYConformerModel(S2UTConformerModel): """ - Direct speech-to-speech translation model with S2T Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder + Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder (UnitY) """ @staticmethod @@ -285,215 +280,6 @@ def forward( return decoder_out -@register_model("spect2_conformer") -class SpecT2ConformerModel(S2SpecTConformerModel): - """ - Direct speech-to-speech translation model with S2T Conformer encoder + MT Transformer decoder + TTS Transformer decoder - """ - - @staticmethod - def add_args(parser): - S2SpecTConformerModel.add_args(parser) - parser.add_argument( - "--translation-decoder-layers", - type=int, - default=4, - metavar="N", - help="num decoder layers in the first-pass translation module", - ) - parser.add_argument( - "--synthesizer", - default="transformer", - choices=["transformer"], - help="", - ) - parser.add_argument( - "--synthesizer-encoder-layers", - type=int, - default=0, - metavar="N", - help="num encoder layers in the second-pass synthesizer module", - ) - - @classmethod - def 
build_multitask_decoder( - cls, - args, - tgt_dict, - in_dim, - is_mt_decoder, - decoder_layers, - decoder_embed_dim, - decoder_attention_heads, - ): - decoder_args = args.decoder_args - decoder_args.encoder_embed_dim = in_dim - if args.decoder_type == "transformer": - if is_mt_decoder: - multitask_text_transformer_decoder_arch( - decoder_args, - decoder_layers, - decoder_embed_dim, - decoder_attention_heads, - ) # 4L - else: - base_multitask_text_transformer_decoder_arch(decoder_args) # 2L - task_decoder = TransformerDecoder( - decoder_args, - tgt_dict, - embed_tokens=TransformerModelBase.build_embedding( - decoder_args, - tgt_dict, - decoder_args.decoder_embed_dim, - ), - ) - elif args.decoder_type == "ctc": - task_decoder = CTCDecoder( - dictionary=tgt_dict, - in_dim=in_dim, - ) - else: - raise NotImplementedError( - "currently only support multitask decoder_type 'transformer', 'ctc'" - ) - - return task_decoder - - @classmethod - def build_decoder(cls, args): - _args = copy.deepcopy(args) - _args.encoder_embed_dim = args.decoder_embed_dim - - if args.synthesizer == "transformer": - return TTSTransformerDecoder(_args, None, padding_idx=1) - else: - raise NotImplementedError(args.synthesizer) - - @classmethod - def build_model(cls, args, task): - encoder = cls.build_encoder(args) - decoder = cls.build_decoder(args) - base_model = cls(encoder, decoder) - - # set up multitask decoders - is_mt_decoder = False - base_model.mt_task_name = None - base_model.multitask_decoders = {} - n_aux_tasks = len(list(task.multitask_tasks.items())) - for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): - if i == n_aux_tasks - 1: - is_mt_decoder = True - base_model.mt_task_name = task_name - assert "target" in task_name - assert task_obj.args.decoder_type == "transformer" - # NOTE: we assume that the last task is for the first-pass decoder - - in_dim = ( - args.encoder_embed_dim - if task_obj.args.input_from == "encoder" - else args.decoder_embed_dim - ) - 
task_decoder = cls.build_multitask_decoder( - task_obj.args, - task_obj.target_dictionary, - in_dim, - is_mt_decoder, - getattr(args, "translation_decoder_layers", 4), - getattr(args, "decoder_embed_dim", 256), - getattr(args, "decoder_attention_heads", 4), - ) - - setattr(base_model, f"{task_name}_decoder", task_decoder) - decoder_model_cls = ( - FairseqEncoderModel - if task_obj.args.decoder_type == "ctc" - else FairseqLanguageModel - ) - base_model.multitask_decoders[task_name] = decoder_model_cls( - getattr(base_model, f"{task_name}_decoder") - ) - - assert is_mt_decoder, "set at least one intermediate non-CTC decoder" - - # set up encoder on top of the auxiliary MT decoder - if getattr(args, "synthesizer_encoder_layers", 0) > 0: - base_model.synthesizer_encoder = cls.build_text_encoder(args) - - return base_model - - @classmethod - def build_text_encoder(cls, args): - _args = copy.deepcopy(args) - _args.encoder_layers = args.synthesizer_encoder_layers - _args.encoder_embed_dim = args.decoder_embed_dim - _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim - _args.encoder_attention_heads = args.decoder_attention_heads - _args.encoder_normalize_before = True - return TransformerEncoderNoEmb(_args) - - def forward( - self, - src_tokens, - src_lengths, - prev_output_tokens, - prev_output_tokens_mt, - tgt_speaker=None, - incremental_state=None, - target_lengths=None, - speaker=None, - return_all_hiddens=False, - ): - encoder_out = self.encoder( - src_tokens, - src_lengths=src_lengths, - tgt_speaker=tgt_speaker, - return_all_hiddens=return_all_hiddens, - ) - - # 1. 
MT decoder - mt_decoder = getattr(self, f"{self.mt_task_name}_decoder") - mt_decoder_out = mt_decoder( - prev_output_tokens_mt, - encoder_out=encoder_out, - ) - x = mt_decoder_out[1]["inner_states"][-1] - if mt_decoder.layer_norm is not None: - x = mt_decoder.layer_norm(x) - - mt_decoder_padding_mask = None - if prev_output_tokens_mt.eq(mt_decoder.padding_idx).any(): - mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) - - # 2. TTS encoder - if hasattr(self, "synthesizer_encoder"): - tts_encoder_out = self.synthesizer_encoder( - x, - mt_decoder_padding_mask, - return_all_hiddens=return_all_hiddens, - ) - else: - tts_encoder_out = { - "encoder_out": [x], # T x B x C - "encoder_padding_mask": [mt_decoder_padding_mask], # B x T - } - - # 3. TTS decoder - decoder_out = self.decoder( - prev_output_tokens, - encoder_out=tts_encoder_out, - incremental_state=incremental_state, - target_lengths=target_lengths, - speaker=speaker, - ) - if return_all_hiddens: - decoder_out[-1]["encoder_states"] = encoder_out["encoder_states"] - decoder_out[-1]["encoder_padding_mask"] = encoder_out[ - "encoder_padding_mask" - ] - decoder_out[-1]["mt_decoder_out"] = mt_decoder_out - return decoder_out - - class TransformerEncoderNoEmb(FairseqEncoder): """Transformer encoder without token embeddings.""" @@ -681,31 +467,9 @@ def unity_conformer_architecture_base(args): s2ut_architecture_base(args) +# for old models @register_model_architecture( model_name="unity_conformer", arch_name="s2ut_conformer_translatotron2" ) def unity_conformer_architecture_base_legacy(args): unity_conformer_architecture_base(args) - - -@register_model_architecture( - model_name="spect2_conformer", arch_name="spect2_conformer" -) -def translatotron2_conformer_architecture_base(args): - args.attn_type = getattr(args, "attn_type", None) - args.pos_enc_type = getattr(args, "pos_enc_type", "abs") - args.max_source_positions = getattr(args, "max_source_positions", 6000) - args.encoder_embed_dim = 
getattr(args, "encoder_embed_dim", 256) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) - args.dropout = getattr(args, "dropout", 0.1) - args.encoder_layers = getattr(args, "encoder_layers", 16) - args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - s2spect_architecture_base(args) - - -@register_model_architecture( - model_name="spect2_conformer", arch_name="s2spect_conformer_translatotron2" -) -def translatotron2_conformer_architecture_base_legacy(args): - translatotron2_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py index 5af07bb673..6ba8808061 100644 --- a/fairseq/models/speech_to_speech/s2s_transformer.py +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -416,8 +416,8 @@ def forward( return decoder_out -@register_model("s2spect_transformer") -class S2SpecTTransformerModel(S2STransformerMultitaskModelBase): +@register_model("translatotron_transformer") +class TranslatotronTransformerModel(S2STransformerMultitaskModelBase): """ Speech-to-spectrogram model with S2T Transformer encoder + TTS Transformer decoder """ @@ -675,9 +675,9 @@ def s2ut_architecture_fisher(args): @register_model_architecture( - model_name="s2spect_transformer", arch_name="s2spect_transformer" + model_name="translatotron_transformer", arch_name="translatotron_transformer" ) -def s2spect_architecture_base(args): +def translatotron_architecture_base(args): base_s2st_transformer_encoder_architecture(args) # decoder @@ -701,8 +701,10 @@ def s2spect_architecture_base(args): args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) -@register_model_architecture("s2spect_transformer", "s2spect_transformer_fisher") -def s2spect_architecture_fisher(args): +@register_model_architecture( + "translatotron_transformer", 
"translatotron_transformer_fisher" +) +def translatotron_architecture_fisher(args): args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) @@ -711,4 +713,16 @@ def s2spect_architecture_fisher(args): # decoder args.prenet_dim = getattr(args, "prenet_dim", 32) - s2spect_architecture_base(args) + translatotron_architecture_base(args) + + +# for old models +@register_model_architecture("translatotron_transformer", "s2spect_transformer") +def translatotron_architecture_base_legacy(args): + translatotron_architecture_base(args) + + +# for old models +@register_model_architecture("translatotron_transformer", "s2spect_transformer_fisher") +def translatotron_architecture_fisher_legacy(args): + translatotron_architecture_fisher(args) diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py index 752bee7602..9234285d32 100644 --- a/fairseq/models/speech_to_text/xm_transformer_unity.py +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -305,8 +305,9 @@ def base_architecture_unity(args): args.decoder_learned_pos = False +# for old models @register_model_architecture( model_name="unity_xm_transformer", arch_name="xm_transformer_t2" ) -def base_architecture_legacy(args): +def base_architecture_unity_legacy(args): base_architecture_unity(args) diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index 8da5fecbd6..50e66df2cb 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -396,7 +396,7 @@ def build_generator( ) from fairseq.models.speech_to_speech import ( - SpecT2ConformerModel, + Translatotron2ConformerModel, UnitYConformerModel, ) @@ -424,7 +424,7 @@ def build_generator( self.args.target_code_size, ) else: - if isinstance(models[0], SpecT2ConformerModel): + if 
isinstance(models[0], Translatotron2ConformerModel): if getattr(args, "teacher_forcing", False): raise NotImplementedError else: diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 5ca6c03c93..5130069adf 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -162,7 +162,7 @@ def build_model(self, args, from_checkpoint=False): args.speaker_to_id = self.speaker_to_id return super(SpeechToTextTask, self).build_model(args, from_checkpoint) - def build_generator_translatotron2( + def build_generator_unity( self, models, args, @@ -246,7 +246,7 @@ def build_generator( from fairseq.models.speech_to_text import UnitYXMTransformerModel if isinstance(models[0], UnitYXMTransformerModel): - return self.build_generator_translatotron2( + return self.build_generator_unity( models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs, From 122acbfc509b81462fe68d8cab5e78287b5cfa8a Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 1 Sep 2022 00:22:30 -0700 Subject: [PATCH 03/35] Refactor conformer encoder construction --- .../models/speech_to_speech/s2s_conformer.py | 44 ++++++++----------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/fairseq/models/speech_to_speech/s2s_conformer.py b/fairseq/models/speech_to_speech/s2s_conformer.py index 2ba0bd389d..7278ccdef3 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer.py +++ b/fairseq/models/speech_to_speech/s2s_conformer.py @@ -22,6 +22,22 @@ logger = logging.getLogger(__name__) +def build_s2s_conformer_encoder(args): + encoder = S2SConformerEncoder(args) + pretraining_path = getattr(args, "load_pretrained_encoder_from", None) + if pretraining_path is not None: + if not Path(pretraining_path).exists(): + logger.warning( + f"skipped pretraining because {pretraining_path} does not exist" + ) + else: + encoder = checkpoint_utils.load_pretrained_component_from_model( + component=encoder, checkpoint=pretraining_path + ) + logger.info(f"loaded 
pretrained encoder from: {pretraining_path}") + return encoder + + class S2SConformerEncoder(S2TConformerEncoder): """Based on S2T transformer encoder, with support to incorporate target speaker embedding.""" @@ -75,19 +91,7 @@ def add_args(parser): @classmethod def build_encoder(cls, args): - encoder = S2SConformerEncoder(args) - pretraining_path = getattr(args, "load_pretrained_encoder_from", None) - if pretraining_path is not None: - if not Path(pretraining_path).exists(): - logger.warning( - f"skipped pretraining because {pretraining_path} does not exist" - ) - else: - encoder = checkpoint_utils.load_pretrained_component_from_model( - component=encoder, checkpoint=pretraining_path - ) - logger.info(f"loaded pretrained encoder from: {pretraining_path}") - return encoder + return build_s2s_conformer_encoder(args) @register_model("translatotron_conformer") @@ -115,19 +119,7 @@ def add_args(parser): @classmethod def build_encoder(cls, args): - encoder = S2SConformerEncoder(args) - pretraining_path = getattr(args, "load_pretrained_encoder_from", None) - if pretraining_path is not None: - if not Path(pretraining_path).exists(): - logger.warning( - f"skipped pretraining because {pretraining_path} does not exist" - ) - else: - encoder = checkpoint_utils.load_pretrained_component_from_model( - component=encoder, checkpoint=pretraining_path - ) - logger.info(f"loaded pretrained encoder from: {pretraining_path}") - return encoder + return build_s2s_conformer_encoder(args) @register_model_architecture("s2ut_conformer", "s2ut_conformer") From 65bad5c0e74d1454c8e843ac0fc5c6ca9df1a0ff Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 1 Sep 2022 00:33:20 -0700 Subject: [PATCH 04/35] Change the order of arguments for rdrop_alpha --- fairseq/criterions/label_smoothed_cross_entropy.py | 2 +- fairseq/criterions/speech_to_speech_criterion.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py 
b/fairseq/criterions/label_smoothed_cross_entropy.py index 036dff943a..37c81d8c49 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -64,9 +64,9 @@ def __init__( task, sentence_avg, label_smoothing, - rdrop_alpha, ignore_prefix_size=0, report_accuracy=False, + rdrop_alpha=0.0, ): super().__init__(task) self.sentence_avg = sentence_avg diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index 564e5f1ce2..fee7a4d800 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -163,17 +163,17 @@ def __init__( task, sentence_avg, label_smoothing, - rdrop_alpha, ignore_prefix_size=0, report_accuracy=False, + rdrop_alpha=0.0, ): super().__init__( task, sentence_avg, label_smoothing, - rdrop_alpha, ignore_prefix_size, report_accuracy, + rdrop_alpha, ) MultitaskCriterion.__init__(self, task.multitask_tasks, rdrop_alpha) @@ -262,17 +262,17 @@ def __init__( task, sentence_avg, label_smoothing, - rdrop_alpha, ignore_prefix_size=0, report_accuracy=False, + rdrop_alpha=0.0, ): super().__init__( task, sentence_avg, label_smoothing, - rdrop_alpha, ignore_prefix_size, report_accuracy, + rdrop_alpha, ) def forward(self, model, sample, reduce=True): From b09a84f2257552f5099f01532c1ec8bbd78c7e79 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 1 Sep 2022 00:39:58 -0700 Subject: [PATCH 05/35] Add compute_loss_with_rdrop --- fairseq/criterions/label_smoothed_cross_entropy.py | 12 +++++++++++- fairseq/criterions/speech_to_speech_criterion.py | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index 37c81d8c49..0d6ee79ae8 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -131,6 +131,17 @@ def 
compute_loss(self, model, net_output, sample, reduce=True): ignore_index=self.padding_idx, reduce=reduce, ) + return loss, nll_loss + + def compute_loss_with_rdrop(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) if self.rdrop_alpha > 0: pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx) @@ -138,7 +149,6 @@ def compute_loss(self, model, net_output, sample, reduce=True): loss += self.rdrop_alpha * rdrop_kl_loss else: rdrop_kl_loss = loss.new_zeros(1) - return loss, nll_loss, rdrop_kl_loss def compute_accuracy(self, model, net_output, sample): diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index fee7a4d800..818f3559d5 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -190,7 +190,7 @@ def forward(self, model, sample, reduce=True): net_input_concat = duplicate_input(net_input_concat) net_output, extra = model(**net_input_concat) - loss, nll_loss, rdrop_kl_loss = self.compute_loss( + loss, nll_loss, rdrop_kl_loss = self.compute_loss_with_rdrop( model, [net_output], sample, reduce=reduce ) sample_size = ( @@ -295,7 +295,7 @@ def forward(self, model, sample, reduce=True): net_input_concat = duplicate_input(net_input_concat) net_output, extra = model(**net_input_concat) - loss, nll_loss, rdrop_kl_loss = self.compute_loss( + loss, nll_loss, rdrop_kl_loss = self.compute_loss_with_rdrop( model, [net_output], sample, reduce=reduce ) From c8fa268d0085ffdaa555a6c36dd37996ba671d80 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 1 Sep 2022 00:47:35 -0700 Subject: [PATCH 06/35] Move build_multitask_decoder to xm_transformer_unity.py --- .../models/speech_to_text/xm_transformer.py | 38 +------------------ 
.../speech_to_text/xm_transformer_unity.py | 34 ++++++++++++++++- 2 files changed, 34 insertions(+), 38 deletions(-) diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py index 5151063f66..0d0e93ec40 100644 --- a/fairseq/models/speech_to_text/xm_transformer.py +++ b/fairseq/models/speech_to_text/xm_transformer.py @@ -20,13 +20,8 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_speech.modules import CTCDecoder from fairseq.models.speech_to_text.hub_interface import S2THubInterface -from fairseq.models.transformer import ( - Embedding, - TransformerDecoder, - TransformerModelBase, -) +from fairseq.models.transformer import Embedding, TransformerDecoder from fairseq.models.wav2vec import Wav2VecEncoder from fairseq.modules.layer_norm import LayerNorm @@ -653,37 +648,6 @@ def build_model(cls, args, task): decoder = cls.build_decoder(args, task, decoder_embed_tokens) return cls(encoder, decoder) - @classmethod - def build_multitask_decoder(cls, args, tgt_dict, in_dim): - decoder_args = args.decoder_args - decoder_args.encoder_embed_dim = in_dim - if args.decoder_type == "transformer": - from fairseq.models.speech_to_speech import ( - base_multitask_text_transformer_decoder_arch, - ) - - base_multitask_text_transformer_decoder_arch(decoder_args) # 2L - task_decoder = TransformerDecoder( - decoder_args, - tgt_dict, - embed_tokens=TransformerModelBase.build_embedding( - decoder_args, - tgt_dict, - decoder_args.decoder_embed_dim, - ), - ) - elif args.decoder_type == "ctc": - task_decoder = CTCDecoder( - dictionary=tgt_dict, - in_dim=in_dim, - ) - else: - raise NotImplementedError( - "currently only support multitask decoder_type 'transformer', 'ctc'" - ) - - return task_decoder - def get_normalized_probs( self, net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py 
b/fairseq/models/speech_to_text/xm_transformer_unity.py index 9234285d32..ba61abde0e 100644 --- a/fairseq/models/speech_to_text/xm_transformer_unity.py +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -12,6 +12,7 @@ register_model, register_model_architecture, ) +from fairseq.models.speech_to_speech.modules import CTCDecoder from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel from fairseq.models.speech_to_text.xm_transformer import ( base_architecture as xm_t_base_architecture, @@ -24,7 +25,7 @@ set_default_transformer_decoder_args, set_default_w2v_encoder_args, ) -from fairseq.models.transformer import Linear, TransformerDecoder +from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder logger = logging.getLogger(__name__) @@ -210,6 +211,37 @@ def build_model(cls, args, task): return base_model + @classmethod + def build_multitask_decoder(cls, args, tgt_dict, in_dim): + decoder_args = args.decoder_args + decoder_args.encoder_embed_dim = in_dim + if args.decoder_type == "transformer": + from fairseq.models.speech_to_speech import ( + base_multitask_text_transformer_decoder_arch, + ) + + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( + decoder_args, + tgt_dict, + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) + elif args.decoder_type == "ctc": + task_decoder = CTCDecoder( + dictionary=tgt_dict, + in_dim=in_dim, + ) + else: + raise NotImplementedError( + "currently only support multitask decoder_type 'transformer', 'ctc'" + ) + + return task_decoder + @classmethod def build_t2u_encoder(cls, args): _args = copy.deepcopy(args) From e19ed1ed74f800b607580d20a1b59839b90c9275 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 14 Sep 2022 23:47:07 -0700 Subject: [PATCH 07/35] Fix generator 
selection --- fairseq/tasks/speech_to_speech.py | 13 +++++-------- fairseq/tasks/speech_to_text.py | 8 ++++---- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index 50e66df2cb..5343b61ec3 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -347,7 +347,7 @@ def build_model(self, args, from_checkpoint=False): return model - def build_generator_translatotron2( + def build_generator_dual_decoder( self, models, args, @@ -395,15 +395,12 @@ def build_generator( else self.vocoder.cpu() ) - from fairseq.models.speech_to_speech import ( - Translatotron2ConformerModel, - UnitYConformerModel, - ) + has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None if self.args.target_is_code: if self.args.n_frames_per_step == 1: - if isinstance(models[0], UnitYConformerModel): - seq_generator = self.build_generator_translatotron2( + if has_dual_decoder: + seq_generator = self.build_generator_dual_decoder( models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs, @@ -424,7 +421,7 @@ def build_generator( self.args.target_code_size, ) else: - if isinstance(models[0], Translatotron2ConformerModel): + if has_dual_decoder: if getattr(args, "teacher_forcing", False): raise NotImplementedError else: diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 5130069adf..f1da986419 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -162,7 +162,7 @@ def build_model(self, args, from_checkpoint=False): args.speaker_to_id = self.speaker_to_id return super(SpeechToTextTask, self).build_model(args, from_checkpoint) - def build_generator_unity( + def build_generator_dual_decoder( self, models, args, @@ -243,10 +243,10 @@ def build_generator( eos_id = self.tgt_dict.index(eos_token) if eos_token else None extra_gen_cls_kwargs["eos"] = eos_id - from fairseq.models.speech_to_text import UnitYXMTransformerModel + 
has_dual_decoder = getattr(models[0], "mt_task_name", None) is not None - if isinstance(models[0], UnitYXMTransformerModel): - return self.build_generator_unity( + if has_dual_decoder: + return self.build_generator_dual_decoder( models, args, extra_gen_cls_kwargs=extra_gen_cls_kwargs, From edc489cf5964ca94d2dd7d3a6e0df78509c49d2c Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 14 Sep 2022 23:50:15 -0700 Subject: [PATCH 08/35] Fix check in build_criterion --- fairseq/tasks/speech_to_speech.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index 5343b61ec3..1be5ba8e7e 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -283,15 +283,17 @@ def setup_task(cls, args, **kwargs): def build_criterion(self, args): from fairseq import criterions - # if len(self.multitask_tasks) > 0: - # if self.args.target_is_code and args._name != "speech_to_unit": - # raise ValueError( - # "set --criterion speech_to_unit for speech-to-unit loss with multitask" - # ) - # elif not self.args.target_is_code and args._name != "speech_to_spectrogram": - # raise ValueError( - # "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask" - # ) + if len(self.multitask_tasks) > 0: + if self.args.target_is_code and not args._name.startswith("speech_to_unit"): + raise ValueError( + "set --criterion speech_to_unit for speech-to-unit loss with multitask" + ) + elif not self.args.target_is_code and not args._name.startswith( + "speech_to_spectrogram" + ): + raise ValueError( + "set --criterion speech_to_spectrogram for speech-to-spectrogram loss with multitask" + ) return criterions.build_criterion(args, self) From 8f547db2eb29931c3525f1f0d6004fe116e50090 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 15 Sep 2022 01:28:57 -0700 Subject: [PATCH 09/35] Modularize Rdrop --- ...label_smoothed_cross_entropy_with_rdrop.py | 
176 ++++++++++++++++++ .../criterions/speech_to_speech_criterion.py | 14 +- fairseq/data/audio/data_cfg.py | 2 +- 3 files changed, 184 insertions(+), 8 deletions(-) create mode 100644 fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py diff --git a/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py b/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py new file mode 100644 index 0000000000..80b5ea0f73 --- /dev/null +++ b/fairseq/criterions/label_smoothed_cross_entropy_with_rdrop.py @@ -0,0 +1,176 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field + +import torch + +from fairseq import metrics, utils +from fairseq.criterions import register_criterion +from fairseq.criterions.label_smoothed_cross_entropy import ( + LabelSmoothedCrossEntropyCriterion, + LabelSmoothedCrossEntropyCriterionConfig, + label_smoothed_nll_loss, +) + + +@dataclass +class RdropLabelSmoothedCrossEntropyCriterionConfig( + LabelSmoothedCrossEntropyCriterionConfig +): + rdrop_alpha: float = field( + default=0.0, + metadata={"help": "alpha for r-drop, 0 means no r-drop"}, + ) + + +@register_criterion( + "label_smoothed_cross_entropy_with_rdrop", + dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig, +) +class RdropLabelSmoothedCrossEntropyCriterion(LabelSmoothedCrossEntropyCriterion): + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=0, + report_accuracy=False, + rdrop_alpha=0.0, + ): + super().__init__( + task, + sentence_avg, + label_smoothing, + ignore_prefix_size=ignore_prefix_size, + report_accuracy=report_accuracy, + ) + self.sentence_avg = sentence_avg + self.eps = label_smoothing + self.ignore_prefix_size = ignore_prefix_size + self.report_accuracy = report_accuracy + self.rdrop_alpha = rdrop_alpha + + def forward(self, model, 
sample, reduce=True, net_output=None): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + if net_output is None: + if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size( + 0 + ) == sample["target"].size(0): + sample = duplicate_input(sample) + net_output = model(**sample["net_input"]) + loss, nll_loss, rdrop_kl_loss = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + logging_output = { + "loss": loss.data, + "nll_loss": nll_loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + } + if self.report_accuracy: + n_correct, total = self.compute_accuracy(model, net_output, sample) + logging_output["n_correct"] = utils.item(n_correct.data) + logging_output["total"] = utils.item(total.data) + if self.rdrop_alpha > 0: + logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) + return loss, sample_size, logging_output + + def get_lprobs_and_target(self, model, net_output, sample): + lprobs = model.get_normalized_probs(net_output, log_probs=True) + target = model.get_targets(sample, net_output) + if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0): + target = torch.cat([target, target.clone()], dim=0) + + if self.ignore_prefix_size > 0: + # lprobs: B x T x C + lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() + target = target[:, self.ignore_prefix_size :].contiguous() + return lprobs.view(-1, lprobs.size(-1)), target.view(-1) + + def compute_loss(self, model, net_output, sample, reduce=True): + lprobs, target = self.get_lprobs_and_target(model, net_output, sample) + loss, nll_loss = label_smoothed_nll_loss( + lprobs, + target, + self.eps, + ignore_index=self.padding_idx, + reduce=reduce, + ) 
+ + if self.rdrop_alpha > 0: + pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx) + rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask) + loss += self.rdrop_alpha * rdrop_kl_loss + else: + rdrop_kl_loss = loss.new_zeros(1) + return loss, nll_loss, rdrop_kl_loss + + @classmethod + def reduce_metrics(cls, logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + super().reduce_metrics(logging_outputs) + + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + rdrop_kl_loss = utils.item( + sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs) + / sample_size + / math.log(2) + ) + if rdrop_kl_loss > 0: + metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss) + + +def duplicate_input(sample): + if "net_input" in sample.keys(): + sample_input = sample["net_input"] + else: + sample_input = sample + + for k, v in sample_input.items(): + if isinstance(v, torch.Tensor): + sample_input[k] = torch.cat([v, v.clone()], dim=0) + if "net_input" in sample.keys(): + sample["net_input"] = sample_input + else: + sample = sample_input + return sample + + +def compute_kl_loss(model, net_output, pad_mask=None, reduce=True): + net_prob = model.get_normalized_probs(net_output, log_probs=True) + net_prob_tec = model.get_normalized_probs(net_output, log_probs=False) + + net_prob = net_prob.view(-1, net_prob.size(-1)) + net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1)) + + p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0) + p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0) + + p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none") + q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none") + + if pad_mask is not None: + p_loss.masked_fill_(pad_mask, 0.0) + q_loss.masked_fill_(pad_mask, 0.0) + + if reduce: + p_loss = p_loss.sum() + q_loss = q_loss.sum() + + loss = (p_loss + q_loss) / 2 + return loss diff --git 
a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index 818f3559d5..246f078c7e 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -12,9 +12,9 @@ from fairseq import metrics, utils from fairseq.criterions import register_criterion from fairseq.criterions.ctc import CtcCriterion -from fairseq.criterions.label_smoothed_cross_entropy import ( - LabelSmoothedCrossEntropyCriterion, - LabelSmoothedCrossEntropyCriterionConfig, +from fairseq.criterions.label_smoothed_cross_entropy_with_rdrop import ( + RdropLabelSmoothedCrossEntropyCriterion, + RdropLabelSmoothedCrossEntropyCriterionConfig, duplicate_input, ) from fairseq.criterions.tacotron2_loss import ( @@ -48,7 +48,7 @@ def __init__(self, multitask_tasks, rdrop_alpha=0.0): else: self.multitask_criterion[ task_name - ] = LabelSmoothedCrossEntropyCriterion( + ] = RdropLabelSmoothedCrossEntropyCriterion( task_obj, task_obj.args.criterion_cfg.sentence_avg, label_smoothing=task_obj.args.criterion_cfg.label_smoothing, @@ -153,10 +153,10 @@ def reduce_metrics(cls, logging_outputs) -> None: @register_criterion( - "speech_to_unit", dataclass=LabelSmoothedCrossEntropyCriterionConfig + "speech_to_unit", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig ) class SpeechToUnitMultitaskTaskCriterion( - LabelSmoothedCrossEntropyCriterion, MultitaskCriterion + RdropLabelSmoothedCrossEntropyCriterion, MultitaskCriterion ): def __init__( self, @@ -252,7 +252,7 @@ def logging_outputs_can_be_summed() -> bool: @register_criterion( - "speech_to_unit_2pass", dataclass=LabelSmoothedCrossEntropyCriterionConfig + "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig ) class SpeechToUnitTranslatotron2MultitaskTaskCriterion( SpeechToUnitMultitaskTaskCriterion diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py index d621cfc931..c6ea331ffe 100644 --- 
a/fairseq/data/audio/data_cfg.py +++ b/fairseq/data/audio/data_cfg.py @@ -348,4 +348,4 @@ def eos_token(self): @property def rdrop_alpha(self): - return self.config.get("rdrop_alpha", 0.0) + return self.config.get("rdrop_alpha", None) From 31ac4b7564a9ce910f4fb5f4be8b8b4b41f5c7c6 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 15 Sep 2022 01:29:58 -0700 Subject: [PATCH 10/35] Minor fix --- fairseq/models/speech_to_text/s2t_conformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/models/speech_to_text/s2t_conformer.py b/fairseq/models/speech_to_text/s2t_conformer.py index 007e412eef..7bbbf42757 100644 --- a/fairseq/models/speech_to_text/s2t_conformer.py +++ b/fairseq/models/speech_to_text/s2t_conformer.py @@ -19,7 +19,9 @@ from fairseq.models.speech_to_text.s2t_transformer import ( S2TTransformerEncoder, S2TTransformerModel, - base_architecture, +) +from fairseq.models.speech_to_text.s2t_transformer import ( + base_architecture as transformer_base_architecture, ) from fairseq.modules import PositionalEmbedding, RelPositionalEncoding from fairseq.modules.conformer_layer import ConformerEncoderLayer @@ -229,4 +231,4 @@ def conformer_base_architecture(args): args.dropout = getattr(args, "dropout", 0.1) args.encoder_layers = getattr(args, "encoder_layers", 16) args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - base_architecture(args) + transformer_base_architecture(args) From 4fecc91e7b4dc879b9461fac8af1adae27b2c1e4 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 15 Sep 2022 01:38:04 -0700 Subject: [PATCH 11/35] Refine class names --- .../criterions/speech_to_speech_criterion.py | 6 +-- .../models/speech_to_speech/s2s_conformer.py | 42 ++++++------------- .../s2s_conformer_translatotron2.py | 28 ++++++------- .../speech_to_speech/s2s_conformer_unity.py | 6 +-- .../speech_to_speech/s2s_transformer.py | 28 ++++--------- 5 files changed, 39 insertions(+), 71 deletions(-) diff --git 
a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index 246f078c7e..1e96ced920 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -254,9 +254,7 @@ def logging_outputs_can_be_summed() -> bool: @register_criterion( "speech_to_unit_2pass", dataclass=RdropLabelSmoothedCrossEntropyCriterionConfig ) -class SpeechToUnitTranslatotron2MultitaskTaskCriterion( - SpeechToUnitMultitaskTaskCriterion -): +class SpeechToUnit2passMultitaskTaskCriterion(SpeechToUnitMultitaskTaskCriterion): def __init__( self, task, @@ -430,7 +428,7 @@ def reduce_metrics(cls, logging_outputs) -> None: @register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig) -class SpeechToSpectrogramTranslatotron2MultitaskTaskCriterion( +class SpeechToSpectrogram2passMultitaskTaskCriterion( Tacotron2Criterion, MultitaskCriterion ): def __init__( diff --git a/fairseq/models/speech_to_speech/s2s_conformer.py b/fairseq/models/speech_to_speech/s2s_conformer.py index c32dac956b..636396d536 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer.py +++ b/fairseq/models/speech_to_speech/s2s_conformer.py @@ -11,10 +11,10 @@ from fairseq import checkpoint_utils from fairseq.models import register_model, register_model_architecture from fairseq.models.speech_to_speech.s2s_transformer import ( + S2SpecTTransformerModel, S2UTTransformerModel, - TranslatotronTransformerModel, + s2spect_architecture_base, s2ut_architecture_base, - translatotron_architecture_base, ) from fairseq.models.speech_to_text import S2TConformerEncoder from fairseq.models.transformer import Linear @@ -69,7 +69,7 @@ def forward( @register_model("s2ut_conformer") class S2UTConformerModel(S2UTTransformerModel): """ - Direct speech-to-speech translation model with Conformer encoder + Transformer discrete unit decoder (S2UT) + Direct speech-to-speech translation model with Conformer encoder + Transformer 
discrete unit decoder """ @staticmethod @@ -99,15 +99,15 @@ def build_encoder(cls, args): return build_s2s_conformer_encoder(args) -@register_model("translatotron_conformer") -class TranslatotronConformerModel(TranslatotronTransformerModel): +@register_model("s2spect_conformer") +class S2SpecTConformerModel(S2SpecTTransformerModel): """ - Direct speech-to-speech translation model with Conformer encoder + TTS Transformer decoder (Translatotron) + Direct speech-to-speech translation model with Conformer encoder + TTS Transformer decoder """ @staticmethod def add_args(parser): - TranslatotronTransformerModel.add_args(parser) + S2SpecTTransformerModel.add_args(parser) parser.add_argument("--depthwise-conv-kernel-size", type=int, default=31) parser.add_argument( "--attn-type", @@ -143,8 +143,8 @@ def s2ut_conformer_architecture_base(args): s2ut_architecture_base(args) -@register_model_architecture("translatotron_conformer", "translatotron_conformer") -def translatotron_conformer_architecture_base(args): +@register_model_architecture("s2spect_conformer", "s2spect_conformer") +def s2spect_conformer_architecture_base(args): args.attn_type = getattr(args, "attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80) @@ -156,13 +156,11 @@ def translatotron_conformer_architecture_base(args): args.dropout = getattr(args, "dropout", 0.1) args.encoder_layers = getattr(args, "encoder_layers", 16) args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - translatotron_architecture_base(args) + s2spect_architecture_base(args) -@register_model_architecture( - "translatotron_conformer", "translatotron_conformer_fisher" -) -def translatotron_architecture_fisher(args): +@register_model_architecture("s2spect_conformer", "s2spect_conformer_fisher") +def s2spect_architecture_fisher(args): args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) 
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) @@ -171,18 +169,4 @@ def translatotron_architecture_fisher(args): # decoder args.prenet_dim = getattr(args, "prenet_dim", 32) - translatotron_conformer_architecture_base(args) - - -# for old models -@register_model_architecture( - model_name="translatotron_conformer", arch_name="s2spect_conformer" -) -def translatotron_conformer_architecture_base_legacy(args): - translatotron_conformer_architecture_base(args) - - -# for old models -@register_model_architecture("translatotron_conformer", "s2spect_conformer_fisher") -def translatotron_architecture_fisher_legacy(args): - translatotron_architecture_fisher(args) + s2spect_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py index c3286c491f..d9c2c790fc 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py @@ -13,30 +13,30 @@ register_model_architecture, ) from fairseq.models.speech_to_speech.modules import CTCDecoder -from fairseq.models.speech_to_speech.s2s_conformer import TranslatotronConformerModel +from fairseq.models.speech_to_speech.s2s_conformer import S2SpecTConformerModel from fairseq.models.speech_to_speech.s2s_conformer_unity import ( TransformerEncoderNoEmb, multitask_text_transformer_decoder_arch, ) from fairseq.models.speech_to_speech.s2s_transformer import ( base_multitask_text_transformer_decoder_arch, - translatotron_architecture_base, + s2spect_architecture_base, ) from fairseq.models.text_to_speech import TTSTransformerDecoder -from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase +from fairseq.models.transformer import TransformerDecoder, TransformerModelBase logger = logging.getLogger(__name__) 
-@register_model("translatotron2_conformer") -class Translatotron2ConformerModel(TranslatotronConformerModel): +@register_model("s2spect2_conformer") +class S2SpecT2ConformerModel(S2SpecTConformerModel): """ - Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder (Translatotron2) + Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + TTS Transformer decoder """ @staticmethod def add_args(parser): - TranslatotronConformerModel.add_args(parser) + S2SpecTConformerModel.add_args(parser) parser.add_argument( "--translation-decoder-layers", type=int, @@ -238,9 +238,9 @@ def forward( @register_model_architecture( - model_name="translatotron2_conformer", arch_name="translatotron2_conformer" + model_name="s2spect2_conformer", arch_name="s2spect2_conformer" ) -def translatotron2_conformer_architecture_base(args): +def s2spect2_conformer_architecture_base(args): args.attn_type = getattr(args, "attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") args.max_source_positions = getattr(args, "max_source_positions", 6000) @@ -250,12 +250,12 @@ def translatotron2_conformer_architecture_base(args): args.dropout = getattr(args, "dropout", 0.1) args.encoder_layers = getattr(args, "encoder_layers", 16) args.depthwise_conv_kernel_size = getattr(args, "depthwise_conv_kernel_size", 31) - translatotron_architecture_base(args) + s2spect_architecture_base(args) -# for old models +# for old naming @register_model_architecture( - model_name="translatotron2_conformer", arch_name="s2spect_conformer_translatotron2" + model_name="s2spect2_conformer", arch_name="s2spect_conformer_translatotron2" ) -def translatotron2_conformer_architecture_base_legacy(args): - translatotron2_conformer_architecture_base(args) +def s2spect2_conformer_architecture_base_legacy(args): + s2spect2_conformer_architecture_base(args) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_unity.py 
b/fairseq/models/speech_to_speech/s2s_conformer_unity.py index f069851868..9eb766999d 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_unity.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -41,9 +41,9 @@ def multitask_text_transformer_decoder_arch( @register_model("unity_conformer") -class UnitYConformerModel(S2UTConformerModel): +class UnityConformerModel(S2UTConformerModel): """ - Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder (UnitY) + Direct speech-to-speech translation model with Conformer encoder + MT Transformer decoder + Transformer discrete unit decoder """ @staticmethod @@ -467,7 +467,7 @@ def unity_conformer_architecture_base(args): s2ut_architecture_base(args) -# for old models +# for old naming @register_model_architecture( model_name="unity_conformer", arch_name="s2ut_conformer_translatotron2" ) diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py index 6ba8808061..5af07bb673 100644 --- a/fairseq/models/speech_to_speech/s2s_transformer.py +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -416,8 +416,8 @@ def forward( return decoder_out -@register_model("translatotron_transformer") -class TranslatotronTransformerModel(S2STransformerMultitaskModelBase): +@register_model("s2spect_transformer") +class S2SpecTTransformerModel(S2STransformerMultitaskModelBase): """ Speech-to-spectrogram model with S2T Transformer encoder + TTS Transformer decoder """ @@ -675,9 +675,9 @@ def s2ut_architecture_fisher(args): @register_model_architecture( - model_name="translatotron_transformer", arch_name="translatotron_transformer" + model_name="s2spect_transformer", arch_name="s2spect_transformer" ) -def translatotron_architecture_base(args): +def s2spect_architecture_base(args): base_s2st_transformer_encoder_architecture(args) # decoder @@ -701,10 +701,8 @@ def 
translatotron_architecture_base(args): args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4) -@register_model_architecture( - "translatotron_transformer", "translatotron_transformer_fisher" -) -def translatotron_architecture_fisher(args): +@register_model_architecture("s2spect_transformer", "s2spect_transformer_fisher") +def s2spect_architecture_fisher(args): args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 256) args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 256 * 8) args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4) @@ -713,16 +711,4 @@ def translatotron_architecture_fisher(args): # decoder args.prenet_dim = getattr(args, "prenet_dim", 32) - translatotron_architecture_base(args) - - -# for old models -@register_model_architecture("translatotron_transformer", "s2spect_transformer") -def translatotron_architecture_base_legacy(args): - translatotron_architecture_base(args) - - -# for old models -@register_model_architecture("translatotron_transformer", "s2spect_transformer_fisher") -def translatotron_architecture_fisher_legacy(args): - translatotron_architecture_fisher(args) + s2spect_architecture_base(args) From 93c2ea5b99042e9d875159c92cc68603ae17ac6e Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 15 Sep 2022 02:14:06 -0700 Subject: [PATCH 12/35] Refactor submodules --- fairseq/models/speech_to_speech/__init__.py | 1 - .../speech_to_speech/modules/__init__.py | 0 .../speech_to_speech/modules/ctc_decoder.py | 18 ++ .../stacked_embedding.py} | 11 - .../modules/transformer_decoder_aug.py | 108 ++++++++++ .../modules/transformer_encoder.py | 85 ++++++++ .../s2s_conformer_translatotron2.py | 6 +- .../speech_to_speech/s2s_conformer_unity.py | 192 +----------------- .../speech_to_speech/s2s_transformer.py | 3 +- 9 files changed, 226 insertions(+), 198 deletions(-) create mode 100644 fairseq/models/speech_to_speech/modules/__init__.py create mode 100644 
fairseq/models/speech_to_speech/modules/ctc_decoder.py rename fairseq/models/speech_to_speech/{modules.py => modules/stacked_embedding.py} (83%) create mode 100644 fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py create mode 100644 fairseq/models/speech_to_speech/modules/transformer_encoder.py diff --git a/fairseq/models/speech_to_speech/__init__.py b/fairseq/models/speech_to_speech/__init__.py index 76fd1ef7ec..f29215c2fe 100644 --- a/fairseq/models/speech_to_speech/__init__.py +++ b/fairseq/models/speech_to_speech/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import * # noqa from .s2s_conformer import * # noqa from .s2s_conformer_translatotron2 import * # noqa from .s2s_conformer_unity import * # noqa diff --git a/fairseq/models/speech_to_speech/modules/__init__.py b/fairseq/models/speech_to_speech/modules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/fairseq/models/speech_to_speech/modules/ctc_decoder.py b/fairseq/models/speech_to_speech/modules/ctc_decoder.py new file mode 100644 index 0000000000..721efbf61a --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/ctc_decoder.py @@ -0,0 +1,18 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torch import nn + +from fairseq.models import FairseqEncoder + + +class CTCDecoder(FairseqEncoder): + def __init__(self, dictionary, in_dim): + super().__init__(dictionary) + self.proj = nn.Linear(in_dim, len(dictionary)) + + def forward(self, src_tokens, src_lengths=None, **kwargs): + encoder_out = self.proj(src_tokens) + return {"encoder_out": encoder_out} diff --git a/fairseq/models/speech_to_speech/modules.py b/fairseq/models/speech_to_speech/modules/stacked_embedding.py similarity index 83% rename from fairseq/models/speech_to_speech/modules.py rename to fairseq/models/speech_to_speech/modules/stacked_embedding.py index a2049816ab..5955a08538 100644 --- a/fairseq/models/speech_to_speech/modules.py +++ b/fairseq/models/speech_to_speech/modules/stacked_embedding.py @@ -6,20 +6,9 @@ import torch from torch import nn -from fairseq.models import FairseqEncoder from fairseq.models.transformer import Linear -class CTCDecoder(FairseqEncoder): - def __init__(self, dictionary, in_dim): - super().__init__(dictionary) - self.proj = nn.Linear(in_dim, len(dictionary)) - - def forward(self, src_tokens, src_lengths=None, **kwargs): - encoder_out = self.proj(src_tokens) - return {"encoder_out": encoder_out} - - class StackedEmbedding(nn.Embedding): """Embedding module that supports stacked units -> single embedding""" diff --git a/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py b/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py new file mode 100644 index 0000000000..6650eed415 --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any, Dict, List, Optional + +from torch import Tensor + +from fairseq.models.transformer import Linear +from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder + + +class AugTransformerUnitDecoder(AugTransformerDecoder): + """Based on Transformer decoder, with support to decoding stacked units""" + + def __init__( + self, + args, + dictionary, + embed_tokens, + no_encoder_attn=False, + output_projection=None, + ): + super().__init__( + args, dictionary, embed_tokens, no_encoder_attn, output_projection + ) + self.n_frames_per_step = args.n_frames_per_step + + self.out_proj_n_frames = ( + Linear( + self.output_embed_dim, + self.output_embed_dim * self.n_frames_per_step, + bias=False, + ) + if self.n_frames_per_step > 1 + else None + ) + + def forward( + self, + prev_output_tokens, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, + encoder_out2: Optional[Dict[str, List[Tensor]]] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + features_only: bool = False, + full_context_alignment: bool = False, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, + src_lengths: Optional[Any] = None, + return_all_hiddens: bool = False, + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (optional): output from the encoder, used for + encoder-side attention, should be of size T x B x C + incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + features_only (bool, optional): only return features without + applying output layer (default: False). + full_context_alignment (bool, optional): don't apply + auto-regressive mask to self-attention (default: False). 
+ + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + + x, extra = self.extract_features( + prev_output_tokens, + encoder_out=encoder_out, + encoder_out2=encoder_out2, + incremental_state=incremental_state, + full_context_alignment=full_context_alignment, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + ) + + if not features_only: + bsz, seq_len, d = x.size() + if self.out_proj_n_frames: + x = self.out_proj_n_frames(x) + x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d)) + x = x.view(bsz, seq_len * self.n_frames_per_step, -1) + if ( + incremental_state is None and self.n_frames_per_step > 1 + ): # teacher-forcing mode in training + x = x[ + :, : -(self.n_frames_per_step - 1), : + ] # remove extra frames after + + return x, extra + + def upgrade_state_dict_named(self, state_dict, name): + if self.n_frames_per_step > 1: + move_keys = [ + ( + f"{name}.project_in_dim.weight", + f"{name}.embed_tokens.project_in_dim.weight", + ) + ] + for from_k, to_k in move_keys: + if from_k in state_dict and to_k not in state_dict: + state_dict[to_k] = state_dict[from_k] + del state_dict[from_k] diff --git a/fairseq/models/speech_to_speech/modules/transformer_encoder.py b/fairseq/models/speech_to_speech/modules/transformer_encoder.py new file mode 100644 index 0000000000..fb1af433d8 --- /dev/null +++ b/fairseq/models/speech_to_speech/modules/transformer_encoder.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch.nn as nn + +from fairseq.models import FairseqEncoder +from fairseq.modules import LayerNorm, TransformerEncoderLayer + + +class TransformerEncoderNoEmb(FairseqEncoder): + """Transformer encoder without token embeddings.""" + + def __init__(self, args): + super().__init__(None) + + self.layers = nn.ModuleList( + [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] + ) + if args.encoder_normalize_before: + self.layer_norm = LayerNorm(args.encoder_embed_dim) + else: + self.layer_norm = None + + def forward(self, x, encoder_padding_mask, return_all_hiddens=False): + + encoder_states = [] + + for layer in self.layers: + x = layer(x, encoder_padding_mask) + if return_all_hiddens: + encoder_states.append(x) + + if self.layer_norm is not None: + x = self.layer_norm(x) + + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask] + if encoder_padding_mask is not None and encoder_padding_mask.any() + else [], # B x T + "encoder_embedding": [], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } + + def reorder_encoder_out(self, encoder_out, new_order): + new_encoder_out = ( + [] + if len(encoder_out["encoder_out"]) == 0 + else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] + ) + + new_encoder_padding_mask = ( + [] + if len(encoder_out["encoder_padding_mask"]) == 0 + else [ + x.index_select(0, new_order) + for x in encoder_out["encoder_padding_mask"] + ] + ) + + new_encoder_embedding = ( + [] + if len(encoder_out["encoder_embedding"]) == 0 + else [ + x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] + ] + ) + + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: + for idx, state in enumerate(encoder_states): + encoder_states[idx] = state.index_select(1, new_order) + + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + 
"encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], # B x T + "src_lengths": [], # B x 1 + } diff --git a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py index d9c2c790fc..d146b31e51 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py @@ -12,10 +12,12 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_speech.modules import CTCDecoder +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.transformer_encoder import ( + TransformerEncoderNoEmb, +) from fairseq.models.speech_to_speech.s2s_conformer import S2SpecTConformerModel from fairseq.models.speech_to_speech.s2s_conformer_unity import ( - TransformerEncoderNoEmb, multitask_text_transformer_decoder_arch, ) from fairseq.models.speech_to_speech.s2s_transformer import ( diff --git a/fairseq/models/speech_to_speech/s2s_conformer_unity.py b/fairseq/models/speech_to_speech/s2s_conformer_unity.py index 9eb766999d..7216a06b38 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_unity.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -5,10 +5,6 @@ import copy import logging -from typing import Any, Dict, List, Optional - -import torch.nn as nn -from torch import Tensor from fairseq.models import ( FairseqEncoder, @@ -17,16 +13,21 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_speech.modules import CTCDecoder +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.stacked_embedding import StackedEmbedding +from fairseq.models.speech_to_speech.modules.transformer_decoder_aug import ( + AugTransformerUnitDecoder, +) +from 
fairseq.models.speech_to_speech.modules.transformer_encoder import ( + TransformerEncoderNoEmb, +) from fairseq.models.speech_to_speech.s2s_conformer import S2UTConformerModel from fairseq.models.speech_to_speech.s2s_transformer import ( TransformerUnitDecoder, base_multitask_text_transformer_decoder_arch, s2ut_architecture_base, ) -from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase -from fairseq.models.transformer.transformer_decoder_aug import AugTransformerDecoder -from fairseq.modules import LayerNorm, TransformerEncoderLayer +from fairseq.models.transformer import TransformerDecoder, TransformerModelBase logger = logging.getLogger(__name__) @@ -122,8 +123,6 @@ def build_multitask_decoder( @classmethod def build_decoder(cls, args, tgt_dict, aug_attn=False): - from fairseq.models.speech_to_speech.modules import StackedEmbedding - num_embeddings = len(tgt_dict) padding_idx = tgt_dict.pad() embed_tokens = StackedEmbedding( @@ -280,179 +279,6 @@ def forward( return decoder_out -class TransformerEncoderNoEmb(FairseqEncoder): - """Transformer encoder without token embeddings.""" - - def __init__(self, args): - super().__init__(None) - - self.layers = nn.ModuleList( - [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] - ) - if args.encoder_normalize_before: - self.layer_norm = LayerNorm(args.encoder_embed_dim) - else: - self.layer_norm = None - - def forward(self, x, encoder_padding_mask, return_all_hiddens=False): - - encoder_states = [] - - for layer in self.layers: - x = layer(x, encoder_padding_mask) - if return_all_hiddens: - encoder_states.append(x) - - if self.layer_norm is not None: - x = self.layer_norm(x) - - return { - "encoder_out": [x], # T x B x C - "encoder_padding_mask": [encoder_padding_mask] - if encoder_padding_mask is not None and encoder_padding_mask.any() - else [], # B x T - "encoder_embedding": [], # B x T x C - "encoder_states": encoder_states, # List[T x B x C] - "src_tokens": [], - 
"src_lengths": [], - } - - def reorder_encoder_out(self, encoder_out, new_order): - new_encoder_out = ( - [] - if len(encoder_out["encoder_out"]) == 0 - else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] - ) - - new_encoder_padding_mask = ( - [] - if len(encoder_out["encoder_padding_mask"]) == 0 - else [ - x.index_select(0, new_order) - for x in encoder_out["encoder_padding_mask"] - ] - ) - - new_encoder_embedding = ( - [] - if len(encoder_out["encoder_embedding"]) == 0 - else [ - x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] - ] - ) - - encoder_states = encoder_out["encoder_states"] - if len(encoder_states) > 0: - for idx, state in enumerate(encoder_states): - encoder_states[idx] = state.index_select(1, new_order) - - return { - "encoder_out": new_encoder_out, # T x B x C - "encoder_padding_mask": new_encoder_padding_mask, # B x T - "encoder_embedding": new_encoder_embedding, # B x T x C - "encoder_states": encoder_states, # List[T x B x C] - "src_tokens": [], # B x T - "src_lengths": [], # B x 1 - } - - -class AugTransformerUnitDecoder(AugTransformerDecoder): - """Based on Transformer decoder, with support to decoding stacked units""" - - def __init__( - self, - args, - dictionary, - embed_tokens, - no_encoder_attn=False, - output_projection=None, - ): - super().__init__( - args, dictionary, embed_tokens, no_encoder_attn, output_projection - ) - self.n_frames_per_step = args.n_frames_per_step - - self.out_proj_n_frames = ( - Linear( - self.output_embed_dim, - self.output_embed_dim * self.n_frames_per_step, - bias=False, - ) - if self.n_frames_per_step > 1 - else None - ) - - def forward( - self, - prev_output_tokens, - encoder_out: Optional[Dict[str, List[Tensor]]] = None, - encoder_out2: Optional[Dict[str, List[Tensor]]] = None, - incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, - features_only: bool = False, - full_context_alignment: bool = False, - alignment_layer: Optional[int] = None, - 
alignment_heads: Optional[int] = None, - src_lengths: Optional[Any] = None, - return_all_hiddens: bool = False, - ): - """ - Args: - prev_output_tokens (LongTensor): previous decoder outputs of shape - `(batch, tgt_len)`, for teacher forcing - encoder_out (optional): output from the encoder, used for - encoder-side attention, should be of size T x B x C - incremental_state (dict): dictionary used for storing state during - :ref:`Incremental decoding` - features_only (bool, optional): only return features without - applying output layer (default: False). - full_context_alignment (bool, optional): don't apply - auto-regressive mask to self-attention (default: False). - - Returns: - tuple: - - the decoder's output of shape `(batch, tgt_len, vocab)` - - a dictionary with any model-specific outputs - """ - - x, extra = self.extract_features( - prev_output_tokens, - encoder_out=encoder_out, - encoder_out2=encoder_out2, - incremental_state=incremental_state, - full_context_alignment=full_context_alignment, - alignment_layer=alignment_layer, - alignment_heads=alignment_heads, - ) - - if not features_only: - bsz, seq_len, d = x.size() - if self.out_proj_n_frames: - x = self.out_proj_n_frames(x) - x = self.output_layer(x.view(bsz, seq_len, self.n_frames_per_step, d)) - x = x.view(bsz, seq_len * self.n_frames_per_step, -1) - if ( - incremental_state is None and self.n_frames_per_step > 1 - ): # teacher-forcing mode in training - x = x[ - :, : -(self.n_frames_per_step - 1), : - ] # remove extra frames after - - return x, extra - - def upgrade_state_dict_named(self, state_dict, name): - if self.n_frames_per_step > 1: - move_keys = [ - ( - f"{name}.project_in_dim.weight", - f"{name}.embed_tokens.project_in_dim.weight", - ) - ] - for from_k, to_k in move_keys: - if from_k in state_dict and to_k not in state_dict: - state_dict[to_k] = state_dict[from_k] - del state_dict[from_k] - - @register_model_architecture(model_name="unity_conformer", arch_name="unity_conformer") def 
unity_conformer_architecture_base(args): args.attn_type = getattr(args, "attn_type", None) diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py index 5af07bb673..ae855aeb1e 100644 --- a/fairseq/models/speech_to_speech/s2s_transformer.py +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -18,7 +18,8 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_speech.modules import CTCDecoder, StackedEmbedding +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.stacked_embedding import StackedEmbedding from fairseq.models.speech_to_text import S2TTransformerEncoder from fairseq.models.text_to_speech import TTSTransformerDecoder from fairseq.models.transformer import Linear, TransformerDecoder, TransformerModelBase From a35d790eab043786415de8976a0f4a4ff660d919 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Thu, 15 Sep 2022 20:54:04 -0700 Subject: [PATCH 13/35] Fix CE --- .../label_smoothed_cross_entropy.py | 90 +------------------ .../criterions/speech_to_speech_criterion.py | 4 +- 2 files changed, 5 insertions(+), 89 deletions(-) diff --git a/fairseq/criterions/label_smoothed_cross_entropy.py b/fairseq/criterions/label_smoothed_cross_entropy.py index 0d6ee79ae8..257466903f 100644 --- a/fairseq/criterions/label_smoothed_cross_entropy.py +++ b/fairseq/criterions/label_smoothed_cross_entropy.py @@ -20,10 +20,6 @@ class LabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass): default=0.0, metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"}, ) - rdrop_alpha: float = field( - default=0.0, - metadata={"help": "alpha for r-drop, 0 means no r-drop"}, - ) report_accuracy: bool = field( default=False, metadata={"help": "report accuracy metric"}, @@ -66,16 +62,14 @@ def __init__( label_smoothing, ignore_prefix_size=0, report_accuracy=False, - rdrop_alpha=0.0, ): super().__init__(task) 
self.sentence_avg = sentence_avg self.eps = label_smoothing self.ignore_prefix_size = ignore_prefix_size self.report_accuracy = report_accuracy - self.rdrop_alpha = rdrop_alpha - def forward(self, model, sample, reduce=True, net_output=None): + def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. Returns a tuple with three elements: @@ -83,15 +77,8 @@ def forward(self, model, sample, reduce=True, net_output=None): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - if net_output is None: - if self.rdrop_alpha > 0 and sample["net_input"]["src_tokens"].size( - 0 - ) == sample["target"].size(0): - sample = duplicate_input(sample) - net_output = model(**sample["net_input"]) - loss, nll_loss, rdrop_kl_loss = self.compute_loss( - model, net_output, sample, reduce=reduce - ) + net_output = model(**sample["net_input"]) + loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce) sample_size = ( sample["target"].size(0) if self.sentence_avg else sample["ntokens"] ) @@ -106,16 +93,11 @@ def forward(self, model, sample, reduce=True, net_output=None): n_correct, total = self.compute_accuracy(model, net_output, sample) logging_output["n_correct"] = utils.item(n_correct.data) logging_output["total"] = utils.item(total.data) - if self.rdrop_alpha > 0: - logging_output["rdrop_kl_loss"] = utils.item(rdrop_kl_loss.data) return loss, sample_size, logging_output def get_lprobs_and_target(self, model, net_output, sample): lprobs = model.get_normalized_probs(net_output, log_probs=True) target = model.get_targets(sample, net_output) - if self.rdrop_alpha > 0 or target.size(0) != lprobs.size(0): - target = torch.cat([target, target.clone()], dim=0) - if self.ignore_prefix_size > 0: # lprobs: B x T x C lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous() @@ -133,24 +115,6 @@ def compute_loss(self, model, net_output, sample, reduce=True): ) return loss, 
nll_loss - def compute_loss_with_rdrop(self, model, net_output, sample, reduce=True): - lprobs, target = self.get_lprobs_and_target(model, net_output, sample) - loss, nll_loss = label_smoothed_nll_loss( - lprobs, - target, - self.eps, - ignore_index=self.padding_idx, - reduce=reduce, - ) - - if self.rdrop_alpha > 0: - pad_mask = target[: target.size(0) // 2].unsqueeze(-1).eq(self.padding_idx) - rdrop_kl_loss = compute_kl_loss(model, net_output, pad_mask) - loss += self.rdrop_alpha * rdrop_kl_loss - else: - rdrop_kl_loss = loss.new_zeros(1) - return loss, nll_loss, rdrop_kl_loss - def compute_accuracy(self, model, net_output, sample): lprobs, target = self.get_lprobs_and_target(model, net_output, sample) mask = target.ne(self.padding_idx) @@ -193,13 +157,6 @@ def reduce_metrics(cls, logging_outputs) -> None: if meters["total"].sum > 0 else float("nan"), ) - rdrop_kl_loss = utils.item( - sum(log.get("rdrop_kl_loss", 0) for log in logging_outputs) - / sample_size - / math.log(2) - ) - if rdrop_kl_loss > 0: - metrics.log_scalar("rdrop_kl_loss", rdrop_kl_loss) @staticmethod def logging_outputs_can_be_summed() -> bool: @@ -209,44 +166,3 @@ def logging_outputs_can_be_summed() -> bool: to True will improves distributed training speed. 
""" return True - - -def duplicate_input(sample): - if "net_input" in sample.keys(): - sample_input = sample["net_input"] - else: - sample_input = sample - - for k, v in sample_input.items(): - if isinstance(v, torch.Tensor): - sample_input[k] = torch.cat([v, v.clone()], dim=0) - if "net_input" in sample.keys(): - sample["net_input"] = sample_input - else: - sample = sample_input - return sample - - -def compute_kl_loss(model, net_output, pad_mask=None, reduce=True): - net_prob = model.get_normalized_probs(net_output, log_probs=True) - net_prob_tec = model.get_normalized_probs(net_output, log_probs=False) - - net_prob = net_prob.view(-1, net_prob.size(-1)) - net_prob_tec = net_prob_tec.view(-1, net_prob_tec.size(-1)) - - p, q = torch.split(net_prob, net_prob.size(0) // 2, dim=0) - p_tec, q_tec = torch.split(net_prob_tec, net_prob_tec.size(0) // 2, dim=0) - - p_loss = torch.nn.functional.kl_div(p, q_tec, reduction="none") - q_loss = torch.nn.functional.kl_div(q, p_tec, reduction="none") - - if pad_mask is not None: - p_loss.masked_fill_(pad_mask, 0.0) - q_loss.masked_fill_(pad_mask, 0.0) - - if reduce: - p_loss = p_loss.sum() - q_loss = q_loss.sum() - - loss = (p_loss + q_loss) / 2 - return loss diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index 1e96ced920..d8e89916f3 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -190,7 +190,7 @@ def forward(self, model, sample, reduce=True): net_input_concat = duplicate_input(net_input_concat) net_output, extra = model(**net_input_concat) - loss, nll_loss, rdrop_kl_loss = self.compute_loss_with_rdrop( + loss, nll_loss, rdrop_kl_loss = self.compute_loss( model, [net_output], sample, reduce=reduce ) sample_size = ( @@ -293,7 +293,7 @@ def forward(self, model, sample, reduce=True): net_input_concat = duplicate_input(net_input_concat) net_output, extra = model(**net_input_concat) - loss, 
nll_loss, rdrop_kl_loss = self.compute_loss_with_rdrop( + loss, nll_loss, rdrop_kl_loss = self.compute_loss( model, [net_output], sample, reduce=reduce ) From 7c230303119233abd99d2fe11e5ef5a1d5e814c6 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Fri, 16 Sep 2022 01:24:00 -0700 Subject: [PATCH 14/35] Fix import --- fairseq/models/speech_to_text/xm_transformer_unity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py index ba61abde0e..e04c05075c 100644 --- a/fairseq/models/speech_to_text/xm_transformer_unity.py +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -12,7 +12,7 @@ register_model, register_model_architecture, ) -from fairseq.models.speech_to_speech.modules import CTCDecoder +from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel from fairseq.models.speech_to_text.xm_transformer import ( base_architecture as xm_t_base_architecture, From 08833136bcceb76655acfd49bbe776306a37db53 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Fri, 16 Sep 2022 01:44:18 -0700 Subject: [PATCH 15/35] Fix argments for datasets --- .../data/audio/speech_to_speech_dataset.py | 30 +++++++------- fairseq/data/audio/speech_to_text_dataset.py | 41 ++++++++++--------- 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/fairseq/data/audio/speech_to_speech_dataset.py b/fairseq/data/audio/speech_to_speech_dataset.py index 833bcedc54..87ab89f300 100644 --- a/fairseq/data/audio/speech_to_speech_dataset.py +++ b/fairseq/data/audio/speech_to_speech_dataset.py @@ -233,8 +233,8 @@ def collater( class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset): - def __init__(self, *argv): - super().__init__(*argv) + def __init__(self, **kwargs): + super().__init__(**kwargs) self.multitask_data = {} def add_multitask_dataset(self, task_name, 
task_data): @@ -325,19 +325,19 @@ def _from_list( ) ds = dataset_cls( - split_name, - is_train_split, - data_cfg, - src_audio_paths, - src_n_frames, - tgt_audio_paths, - tgt_n_frames, - src_langs, - tgt_langs, - ids, - target_is_code, - target_dictionary, - n_frames_per_step, + split=split_name, + is_train_split=is_train_split, + data_cfg=data_cfg, + src_audio_paths=src_audio_paths, + src_n_frames=src_n_frames, + tgt_audio_paths=tgt_audio_paths, + tgt_n_frames=tgt_n_frames, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + target_is_code=target_is_code, + target_dictionary=target_dictionary, + n_frames_per_step=n_frames_per_step, ) if has_multitask: diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 046d0373e3..1333c5ec43 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -498,8 +498,8 @@ def collater(self, samples: List[torch.Tensor]) -> torch.Tensor: class SpeechToTextMultitaskDataset(SpeechToTextDataset): - def __init__(self, *argv): - super().__init__(*argv) + def __init__(self, **kwargs): + super().__init__(**kwargs) self.multitask_data = {} def add_multitask_dataset(self, task_name, task_data): @@ -586,24 +586,25 @@ def _from_list( SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset ) - ds = dataset_cls( - split_name, - is_train_split, - cfg, - audio_paths, - n_frames, - src_texts, - tgt_texts, - speakers, - src_langs, - tgt_langs, - ids, - tgt_dict, - pre_tokenizer, - bpe_tokenizer, - n_frames_per_step, - speaker_to_id, - ) + if has_multitask: + ds = dataset_cls( + split=split_name, + is_train_split=is_train_split, + cfg=cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + 
n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + ) if has_multitask: for task_name, task_obj in multitask.items(): From 0c3e731e71ef8283902fef9dcd1d44d896cb8576 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Fri, 16 Sep 2022 10:22:40 -0700 Subject: [PATCH 16/35] Add description to AugTransformerDecoderBase --- .../models/transformer/transformer_decoder.py | 4 +-- .../transformer/transformer_decoder_aug.py | 32 +++++++++---------- fairseq/modules/transformer_layer.py | 2 +- fairseq/modules/transformer_layer_aug.py | 28 +++++++--------- 4 files changed, 30 insertions(+), 36 deletions(-) diff --git a/fairseq/models/transformer/transformer_decoder.py b/fairseq/models/transformer/transformer_decoder.py index 1a0f978b3b..c22e5625d4 100644 --- a/fairseq/models/transformer/transformer_decoder.py +++ b/fairseq/models/transformer/transformer_decoder.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +from torch import Tensor from fairseq import utils from fairseq.distributed import fsdp_wrap @@ -25,7 +26,6 @@ ) from fairseq.modules.checkpoint_activations import checkpoint_wrapper from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_ -from torch import Tensor # rewrite name for backward compatibility in `make_generation_fast_` @@ -42,7 +42,7 @@ class TransformerDecoderBase(FairseqIncrementalDecoder): is a :class:`TransformerDecoderLayer`. 
Args: - args (argparse.Namespace): parsed command-line arguments + cfg (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding no_encoder_attn (bool, optional): whether to attend to encoder outputs diff --git a/fairseq/models/transformer/transformer_decoder_aug.py b/fairseq/models/transformer/transformer_decoder_aug.py index 0a35db13c7..3f0603045d 100644 --- a/fairseq/models/transformer/transformer_decoder_aug.py +++ b/fairseq/models/transformer/transformer_decoder_aug.py @@ -23,15 +23,19 @@ class AugTransformerDecoderBase(TransformerDecoderBase): """ - Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer - is a :class:`TransformerDecoderLayer`. + Transformer decoder augmented with an additional cross-attention. Each layer + is a :class:`AugTransformerDecoderLayerBase`. Args: - args (argparse.Namespace): parsed command-line arguments + cfg (argparse.Namespace): parsed command-line arguments dictionary (~fairseq.data.Dictionary): decoding dictionary embed_tokens (torch.nn.Embedding): output embedding - no_encoder_attn (bool, optional): whether to attend to encoder outputs - (default: False). + encoder_attn_merge_type (str, optional): the way to combine outputs from + two cross-attention modules. If "sequential" is set, two cross-attention + modules are stacked sequentially. If "parallel" is set, they are processed + in parallel and combined before feeding it to FFN (default: sequential). + dropnet_ratio (float, optional): a probability to drop each cross-attention + module during training (default: 0.0). 
""" def __init__( @@ -39,16 +43,15 @@ def __init__( cfg, dictionary, embed_tokens, - no_encoder_attn=False, output_projection=None, encoder_attn_merge_type="sequential", - dropnet_ratio=0, + dropnet_ratio=0.0, ): super().__init__( cfg, dictionary, embed_tokens, - no_encoder_attn=no_encoder_attn, + no_encoder_attn=False, output_projection=output_projection, ) # assert cfg.cross_self_attention @@ -60,9 +63,7 @@ def __init__( self.layers = nn.ModuleList([]) self.layers.extend( [ - self.build_decoder_layer( - cfg, no_encoder_attn, encoder_attn_merge_type, dropnet_ratio - ) + self.build_decoder_layer(cfg, encoder_attn_merge_type, dropnet_ratio) for _ in range(cfg.decoder.layers) ] ) @@ -70,13 +71,12 @@ def __init__( def build_decoder_layer( self, cfg, - no_encoder_attn=False, encoder_attn_merge_type="sequential", dropnet_ratio=0, ): layer = transformer_layer_aug.AugTransformerDecoderLayerBase( cfg, - no_encoder_attn, + no_encoder_attn=False, encoder_attn_merge_type=encoder_attn_merge_type, dropnet_ratio=dropnet_ratio, ) @@ -355,7 +355,6 @@ def __init__( args, dictionary, embed_tokens, - no_encoder_attn=False, output_projection=None, ): self.args = args @@ -363,7 +362,7 @@ def __init__( TransformerConfig.from_namespace(args), dictionary, embed_tokens, - no_encoder_attn=no_encoder_attn, + no_encoder_attn=False, output_projection=output_projection, encoder_attn_merge_type=getattr( args, "synthesizer_augmented_cross_attention_merge_type", "sequential" @@ -379,13 +378,12 @@ def build_output_projection(self, args, dictionary, embed_tokens): def build_decoder_layer( self, args, - no_encoder_attn=False, encoder_attn_merge_type="sequential", dropnet_ratio=0, ): return super().build_decoder_layer( TransformerConfig.from_namespace(args), - no_encoder_attn=no_encoder_attn, + no_encoder_attn=False, encoder_attn_merge_type=encoder_attn_merge_type, dropnet_ratio=dropnet_ratio, ) diff --git a/fairseq/modules/transformer_layer.py b/fairseq/modules/transformer_layer.py index 
4a283762b8..19e035dec5 100644 --- a/fairseq/modules/transformer_layer.py +++ b/fairseq/modules/transformer_layer.py @@ -28,7 +28,7 @@ class TransformerEncoderLayerBase(nn.Module): *cfg.encoder.normalize_before* to ``True``. Args: - args (argparse.Namespace): parsed command-line arguments + cfg (argparse.Namespace): parsed command-line arguments """ def __init__(self, cfg, return_fc=False): diff --git a/fairseq/modules/transformer_layer_aug.py b/fairseq/modules/transformer_layer_aug.py index b63bdbd77f..2acd5e2e8f 100644 --- a/fairseq/modules/transformer_layer_aug.py +++ b/fairseq/modules/transformer_layer_aug.py @@ -14,39 +14,35 @@ class AugTransformerDecoderLayerBase(TransformerDecoderLayerBase): - """Decoder layer block. + """Decoder layer block augmented with an additional cross-attention. - In the original paper each operation (multi-head attention, encoder - attention or FFN) is postprocessed with: `dropout -> add residual -> - layernorm`. In the tensor2tensor code they suggest that learning is more - robust when preprocessing each layer with layernorm and postprocessing with: - `dropout -> add residual`. We default to the approach in the paper, but the - tensor2tensor approach can be enabled by setting - *cfg.decoder.normalize_before* to ``True``. + This decoder block is processed with the sequence of the following sub-modules. + self-attention -> cross-attention (first) -> cross-attention (second) -> FFN Args: - args (argparse.Namespace): parsed command-line arguments - no_encoder_attn (bool, optional): whether to attend to encoder outputs - (default: False). + cfg (argparse.Namespace): parsed command-line arguments + encoder_attn_merge_type (str, optional): the way to combine outputs from + two cross-attention modules. If "sequential" is set, two cross-attention + modules are stacked sequentially. If "parallel" is set, they are processed + in parallel and combined before feeding it to FFN (default: sequential). 
+ dropnet_ratio (float, optional): a probability to drop each cross-attention + module during training (default: 0.0). """ def __init__( self, cfg, - no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False, encoder_attn_merge_type="sequential", - dropnet_ratio=0, + dropnet_ratio=0.0, ): super().__init__( cfg, - no_encoder_attn=no_encoder_attn, + no_encoder_attn=False, add_bias_kv=add_bias_kv, add_zero_attn=False, ) - assert not no_encoder_attn - self.encoder_attn = self.build_encoder_attention(self.embed_dim, cfg) self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=cfg.export) self.encoder_attn2 = self.build_encoder_attention(self.embed_dim, cfg) From 8a3f42ad8daff0d559312eef59bef8ecaee5603f Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Fri, 16 Sep 2022 11:15:00 -0700 Subject: [PATCH 17/35] Fix SpeechToTextDatasetCreator --- fairseq/data/audio/speech_to_text_dataset.py | 37 ++++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 1333c5ec43..b690cc206e 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -586,25 +586,24 @@ def _from_list( SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset ) - if has_multitask: - ds = dataset_cls( - split=split_name, - is_train_split=is_train_split, - cfg=cfg, - audio_paths=audio_paths, - n_frames=n_frames, - src_texts=src_texts, - tgt_texts=tgt_texts, - speakers=speakers, - src_langs=src_langs, - tgt_langs=tgt_langs, - ids=ids, - tgt_dict=tgt_dict, - pre_tokenizer=pre_tokenizer, - bpe_tokenizer=bpe_tokenizer, - n_frames_per_step=n_frames_per_step, - speaker_to_id=speaker_to_id, - ) + ds = dataset_cls( + split=split_name, + is_train_split=is_train_split, + cfg=cfg, + audio_paths=audio_paths, + n_frames=n_frames, + src_texts=src_texts, + tgt_texts=tgt_texts, + speakers=speakers, + src_langs=src_langs, + 
tgt_langs=tgt_langs, + ids=ids, + tgt_dict=tgt_dict, + pre_tokenizer=pre_tokenizer, + bpe_tokenizer=bpe_tokenizer, + n_frames_per_step=n_frames_per_step, + speaker_to_id=speaker_to_id, + ) if has_multitask: for task_name, task_obj in multitask.items(): From 1863f8daec933524db3d354962fe4c41b531dc88 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 20 Sep 2022 01:41:18 -0700 Subject: [PATCH 18/35] Fix metavar in arguments --- .../models/speech_to_speech/s2s_conformer_translatotron2.py | 1 + fairseq/models/speech_to_speech/s2s_conformer_unity.py | 1 + fairseq/models/speech_to_speech/s2s_transformer.py | 6 ++---- fairseq/models/speech_to_text/s2t_transformer.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py index d146b31e51..8637e0de9f 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py @@ -243,6 +243,7 @@ def forward( model_name="s2spect2_conformer", arch_name="s2spect2_conformer" ) def s2spect2_conformer_architecture_base(args): + args.conv_version = getattr(args, "conv_version", "convtransformer") args.attn_type = getattr(args, "attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") args.max_source_positions = getattr(args, "max_source_positions", 6000) diff --git a/fairseq/models/speech_to_speech/s2s_conformer_unity.py b/fairseq/models/speech_to_speech/s2s_conformer_unity.py index 7216a06b38..c759ef36a2 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_unity.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -281,6 +281,7 @@ def forward( @register_model_architecture(model_name="unity_conformer", arch_name="unity_conformer") def unity_conformer_architecture_base(args): + args.conv_version = getattr(args, "conv_version", "convtransformer") args.attn_type = getattr(args, 
"attn_type", None) args.pos_enc_type = getattr(args, "pos_enc_type", "abs") args.max_source_positions = getattr(args, "max_source_positions", 6000) diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py index 074d84ad28..94cffd5e34 100644 --- a/fairseq/models/speech_to_speech/s2s_transformer.py +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -247,7 +247,7 @@ def add_args(parser): parser.add_argument( "--conv-kernel-sizes", type=str, - metavar="N", + metavar="STR", help="kernel sizes of Conv1d (s2t_transformer) subsampling layers", ) parser.add_argument( @@ -435,7 +435,7 @@ def add_args(parser): parser.add_argument( "--conv-kernel-sizes", type=str, - metavar="N", + metavar="STR", help="kernel sizes of Conv1d (s2t_transformer) subsampling layers", ) parser.add_argument( @@ -621,8 +621,6 @@ def base_s2st_transformer_encoder_architecture(args): # Convolutional subsampler args.input_channels = getattr(args, "input_channels", 1) - args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") - args.conv_channels = getattr(args, "conv_channels", 1024) args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") # for Conv1d args.conv_channels = getattr(args, "conv_channels", 1024) # for Conv1d args.conv_out_channels = getattr(args, "conv_out_channels", 256) # for Conv2d diff --git a/fairseq/models/speech_to_text/s2t_transformer.py b/fairseq/models/speech_to_text/s2t_transformer.py index 6adbd3c339..50fae2ffa2 100644 --- a/fairseq/models/speech_to_text/s2t_transformer.py +++ b/fairseq/models/speech_to_text/s2t_transformer.py @@ -85,7 +85,7 @@ def add_args(parser): parser.add_argument( "--conv-kernel-sizes", type=str, - metavar="N", + metavar="STR", help="kernel sizes of Conv1d (s2t_transformer) subsampling layers", ) parser.add_argument( From ee84288c7f047caf18e20609b7a7779eb98de054 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 20 Sep 2022 02:09:33 -0700 Subject: [PATCH 19/35] 
Uncomment override_decoder_args --- fairseq/models/speech_to_text/xm_transformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fairseq/models/speech_to_text/xm_transformer.py b/fairseq/models/speech_to_text/xm_transformer.py index b4b9ea9312..a9ca50ef6e 100644 --- a/fairseq/models/speech_to_text/xm_transformer.py +++ b/fairseq/models/speech_to_text/xm_transformer.py @@ -633,10 +633,10 @@ def build_model(cls, args, task): # make sure all arguments are present in older models base_architecture(args) - # if getattr(args, "load_pretrained_decoder_from", None) is not None: - # ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None)) - # decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"]) - # args = cls.override_decoder_args(args, decoder_args_dict) + if getattr(args, "load_pretrained_decoder_from", None) is not None: + ckpt = torch.load(getattr(args, "load_pretrained_decoder_from", None)) + decoder_args_dict = cls.get_decoder_args_from_checkpoint(ckpt["cfg"]) + args = cls.override_decoder_args(args, decoder_args_dict) decoder_embed_tokens = build_embedding( task.target_dictionary, args.decoder_embed_dim From d3427df2bcf4769bbfde3cb97e3404daca791943 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 20 Sep 2022 02:14:06 -0700 Subject: [PATCH 20/35] Fix comment in warning --- fairseq/tasks/speech_to_speech.py | 2 +- fairseq/tasks/speech_to_text.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index b112291e0a..da40b79656 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -232,7 +232,7 @@ def __init__(self, args, tgt_dict, infer_tgt_lang_id=None): if not self.eos_token_mt: raise Warning( - "Please provide --eos_token to replace eos in sequence generator" + "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" ) self._infer_tgt_lang_id = 
infer_tgt_lang_id diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index f1da986419..829bef94b0 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -85,7 +85,7 @@ def __init__(self, args, tgt_dict): if not self.eos_token_mt: raise Warning( - "Please provide --eos_token to replace eos in sequence generator" + "Please provide eos_token in --multitask-config-yaml to replace eos in sequence generator" ) def _get_speaker_to_id(self): From 71f0b8113a188d2e9265464f3db9dce07d7be92f Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 20 Sep 2022 03:19:13 -0700 Subject: [PATCH 21/35] Add is_fisrt_pass_decoder flag --- fairseq/data/audio/data_cfg.py | 32 +++++++++ .../s2s_conformer_translatotron2.py | 20 +++--- .../speech_to_speech/s2s_conformer_unity.py | 22 +++--- .../speech_to_text/xm_transformer_unity.py | 70 ++++++++++--------- fairseq/tasks/speech_to_speech.py | 11 ++- fairseq/tasks/speech_to_text.py | 18 +++-- 6 files changed, 110 insertions(+), 63 deletions(-) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py index c6ea331ffe..dc8154e0de 100644 --- a/fairseq/data/audio/data_cfg.py +++ b/fairseq/data/audio/data_cfg.py @@ -257,6 +257,24 @@ def get_single_task(self, name): assert name in self.config, f"multitask '{name}' does not exist!" return self.config[name] + @property + def first_pass_decoder_task_index(self): + """Return the task index of the first-pass text decoder. + If there are multiple 'is_first_pass_decoder: True' in the config file, + the last task is used for the first-pass decoder. + If there is no 'is_first_pass_decoder: True' in the config file, + the last task whose task_name includes 'target' and decoder_type is not ctc. 
+ """ + idx = -1 + for i, (k, v) in enumerate(self.config.items()): + if v.is_first_pass_decoder: + idx = i + if idx < 0: + for i, (k, v) in enumerate(self.config.items()): + if k.startwith("target") and v.decoder_type == "transformer": + idx = i + return idx + class SingleTaskConfig(object): def __init__(self, name, config): @@ -349,3 +367,17 @@ def eos_token(self): @property def rdrop_alpha(self): return self.config.get("rdrop_alpha", None) + + @property + def is_first_pass_decoder(self): + flag = self.config.get("is_first_pass_decoder", False) + if flag: + if self.decoder_type == "ctc": + raise ValueError( + "First-pass decoder in the multi-decoder model must not be CTC." + ) + if "target" not in self.task_name: + raise Warning( + 'The name of the first-pass decoder does not include "target".' + ) + return flag diff --git a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py index 8637e0de9f..8016daee8d 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_translatotron2.py @@ -121,17 +121,13 @@ def build_model(cls, args, task): base_model = cls(encoder, decoder) # set up multitask decoders - is_mt_decoder = False base_model.mt_task_name = None base_model.multitask_decoders = {} - n_aux_tasks = len(list(task.multitask_tasks.items())) - for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): - if i == n_aux_tasks - 1: - is_mt_decoder = True + has_first_pass_decoder = False + for task_name, task_obj in task.multitask_tasks.items(): + if task_obj.is_first_pass_decoder: + has_first_pass_decoder = True base_model.mt_task_name = task_name - assert "target" in task_name - assert task_obj.args.decoder_type == "transformer" - # NOTE: we assume that the last task is for the first-pass decoder in_dim = ( args.encoder_embed_dim @@ -142,7 +138,7 @@ def build_model(cls, args, task): task_obj.args, 
task_obj.target_dictionary, in_dim, - is_mt_decoder, + task_obj.is_first_pass_decoder, getattr(args, "translation_decoder_layers", 4), getattr(args, "decoder_embed_dim", 256), getattr(args, "decoder_attention_heads", 4), @@ -158,11 +154,13 @@ def build_model(cls, args, task): getattr(base_model, f"{task_name}_decoder") ) - assert is_mt_decoder, "set at least one intermediate non-CTC decoder" + assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder" # set up encoder on top of the auxiliary MT decoder if getattr(args, "synthesizer_encoder_layers", 0) > 0: base_model.synthesizer_encoder = cls.build_text_encoder(args) + else: + base_model.synthesizer_encoder = None return base_model @@ -210,7 +208,7 @@ def forward( mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) # 2. TTS encoder - if hasattr(self, "synthesizer_encoder"): + if self.synthesizer_encoder is not None: tts_encoder_out = self.synthesizer_encoder( x, mt_decoder_padding_mask, diff --git a/fairseq/models/speech_to_speech/s2s_conformer_unity.py b/fairseq/models/speech_to_speech/s2s_conformer_unity.py index c759ef36a2..b7b1a5eed1 100644 --- a/fairseq/models/speech_to_speech/s2s_conformer_unity.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -83,7 +83,7 @@ def build_multitask_decoder( args, tgt_dict, in_dim, - is_mt_decoder, + is_first_pass_decoder, decoder_layers, decoder_embed_dim, decoder_attention_heads, @@ -91,7 +91,7 @@ def build_multitask_decoder( decoder_args = args.decoder_args decoder_args.encoder_embed_dim = in_dim if args.decoder_type == "transformer": - if is_mt_decoder: + if is_first_pass_decoder: multitask_text_transformer_decoder_arch( decoder_args, decoder_layers, @@ -157,17 +157,13 @@ def build_model(cls, args, task): ) # set up multitask decoders - is_mt_decoder = False base_model.mt_task_name = None base_model.multitask_decoders = {} - n_aux_tasks = len(list(task.multitask_tasks.items())) - for i, (task_name, task_obj) in 
enumerate(task.multitask_tasks.items()): - if i == n_aux_tasks - 1: - is_mt_decoder = True + has_first_pass_decoder = False + for task_name, task_obj in task.multitask_tasks.items(): + if task_obj.is_first_pass_decoder: + has_first_pass_decoder = True base_model.mt_task_name = task_name - assert "target" in task_name - assert task_obj.args.decoder_type == "transformer" - # NOTE: we assume that the last task is for the first-pass decoder in_dim = ( args.encoder_embed_dim @@ -178,7 +174,7 @@ def build_model(cls, args, task): task_obj.args, task_obj.target_dictionary, in_dim, - is_mt_decoder, + task_obj.is_first_pass_decoder, getattr(args, "translation_decoder_layers", 4), getattr(args, "decoder_embed_dim", 256), getattr(args, "decoder_attention_heads", 4), @@ -194,7 +190,7 @@ def build_model(cls, args, task): getattr(base_model, f"{task_name}_decoder") ) - assert is_mt_decoder, "set at least one intermediate non-CTC decoder" + assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder" # set up encoder on top of the auxiliary MT decoder if getattr(args, "synthesizer_encoder_layers", 0) > 0: @@ -246,7 +242,7 @@ def forward( mt_decoder_padding_mask = prev_output_tokens_mt.eq(mt_decoder.padding_idx) # 2. 
T2U encoder - if hasattr(self, "synthesizer_encoder"): + if self.synthesizer_encoder is not None: t2u_encoder_out = self.synthesizer_encoder( x, mt_decoder_padding_mask, diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py index e04c05075c..7406f483f5 100644 --- a/fairseq/models/speech_to_text/xm_transformer_unity.py +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -13,6 +13,9 @@ register_model_architecture, ) from fairseq.models.speech_to_speech.modules.ctc_decoder import CTCDecoder +from fairseq.models.speech_to_speech.modules.transformer_encoder import ( + TransformerEncoderNoEmb, +) from fairseq.models.speech_to_text.xm_transformer import XMTransformerModel from fairseq.models.speech_to_text.xm_transformer import ( base_architecture as xm_t_base_architecture, @@ -100,7 +103,7 @@ def add_args(cls, parser): ) @classmethod - def build_text_decoder(cls, args, task): + def build_text_decoder(cls, args, tgt_dict): _args = copy.deepcopy(args) if args.adaptor_proj or args.encoder_proj: # not V0 arch @@ -111,8 +114,8 @@ def build_text_decoder(cls, args, task): _args.layerdrop = _args.decoder_layerdrop _args.decoder_layers = _args.translation_decoder_layers - embed_tokens = build_embedding(task.target_dictionary, _args.decoder_embed_dim) - decoder = TransformerDecoder(_args, task.target_dictionary, embed_tokens) + embed_tokens = build_embedding(tgt_dict, _args.decoder_embed_dim) + decoder = TransformerDecoder(_args, tgt_dict, embed_tokens) if getattr(args, "load_pretrained_aux_decoder_from", None) is not None: decoder = cls.maybe_load_pretrained( @@ -178,20 +181,19 @@ def build_model(cls, args, task): # set up multitask decoders base_model.mt_task_name = None base_model.multitask_decoders = {} - n_aux_tasks = len(list(task.multitask_tasks.items())) - for i, (task_name, task_obj) in enumerate(task.multitask_tasks.items()): - - if i < n_aux_tasks - 1: - task_decoder = 
cls.build_multitask_decoder( - task_obj.args, task_obj.target_dictionary, args.decoder_embed_dim - ) - else: + has_first_pass_decoder = False + for task_name, task_obj in task.multitask_tasks.items(): + if task_obj.is_first_pass_decoder: + has_first_pass_decoder = True base_model.mt_task_name = task_name - assert "target" in task_name - assert task_obj.args.decoder_type == "transformer" - # NOTE: we assume that the last task is for the first-pass decoder - task_decoder = cls.build_text_decoder(args, task_obj) + task_decoder = cls.build_multitask_decoder( + args, + task_obj.args, + task_obj.target_dictionary, + args.decoder_embed_dim, + task_obj.is_first_pass_decoder, + ) setattr(base_model, f"{task_name}_decoder", task_decoder) decoder_model_cls = ( @@ -203,6 +205,8 @@ def build_model(cls, args, task): getattr(base_model, f"{task_name}_decoder") ) + assert has_first_pass_decoder, "set at least one intermediate non-CTC decoder" + # set up encoder on top of the auxiliary MT decoder if getattr(args, "synthesizer_encoder_layers", 0) > 0: base_model.synthesizer_encoder = cls.build_t2u_encoder(unit_args) @@ -212,24 +216,29 @@ def build_model(cls, args, task): return base_model @classmethod - def build_multitask_decoder(cls, args, tgt_dict, in_dim): - decoder_args = args.decoder_args + def build_multitask_decoder( + cls, args, mtl_args, tgt_dict, in_dim, is_first_pass_decoder + ): + decoder_args = mtl_args.decoder_args decoder_args.encoder_embed_dim = in_dim - if args.decoder_type == "transformer": - from fairseq.models.speech_to_speech import ( - base_multitask_text_transformer_decoder_arch, - ) + if mtl_args.decoder_type == "transformer": + if is_first_pass_decoder: + task_decoder = cls.build_text_decoder(args, tgt_dict) + else: + from fairseq.models.speech_to_speech import ( + base_multitask_text_transformer_decoder_arch, + ) - base_multitask_text_transformer_decoder_arch(decoder_args) # 2L - task_decoder = TransformerDecoder( - decoder_args, - tgt_dict, - 
embed_tokens=TransformerModelBase.build_embedding( + base_multitask_text_transformer_decoder_arch(decoder_args) # 2L + task_decoder = TransformerDecoder( decoder_args, tgt_dict, - decoder_args.decoder_embed_dim, - ), - ) + embed_tokens=TransformerModelBase.build_embedding( + decoder_args, + tgt_dict, + decoder_args.decoder_embed_dim, + ), + ) elif args.decoder_type == "ctc": task_decoder = CTCDecoder( dictionary=tgt_dict, @@ -250,9 +259,6 @@ def build_t2u_encoder(cls, args): _args.encoder_ffn_embed_dim = args.decoder_ffn_embed_dim _args.encoder_attention_heads = args.decoder_attention_heads _args.encoder_normalize_before = True - - from fairseq.models.speech_to_speech import TransformerEncoderNoEmb - return TransformerEncoderNoEmb(_args) def forward( diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index da40b79656..e4666172e8 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -221,10 +221,15 @@ def __init__(self, args, tgt_dict, infer_tgt_lang_id=None): multitask_cfg = MultitaskConfig( Path(args.data) / args.multitask_config_yaml ) - for task_name, task_config in multitask_cfg.get_all_tasks().items(): - task_obj = DummyMultiTask(task_config, task_config.tgt_dict) + first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index + for i, (task_name, task_config) in enumerate( + multitask_cfg.get_all_tasks().items() + ): + task_obj = DummyMultiTask( + task_config, task_config.tgt_dict, i == first_pass_task_idx + ) self.multitask_tasks[task_name] = task_obj - if "target" in task_name and task_obj.args.decoder_type != "ctc": + if task_obj.is_first_pass_decoder: self.tgt_dict_mt = task_obj.target_dictionary if task_config.prepend_bos_and_append_tgt_lang_tag: self.eos_token_mt = task_config.eos_token diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 829bef94b0..5496b30ed5 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ 
-74,10 +74,15 @@ def __init__(self, args, tgt_dict): multitask_cfg = MultitaskConfig( Path(args.data) / args.multitask_config_yaml ) - for task_name, task_config in multitask_cfg.get_all_tasks().items(): - task_obj = DummyMultiTask(task_config, task_config.tgt_dict) + first_pass_task_idx = multitask_cfg.first_pass_decoder_task_index + for i, (task_name, task_config) in enumerate( + multitask_cfg.get_all_tasks().items() + ): + task_obj = DummyMultiTask( + task_config, task_config.tgt_dict, i == first_pass_task_idx + ) self.multitask_tasks[task_name] = task_obj - if "target" in task_name and task_config.decoder_type != "ctc": + if task_obj.is_first_pass_decoder: self.tgt_dict_mt = task_obj.target_dictionary if task_config.prepend_bos_and_append_tgt_lang_tag: self.eos_token_mt = task_config.eos_token @@ -301,14 +306,19 @@ def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): class DummyMultiTask(LegacyFairseqTask): - def __init__(self, args, tgt_dict): + def __init__(self, args, tgt_dict, first_pass=False): super().__init__(args) self.tgt_dict = tgt_dict + self.first_pass = first_pass @property def target_dictionary(self): return self.tgt_dict + @property + def is_first_pass_decoder(self): + return self.first_pass + def inference_step( self, generator, models, sample, prefix_tokens=None, constraints=None ): From c6d95684e7ae4a4536b0576a20fb27e71c750ecb Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 21 Sep 2022 07:06:47 -0700 Subject: [PATCH 22/35] Change Translatotron2SpeechGenerator to MultiDecoderSpeechGenerator --- fairseq/speech_generator.py | 2 +- fairseq/tasks/speech_to_speech.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fairseq/speech_generator.py b/fairseq/speech_generator.py index 8951eefbff..d44ccfb54c 100644 --- a/fairseq/speech_generator.py +++ b/fairseq/speech_generator.py @@ -126,7 +126,7 @@ def generate(self, model, sample, has_targ=False, **kwargs): return finalized -class 
Translatotron2SpeechGenerator(SpeechGenerator): +class MultiDecoderSpeechGenerator(SpeechGenerator): def __init__( self, models, diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index 07e3aaaca8..4d280da615 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -432,9 +432,9 @@ def build_generator( if getattr(args, "teacher_forcing", False): raise NotImplementedError else: - from fairseq.speech_generator import Translatotron2SpeechGenerator + from fairseq.speech_generator import MultiDecoderSpeechGenerator - generator = Translatotron2SpeechGenerator + generator = MultiDecoderSpeechGenerator lang_token_ids_aux = { i From bfb9b63146859972c9cb76d980de46053d607eae Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 11:43:14 -0700 Subject: [PATCH 23/35] Move inference code to examples/speech_to_speech/unity --- examples/speech_to_speech/__init__.py | 2 + examples/speech_to_speech/unity/__init__.py | 7 + .../unity/sequence_generator.py | 1066 +++++++++++++++++ .../sequence_generator_multi_decoder.py | 4 +- fairseq/sequence_generator.py | 108 +- fairseq/speech_generator.py | 2 +- fairseq/tasks/speech_to_speech.py | 2 +- fairseq/tasks/speech_to_text.py | 2 +- 8 files changed, 1109 insertions(+), 84 deletions(-) create mode 100644 examples/speech_to_speech/unity/__init__.py create mode 100644 examples/speech_to_speech/unity/sequence_generator.py rename {fairseq => examples/speech_to_speech/unity}/sequence_generator_multi_decoder.py (99%) diff --git a/examples/speech_to_speech/__init__.py b/examples/speech_to_speech/__init__.py index 6264236915..812b3c30b9 100644 --- a/examples/speech_to_speech/__init__.py +++ b/examples/speech_to_speech/__init__.py @@ -2,3 +2,5 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. + +from . 
import unity # noqa diff --git a/examples/speech_to_speech/unity/__init__.py b/examples/speech_to_speech/unity/__init__.py new file mode 100644 index 0000000000..349db7c65e --- /dev/null +++ b/examples/speech_to_speech/unity/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import sequence_generator # noqa +from . import sequence_generator_multi_decoder # noqa diff --git a/examples/speech_to_speech/unity/sequence_generator.py b/examples/speech_to_speech/unity/sequence_generator.py new file mode 100644 index 0000000000..ac542f4aa3 --- /dev/null +++ b/examples/speech_to_speech/unity/sequence_generator.py @@ -0,0 +1,1066 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import sys +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from fairseq import search, utils +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder +from fairseq.ngram_repeat_block import NGramRepeatBlock + + +class SequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + tokens_to_suppress=(), + ): + """Generates translations of a given source sentence. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + + self.token_indices_to_suppress: Optional[Tensor] = None + token_indices_to_suppress = [] + for token_string in tokens_to_suppress: + token_index = tgt_dict.index(token_string) + assert token_index != self.unk + token_indices_to_suppress.append(token_index) + if len(token_indices_to_suppress) > 0: + self.token_indices_to_suppress = torch.Tensor( + token_indices_to_suppress + ).long() + + self.vocab_size = len(tgt_dict) + # the max beam size is the dictionary size - 
1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.model.set_decoder_beam_size(self.beam_size) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len or self.model.max_decoder_positions() + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. + self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. 
+ Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate( + self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs + ) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception( + "expected src_tokens or source in net input. 
input keys: " + + str(net_input.keys()) + ) + + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, self.beam_size) + + # compute the encoder output for each beam + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.model.forward_encoder(net_input) + + finalized = self.generate_decoder( + encoder_outs, + src_tokens, + src_lengths, + sample, + prefix_tokens, + constraints, + bos_token, + ) + return finalized + + def generate_decoder( + self, + encoder_outs, + src_tokens, + src_lengths, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + aux_task_name="", + encoder_outs2: Optional[Tensor] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + + decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder" + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + self.max_len - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" 
+ + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. + assert encoder_outs is not None + if encoder_outs2 is not None: + encoder_outs2 = self.model.reorder_encoder_out(encoder_outs2, new_order) + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_tokens) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. 
+ cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_tokens).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_tokens.device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state( + incremental_states, reorder_state, decoder_name + ) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + if encoder_outs2 is not None: + encoder_outs2 = self.model.reorder_encoder_out( + 
encoder_outs2, reorder_state + ) + with torch.autograd.profiler.record_function( + "EnsembleModel: forward_decoder" + ): + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + decoder_name=decoder_name, + encoder_outs2=encoder_outs2, + ) + + if self.lm_model is not None and not aux_task_name: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + else: + if step < self.min_len: + # minimum length constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + if self.token_indices_to_suppress is not None: + lprobs[:, self.token_indices_to_suppress] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) 
+ + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. 
+ if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. 
Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). + active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, 
descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized 
information in `finalized`, and change `finished` accordingly. + A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. 
+ cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) + + unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode="trunc") + sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) + + # Create a set of "{sent}{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # For every finished beam item + # sentence index in the current (possibly reduced) batch + seen = (sent << 32) + unfin_idx + unique_seen: List[int] = torch.unique(seen).tolist() + + if self.match_source_len: + condition = step > torch.index_select(src_lengths, 0, unfin_idx) + eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores) + sent_list: List[int] = sent.tolist() + for i in range(bbsz_idx.size()[0]): + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent_list[i]]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent_list[i]].append( + { + "tokens": tokens_clone[i], + "score": eos_scores[i], + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + for unique_s in unique_seen: + # check termination conditions for this sentence + unique_sent: int = unique_s >> 32 + unique_unfin_idx: int = unique_s - (unique_sent << 32) + + if not finished[unique_sent] and self.is_finished( + step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size + ): + finished[unique_sent] = True + newly_finished.append(unique_unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + 
finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. + """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min( + [ + m.max_decoder_positions() + for m in self.models + if hasattr(m, "max_decoder_positions") + ] + + [sys.maxsize] + ) + + def set_decoder_beam_size(self, beam_size): + """Set beam size for efficient beamable enc-dec attention.""" + if beam_size > 1: + for model in self.models: + if hasattr(model, "set_beam_size"): + model.set_beam_size(beam_size) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + decoder_name="decoder", + encoder_outs2: List[Dict[str, List[Tensor]]] = None, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: 
Optional[Dict[str, List[Tensor]]] = None + encoder_out2: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + if encoder_outs2 is not None: + encoder_out2 = encoder_outs2[i] + # decode each model + if self.has_incremental_states(): + if encoder_out2 is not None: + decoder_out = getattr(model, decoder_name).forward( + tokens, + encoder_out=encoder_out, + encoder_out2=encoder_out2, + incremental_state=incremental_states[i], + ) + else: + decoder_out = getattr(model, decoder_name).forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + if hasattr(model, decoder_name): + decoder_out = getattr(model, decoder_name).forward( + tokens, encoder_out=encoder_out + ) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + probs = getattr(model, decoder_name).get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): 
+ """ + Reorder encoder output according to *new_order*. + + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + decoder_name="decoder", + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + getattr(model, decoder_name).reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +# class SequenceGeneratorWithAlignment(SequenceGenerator): +# def __init__( +# self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs +# ): +# """Generates translations of a given source sentence. + +# Produces alignments following "Jointly Learning to Align and +# Translate with Transformer Models" (Garg et al., EMNLP 2019). + +# Args: +# left_pad_target (bool, optional): Whether or not the +# hypothesis should be left padded or not when they are +# teacher forced for generating alignments. 
+# """ +# super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) +# self.left_pad_target = left_pad_target + +# if print_alignment == "hard": +# self.extract_alignment = utils.extract_hard_alignment +# elif print_alignment == "soft": +# self.extract_alignment = utils.extract_soft_alignment + +# @torch.no_grad() +# def generate(self, models, sample, **kwargs): +# finalized = super()._generate(sample, **kwargs) + +# src_tokens = sample["net_input"]["src_tokens"] +# bsz = src_tokens.shape[0] +# beam_size = self.beam_size +# ( +# src_tokens, +# src_lengths, +# prev_output_tokens, +# tgt_tokens, +# ) = self._prepare_batch_for_alignment(sample, finalized) +# if any(getattr(m, "full_context_alignment", False) for m in self.model.models): +# attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) +# else: +# attn = [ +# finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) +# for i in range(bsz * beam_size) +# ] + +# if src_tokens.device != "cpu": +# src_tokens = src_tokens.to("cpu") +# tgt_tokens = tgt_tokens.to("cpu") +# attn = [i.to("cpu") for i in attn] + +# # Process the attn matrix to extract hard alignments. 
+# for i in range(bsz * beam_size): +# alignment = self.extract_alignment( +# attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos +# ) +# finalized[i // beam_size][i % beam_size]["alignment"] = alignment +# return finalized + +# def _prepare_batch_for_alignment(self, sample, hypothesis): +# src_tokens = sample["net_input"]["src_tokens"] +# bsz = src_tokens.shape[0] +# src_tokens = ( +# src_tokens[:, None, :] +# .expand(-1, self.beam_size, -1) +# .contiguous() +# .view(bsz * self.beam_size, -1) +# ) +# src_lengths = sample["net_input"]["src_lengths"] +# src_lengths = ( +# src_lengths[:, None] +# .expand(-1, self.beam_size) +# .contiguous() +# .view(bsz * self.beam_size) +# ) +# prev_output_tokens = data_utils.collate_tokens( +# [beam["tokens"] for example in hypothesis for beam in example], +# self.pad, +# self.eos, +# self.left_pad_target, +# move_eos_to_beginning=True, +# ) +# tgt_tokens = data_utils.collate_tokens( +# [beam["tokens"] for example in hypothesis for beam in example], +# self.pad, +# self.eos, +# self.left_pad_target, +# move_eos_to_beginning=False, +# ) +# return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +# class EnsembleModelWithAlignment(EnsembleModel): +# """A wrapper around an ensemble of models.""" + +# def __init__(self, models): +# super().__init__(models) + +# def forward_align(self, src_tokens, src_lengths, prev_output_tokens): +# avg_attn = None +# for model in self.models: +# decoder_out = model(src_tokens, src_lengths, prev_output_tokens) +# attn = decoder_out[1]["attn"][0] +# if avg_attn is None: +# avg_attn = attn +# else: +# avg_attn.add_(attn) +# if len(self.models) > 1: +# avg_attn.div_(len(self.models)) +# return avg_attn diff --git a/fairseq/sequence_generator_multi_decoder.py b/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py similarity index 99% rename from fairseq/sequence_generator_multi_decoder.py rename to examples/speech_to_speech/unity/sequence_generator_multi_decoder.py index 
9fae5d5f7a..8af6413c28 100644 --- a/fairseq/sequence_generator_multi_decoder.py +++ b/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py @@ -10,7 +10,6 @@ from torch import Tensor from fairseq import search -from fairseq.sequence_generator import SequenceGenerator class MultiDecoderSequenceGenerator(nn.Module): @@ -69,6 +68,9 @@ def __init__( length (default: False) """ super().__init__() + + from examples.speech_to_speech.unity.sequence_generator import SequenceGenerator + self.generator = SequenceGenerator( models, tgt_dict, diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index e01f2fd113..13f99078c7 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -91,6 +91,7 @@ def __init__( ).long() self.vocab_size = len(tgt_dict) + self.beam_size = beam_size # the max beam size is the dictionary size - 1, since we never select pad self.beam_size = min(beam_size, self.vocab_size - 1) self.model.set_decoder_beam_size(self.beam_size) @@ -209,6 +210,13 @@ def _generate( constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) net_input = sample["net_input"] if "src_tokens" in net_input: @@ -237,55 +245,18 @@ def _generate( + str(net_input.keys()) ) + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. 
audio features) + bsz, src_len = src_tokens.size()[:2] + beam_size = self.beam_size + if constraints is not None and not self.search.supports_constraints: raise NotImplementedError( "Target-side constraints were provided, but search method doesn't support them" ) # Initialize constraints, when active - self.search.init_constraints(constraints, self.beam_size) - - # compute the encoder output for each beam - with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): - encoder_outs = self.model.forward_encoder(net_input) - - finalized = self.generate_decoder( - encoder_outs, - src_tokens, - src_lengths, - sample, - prefix_tokens, - constraints, - bos_token, - ) - return finalized - - def generate_decoder( - self, - encoder_outs, - src_tokens, - src_lengths, - sample: Dict[str, Dict[str, Tensor]], - prefix_tokens: Optional[Tensor] = None, - constraints: Optional[Tensor] = None, - bos_token: Optional[int] = None, - aux_task_name="", - encoder_outs2: Optional[Tensor] = None, - ): - incremental_states = torch.jit.annotate( - List[Dict[str, Dict[str, Optional[Tensor]]]], - [ - torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) - for i in range(self.model.models_size) - ], - ) - - # bsz: total number of sentences in beam - # Note that src_tokens may have more than 2 dimensions (i.e. audio features) - bsz, src_len = src_tokens.size()[:2] - beam_size = self.beam_size - - decoder_name = f"{aux_task_name}_decoder" if aux_task_name else "decoder" + self.search.init_constraints(constraints, beam_size) max_len: int = -1 if self.match_source_len: @@ -298,6 +269,9 @@ def generate_decoder( assert ( self.min_len <= max_len ), "min_len cannot be larger than max_len, please adjust these!" 
+ # compute the encoder output for each beam + with torch.autograd.profiler.record_function("EnsembleModel: forward_encoder"): + encoder_outs = self.model.forward_encoder(net_input) # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) @@ -305,8 +279,6 @@ def generate_decoder( encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) # ensure encoder_outs is a List. assert encoder_outs is not None - if encoder_outs2 is not None: - encoder_outs2 = self.model.reorder_encoder_out(encoder_outs2, new_order) # initialize buffers scores = ( @@ -372,16 +344,10 @@ def generate_decoder( corr.unsqueeze(-1) * beam_size ) original_batch_idxs = original_batch_idxs[batch_idxs] - self.model.reorder_incremental_state( - incremental_states, reorder_state, decoder_name - ) + self.model.reorder_incremental_state(incremental_states, reorder_state) encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) - if encoder_outs2 is not None: - encoder_outs2 = self.model.reorder_encoder_out( - encoder_outs2, reorder_state - ) with torch.autograd.profiler.record_function( "EnsembleModel: forward_decoder" ): @@ -390,11 +356,9 @@ def generate_decoder( encoder_outs, incremental_states, self.temperature, - decoder_name=decoder_name, - encoder_outs2=encoder_outs2, ) - if self.lm_model is not None and not aux_task_name: + if self.lm_model is not None: lm_out = self.lm_model(tokens[:, : step + 1]) probs = self.lm_model.get_normalized_probs( lm_out, log_probs=True, sample=None @@ -843,38 +807,23 @@ def forward_decoder( encoder_outs: List[Dict[str, List[Tensor]]], incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], temperature: float = 1.0, - decoder_name="decoder", - encoder_outs2: List[Dict[str, List[Tensor]]] = None, ): log_probs = [] avg_attn: Optional[Tensor] = None encoder_out: Optional[Dict[str, List[Tensor]]] = None - encoder_out2: 
Optional[Dict[str, List[Tensor]]] = None for i, model in enumerate(self.models): if self.has_encoder(): encoder_out = encoder_outs[i] - if encoder_outs2 is not None: - encoder_out2 = encoder_outs2[i] # decode each model if self.has_incremental_states(): - if encoder_out2 is not None: - decoder_out = getattr(model, decoder_name).forward( - tokens, - encoder_out=encoder_out, - encoder_out2=encoder_out2, - incremental_state=incremental_states[i], - ) - else: - decoder_out = getattr(model, decoder_name).forward( - tokens, - encoder_out=encoder_out, - incremental_state=incremental_states[i], - ) + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) else: - if hasattr(model, decoder_name): - decoder_out = getattr(model, decoder_name).forward( - tokens, encoder_out=encoder_out - ) + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) else: decoder_out = model.forward(tokens) @@ -896,7 +845,7 @@ def forward_decoder( decoder_out[0][:, -1:, :].div_(temperature), None if decoder_len <= 1 else decoder_out[1], ) - probs = getattr(model, decoder_name).get_normalized_probs( + probs = model.get_normalized_probs( decoder_out_tuple, log_probs=True, sample=None ) probs = probs[:, -1, :] @@ -947,12 +896,11 @@ def reorder_incremental_state( self, incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], new_order, - decoder_name="decoder", ): if not self.has_incremental_states(): return for i, model in enumerate(self.models): - getattr(model, decoder_name).reorder_incremental_state_scripting( + model.decoder.reorder_incremental_state_scripting( incremental_states[i], new_order ) diff --git a/fairseq/speech_generator.py b/fairseq/speech_generator.py index d44ccfb54c..f2cc8b5e86 100644 --- a/fairseq/speech_generator.py +++ b/fairseq/speech_generator.py @@ -146,8 +146,8 @@ def __init__( self.tgt_dict_mt = tgt_dict_mt self.eos_mt = eos_mt + from 
examples.speech_to_speech.unity.sequence_generator import SequenceGenerator from fairseq import search - from fairseq.sequence_generator import SequenceGenerator self.text_generator = SequenceGenerator( models, diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index 4d280da615..8991b49e6d 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -360,7 +360,7 @@ def build_generator_dual_decoder( args, extra_gen_cls_kwargs=None, ): - from fairseq.sequence_generator_multi_decoder import ( + from examples.speech_to_speech.unity.sequence_generator_multi_decoder import ( MultiDecoderSequenceGenerator, ) diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 3d5d54efc4..b5e500e56c 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -173,7 +173,7 @@ def build_generator_dual_decoder( args, extra_gen_cls_kwargs, ): - from fairseq.sequence_generator_multi_decoder import ( + from examples.speech_to_speech.unity.sequence_generator_multi_decoder import ( MultiDecoderSequenceGenerator, ) From 4c2975c60901f74e173032a2adad2bb55e6ce8fb Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 13:06:40 -0700 Subject: [PATCH 24/35] Fix rdrop default value in aux tasks --- fairseq/data/audio/data_cfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py index dc8154e0de..bdcb7745f5 100644 --- a/fairseq/data/audio/data_cfg.py +++ b/fairseq/data/audio/data_cfg.py @@ -366,7 +366,7 @@ def eos_token(self): @property def rdrop_alpha(self): - return self.config.get("rdrop_alpha", None) + return self.config.get("rdrop_alpha", 0.0) @property def is_first_pass_decoder(self): From 4a1451fd4cef1fd8724fe13460e9f47c28a6b4bc Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 14:19:53 -0700 Subject: [PATCH 25/35] Add language tag mapping option to 
multitask-config-yaml --- fairseq/data/audio/data_cfg.py | 4 ++++ fairseq/data/audio/speech_to_text_dataset.py | 13 ++++--------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py index bdcb7745f5..02d88234d5 100644 --- a/fairseq/data/audio/data_cfg.py +++ b/fairseq/data/audio/data_cfg.py @@ -381,3 +381,7 @@ def is_first_pass_decoder(self): 'The name of the first-pass decoder does not include "target".' ) return flag + + @property + def get_lang_tag_mapping(self): + return self.config.get("lang_tag_mapping", {}) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index 8a466a8064..a7e3a34859 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -392,11 +392,6 @@ class TextTargetMultitaskData(object): # mandatory columns KEY_ID, KEY_TEXT = "id", "tgt_text" LANG_TAG_TEMPLATE = "" - LANG_TAG_MAPPING = { - "": "[en_XX]", - "": "[es_XX]", - "": "[ru_RU]", - } # FIXME: make this optional def __init__(self, args, split, tgt_dict): samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split) @@ -409,6 +404,7 @@ def __init__(self, args, split, tgt_dict): args.prepend_bos_and_append_tgt_lang_tag ) self.eos_token = args.eos_token + self.lang_tag_mapping = args.get_lang_tag_mapping @classmethod def is_lang_tag(cls, token): @@ -424,10 +420,9 @@ def get_tokenized_tgt_text(self, index: int): text = self.tokenize(self.bpe_tokenizer, text) return text - @classmethod - def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): - lang_tag = cls.LANG_TAG_TEMPLATE.format(lang) - lang_tag = cls.LANG_TAG_MAPPING.get(lang_tag, lang_tag) + def get_lang_tag_idx(self, lang: str, dictionary: Dictionary): + lang_tag = self.LANG_TAG_TEMPLATE.format(lang) + lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag) lang_tag_idx = dictionary.index(lang_tag) assert lang_tag_idx != 
dictionary.unk(), (lang, lang_tag) return lang_tag_idx From 74c210e52b3ae654d57c2868817bfbd35feb9681 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 15:36:20 -0700 Subject: [PATCH 26/35] Rename encoder_out2 and encoder_outs2 --- .../unity/sequence_generator.py | 28 ++++++------ .../unity/sequence_generator_multi_decoder.py | 6 +-- .../modules/transformer_decoder_aug.py | 4 +- .../speech_to_speech/s2s_conformer_unity.py | 2 +- .../speech_to_speech/s2s_transformer.py | 2 +- .../speech_to_text/xm_transformer_unity.py | 2 +- .../transformer/transformer_decoder_aug.py | 45 ++++++++++--------- fairseq/modules/transformer_layer_aug.py | 12 ++--- 8 files changed, 53 insertions(+), 48 deletions(-) diff --git a/examples/speech_to_speech/unity/sequence_generator.py b/examples/speech_to_speech/unity/sequence_generator.py index ac542f4aa3..0db842fdcd 100644 --- a/examples/speech_to_speech/unity/sequence_generator.py +++ b/examples/speech_to_speech/unity/sequence_generator.py @@ -270,7 +270,7 @@ def generate_decoder( constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, aux_task_name="", - encoder_outs2: Optional[Tensor] = None, + encoder_outs_aug: Optional[Tensor] = None, ): incremental_states = torch.jit.annotate( List[Dict[str, Dict[str, Optional[Tensor]]]], @@ -305,8 +305,10 @@ def generate_decoder( encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) # ensure encoder_outs is a List. 
assert encoder_outs is not None - if encoder_outs2 is not None: - encoder_outs2 = self.model.reorder_encoder_out(encoder_outs2, new_order) + if encoder_outs_aug is not None: + encoder_outs_aug = self.model.reorder_encoder_out( + encoder_outs_aug, new_order + ) # initialize buffers scores = ( @@ -378,9 +380,9 @@ def generate_decoder( encoder_outs = self.model.reorder_encoder_out( encoder_outs, reorder_state ) - if encoder_outs2 is not None: - encoder_outs2 = self.model.reorder_encoder_out( - encoder_outs2, reorder_state + if encoder_outs_aug is not None: + encoder_outs_aug = self.model.reorder_encoder_out( + encoder_outs_aug, reorder_state ) with torch.autograd.profiler.record_function( "EnsembleModel: forward_decoder" @@ -391,7 +393,7 @@ def generate_decoder( incremental_states, self.temperature, decoder_name=decoder_name, - encoder_outs2=encoder_outs2, + encoder_outs_aug=encoder_outs_aug, ) if self.lm_model is not None and not aux_task_name: @@ -844,24 +846,24 @@ def forward_decoder( incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], temperature: float = 1.0, decoder_name="decoder", - encoder_outs2: List[Dict[str, List[Tensor]]] = None, + encoder_outs_aug: List[Dict[str, List[Tensor]]] = None, ): log_probs = [] avg_attn: Optional[Tensor] = None encoder_out: Optional[Dict[str, List[Tensor]]] = None - encoder_out2: Optional[Dict[str, List[Tensor]]] = None + encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None for i, model in enumerate(self.models): if self.has_encoder(): encoder_out = encoder_outs[i] - if encoder_outs2 is not None: - encoder_out2 = encoder_outs2[i] + if encoder_outs_aug is not None: + encoder_out_aug = encoder_outs_aug[i] # decode each model if self.has_incremental_states(): - if encoder_out2 is not None: + if encoder_out_aug is not None: decoder_out = getattr(model, decoder_name).forward( tokens, encoder_out=encoder_out, - encoder_out2=encoder_out2, + encoder_out_aug=encoder_out_aug, incremental_state=incremental_states[i], ) 
else: diff --git a/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py b/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py index 8af6413c28..49d127b34d 100644 --- a/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py +++ b/examples/speech_to_speech/unity/sequence_generator_multi_decoder.py @@ -241,10 +241,10 @@ def _generate( } if getattr(single_model, "t2u_augmented_cross_attn", False): - encoder_outs2 = [t2u_encoder_out] + encoder_outs_aug = [t2u_encoder_out] else: encoder_outs = [t2u_encoder_out] - encoder_outs2 = None + encoder_outs_aug = None # 3. T2U decoder finalized = self.generator.generate_decoder( @@ -255,6 +255,6 @@ def _generate( prefix_tokens, constraints, bos_token, - encoder_outs2=encoder_outs2, + encoder_outs_aug=encoder_outs_aug, ) return finalized diff --git a/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py b/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py index 6650eed415..68f42c2b36 100644 --- a/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py +++ b/fairseq/models/speech_to_speech/modules/transformer_decoder_aug.py @@ -41,7 +41,7 @@ def forward( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, - encoder_out2: Optional[Dict[str, List[Tensor]]] = None, + encoder_out_aug: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, @@ -72,7 +72,7 @@ def forward( x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, - encoder_out2=encoder_out2, + encoder_out_aug=encoder_out_aug, incremental_state=incremental_state, full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, diff --git a/fairseq/models/speech_to_speech/s2s_conformer_unity.py b/fairseq/models/speech_to_speech/s2s_conformer_unity.py index b7b1a5eed1..64388d6d16 100644 --- 
a/fairseq/models/speech_to_speech/s2s_conformer_unity.py +++ b/fairseq/models/speech_to_speech/s2s_conformer_unity.py @@ -259,7 +259,7 @@ def forward( decoder_out = self.decoder( prev_output_tokens, encoder_out=encoder_out, - encoder_out2=t2u_encoder_out, + encoder_out_aug=t2u_encoder_out, ) else: decoder_out = self.decoder( diff --git a/fairseq/models/speech_to_speech/s2s_transformer.py b/fairseq/models/speech_to_speech/s2s_transformer.py index 94cffd5e34..07393d2598 100644 --- a/fairseq/models/speech_to_speech/s2s_transformer.py +++ b/fairseq/models/speech_to_speech/s2s_transformer.py @@ -237,7 +237,7 @@ def forward_encoder(self, src_tokens, src_lengths, speaker=None, **kwargs): @register_model("s2ut_transformer") class S2UTTransformerModel(S2STransformerMultitaskModelBase): """ - Direct speech-to-speech translation model with S2T Transformer encoder + Transformer discrete unit decoder + Direct speech-to-speech translation model with Transformer encoder + Transformer discrete unit decoder https://arxiv.org/abs/2107.05604 """ diff --git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py index 7406f483f5..0e7230d8a7 100644 --- a/fairseq/models/speech_to_text/xm_transformer_unity.py +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -313,7 +313,7 @@ def forward( decoder_out = self.decoder( prev_output_tokens, encoder_out=encoder_out, - encoder_out2=t2u_encoder_out, + encoder_out_aug=t2u_encoder_out, ) else: decoder_out = self.decoder( diff --git a/fairseq/models/transformer/transformer_decoder_aug.py b/fairseq/models/transformer/transformer_decoder_aug.py index 3f0603045d..c5e7101794 100644 --- a/fairseq/models/transformer/transformer_decoder_aug.py +++ b/fairseq/models/transformer/transformer_decoder_aug.py @@ -94,7 +94,7 @@ def forward( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]] = None, - encoder_out2: Optional[Dict[str, List[Tensor]]] = None, + encoder_out_aug: 
Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, @@ -125,7 +125,7 @@ def forward( x, extra = self.extract_features( prev_output_tokens, encoder_out=encoder_out, - encoder_out2=encoder_out2, + encoder_out_aug=encoder_out_aug, incremental_state=incremental_state, full_context_alignment=full_context_alignment, alignment_layer=alignment_layer, @@ -140,7 +140,7 @@ def extract_features( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]], - encoder_out2: Optional[Dict[str, List[Tensor]]], + encoder_out_aug: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, @@ -149,7 +149,7 @@ def extract_features( return self.extract_features_scriptable( prev_output_tokens, encoder_out, - encoder_out2, + encoder_out_aug, incremental_state, full_context_alignment, alignment_layer, @@ -166,7 +166,7 @@ def extract_features_scriptable( self, prev_output_tokens, encoder_out: Optional[Dict[str, List[Tensor]]], - encoder_out2: Optional[Dict[str, List[Tensor]]], + encoder_out_aug: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, @@ -202,12 +202,15 @@ def extract_features_scriptable( if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0: padding_mask = encoder_out["encoder_padding_mask"][0] - enc2: Optional[Tensor] = None - padding_mask2: Optional[Tensor] = None - if encoder_out2 is not None and len(encoder_out2["encoder_out"]) > 0: - enc2 = encoder_out2["encoder_out"][0] - if encoder_out2 is not None and len(encoder_out2["encoder_padding_mask"]) > 0: - padding_mask2 = encoder_out2["encoder_padding_mask"][0] + enc_aug: Optional[Tensor] = None + 
padding_mask_aug: Optional[Tensor] = None + if encoder_out_aug is not None and len(encoder_out_aug["encoder_out"]) > 0: + enc_aug = encoder_out_aug["encoder_out"][0] + if ( + encoder_out_aug is not None + and len(encoder_out_aug["encoder_padding_mask"]) > 0 + ): + padding_mask_aug = encoder_out_aug["encoder_padding_mask"][0] # embed positions positions = None @@ -249,7 +252,7 @@ def extract_features_scriptable( # decoder layers attn: Optional[Tensor] = None - attn2: Optional[Tensor] = None + attn_aug: Optional[Tensor] = None inner_states: List[Optional[Tensor]] = [x] for idx, layer in enumerate(self.layers): if incremental_state is None and not full_context_alignment: @@ -257,12 +260,12 @@ def extract_features_scriptable( else: self_attn_mask = None - x, layer_attn, layer_attn2, _ = layer( + x, layer_attn, layer_attn_aug, _ = layer( x, enc, padding_mask, - enc2, - padding_mask2, + enc_aug, + padding_mask_aug, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, @@ -272,8 +275,8 @@ def extract_features_scriptable( inner_states.append(x) if layer_attn is not None and idx == alignment_layer: attn = layer_attn.float().to(x) - if layer_attn2 is not None and idx == alignment_layer: - attn2 = layer_attn2.float().to(x) + if layer_attn_aug is not None and idx == alignment_layer: + attn_aug = layer_attn_aug.float().to(x) if attn is not None: if alignment_heads is not None: @@ -282,12 +285,12 @@ def extract_features_scriptable( # average probabilities over heads attn = attn.mean(dim=0) - if attn2 is not None: + if attn_aug is not None: if alignment_heads is not None: - attn2 = attn2[:alignment_heads] + attn_aug = attn_aug[:alignment_heads] # average probabilities over heads - attn2 = attn2.mean(dim=0) + attn_aug = attn_aug.mean(dim=0) if self.layer_norm is not None: x = self.layer_norm(x) @@ -298,7 +301,7 @@ def extract_features_scriptable( if self.project_out_dim is not None: x = self.project_out_dim(x) - return x, {"attn": [attn], 
"attn2": [attn2], "inner_states": inner_states} + return x, {"attn": [attn], "attn_aug": [attn_aug], "inner_states": inner_states} def upgrade_state_dict_named(self, state_dict, name): """Upgrade a (possibly old) state dict for new versions of fairseq.""" diff --git a/fairseq/modules/transformer_layer_aug.py b/fairseq/modules/transformer_layer_aug.py index 2acd5e2e8f..7eb816978a 100644 --- a/fairseq/modules/transformer_layer_aug.py +++ b/fairseq/modules/transformer_layer_aug.py @@ -59,7 +59,7 @@ def forward( x, encoder_out: Optional[torch.Tensor] = None, encoder_padding_mask: Optional[torch.Tensor] = None, - encoder_out2: Optional[torch.Tensor] = None, + encoder_out_aug: Optional[torch.Tensor] = None, encoder_padding_mask2: Optional[torch.Tensor] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, prev_self_attn_state: Optional[List[torch.Tensor]] = None, @@ -145,7 +145,7 @@ def forward( x = self.self_attn_layer_norm(x) assert encoder_out is not None - assert encoder_out2 is not None + assert encoder_out_aug is not None if self.encoder_attn_merge_type == "sequential": ratios = self.get_dropnet_ratio() @@ -200,8 +200,8 @@ def forward( x, attn2 = self.encoder_attn2( query=x, - key=encoder_out2, - value=encoder_out2, + key=encoder_out_aug, + value=encoder_out_aug, key_padding_mask=encoder_padding_mask2, incremental_state=incremental_state, static_kv=True, @@ -241,8 +241,8 @@ def forward( ) x2, attn2 = self.encoder_attn2( query=x, - key=encoder_out2, - value=encoder_out2, + key=encoder_out_aug, + value=encoder_out_aug, key_padding_mask=encoder_padding_mask2, incremental_state=incremental_state, static_kv=True, From 961de2e2296d1fe8d5254419c717f92d33d53362 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 15:44:28 -0700 Subject: [PATCH 27/35] Rename UnitYXMTransformerModel to XMTransformerModelUnitY --- fairseq/models/speech_to_text/xm_transformer_unity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/fairseq/models/speech_to_text/xm_transformer_unity.py b/fairseq/models/speech_to_text/xm_transformer_unity.py index 0e7230d8a7..450bf01efe 100644 --- a/fairseq/models/speech_to_text/xm_transformer_unity.py +++ b/fairseq/models/speech_to_text/xm_transformer_unity.py @@ -61,7 +61,7 @@ def unit_transformer_decoder_arch_large( @register_model("unity_xm_transformer") -class UnitYXMTransformerModel(XMTransformerModel): +class XMTransformerModelUnitY(XMTransformerModel): @classmethod def hub_models(cls): base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t" From 2c580c7bac7842c6a9863b78b953fae04d497d93 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 15:48:07 -0700 Subject: [PATCH 28/35] Support num_best_checkpoints in average_checkpoints --- scripts/average_checkpoints.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/scripts/average_checkpoints.py b/scripts/average_checkpoints.py index a4711e4840..49f4f9d912 100644 --- a/scripts/average_checkpoints.py +++ b/scripts/average_checkpoints.py @@ -10,6 +10,7 @@ import re import torch + from fairseq.file_io import PathManager @@ -113,6 +114,9 @@ def main(): num_group.add_argument('--num-update-checkpoints', type=int, help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by' ' input, and average last this many of them.') + num_group.add_argument('--num-best-checkpoints', type=int, default=0, + help='if set, will try to find checkpoints with names checkpoint_best_ee_xx.pt in the path specified by' + ' input, and average last this many of them.') parser.add_argument('--checkpoint-upper-bound', type=int, help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use, ' 'when using --num-update-checkpoints, this will set an upper bound on which update to use' @@ -150,6 +154,18 @@ def main(): ) print("averaging checkpoints: ", args.inputs) + if args.num_best_checkpoints > 0: + args.inputs = list( + sorted( + 
args.inputs, + key=lambda x: float( + os.path.basename(x).split("_")[-1].replace(".pt", "") + ), + ) + ) + args.inputs = args.inputs[: args.num_best_checkpoints] + for path in args.inputs: + print(os.path.basename(path)) new_state = average_checkpoints(args.inputs) with PathManager.open(args.output, "wb") as f: torch.save(new_state, f) From 40a8a0f03ff7c6d8352c8299eb297145907f48a2 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Wed, 28 Sep 2022 15:49:07 -0700 Subject: [PATCH 29/35] Fix has_multitask --- fairseq/data/audio/speech_to_speech_dataset.py | 2 +- fairseq/data/audio/speech_to_text_dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/data/audio/speech_to_speech_dataset.py b/fairseq/data/audio/speech_to_speech_dataset.py index 6efd834730..fe4b61f831 100644 --- a/fairseq/data/audio/speech_to_speech_dataset.py +++ b/fairseq/data/audio/speech_to_speech_dataset.py @@ -319,7 +319,7 @@ def _from_list( src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] - has_multitask = len(multitask) > 0 + has_multitask = multitask is not None and len(multitask.keys()) > 0 dataset_cls = ( SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset ) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index a7e3a34859..f0e13ad7a7 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -576,7 +576,7 @@ def _from_list( src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] - has_multitask = len(multitask) > 0 + has_multitask = multitask is not None and len(multitask.keys()) > 0 dataset_cls = ( SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset ) From b41a5df30820ad1f9840301333b521de6f2e5c85 Mon Sep 17 00:00:00 2001 
From: Hirofumi Inaguma Date: Tue, 4 Oct 2022 10:49:29 -0700 Subject: [PATCH 30/35] Inherit SequenceGenerator --- .../unity/sequence_generator.py | 508 ++---------------- 1 file changed, 30 insertions(+), 478 deletions(-) diff --git a/examples/speech_to_speech/unity/sequence_generator.py b/examples/speech_to_speech/unity/sequence_generator.py index 0db842fdcd..3373ee88bb 100644 --- a/examples/speech_to_speech/unity/sequence_generator.py +++ b/examples/speech_to_speech/unity/sequence_generator.py @@ -8,16 +8,13 @@ from typing import Dict, List, Optional import torch -import torch.nn as nn from torch import Tensor -from fairseq import search, utils -from fairseq.data import data_utils -from fairseq.models import FairseqIncrementalDecoder -from fairseq.ngram_repeat_block import NGramRepeatBlock +from fairseq.sequence_generator import EnsembleModel as EnsembleModelBase +from fairseq.sequence_generator import SequenceGenerator as SequenceGeneratorBase -class SequenceGenerator(nn.Module): +class SequenceGenerator(SequenceGeneratorBase): def __init__( self, models, @@ -64,144 +61,36 @@ def __init__( match_source_len (bool, optional): outputs should match the source length (default: False) """ - super().__init__() + super().__init__( + models=models, + tgt_dict=tgt_dict, + beam_size=beam_size, + max_len_a=max_len_a, + max_len_b=max_len_b, + max_len=max_len, + min_len=min_len, + normalize_scores=normalize_scores, + len_penalty=len_penalty, + unk_penalty=unk_penalty, + temperature=temperature, + match_source_len=match_source_len, + no_repeat_ngram_size=no_repeat_ngram_size, + search_strategy=search_strategy, + eos=eos, + symbols_to_strip_from_output=symbols_to_strip_from_output, + lm_model=lm_model, + lm_weight=lm_weight, + tokens_to_suppress=tokens_to_suppress, + ) + if isinstance(models, EnsembleModel): self.model = models else: self.model = EnsembleModel(models) - self.tgt_dict = tgt_dict - self.pad = tgt_dict.pad() - self.unk = tgt_dict.unk() - self.eos = tgt_dict.eos() 
if eos is None else eos - self.symbols_to_strip_from_output = ( - symbols_to_strip_from_output.union({self.eos}) - if symbols_to_strip_from_output is not None - else {self.eos} - ) - self.token_indices_to_suppress: Optional[Tensor] = None - token_indices_to_suppress = [] - for token_string in tokens_to_suppress: - token_index = tgt_dict.index(token_string) - assert token_index != self.unk - token_indices_to_suppress.append(token_index) - if len(token_indices_to_suppress) > 0: - self.token_indices_to_suppress = torch.Tensor( - token_indices_to_suppress - ).long() - - self.vocab_size = len(tgt_dict) - # the max beam size is the dictionary size - 1, since we never select pad - self.beam_size = min(beam_size, self.vocab_size - 1) self.model.set_decoder_beam_size(self.beam_size) - self.max_len_a = max_len_a - self.max_len_b = max_len_b - self.min_len = min_len - self.max_len = max_len or self.model.max_decoder_positions() - - self.normalize_scores = normalize_scores - self.len_penalty = len_penalty - self.unk_penalty = unk_penalty - self.temperature = temperature - self.match_source_len = match_source_len - - if no_repeat_ngram_size > 0: - self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) - else: - self.repeat_ngram_blocker = None - - assert temperature > 0, "--temperature must be greater than 0" - - self.search = ( - search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy - ) - # We only need to set src_lengths in LengthConstrainedBeamSearch. - # As a module attribute, setting it would break in multithread - # settings when the model is shared. 
- self.should_set_src_lengths = ( - hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths - ) - self.model.eval() - self.lm_model = lm_model - self.lm_weight = lm_weight - if self.lm_model is not None: - self.lm_model.eval() - - def cuda(self): - self.model.cuda() - return self - - @torch.no_grad() - def forward( - self, - sample: Dict[str, Dict[str, Tensor]], - prefix_tokens: Optional[Tensor] = None, - bos_token: Optional[int] = None, - ): - """Generate a batch of translations. - - Args: - sample (dict): batch - prefix_tokens (torch.LongTensor, optional): force decoder to begin - with these tokens - bos_token (int, optional): beginning of sentence token - (default: self.eos) - """ - return self._generate(sample, prefix_tokens, bos_token=bos_token) - - # TODO(myleott): unused, deprecate after pytorch-translate migration - def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): - """Iterate over a batched dataset and yield individual translations. 
- Args: - cuda (bool, optional): use GPU for generation - timer (StopwatchMeter, optional): time generations - """ - for sample in data_itr: - s = utils.move_to_cuda(sample) if cuda else sample - if "net_input" not in s: - continue - input = s["net_input"] - # model.forward normally channels prev_output_tokens into the decoder - # separately, but SequenceGenerator directly calls model.encoder - encoder_input = { - k: v for k, v in input.items() if k != "prev_output_tokens" - } - if timer is not None: - timer.start() - with torch.no_grad(): - hypos = self.generate(encoder_input) - if timer is not None: - timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) - for i, id in enumerate(s["id"].data): - # remove padding - src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) - ref = ( - utils.strip_pad(s["target"].data[i, :], self.pad) - if s["target"] is not None - else None - ) - yield id, src, ref, hypos[i] - - @torch.no_grad() - def generate( - self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs - ) -> List[List[Dict[str, Tensor]]]: - """Generate translations. Match the api of other fairseq generators. 
- - Args: - models (List[~fairseq.models.FairseqModel]): ensemble of models - sample (dict): batch - prefix_tokens (torch.LongTensor, optional): force decoder to begin - with these tokens - constraints (torch.LongTensor, optional): force decoder to include - the list of constraints - bos_token (int, optional): beginning of sentence token - (default: self.eos) - """ - return self._generate(sample, **kwargs) - def _generate( self, sample: Dict[str, Dict[str, Tensor]], @@ -270,7 +159,9 @@ def generate_decoder( constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, aux_task_name="", - encoder_outs_aug: Optional[Tensor] = None, + encoder_outs_aug: Optional[ + Tensor + ] = None, # an additional/augmented encoder_outs ): incremental_states = torch.jit.annotate( List[Dict[str, Dict[str, Optional[Tensor]]]], @@ -625,218 +516,12 @@ def generate_decoder( ) return finalized - def _prefix_tokens( - self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int - ): - """Handle prefix tokens""" - prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) - prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) - prefix_mask = prefix_toks.ne(self.pad) - lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) - lprobs[prefix_mask] = lprobs[prefix_mask].scatter( - -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] - ) - # if prefix includes eos, then we should make sure tokens and - # scores are the same across all beams - eos_mask = prefix_toks.eq(self.eos) - if eos_mask.any(): - # validate that the first beam matches the prefix - first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ - :, 0, 1 : step + 1 - ] - eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] - target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] - assert (first_beam == target_prefix).all() - - # copy tokens, scores and lprobs from the first beam to all beams - tokens = self.replicate_first_beam(tokens, 
eos_mask_batch_dim, beam_size) - scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) - lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) - return lprobs, tokens, scores - - def replicate_first_beam(self, tensor, mask, beam_size: int): - tensor = tensor.view(-1, beam_size, tensor.size(-1)) - tensor[mask] = tensor[mask][:, :1, :] - return tensor.view(-1, tensor.size(-1)) - - def finalize_hypos( - self, - step: int, - bbsz_idx, - eos_scores, - tokens, - scores, - finalized: List[List[Dict[str, Tensor]]], - finished: List[bool], - beam_size: int, - attn: Optional[Tensor], - src_lengths, - max_len: int, - ): - """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. - A sentence is finalized when {beam_size} finished items have been collected for it. - Returns number of sentences (not beam items) being finalized. - These will be removed from the batch and not processed further. - Args: - bbsz_idx (Tensor): - """ - assert bbsz_idx.numel() == eos_scores.numel() - - # clone relevant token and attention tensors. - # tokens is (batch * beam, max_len). So the index_select - # gets the newly EOS rows, then selects cols 1..{step + 2} - tokens_clone = tokens.index_select(0, bbsz_idx)[ - :, 1 : step + 2 - ] # skip the first index, which is EOS - - tokens_clone[:, step] = self.eos - attn_clone = ( - attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] - if attn is not None - else None - ) - - # compute scores per token position - pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] - pos_scores[:, step] = eos_scores - # convert from cumulative to per-position scores - pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] - - # normalize sentence-level scores - if self.normalize_scores: - eos_scores /= (step + 1) ** self.len_penalty - - # cum_unfin records which sentences in the batch are finished. 
- # It helps match indexing between (a) the original sentences - # in the batch and (b) the current, possibly-reduced set of - # sentences. - cum_unfin: List[int] = [] - prev = 0 - for f in finished: - if f: - prev += 1 - else: - cum_unfin.append(prev) - cum_fin_tensor = torch.tensor(cum_unfin, dtype=torch.int).to(bbsz_idx) - - unfin_idx = torch.div(bbsz_idx, beam_size, rounding_mode="trunc") - sent = unfin_idx + torch.index_select(cum_fin_tensor, 0, unfin_idx) - - # Create a set of "{sent}{unfin_idx}", where - # "unfin_idx" is the index in the current (possibly reduced) - # list of sentences, and "sent" is the index in the original, - # unreduced batch - # For every finished beam item - # sentence index in the current (possibly reduced) batch - seen = (sent << 32) + unfin_idx - unique_seen: List[int] = torch.unique(seen).tolist() - - if self.match_source_len: - condition = step > torch.index_select(src_lengths, 0, unfin_idx) - eos_scores = torch.where(condition, torch.tensor(-math.inf), eos_scores) - sent_list: List[int] = sent.tolist() - for i in range(bbsz_idx.size()[0]): - # An input sentence (among those in a batch) is finished when - # beam_size hypotheses have been collected for it - if len(finalized[sent_list[i]]) < beam_size: - if attn_clone is not None: - # remove padding tokens from attn scores - hypo_attn = attn_clone[i] - else: - hypo_attn = torch.empty(0) - - finalized[sent_list[i]].append( - { - "tokens": tokens_clone[i], - "score": eos_scores[i], - "attention": hypo_attn, # src_len x tgt_len - "alignment": torch.empty(0), - "positional_scores": pos_scores[i], - } - ) - - newly_finished: List[int] = [] - for unique_s in unique_seen: - # check termination conditions for this sentence - unique_sent: int = unique_s >> 32 - unique_unfin_idx: int = unique_s - (unique_sent << 32) - - if not finished[unique_sent] and self.is_finished( - step, unique_unfin_idx, max_len, len(finalized[unique_sent]), beam_size - ): - finished[unique_sent] = True - 
newly_finished.append(unique_unfin_idx) - - return newly_finished - - def is_finished( - self, - step: int, - unfin_idx: int, - max_len: int, - finalized_sent_len: int, - beam_size: int, - ): - """ - Check whether decoding for a sentence is finished, which - occurs when the list of finalized sentences has reached the - beam size, or when we reach the maximum length. - """ - assert finalized_sent_len <= beam_size - if finalized_sent_len == beam_size or step == max_len: - return True - return False - - -class EnsembleModel(nn.Module): +class EnsembleModel(EnsembleModelBase): """A wrapper around an ensemble of models.""" def __init__(self, models): - super().__init__() - self.models_size = len(models) - # method '__len__' is not supported in ModuleList for torch script - self.single_model = models[0] - self.models = nn.ModuleList(models) - - self.has_incremental: bool = False - if all( - hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) - for m in models - ): - self.has_incremental = True - - def forward(self): - pass - - def has_encoder(self): - return hasattr(self.single_model, "encoder") - - def has_incremental_states(self): - return self.has_incremental - - def max_decoder_positions(self): - return min( - [ - m.max_decoder_positions() - for m in self.models - if hasattr(m, "max_decoder_positions") - ] - + [sys.maxsize] - ) - - def set_decoder_beam_size(self, beam_size): - """Set beam size for efficient beamable enc-dec attention.""" - if beam_size > 1: - for model in self.models: - if hasattr(model, "set_beam_size"): - model.set_beam_size(beam_size) - - @torch.jit.export - def forward_encoder(self, net_input: Dict[str, Tensor]): - if not self.has_encoder(): - return None - return [model.encoder.forward_torchscript(net_input) for model in self.models] + super().__init__(models) @torch.jit.export def forward_decoder( @@ -920,30 +605,6 @@ def forward_decoder( avg_attn.div_(self.models_size) return avg_probs, avg_attn - @torch.jit.export - def 
reorder_encoder_out( - self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order - ): - """ - Reorder encoder output according to *new_order*. - - Args: - encoder_out: output from the ``forward()`` method - new_order (LongTensor): desired order - - Returns: - *encoder_out* rearranged according to *new_order* - """ - new_outs: List[Dict[str, List[Tensor]]] = [] - if not self.has_encoder(): - return new_outs - for i, model in enumerate(self.models): - assert encoder_outs is not None - new_outs.append( - model.encoder.reorder_encoder_out(encoder_outs[i], new_order) - ) - return new_outs - @torch.jit.export def reorder_incremental_state( self, @@ -957,112 +618,3 @@ def reorder_incremental_state( getattr(model, decoder_name).reorder_incremental_state_scripting( incremental_states[i], new_order ) - - -# class SequenceGeneratorWithAlignment(SequenceGenerator): -# def __init__( -# self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs -# ): -# """Generates translations of a given source sentence. - -# Produces alignments following "Jointly Learning to Align and -# Translate with Transformer Models" (Garg et al., EMNLP 2019). - -# Args: -# left_pad_target (bool, optional): Whether or not the -# hypothesis should be left padded or not when they are -# teacher forced for generating alignments. 
-# """ -# super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) -# self.left_pad_target = left_pad_target - -# if print_alignment == "hard": -# self.extract_alignment = utils.extract_hard_alignment -# elif print_alignment == "soft": -# self.extract_alignment = utils.extract_soft_alignment - -# @torch.no_grad() -# def generate(self, models, sample, **kwargs): -# finalized = super()._generate(sample, **kwargs) - -# src_tokens = sample["net_input"]["src_tokens"] -# bsz = src_tokens.shape[0] -# beam_size = self.beam_size -# ( -# src_tokens, -# src_lengths, -# prev_output_tokens, -# tgt_tokens, -# ) = self._prepare_batch_for_alignment(sample, finalized) -# if any(getattr(m, "full_context_alignment", False) for m in self.model.models): -# attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) -# else: -# attn = [ -# finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) -# for i in range(bsz * beam_size) -# ] - -# if src_tokens.device != "cpu": -# src_tokens = src_tokens.to("cpu") -# tgt_tokens = tgt_tokens.to("cpu") -# attn = [i.to("cpu") for i in attn] - -# # Process the attn matrix to extract hard alignments. 
-# for i in range(bsz * beam_size): -# alignment = self.extract_alignment( -# attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos -# ) -# finalized[i // beam_size][i % beam_size]["alignment"] = alignment -# return finalized - -# def _prepare_batch_for_alignment(self, sample, hypothesis): -# src_tokens = sample["net_input"]["src_tokens"] -# bsz = src_tokens.shape[0] -# src_tokens = ( -# src_tokens[:, None, :] -# .expand(-1, self.beam_size, -1) -# .contiguous() -# .view(bsz * self.beam_size, -1) -# ) -# src_lengths = sample["net_input"]["src_lengths"] -# src_lengths = ( -# src_lengths[:, None] -# .expand(-1, self.beam_size) -# .contiguous() -# .view(bsz * self.beam_size) -# ) -# prev_output_tokens = data_utils.collate_tokens( -# [beam["tokens"] for example in hypothesis for beam in example], -# self.pad, -# self.eos, -# self.left_pad_target, -# move_eos_to_beginning=True, -# ) -# tgt_tokens = data_utils.collate_tokens( -# [beam["tokens"] for example in hypothesis for beam in example], -# self.pad, -# self.eos, -# self.left_pad_target, -# move_eos_to_beginning=False, -# ) -# return src_tokens, src_lengths, prev_output_tokens, tgt_tokens - - -# class EnsembleModelWithAlignment(EnsembleModel): -# """A wrapper around an ensemble of models.""" - -# def __init__(self, models): -# super().__init__(models) - -# def forward_align(self, src_tokens, src_lengths, prev_output_tokens): -# avg_attn = None -# for model in self.models: -# decoder_out = model(src_tokens, src_lengths, prev_output_tokens) -# attn = decoder_out[1]["attn"][0] -# if avg_attn is None: -# avg_attn = attn -# else: -# avg_attn.add_(attn) -# if len(self.models) > 1: -# avg_attn.div_(len(self.models)) -# return avg_attn From b403911edcb2d31374949aaa5a06ce090dbf91c9 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 4 Oct 2022 10:52:04 -0700 Subject: [PATCH 31/35] Reflect recent updates --- .../speech_to_speech/unity/sequence_generator.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 
deletions(-) diff --git a/examples/speech_to_speech/unity/sequence_generator.py b/examples/speech_to_speech/unity/sequence_generator.py index 3373ee88bb..c482098feb 100644 --- a/examples/speech_to_speech/unity/sequence_generator.py +++ b/examples/speech_to_speech/unity/sequence_generator.py @@ -103,9 +103,15 @@ def _generate( if "src_tokens" in net_input: src_tokens = net_input["src_tokens"] # length of the source text being the character length except EndOfSentence and pad - src_lengths = ( - (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) - ) + # if src_lengths exists in net_input (speech_to_text dataset case), then use it + if "src_lengths" in net_input: + src_lengths = net_input["src_lengths"] + else: + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)) + .long() + .sum(dim=1) + ) elif "source" in net_input: src_tokens = net_input["source"] src_lengths = ( From 05688ded533e39b9dc62e87c9258a7cb2aa4cbf1 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 4 Oct 2022 10:52:37 -0700 Subject: [PATCH 32/35] Minor fix in logging --- fairseq/criterions/speech_to_speech_criterion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index d8e89916f3..4d426cfb03 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -37,7 +37,7 @@ def __init__(self, multitask_tasks, rdrop_alpha=0.0): if rdrop_alpha_task is None: rdrop_alpha_task = rdrop_alpha self.rdrop_alpha_mtl = rdrop_alpha_task - logger.info(f"rdrop_alpha is set to {rdrop_alpha_task}") + logger.info(f"rdrop_alpha is set to {rdrop_alpha_task} for {task_name}") if task_obj.args.decoder_type == "ctc": self.multitask_criterion[task_name] = CtcCriterion( From aa11f8da65702fe09d0911b14040c2157e2d9850 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 4 Oct 2022 11:21:19 -0700 Subject: [PATCH 
33/35] Fix typo --- fairseq/data/audio/data_cfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fairseq/data/audio/data_cfg.py b/fairseq/data/audio/data_cfg.py index 02d88234d5..6be6f6521c 100644 --- a/fairseq/data/audio/data_cfg.py +++ b/fairseq/data/audio/data_cfg.py @@ -271,7 +271,7 @@ def first_pass_decoder_task_index(self): idx = i if idx < 0: for i, (k, v) in enumerate(self.config.items()): - if k.startwith("target") and v.decoder_type == "transformer": + if k.startswith("target") and v.decoder_type == "transformer": idx = i return idx From d56ba4a97dc751d3917e1a3514ca4dd896350fae Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 4 Oct 2022 11:22:47 -0700 Subject: [PATCH 34/35] Refactor SpeechToSpectrogram2passMultitaskTaskCriterion --- .../criterions/speech_to_speech_criterion.py | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/fairseq/criterions/speech_to_speech_criterion.py b/fairseq/criterions/speech_to_speech_criterion.py index 4d426cfb03..7fb3e6ba68 100644 --- a/fairseq/criterions/speech_to_speech_criterion.py +++ b/fairseq/criterions/speech_to_speech_criterion.py @@ -429,7 +429,7 @@ def reduce_metrics(cls, logging_outputs) -> None: @register_criterion("speech_to_spectrogram_2pass", dataclass=Tacotron2CriterionConfig) class SpeechToSpectrogram2passMultitaskTaskCriterion( - Tacotron2Criterion, MultitaskCriterion + SpeechToSpectrogramMultitaskTaskCriterion ): def __init__( self, @@ -448,7 +448,6 @@ def __init__( bce_pos_weight, ctc_weight, ) - MultitaskCriterion.__init__(self, task.multitask_tasks) def forward(self, model, sample, reduction="mean"): bsz, max_len, _ = sample["target"].size() @@ -511,24 +510,3 @@ def forward(self, model, sample, reduction="mean"): loss += multitask_loss logging_output["multitask"] = multitask_log return loss, sample_size, logging_output - - @classmethod - def reduce_metrics(cls, logging_outputs) -> None: - super().reduce_metrics(logging_outputs) - - # 
inference metrics - if "targ_frames" in logging_outputs[0]: - n = sum(log.get("norm_frames", 0) for log in logging_outputs) - for key, new_key in [ - ("mcd_loss", "mcd_loss"), - ("pred_frames", "pred_ratio"), - ("nins", "ins_rate"), - ("ndel", "del_rate"), - ]: - val = sum(log.get(key, 0) for log in logging_outputs) - metrics.log_scalar(new_key, val / n, n, round=3) - - if "multitask" not in logging_outputs[0]: - return - - MultitaskCriterion.reduce_metrics(logging_outputs) From 0da578a24255594083f177e7a0c3e41c850f0478 Mon Sep 17 00:00:00 2001 From: Hirofumi Inaguma Date: Tue, 4 Oct 2022 11:38:44 -0700 Subject: [PATCH 35/35] Minor update for multitask --- fairseq/data/audio/speech_to_text_dataset.py | 2 +- fairseq/tasks/speech_to_speech.py | 4 +++- fairseq/tasks/speech_to_text.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/fairseq/data/audio/speech_to_text_dataset.py b/fairseq/data/audio/speech_to_text_dataset.py index f0e13ad7a7..cdf71558fd 100644 --- a/fairseq/data/audio/speech_to_text_dataset.py +++ b/fairseq/data/audio/speech_to_text_dataset.py @@ -443,7 +443,7 @@ def build_bpe(self, args): else: return None - def get(self, sample_id, tgt_lang): + def get(self, sample_id, tgt_lang=None): if sample_id in self.data: tokenized = self.get_tokenized_tgt_text(sample_id) target = self.dict.encode_line( diff --git a/fairseq/tasks/speech_to_speech.py b/fairseq/tasks/speech_to_speech.py index 8991b49e6d..5aaaa95a90 100644 --- a/fairseq/tasks/speech_to_speech.py +++ b/fairseq/tasks/speech_to_speech.py @@ -226,7 +226,9 @@ def __init__(self, args, tgt_dict, infer_tgt_lang_id=None): multitask_cfg.get_all_tasks().items() ): task_obj = DummyMultiTask( - task_config, task_config.tgt_dict, i == first_pass_task_idx + task_config, + task_config.tgt_dict, + first_pass=i == first_pass_task_idx, ) self.multitask_tasks[task_name] = task_obj if task_obj.is_first_pass_decoder: diff --git a/fairseq/tasks/speech_to_text.py b/fairseq/tasks/speech_to_text.py index 
b5e500e56c..884082112a 100644 --- a/fairseq/tasks/speech_to_text.py +++ b/fairseq/tasks/speech_to_text.py @@ -79,7 +79,9 @@ def __init__(self, args, tgt_dict): multitask_cfg.get_all_tasks().items() ): task_obj = DummyMultiTask( - task_config, task_config.tgt_dict, i == first_pass_task_idx + task_config, + task_config.tgt_dict, + first_pass=i == first_pass_task_idx, ) self.multitask_tasks[task_name] = task_obj if task_obj.is_first_pass_decoder: