# bleu.py
# author: Playinf
# email: [email protected]

import math
from collections import Counter

def closest_length(candidate, references):
    # Length of the reference closest in length to the candidate;
    # ties are broken in favor of the shorter reference.
    clen = len(candidate)
    closest_diff = 9999
    closest_len = 9999

    for reference in references:
        rlen = len(reference)
        diff = abs(rlen - clen)

        if diff < closest_diff:
            closest_diff = diff
            closest_len = rlen
        elif diff == closest_diff:
            closest_len = min(rlen, closest_len)

    return closest_len
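
# Illustrative note (not part of the original file): for a 7-token
# candidate and references of lengths 6 and 8, both differ by 1, and the
# tie goes to the shorter reference, so closest_length returns 6.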

def shortest_length(references):
    # Length of the shortest reference.
    return min(len(ref) for ref in references)

def modified_precision(candidate, references, n):
    # Clipped n-gram counts (Papineni et al., 2002): each candidate
    # n-gram is credited at most as many times as it appears in any
    # single reference. Returns (clipped count, total count).
    tngrams = len(candidate) + 1 - n
    counts = Counter(tuple(candidate[i : i + n]) for i in range(tngrams))

    if len(counts) == 0:
        return 0, 0

    max_counts = {}

    for reference in references:
        rngrams = len(reference) + 1 - n
        ngrams = [tuple(reference[i : i + n]) for i in range(rngrams)]
        ref_counts = Counter(ngrams)

        for ngram in counts:
            mcount = max_counts.get(ngram, 0)
            rcount = ref_counts.get(ngram, 0)
            max_counts[ngram] = max(mcount, rcount)

    clipped_counts = {}

    for ngram, count in counts.items():
        clipped_counts[ngram] = min(count, max_counts[ngram])

    return float(sum(clipped_counts.values())), float(sum(counts.values()))
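
# Worked example (illustrative, not part of the original file): for the
# degenerate candidate from the BLEU paper,
#
#     candidate = "the the the the the the the".split()
#     reference = "the cat is on the mat".split()
#     modified_precision(candidate, [reference], 1)  # -> (2.0, 7.0)
#
# "the" occurs 7 times in the candidate but is clipped to its maximum
# reference count of 2, giving a modified unigram precision of 2/7.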

def brevity_penalty(trans, refs, mode="closest"):
    # Corpus-level brevity penalty: BP = exp(min(0, 1 - r / c)), where c
    # is the total candidate length and r the total effective reference
    # length ("closest" or "shortest" matching, depending on mode).
    bp_c = 0.0
    bp_r = 0.0

    for candidate, references in zip(trans, refs):
        bp_c += len(candidate)

        if mode == "shortest":
            bp_r += shortest_length(references)
        else:
            bp_r += closest_length(candidate, references)

    return math.exp(min(0, 1.0 - bp_r / bp_c))

# trans: a list of tokenized candidate sentences
# refs: a list of lists of tokenized reference sentences
def bleu(trans, refs, bp="closest", smooth=False, n=4, weights=None):
    p_norm = [0 for _ in range(n)]
    p_denorm = [0 for _ in range(n)]

    # accumulate clipped and total n-gram counts over the whole corpus
    for candidate, references in zip(trans, refs):
        for i in range(n):
            ccount, tcount = modified_precision(candidate, references, i + 1)
            p_norm[i] += ccount
            p_denorm[i] += tcount

    bleu_n = [0 for _ in range(n)]

    for i in range(n):
        # add-one smoothing, applied to higher-order n-grams only
        if smooth and i > 0:
            p_norm[i] += 1
            p_denorm[i] += 1

        if p_norm[i] == 0 or p_denorm[i] == 0:
            # stand-in for log(0) so a zero precision drives BLEU to ~0
            bleu_n[i] = -9999
        else:
            bleu_n[i] = math.log(float(p_norm[i]) / float(p_denorm[i]))

    if weights:
        if len(weights) != n:
            raise ValueError("len(weights) != n: invalid weight number")
        log_precision = sum([bleu_n[i] * weights[i] for i in range(n)])
    else:
        # uniform weights: the geometric mean of the n-gram precisions
        log_precision = sum(bleu_n) / float(n)

    bp = brevity_penalty(trans, refs, bp)

    return bp * math.exp(log_precision)
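
# Minimal usage sketch (illustrative only; the sentences below are made
# up for demonstration and are not part of the original file):
if __name__ == "__main__":
    trans = ["the cat sat on the mat".split()]
    refs = [["the cat is on the mat".split(),
             "there is a cat on the mat".split()]]
    # smoothing avoids log(0) when a higher-order n-gram never matches
    print("BLEU: %f" % bleu(trans, refs, smooth=True))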