-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
53 lines (41 loc) · 1.45 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import pickle as pkl
import glob
from sklearn import svm
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_score
import networkx as nx
def load_labels(path):
    """
    Loads the graph-level labels from the GraphML files in a folder.

    The number of graphs is the count of ``*.graphml`` files in the folder; the
    files are then read in index order (``0.graphml``, ``1.graphml``, ...), not
    in glob order.

    :param path: path to the folder containing the .graphml files.
    :return List of labels, one per graph, taken from the "label" graph attribute
    :raises KeyError: if a graph has no "label" graph attribute
    """
    graph_count = len(glob.glob("%s/*.graphml" % (path)))
    labels = []
    for i in range(graph_count):
        G = nx.read_graphml("%s/%s.graphml" % (path, i))
        # G.graph is the dict of graph-level attributes: it must be indexed.
        # Calling it (G.graph("label")) raises TypeError.
        labels.append(G.graph["label"])
    return labels
def load_embeddings(path):
    """
    Loads embeddings from pickle files and returns them.

    The number of embeddings is the count of ``*.pkl`` files in the folder; the
    files are then read in index order (``0.pkl``, ``1.pkl``, ...), not in glob
    order.

    :param path: path to the folder containing all pickle files. These files are usually generated by the SG2V or run_sgcn scripts.
    :return List of embeddings, one entry per pickle file
    """
    file_count = len(glob.glob("%s/*.pkl" % (path)))
    embeddings = []
    for i in range(file_count):
        with open("%s/%s.pkl" % (path, i), "rb") as f:
            # NOTE: unpickling can execute arbitrary code; only load files
            # from a trusted output folder.
            embeddings.append(pkl.load(f))
    return embeddings
def evaluation(embeddings, labels):
    """
    Performs the evaluation.

    Scores a class-weight-balanced SVM classifier on the embeddings with
    3-fold cross-validation and prints the mean micro-averaged F1 score.

    :param embeddings: list of embeddings
    :param labels: List of labels
    :return None
    """
    binary_classifier_model = svm.SVC(class_weight="balanced")
    scores = cross_val_score(binary_classifier_model, embeddings, labels, cv=3, scoring='f1_micro')
    print("Micro F-measure: %0.4f" % (scores.mean()))
def evaluate_SG2V():
    """
    Runs the evaluation once for each of the variants "g2v", "sg2vn" and "sg2vsb".

    For each variant, embeddings are loaded from out/SG2V/<variant> and labels
    from data/<variant>, and both are passed to evaluation().

    :return None
    """
    variants = ("g2v", "sg2vn", "sg2vsb")
    for variant in variants:
        embedding_dir = "out/SG2V/%s" % (variant)
        label_dir = "data/%s" % (variant)
        # Embeddings are loaded before labels, as in the original call order.
        evaluation(load_embeddings(embedding_dir), load_labels(label_dir))