-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdemo-eval-2.py
45 lines (36 loc) · 1.92 KB
/
demo-eval-2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from clause.util.utils import get_base_dir
from clause import Ranking
from clause import TripleSet
# ***Example Evaluation***
# This example illustrates how to read and evaluate a ranking that has been stored previously.
# Please run demo-eval-1.py first to create this ranking
target = f"{get_base_dir()}/data/wnrr/test.txt"
ranking_path = f"{get_base_dir()}/local/ranking-wnrr.txt"
testset = TripleSet(target)
ranking = Ranking(ranking_path)
ranking.compute_scores(testset.triples)
print("MRR " + '{0:.6f}'.format(ranking.hits.get_mrr()))
print("hits@1 " + '{0:.6f}'.format(ranking.hits.get_hits_at_k(1)))
print("hits@10 " + '{0:.6f}'.format(ranking.hits.get_hits_at_k(10)))
print("*** EVALUATION RESULTS ****")
print("Num triples: " + str(len(testset.triples)))
print("MRR " + '{0:.6f}'.format(ranking.hits.get_mrr()))
print("hits@1 " + '{0:.6f}'.format(ranking.hits.get_hits_at_k(1)))
print("hits@3 " + '{0:.6f}'.format(ranking.hits.get_hits_at_k(3)))
print("hits@10 " + '{0:.6f}'.format(ranking.hits.get_hits_at_k(10)))
print()
# now some code to some nice overview on the different relations and directions
# the loop interates over all relations in the test set
print("relation".ljust(25) + "\t" + "MRR-h" + "\t" + "MRR-t" + "\t" + "Num triples")
for rel in testset.rels:
rel_token = testset.index.id2to[rel]
# store all triples that use the current relation rel in rtriples
rtriples = list(filter(lambda x: x.rel == rel, testset.triples))
# compute scores in head direction ...
ranking.compute_scores(rtriples, True, False)
(mrr_head, h1_head) = (ranking.hits.get_mrr(), ranking.hits.get_hits_at_k(1))
# ... and in tail direction
ranking.compute_scores(rtriples, False, True)
(mrr_tail, h1_tail) = (ranking.hits.get_mrr(), ranking.hits.get_hits_at_k(1))
# print the resulting scores
print(rel_token.ljust(25) + "\t" + '{0:.3f}'.format(mrr_head) + "\t" + '{0:.3f}'.format(mrr_tail) + "\t" + str(len(rtriples)))