-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathoptions.py
78 lines (75 loc) · 4.43 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse
import torch
parser = argparse.ArgumentParser()
parser.add_argument('--max_epoch', type=int, default=50)
# parser.add_argument('--dataset', default='yelp') # yelp, yelp-aren or amazon
# path to the datasets
parser.add_argument('--src_data_dir', default='../data/en_triples/')
parser.add_argument('--tgt_data_dir', default='../data/kr_triples/')
parser.add_argument('--en_train_lines', type=int, default=0) # set to 0 to use all
parser.add_argument('--ch_train_lines', type=int, default=0) # set to 0 to use all
parser.add_argument('--max_seq_len', type=int, default=32) # set to 0 to not truncate
parser.add_argument('--random_seed', type=int, default=1)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--model_save_file', default='./save/adan')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0005)
parser.add_argument('--Q_learning_rate', type=float, default=0.0005)
# path to BWE
# parser.add_argument('--emb_filename', default='../data/wiki.ko.align.vec')
parser.add_argument('--emb_filename', default='en_fr_word_vec_reduced.txt')
parser.add_argument('--fix_emb', action='store_true')
parser.add_argument('--random_emb', action='store_true')
# use a fixed <unk> token for all words without pretrained embeddings when building vocab
parser.add_argument('--fix_unk', action='store_true')
parser.add_argument('--emb_size', type=int, default=300)
parser.add_argument('--model', default='cnn') # dan or lstm or cnn
# for LSTM model
parser.add_argument('--attn', default='dot') # attention mechanism (for LSTM): avg, last, dot
parser.add_argument('--bdrnn', dest='bdrnn', action='store_true', default=True) # bi-directional LSTM
# use deep averaging network or deep summing network (for DAN model)
parser.add_argument('--sum_pooling/', dest='sum_pooling', action='store_true')
parser.add_argument('--avg_pooling/', dest='sum_pooling', action='store_false')
# for CNN model
parser.add_argument('--kernel_num', type=int, default=400)
parser.add_argument('--kernel_sizes', type=int, nargs='+', default=[3,4,5])
parser.add_argument('--hidden_size', type=int, default=900)
parser.add_argument('--F_layers', type=int, default=1)
parser.add_argument('--P_layers', type=int, default=2)
parser.add_argument('--Q_layers', type=int, default=2)
parser.add_argument('--n_critic', type=int, default=5)
parser.add_argument('--lambd', type=float, default=0.01)
parser.add_argument('--F_bn/', dest='F_bn', action='store_true')
parser.add_argument('--no_F_bn/', dest='F_bn', action='store_false')
parser.add_argument('--P_bn/', dest='P_bn', action='store_true', default=True)
parser.add_argument('--no_P_bn/', dest='P_bn', action='store_false')
parser.add_argument('--Q_bn/', dest='Q_bn', action='store_true', default=True)
parser.add_argument('--no_Q_bn/', dest='Q_bn', action='store_false')
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--clip_lower', type=float, default=-0.01)
parser.add_argument('--clip_upper', type=float, default=0.01)
parser.add_argument('--max_step', type=int, default=100000)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--debug/', dest='debug', action='store_true')
parser.add_argument("--bert_model", default='bert-base-multilingual-cased', type=str)
parser.add_argument("--tb_log_dir", default="runs/null", type=str)
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--target_lang", required=True, type=str, choices=["kr", "fr", "zh", "ru", "nl", "ar", "ja", "es", "it", "pt", "hi", "de", "sv", "ur", "mn", "sk", "kk", "tr", "sl", "ms", "hu"])
parser.add_argument('--do_eval', action='store_true')
parser.add_argument('--do_train', action='store_true')
parser.add_argument('--test_model', type=str)
parser.add_argument('--test_F_model', type=str)
parser.add_argument('--test_P_model', type=str)
parser.add_argument('--test_file', type=str)
parser.add_argument('--do_raw_mbert', action='store_true')
parser.add_argument('--debugging', action='store_true')
parser.add_argument('--mono_train', type=str)
parser.add_argument('--tv_train', type=str)
parser.add_argument('--dev_file', type=str)
parser.add_argument('--monolingual', action='store_true')
parser.add_argument('--sov', action='store_true')
parser.add_argument('--do_aug', action='store_true')
opt = parser.parse_args()
if not torch.cuda.is_available():
opt.device = 'cpu'