forked from awni/semantic-rntn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_baseline.py
59 lines (43 loc) · 1.72 KB
/
run_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import optparse
import os
import pickle

from baseline import Baseline
from ucca_tree import *
def run(args=None):
    """Parse command-line options, then either train or evaluate the baseline.

    Without ``--test``: loads the training trees, fits a ``Baseline`` and
    writes the parsed options followed by the model into ``--out_file``.
    With ``--test``: delegates to :func:`test` using ``--in_file`` and
    ``--data``.

    :param args: argument list to parse; ``None`` makes optparse fall back
        to ``sys.argv[1:]``.
    """
    usage = "usage : %prog [options]"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--test", action="store_true", dest="test", default=False)
    # 0 means "infer the output dimension from the label map" (see below).
    parser.add_option("--output_dim", dest="output_dim", type="int", default=0)
    parser.add_option("--out_file", dest="out_file", type="string",
                      default="models/baseline.bin")
    parser.add_option("--in_file", dest="in_file", type="string",
                      default="models/baseline.bin")
    parser.add_option("--data", dest="data", type="string", default="train")
    # Positional leftovers are unused; don't rebind the `args` parameter.
    opts, _ = parser.parse_args(args)

    # Testing
    if opts.test:
        test(opts.in_file, opts.data)
        return

    # Create the output directory up front so a fresh checkout does not
    # fail only after the (potentially long) training run has finished.
    out_dir = os.path.dirname(opts.out_file)
    if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    print("Loading data...")
    # load training data
    # NOTE(review): --data is only honoured in test mode; training always
    # uses load_trees()'s default data set — confirm this is intended.
    trees = load_trees()
    if opts.output_dim == 0:
        opts.output_dim = len(load_label_map())

    baseline = Baseline(opts.output_dim)
    baseline.train(trees)

    with open(opts.out_file, 'wb') as fid:
        # The options are written first so test() can rebuild Baseline with
        # the same output_dim before reading the model itself.
        pickle.dump(opts, fid)
        baseline.to_file(fid)
def test(baseline_file, data_set):
    """Evaluate a saved baseline model on one data set and dump the trees.

    Reads a file written by :func:`run` (pickled options followed by the
    model), prints the accuracy computed from ``Baseline.predict`` and writes
    the gold and predicted trees to ``results/gold.txt`` and
    ``results/pred_baseline.txt``.

    :param baseline_file: path to the saved model file.
    :param data_set: data-set name passed to ``load_trees``.
    :raises ValueError: if no model path is given or no trees are loaded.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if baseline_file is None:
        raise ValueError("Must give model to test")
    trees = load_trees(data_set)
    if not trees:
        raise ValueError("No data found")

    # SECURITY: pickle.load can execute arbitrary code from the file; only
    # load model files from a trusted source.
    with open(baseline_file, 'rb') as fid:
        opts = pickle.load(fid)
        baseline = Baseline(opts.output_dim)
        baseline.from_file(fid)

    print("Testing...")
    correct, total, pred = baseline.predict(trees)
    # Avoid ZeroDivisionError if predict() reports nothing to score.
    acc = correct / float(total) if total else 0.0
    print("Correct %d/%d, Acc %f" % (correct, total, acc))

    # The output paths below are fixed; make sure their directory exists.
    if not os.path.isdir('results'):
        os.makedirs('results')
    print_trees('results/gold.txt', trees, 'Labeled')
    print_trees('results/pred_baseline.txt', pred, 'Baseline predicted')
if __name__ == '__main__':
    # Script entry point: passing None lets optparse read sys.argv[1:].
    run(args=None)