-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
87 lines (66 loc) · 2.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from cosine import getCosinePairwise, getAverageVector, getCosineSimilarity, getProxyDistance
def prototypeClustering(x_data, th, opt1=False):
    """Greedily group vectors around successive seed ("prototype") vectors.

    The first seed is index 0. Every still-unassigned vector whose score
    against the current seed is at least ``th`` joins the seed's cluster.
    Among the vectors that did not join, the one with the lowest score
    against the seed becomes the next seed, and the pass repeats until no
    unassigned vector is left. A seed that attracts no other vector is
    left unlabelled instead of forming a one-element cluster.

    Args:
        x_data: Sequence of vectors, in whatever form the ``cosine``
            helpers accept.
        th: Minimum score against the seed required to join its cluster.
        opt1: When true, scores are read from a matrix built once by
            ``getCosinePairwise(x_data)`` (indexed ``[seed][target]``);
            otherwise ``getCosineSimilarity`` is called per pair.

    Returns:
        A list with one entry per input vector: consecutive cluster ids
        starting at 0, or -1 for vectors that ended up in no cluster.
    """
    n_items = len(x_data)
    label_list = [-1] * n_items
    if n_items == 0:
        return label_list

    cos_mat = getCosinePairwise(x_data) if opt1 else None
    ready_list = list(range(n_items))
    cluster_idx = 0
    src_idx = 0

    while ready_list:
        members = [src_idx]
        remaining = []
        nxt_idx = None
        nxt_score = 1
        for trg_idx in ready_list:
            if trg_idx == src_idx:
                # The seed always belongs to its own cluster. Not scoring it
                # against itself guarantees every pass removes at least one
                # item, so the loop terminates even when the self-score
                # would fall below ``th`` (rounding, NaN, th > 1).
                continue
            if opt1:
                score = cos_mat[src_idx][trg_idx]
            else:
                score = getCosineSimilarity(x_data[src_idx], x_data[trg_idx])
            if score >= th:
                members.append(trg_idx)
            else:
                remaining.append(trg_idx)
                # Track the least similar leftover as the next seed.
                if nxt_score > score:
                    nxt_score = score
                    nxt_idx = trg_idx

        # Only clusters with at least one vector besides the seed get an id;
        # a lone seed keeps its initial -1 label.
        if len(members) > 1:
            for idx in members:
                label_list[idx] = cluster_idx
            cluster_idx += 1

        ready_list = remaining
        if remaining:
            # Fall back to the first leftover when no score qualified as a
            # candidate (e.g. NaN scores), so the seed is never a vector
            # that has already been handled.
            src_idx = remaining[0] if nxt_idx is None else nxt_idx

    return label_list
def hierarchicalClustering(x_data, prev_label_list, th):
    """Merge existing clusters by clustering their centroids.

    Vectors are grouped by their label in ``prev_label_list``; one centroid
    per group is computed with ``getAverageVector``. The centroids are then
    clustered with single-linkage agglomerative clustering, using
    ``getProxyDistance`` as the distance callable and ``1 - th`` as the
    distance threshold. Each vector receives the label of the merged group
    its previous cluster landed in.

    Args:
        x_data: Sequence of vectors, indexable by position.
        prev_label_list: One label per vector; -1 marks "no cluster", other
            values are used as indices into the list of clusters.
            NOTE(review): labels are assumed to be contiguous from 0 -
            a gap would pass an empty list to ``getAverageVector``.
        th: Similarity threshold; converted to a distance threshold as
            ``1 - th``.

    Returns:
        ``prev_label_list`` itself when it contains no cluster (including
        when it is empty); otherwise a new list of plain ``int`` labels with
        -1 preserved for unclustered vectors.
    """
    # ``default=-1`` makes an empty label list mean "no clusters" instead of
    # raising ValueError from max().
    n_clusters = max(prev_label_list, default=-1) + 1
    cluster_list = [[] for _ in range(n_clusters)]
    for i, label in enumerate(prev_label_list):
        if label != -1:
            cluster_list[label].append(i)

    centroid_list = [
        getAverageVector([x_data[i] for i in idxs]) for idxs in cluster_list
    ]

    if not centroid_list:
        return prev_label_list

    if len(centroid_list) == 1:
        # A single centroid cannot be clustered; it trivially maps to 0.
        labels = [0]
    else:
        options = dict(
            n_clusters=None,
            linkage="single",
            distance_threshold=1 - th,
        )
        # scikit-learn renamed ``affinity`` to ``metric`` (deprecated in 1.2,
        # removed in 1.4). Try the current name first and fall back so the
        # function works on both old and new releases.
        try:
            model = AgglomerativeClustering(metric=getProxyDistance, **options)
        except TypeError:
            model = AgglomerativeClustering(affinity=getProxyDistance, **options)
        labels = model.fit(centroid_list).labels_

    # int() keeps the result a homogeneous list of Python ints rather than a
    # mix of int (-1) and numpy integer scalars.
    return [
        -1 if prev_label_list[i] == -1 else int(labels[prev_label_list[i]])
        for i in range(len(x_data))
    ]