-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhierarchical_clustering.py
69 lines (58 loc) · 2.34 KB
/
hierarchical_clustering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import pdist

from src import config
def make_cat_cluster_fig(cat, bottom_off=False, max_probes=20, metric='cosine', x_max=config.Fig.CAT_CLUSTER_XLIM):
    """
    Returns fig showing hierarchical clustering of probes in single category

    Args:
        cat: category whose rows are fetched via
            model.get_single_cat_probe_prototype_acts_df; also used as the
            axes title.
        bottom_off: if True, hide the x tick labels and the bottom spine.
        max_probes: maximum number of rows to cluster. When the frame has
            more rows, a random subset of this size is drawn without
            replacement.
        metric: distance metric name forwarded to scipy's pdist.
        x_max: upper limit of the x axis (lower limit is 0).

    Returns:
        The matplotlib figure holding the dendrogram.
    """
    # load data
    # NOTE(review): `model` is not imported in this file; presumably it is
    # provided elsewhere -- confirm.
    cat_prototypes_df = model.get_single_cat_probe_prototype_acts_df(cat)
    num_rows = len(cat_prototypes_df)
    if num_rows > max_probes:
        # sample from ALL row positions (0..num_rows-1); sampling from
        # num_rows - 1 would make the last row impossible to select
        ids = np.random.choice(num_rows, max_probes, replace=False)
        cat_prototypes_df = cat_prototypes_df.iloc[ids]
    # plain list in both cases so leaf labelling behaves the same either way
    probes_in_cat = cat_prototypes_df.index.tolist()
    # fig
    fig, ax = plt.subplots(figsize=(config.Fig.fig_size, 4), dpi=config.Fig.dpi)
    dist_matrix = pdist(cat_prototypes_df.values, metric=metric)
    linkages = linkage(dist_matrix, method='complete')
    dendrogram(linkages,
               ax=ax,
               # leaf ids index rows of the (possibly subsampled) frame
               leaf_label_func=lambda leaf_id: probes_in_cat[leaf_id],
               orientation='right',
               leaf_font_size=8)
    ax.set_title(cat)
    ax.set_xlim([0, x_max])
    for side in ('right', 'left', 'top'):
        ax.spines[side].set_visible(False)
    if bottom_off:
        ax.xaxis.set_ticklabels([])  # hides ticklabels
        ax.spines['bottom'].set_visible(False)
    fig.tight_layout()
    return fig
def make_multi_cat_clust_fig(cats, metric='cosine'):
    """
    Returns fig showing hierarchical clustering of probes from multiple categories

    Args:
        cats: iterable of categories; the frame returned by
            model.get_single_cat_probe_prototype_acts_df for each one is
            stacked row-wise before clustering. Must support len(), which
            is used to scale the figure height.
        metric: distance metric name forwarded to scipy's pdist.

    Returns:
        The matplotlib figure holding the dendrogram.
    """
    # load data
    # NOTE(review): `model` is not imported in this file; presumably it is
    # provided elsewhere -- confirm.
    df = pd.concat([model.get_single_cat_probe_prototype_acts_df(cat) for cat in cats], axis=0)
    cat_acts_mat = df.values
    # plain list: some scipy versions truth-test `labels`, which raises on a
    # pandas Index
    cats_probe_list = df.index.tolist()
    # fig
    fig, ax = plt.subplots(figsize=(config.Fig.fig_size, 5 * len(cats)), dpi=config.Fig.dpi)
    dist_matrix = pdist(cat_acts_mat, metric=metric)
    linkages = linkage(dist_matrix, method='complete')
    dendrogram(linkages,
               ax=ax,
               labels=cats_probe_list,
               orientation='right',
               leaf_font_size=10)
    for side in ('right', 'left', 'top'):
        ax.spines[side].set_visible(False)
    return fig