-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
70 lines (55 loc) · 2.13 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import cfgs.config as config
from common.trainer import Trainer
# from common.trainer_bert import Trainer
import argparse, yaml
import random
from easydict import EasyDict as edict
from get_save_vectors import initial_vectors
def parse_args(argv=None):
    """Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers are unaffected; passing a list allows programmatic
            invocation and testing.

    Returns:
        argparse.Namespace with the attributes ``run_mode``, ``model``,
        ``dataset``, ``gpu``, ``seed`` and ``version``.

    Raises:
        SystemExit: raised by argparse when ``--run`` is missing, a value
            is not among the allowed choices, or ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description='Bilinear Args')
    # Help strings use argparse's %(choices)s / %(default)s specifiers so
    # they can never drift out of sync with the accepted values again.
    parser.add_argument('--run', dest='run_mode',
                        choices=['train', 'val', 'test', 'sample', 'attr'],
                        help='run mode: {%(choices)s}',
                        type=str, required=True)
    parser.add_argument('--model', dest='model',
                        choices=['bilinear', 'lstm', 'nsc', 'upnn', 'huapa'],
                        help='model: {%(choices)s} (default: %(default)s)',
                        default='bilinear', type=str)
    parser.add_argument('--dataset', dest='dataset',
                        choices=['imdb', 'yelp_13', 'yelp_14', 'digital',
                                 'industrial', 'software'],
                        help='dataset: {%(choices)s} (default: %(default)s)',
                        default='imdb', type=str)
    parser.add_argument('--gpu', dest='gpu',
                        help="gpu select, e.g. '0, 1, 2'",
                        type=str,
                        default="0, 1")
    # The default seed is drawn anew on every call, so runs are
    # non-reproducible unless --seed is passed explicitly.
    parser.add_argument('--seed', dest='seed',
                        help='fix random seed',
                        type=int,
                        default=random.randint(0, 99999999))
    parser.add_argument('--version', dest='version',
                        help='version control',
                        type=str,
                        default="default")
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    # Script entry point: merge the per-model YAML config with the CLI
    # arguments, finalise the shared config object and dispatch to the
    # Trainer.
    import sys  # local import: only the script path needs sys.exit

    __C = config.__C
    args = parse_args()

    # Per-model hyper-parameters are read from cfgs/<model>_model.yml.
    cfg_file = "cfgs/{}_model.yml".format(args.model)
    with open(cfg_file, 'r', encoding='utf-8') as f:
        # safe_load: yaml.load() without an explicit Loader is unsafe and is
        # a TypeError on PyYAML >= 6. An empty file loads as None, so fall
        # back to an empty mapping to keep the merge below valid.
        yaml_dict = yaml.safe_load(f) or {}

    # CLI values override YAML values that share a key (later keys win).
    args_dict = edict({**yaml_dict, **vars(args)})
    config.add_edit(args_dict, __C)
    config.proc(__C)
    print('Hyper Parameters:')
    config.config_print(__C)
    # __C.check_path()

    if __C.model == "bilinear":
        execution = Trainer(__C)
        execution.run(__C.run_mode)
    else:
        # argparse accepts other --model choices, but no trainer is wired
        # up for them here; fail loudly (non-zero status) instead of
        # exiting silently with status 0.
        sys.exit("Model '{}' is not supported by run.py yet "
                 "(only 'bilinear' is implemented).".format(__C.model))