Prune-Tune: Finding Sparse Structures for Domain Specific NMT

This example shows how to run the Prune-Tune method (Liang et al., 2021), which first prunes the NMT model and then tunes a small subset of the model parameters to learn domain-specific knowledge; a brief conceptual sketch follows the citation below.

The Prune-Tune method is from the following paper:

@inproceedings{jianze2021prunetune,
  title={Finding Sparse Structures for Domain Specific Neural Machine Translation},
  author={Jianze Liang and Chengqi Zhao and Mingxuan Wang and Xipeng Qiu and Lei Li},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2021}
}
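
Conceptually, the pruning phase frees a small fraction of the weights by zeroing them and recording a binary mask, and the tuning phase learns the target domain only inside those freed slots, so the general-domain sub-network stays intact. The following minimal numpy sketch (one-shot magnitude pruning, a simplification of the gradual pruning used in this recipe and not the neurst implementation) shows how such a binary mask could be produced:

import numpy as np

# Conceptual sketch only: zero the smallest-magnitude weights and record a binary mask,
# where 1 = kept general-domain weight and 0 = pruned slot later freed for the target domain.
def prune_by_magnitude(weight, sparsity=0.1):
    threshold = np.quantile(np.abs(weight), sparsity)
    mask = (np.abs(weight) >= threshold).astype(weight.dtype)
    return weight * mask, mask

w = np.random.default_rng(0).normal(size=(512, 512))
w_pruned, m = prune_by_magnitude(w, sparsity=0.1)
print(f"pruned fraction: {1.0 - m.mean():.2%}")   # roughly 10%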

Datasets, results and available models

In this recipe, we will show how to adapt a general-domain NMT model to various target domains using Prune-Tune.

Datasets:

  • The general domain model is trained on the WMT14 en->de dataset.
  • The target domains are Novel, EMEA, and Oral (IWSLT14). The datasets can be downloaded from HERE.

We report tokenized BLEU. The baseline model is Transformer big.

| # | Model | Pre-trained Model | Dataset | Approach | Train Steps | newstest2014 BLEU | Target-domain BLEU |
|---|-------|-------------------|---------|----------|-------------|-------------------|--------------------|
| 1 | Baseline | - | wmt14 | training from scratch | 500000 | 28.4 | - |
| 2 | Baseline_pruned10 | #1 | wmt14 | gradual pruning | 10000 | 28.5 | - |
| 3 | IWSLTspec_tune10 | #2 | iwslt14 | partial tuning | 10000 | 28.5 | 31.4 |
| 4 | EMEAspec_tune10 | #2 | emea | partial tuning | 10000 | 28.5 | 30.9 |
| 5 | NOVELspec_tune10 | #2 | novel | partial tuning | 10000 | 28.5 | 24.2 |

Run the Prune-Tune method

Train the general domain model and prune

Following the Weight Pruning procedure, assume we have a well-trained Transformer-big model on the WMT14 en->de dataset with 10% of its parameters pruned [LINK], where the vocabulary is built using word pieces.

Note that we should add three extra options when pruning:

--include examples/prune_tune/src/ \
--entry prune_tune_train \
--nopruning_variable_pattern "(ln/gamma)|(ln/beta)|(modalit)"  # No pruning to LayerNorm/Embedding Layers

Then we will get the pruned model transformer_big_baseline_pruned10/, in which 10% of the parameters are pruned, and the weight masks are saved to transformer_big_baseline_pruned10/mask.pkl, where a value of 0 indicates a pruned (zero-valued) weight.
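
As a sanity check, the overall sparsity recorded in the mask file can be inspected directly. The snippet below is a hedged sketch that assumes mask.pkl is a pickled mapping from variable names to binary arrays (0 = pruned); verify the exact structure against the code under examples/prune_tune/src/ before relying on it.

import pickle
import numpy as np

# Assumption: mask.pkl stores {variable name: binary array}, with 0 marking pruned positions.
with open("transformer_big_baseline_pruned10/mask.pkl", "rb") as f:
    masks = pickle.load(f)

total = kept = 0
for name, mask in masks.items():
    mask = np.asarray(mask)
    total += mask.size
    kept += int(mask.sum())

print(f"overall sparsity: {1.0 - kept / total:.2%}")   # expect roughly 10%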

Prepare domain dataset

Download the domain-specific datasets from HERE.

# Download novel.tar.gz via the link above.

# Untar novel dataset
tar -zxvf novel.tar.gz

# Preprocess novel data.
bash ./examples/prune_tune/scripts/prepare-target-dataset-wp.sh novel/

We will get the preprocessed training data and raw test sets under the directory novel/:

novel/
├── dev.de
├── dev.en
├── prediction_args.yml   # the arguments for prediction
├── test.de   # the raw test data
├── test.en
├── train.de   # the raw training data
├── train.en
├── training_args.yml   # the arguments for training
├── training_records   # directory of training TFRecords
│   ├── train.tfrecords-00000-of-00032
│   ├── train.tfrecords-00001-of-00032
│   └── ...
├── translation_wordpiece.yml   # the arguments for the training data and data pre-processing logic
└── validation_args.yml   # the arguments for validation

Partially tune the model with the target domain dataset

According to the mask file transformer_big_baseline_pruned10/mask.pkl, we tune only the parameters at the pruned positions (where the mask value is 0), leaving the general-domain sub-network frozen, as sketched after the command below.

python3 -m neurst.cli.run_exp \
    --include examples/prune_tune/src/ \
    --entry prune_tune_train \
    --config_paths novel/training_args.yml,novel/translation_wordpiece.yml,novel/validation_args.yml \
    --hparams_set transformer_big \
    --pretrain_model transformer_big_baseline_pruned10/ \
    --model_dir transformer_big_baseline_pruned10_novel/ \
    --initial_global_step 0 \
    --train_steps 10000 \
    --summary_steps 200 \
    --save_checkpoints_steps 1000 \
    --partial_tuning \
    --mask_pkl transformer_big_baseline_pruned10/mask.pkl 
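
Conceptually, --partial_tuning restricts the update to the pruned slots, so the general-domain sub-network is preserved exactly. A minimal runnable numpy sketch of this update rule (not the actual neurst training loop) is:

import numpy as np

# With 0 marking pruned weights in the mask, the target-domain update only touches
# those slots; positions with mask == 1 keep their general-domain values exactly.
rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 4))
mask = (np.abs(weight) >= np.quantile(np.abs(weight), 0.1)).astype(weight.dtype)

grad = rng.normal(size=weight.shape)               # stand-in for a target-domain gradient
new_weight = weight - 1e-3 * grad * (1.0 - mask)   # update only where mask == 0

assert np.array_equal(new_weight[mask == 1], weight[mask == 1])  # general weights frozen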

Evaluation on the general and target domain

  • To evaluate on the target domain with the full model:
python3 -m neurst.cli.run_exp \
    --include examples/prune_tune/src/ \
    --entry mask_predict \
    --config_paths novel/prediction_args.yml \
    --model_dir transformer_big_baseline_pruned10_novel/best

or

python3 -m neurst.cli.run_exp \
    --entry predict \
    --config_paths novel/prediction_args.yml \
    --model_dir transformer_big_baseline_pruned10_novel/best
  • To evaluate on the general domain with the general sub-network (see the sketch after this command):
python3 -m neurst.cli.run_exp \
    --include examples/prune_tune/src/ \
    --entry mask_predict \
    --config_paths wmt14_en_de/prediction_args.yml \
    --model_dir transformer_big_baseline_pruned10_novel/best \
    --mask_pkl transformer_big_baseline_pruned10_novel/mask.pkl
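
For the general-domain evaluation, the mask is applied again at inference time: zeroing the slots that were re-purposed for the target domain recovers the general sub-network. The snippet below is only a hedged sketch of that masking step; the mask_predict entry under examples/prune_tune/src/ is the actual implementation.

import numpy as np

# Multiplying each tuned weight matrix by its binary mask zeroes the target-domain
# slots (mask == 0), leaving only the general-domain sub-network for decoding.
def restore_general_subnetwork(weights, masks):
    return {name: w * masks[name] if name in masks else w for name, w in weights.items()}

# Toy usage with a single 2x2 weight and its mask (names and shapes are illustrative).
weights = {"w": np.array([[0.5, -1.2], [0.0, 2.0]])}
masks = {"w": np.array([[1.0, 0.0], [0.0, 1.0]])}
print(restore_general_subnetwork(weights, masks)["w"])   # off-diagonal slots zeroed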