-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathbleu.py
147 lines (134 loc) · 4.38 KB
/
bleu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
import os
import random
import shutil
import sys
import tempfile
from itertools import islice
from subprocess import PIPE, Popen
# Command-line interface. Parsed at import time into the module-level `args`,
# which the rest of the script reads.
parser = argparse.ArgumentParser(description='Script to compute BLEU')
parser.add_argument(
    "--ref", type=str, required=True,
    help="path to file with references"
)
parser.add_argument(
    "--hyp", type=str, required=True,
    help="path to file with hypotheses"
)
parser.add_argument(
    "--n", type=int, required=True,
    help="--n argument used to generate --ref file using eval.py"
)
parser.add_argument(
    "--with_contexts", dest="with_contexts", action="store_true",
    help="whether to consider contexts or not when compute BLEU"
)
parser.add_argument(
    "--bleu_path", type=str, required=True,
    help="path to mosesdecoder sentence-bleu binary"
)
parser.add_argument(
    "--mode", type=str, required=True,
    # Let argparse validate the value: it prints a usage error for anything
    # else. The previous `assert` check disappeared under `python -O`.
    choices=["average", "random"],
    help="whether to average or take random example per word"
)
args = parser.parse_args()
def next_n_lines(file_opened, N):
    """Consume up to ``N`` lines from ``file_opened`` and return them stripped.

    Fewer than ``N`` lines come back when the iterator runs out; an empty
    list means nothing was left. Lines past the first ``N`` are not consumed,
    so repeated calls walk through the file record by record.
    """
    chunk = []
    for raw_line in islice(file_opened, N):
        chunk.append(raw_line.strip())
    return chunk
def read_def_file(file, n, with_contexts=False):
    """Parse a hypotheses file into ``{key: [definition, ...]}``.

    The file is read in records of ``n + 2`` lines: a line containing
    ``Word:``, a line containing ``Context:``, then ``n`` definition lines.

    Args:
        file: open text file (any iterable of lines).
        n: number of definition lines that follow each Word/Context header.
        with_contexts: if True, key by ``"<word> <context>"`` so the same word
            in different contexts stays separate; otherwise all definitions
            for a word are pooled under the bare word.

    Returns:
        dict mapping each key to its stripped definitions, in file order.

    Raises:
        ValueError: on a truncated trailing record, or a header line that
            lacks its ``Word:`` / ``Context:`` marker.
    """
    def field(line, marker):
        # Split on the first marker only, so text that itself contains the
        # marker is kept whole instead of being cut off.
        _, found, rest = line.partition(marker)
        if not found:
            raise ValueError(
                "Something bad in hyps file: expected %r in %r" % (marker, line))
        return rest.strip()

    defs = {}
    record_len = n + 2
    while True:
        lines = next_n_lines(file, record_len)
        if not lines:
            break
        # Explicit raise instead of assert: asserts vanish under `python -O`,
        # and a truncated file would then silently skew the scores.
        if len(lines) != record_len:
            raise ValueError("Something bad in hyps file")
        word = field(lines[0], "Word:")
        context = field(lines[1], "Context:")
        dict_key = word + " " + context if with_contexts else word
        # next_n_lines already stripped every line.
        defs.setdefault(dict_key, []).extend(lines[2:])
    return defs
def read_ref_file(file, with_contexts=False):
    """Parse a references file into ``{key: [definition, ...]}``.

    The file is read in 3-line records: a line containing ``Word:``, one
    containing ``Context:``, and one containing ``Definition:``.

    Args:
        file: open text file (any iterable of lines).
        with_contexts: if True, key by ``"<word> <context>"``; otherwise all
            reference definitions of a word are pooled under the bare word.

    Returns:
        dict mapping each key to its reference definitions, in file order.

    Raises:
        ValueError: on a truncated trailing record, or a line that lacks its
            ``Word:`` / ``Context:`` / ``Definition:`` marker.
    """
    def field(line, marker):
        # Split on the first marker only, so a definition or context that
        # contains the marker text is kept whole.
        _, found, rest = line.partition(marker)
        if not found:
            raise ValueError(
                "Something bad in refs file: expected %r in %r" % (marker, line))
        return rest.strip()

    defs = {}
    while True:
        lines = next_n_lines(file, 3)
        if not lines:
            break
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if len(lines) != 3:
            raise ValueError("Something bad in refs file")
        word = field(lines[0], "Word:")
        context = field(lines[1], "Context:")
        definition = field(lines[2], "Definition:")
        dict_key = word + " " + context if with_contexts else word
        defs.setdefault(dict_key, []).append(definition)
    return defs
def get_bleu_score(bleu_path, all_ref_paths, d, hyp_path):
    """Score one hypothesis against reference files with an external BLEU tool.

    Writes ``d`` to ``hyp_path``, then runs ``[bleu_path] + all_ref_paths``
    with that file on stdin and parses stdout as a float. The ``--bleu_path``
    help describes this as the mosesdecoder sentence-bleu binary.

    Args:
        bleu_path: executable to run.
        all_ref_paths: paths to reference files, passed as arguments.
        d: hypothesis text.
        hyp_path: scratch file that receives ``d`` (left on disk for the
            caller to clean up).

    Returns:
        The parsed score, or None if the tool exits non-zero or prints
        something that is not a number.
    """
    with open(hyp_path, 'w') as ofp:
        ofp.write(d)
    bleu_cmd = [bleu_path] + all_ref_paths
    # Feed the file straight to stdin. The old `cat` pipeline left the cat
    # process unreaped. Discard stderr with a local handle instead of relying
    # on a global that is opened later in the script.
    with open(hyp_path) as hyp_in, open(os.devnull, 'w') as sink:
        bp = Popen(bleu_cmd, stdin=hyp_in, stdout=PIPE, stderr=sink)
        out, _ = bp.communicate()
    # The old `err is None` test was always true (stderr was never piped), so
    # failures crashed in float(). Check the exit status and the output instead.
    if bp.returncode != 0:
        return None
    try:
        return float(out.strip())
    except ValueError:
        return None
# Main script: load references and hypotheses, score each word's hypotheses
# with sentence-level BLEU, and print the mean of the per-word scores.
with open(args.ref) as ifp:
    refs = read_ref_file(ifp, args.with_contexts)
with open(args.hyp) as ifp:
    hyps = read_def_file(ifp, args.n, args.with_contexts)
if len(refs) != len(hyps):
    # Explicit check instead of assert: asserts are stripped under `python -O`.
    sys.exit("Number of words being defined mismatched!")

# Keep scratch files in a private directory (mkdtemp makes it owner-only) and
# delete the whole directory at the end, even on error. The old code used
# guessable names in /tmp and crashed removing a hyp file that was never written.
tmp_dir = tempfile.mkdtemp(prefix="bleu_")
hyp_path = os.path.join(tmp_dir, 'hyp')
# Stderr sink that the original get_bleu_score reads as a global.
devnull = open(os.devnull, 'w')
score = 0
count = 0
total_refs = 0
total_hyps = 0
try:
    for word in refs:
        if word not in hyps:
            continue
        wrefs = refs[word]
        whyps = hyps[word]
        # Write one file per reference definition; they are passed to the
        # BLEU tool as arguments.
        all_ref_paths = []
        for i, d in enumerate(wrefs):
            ref_path = os.path.join(tmp_dir, 'ref' + str(i))
            with open(ref_path, 'w') as ofp:
                ofp.write(d)
            all_ref_paths.append(ref_path)
        total_refs += len(all_ref_paths)
        # Pick which hypotheses to score: all of them ("average") or one
        # sampled at random ("random").
        if args.mode == "average":
            candidates = whyps
        elif args.mode == "random":
            candidates = [random.choice(whyps)] if whyps else []
        else:
            sys.exit("--mode must be average or random")
        micro_score = 0
        micro_count = 0
        for d in candidates:
            rhscore = get_bleu_score(args.bleu_path, all_ref_paths, d, hyp_path)
            if rhscore is not None:
                micro_score += rhscore
                micro_count += 1
        total_hyps += micro_count
        if micro_count == 0:
            # No usable score for this word; skip it rather than divide by zero.
            continue
        score += micro_score / micro_count
        count += 1
finally:
    devnull.close()
    shutil.rmtree(tmp_dir, ignore_errors=True)

if count == 0:
    sys.exit("No word could be scored; check --bleu_path and the input files.")
print("BLEU: ", score / count)
print("NUM HYPS USED: ", total_hyps)
print("NUM REFS USED: ", total_refs)