-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathtrain_transmitter.py
145 lines (128 loc) · 4.25 KB
/
train_transmitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Train model for ppl metric with pre-selected parameters.
These parameters have some variance in their final perplexity, but they were
used to achieve the pre-trained model.
"""
import os
import random
import torch
from agents.transmitter.transmitter import ARCH_CHOICE
from parlai.scripts.train_model import setup_args as setup_dict_args, TrainLoop
# Dataset switch: when True the model is trained on the original data
# (SelfOriginalTeacher); when False, on the revised data (SelfRevisedTeacher).
IS_ORIGINAL = False
# Directory of the transmitter model files; it matches the directory part of
# the model_file default configured in setup_args().
TRANSMITTER_DIR = './tmp/transmitter'
# Experiment name: setup_args() uses it for the model file name and for the
# `exp` option.
# NOTE(review): this is a fixed literal and does not follow IS_ORIGINAL --
# presumably it has to be changed by hand when training on the original data.
VERSION = "transmitter_revised"
def setup_task():
    """Pick the teacher task that matches the ``IS_ORIGINAL`` switch.

    :return: task string of the original-data teacher when ``IS_ORIGINAL`` is
        truthy, otherwise the task string of the revised-data teacher.
    """
    # Guard clause for the revised data set; fall through to the original one.
    if not IS_ORIGINAL:
        return 'tasks.convai2transmitter.agents:SelfRevisedTeacher'
    return 'tasks.convai2transmitter.agents:SelfOriginalTeacher'
def setup_seed(seed=1706123):
    """Seed the random number generators so that runs can be compared.

    :param seed: value passed to Python's ``random`` module and to torch via
        ``torch.random.manual_seed`` and ``torch.cuda.manual_seed``.
    """
    random.seed(seed)
    # The torch seeders are applied in the same order as before.
    for seed_fn in (torch.random.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
def gpt_setting():
    """Return the training settings used when ``ARCH_CHOICE`` is ``'gpt'``.

    :return: tuple ``(batchsize, lr, optimizer, gradient_clip)``.
    """
    batchsize = 10
    lr = 1e-4
    optimizer = 'gpt_custom'
    gradient_clip = 1.0
    return batchsize, lr, optimizer, gradient_clip
def lstm_setting():
    """Return the training settings used when ``ARCH_CHOICE`` is not ``'gpt'``.

    :return: tuple ``(batchsize, lr, optimizer, gradient_clip)``.
    """
    batchsize = 64
    lr = 3
    optimizer = 'sgd'
    gradient_clip = 0.1
    return batchsize, lr, optimizer, gradient_clip
def setup_args():
    """Build the training argument parser with the pre-selected defaults.

    The parser is obtained from ``setup_dict_args`` (ParlAI's
    ``train_model.setup_args``) and every hyper-parameter chosen below is
    installed on it through ``parser.set_defaults``.

    :return: the configured parser (note: the parser itself, not a parsed
        ``opt`` dictionary).
    """
    parser = setup_dict_args()

    # Experiment identity: names both the model file and the `exp` option.
    exp_name = VERSION
    # Built from TRANSMITTER_DIR so the path is not duplicated as a literal.
    model_file = '{}/{}.model'.format(TRANSMITTER_DIR, exp_name)

    # Epoch counts: both are forwarded, under the separate option keys
    # `n_epoches` and `num_epochs`.
    n_epoches = 100
    num_train_epochs = 4

    # Model architecture.
    beam_size = 2
    encode_layers = 2
    decode_layers = 2
    embedding_size = 256
    turn_embed_size = 50
    encoder_turn_use = False
    encoder_dis_use = False
    encoder_hidden_size = 1024
    decoder_hidden_size = 1024
    encode_max_seq_len = 256
    decode_max_seq_len = 32
    embedding_type = 'glove'
    share_decoder_input_output_embed = False

    # Regularisation and optimisation.
    smoothing = 0.05
    dropout = 0.1
    momentum = 0.9

    # Persona / history handling.
    persona_append_strategy = 'concat'
    history_append_strategy = -1
    select_persona = False
    shuffle_persona = True

    # Batch size, learning rate, optimizer and gradient clipping depend on the
    # architecture selected in the transmitter agent module.
    if ARCH_CHOICE == 'gpt':
        batchsize, lr, optimizer, gradient_clip = gpt_setting()
    else:
        batchsize, lr, optimizer, gradient_clip = lstm_setting()

    task_name = setup_task()

    parser.set_defaults(
        task=task_name,
        rank_candidates=False,
        model='agents.transmitter.transmitter:TransformerAgent',
        model_file=model_file,
        dict_tokenizer='split',
        datatype='train',
        gpt_lr=6.25e-5,
        n_epoches=n_epoches,
        num_epochs=num_train_epochs,
        batchsize=batchsize,
        beam_size=beam_size,
        encoder_layers=encode_layers,
        decoder_layers=decode_layers,
        encoder_embed_dim=embedding_size,
        encoder_turn_dim=turn_embed_size,
        encoder_turn_use=encoder_turn_use,
        encoder_dis_use=encoder_dis_use,
        decoder_embed_dim=embedding_size,
        encode_max_seq_len=encode_max_seq_len,
        decode_max_seq_len=decode_max_seq_len,
        select_persona=select_persona,
        shuffle_persona=shuffle_persona,
        persona_append_strategy=persona_append_strategy,
        history_append_strategy=history_append_strategy,
        encoder_bidirectional=False,
        encoder_hidden_size=encoder_hidden_size,
        decoder_hidden_size=decoder_hidden_size,
        smoothing=smoothing,
        lr=lr,
        dropout=dropout,
        # Dropout is applied on the encoder/decoder inputs only; the output
        # dropout options are set to 0.
        encoder_dropout_in=dropout,
        encoder_dropout_out=0,
        decoder_dropout_in=dropout,
        decoder_dropout_out=0,
        share_decoder_input_output_embed=share_decoder_input_output_embed,
        gradient_clip=gradient_clip,
        lookuptable='enc_dec',
        optimizer=optimizer,
        embedding_type=embedding_type,
        momentum=momentum,
        # validation configuration: perplexity is the metric, lower is better
        validation_max_exs=-1,
        validation_every_n_secs=3600,
        validation_metric='ppl',
        validation_metric_mode='min',
        validation_patience=10,
        log_every_n_secs=30,
        gpu=0,
        # logging configuration
        exp=exp_name,
        tensorboard_log=True,
        tensorboard_tag='exp',
        train_report_metrics='ppl,f1,hits@1',
        tensorboard_metrics='ppl,f1,hits@1',
    )
    return parser
if __name__ == '__main__':
    # Script entry point: build the configured parser, seed the RNGs with the
    # default seed, then hand the parser to the training loop and run it.
    train_parser = setup_args()
    setup_seed()
    trainer = TrainLoop(train_parser)
    trainer.train()