Skip to content

Commit 79436e1

Browse files
authored
Test python clusterer's plotting methods (#301)
1 parent 1db729d commit 79436e1

File tree

2 files changed

+123
-17
lines changed

2 files changed

+123
-17
lines changed

CLUEstering/CLUEstering.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -760,17 +760,6 @@ def n_clusters(self) -> int:
760760
"""
761761
return self.clust_prop.n_clusters
762762

763-
@property
764-
def n_seeds(self) -> int:
765-
"""
766-
Number of seeds found.
767-
768-
:return: Number of seeds.
769-
:rtype: int
770-
"""
771-
772-
return self.clust_prop.n_seeds
773-
774763
@property
775764
def clusters(self) -> np.ndarray:
776765
"""
@@ -1088,8 +1077,6 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
10881077
# Customization of axis ticks
10891078
if x_ticks is not None:
10901079
plt.xticks(x_ticks)
1091-
if y_ticks is not None:
1092-
plt.yticks(y_ticks)
10931080

10941081
if filepath is not None:
10951082
plt.savefig(filepath)
@@ -1098,8 +1085,7 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
10981085
elif self.clust_data.n_dim == 2:
10991086
data = {'x0': self.coords[0],
11001087
'x1': self.coords[1],
1101-
'cluster_ids': self._cluster_ids,
1102-
'isSeed': self._is_seed}
1088+
'cluster_ids': self.cluster_ids}
11031089
df_ = pd.DataFrame(data)
11041090

11051091
max_clusterid = max(df_["cluster_ids"])
@@ -1109,8 +1095,6 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
11091095
for i in range(0, max_clusterid+1):
11101096
dfi = df_[df_.cluster_ids == i] # ith cluster
11111097
plt.scatter(dfi.x0, dfi.x1, s=pt_size, marker='.')
1112-
df_seed = df_[df_.isSeed == 1] # Only Seeds
1113-
plt.scatter(df_seed.x0, df_seed.x1, s=seed_size, color='r', marker='*')
11141098

11151099
# Customization of the plot title
11161100
plt.title(plot_title, fontsize=title_size)

tests/test_plotters.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
'''
2+
Test the clusterer's plotting methods
3+
'''
4+
5+
import matplotlib
6+
import os
7+
import sys
8+
import pandas as pd
9+
import pytest
10+
from check_result import check_result
11+
sys.path.insert(1, '../CLUEstering/')
12+
import CLUEstering as clue
13+
matplotlib.use("Agg")
14+
15+
16+
@pytest.fixture
17+
def dataset1d():
18+
'''
19+
Returns a 1d dataset for testing
20+
'''
21+
22+
data = {'x': [1, 2, 1.5, 8, 9, 8.5, 50, 51, 49.5], 'weight': [1]*9}
23+
return pd.DataFrame(data)
24+
25+
26+
@pytest.fixture
27+
def dataset2d():
28+
'''
29+
Returns a 2d dataset for testing
30+
'''
31+
32+
data = {'x': [1, 2, 1.5, 8, 9, 8.5, 50, 51, 49.5],
33+
'y': [1, 2, 1.5, 8, 9, 8.5, 50, 51, 49.5],
34+
'weight': [1]*9}
35+
return pd.DataFrame(data)
36+
37+
38+
@pytest.fixture
39+
def dataset3d():
40+
'''
41+
Returns a generic dataset for testing
42+
'''
43+
44+
data = {'x': [1, 2, 1.5, 8, 9, 8.5, 50, 51, 49.5],
45+
'y': [1, 2, 1.5, 8, 9, 8.5, 50, 51, 49.5],
46+
'z': [1, 2, 1.5, 8, 9, 8.5, 50, 51, 49.5],
47+
'weight': [1]*9}
48+
return pd.DataFrame(data)
49+
50+
51+
def test_input_plotter(dataset1d, dataset2d, dataset3d):
52+
'''
53+
Tests the input plotter method of the clusterer
54+
'''
55+
56+
c = clue.clusterer(3., 2., 3.)
57+
58+
xticks = [0, 10, 20, 30, 40, 50, 60]
59+
yticks = [0, 10, 20, 30, 40, 50, 60]
60+
zticks = [0, 10, 20, 30, 40, 50, 60]
61+
62+
filename = 'dataset1d.png'
63+
c.read_data(dataset1d)
64+
assert c.n_dim == 1
65+
c.input_plotter()
66+
c.input_plotter(grid=True, xticks=xticks)
67+
c.input_plotter(filename)
68+
assert os.path.isfile(filename)
69+
70+
filename = 'dataset2d.png'
71+
c.read_data(dataset2d)
72+
assert c.n_dim == 2
73+
c.input_plotter()
74+
c.input_plotter(grid=True, xticks=xticks, yticks=yticks)
75+
c.input_plotter(filename)
76+
assert os.path.isfile(filename)
77+
78+
filename = 'dataset3d.png'
79+
c.read_data(dataset3d)
80+
assert c.n_dim == 3
81+
c.input_plotter()
82+
c.input_plotter(grid=True, xticks=xticks, yticks=yticks, zticks=zticks)
83+
c.input_plotter(filename)
84+
assert os.path.isfile(filename)
85+
86+
def test_output_plotter(dataset1d, dataset2d, dataset3d):
87+
'''
88+
Tests the output plotter method of the clusterer
89+
'''
90+
91+
c = clue.clusterer(3., 2., 3.)
92+
93+
xticks = [0, 10, 20, 30, 40, 50, 60]
94+
yticks = [0, 10, 20, 30, 40, 50, 60]
95+
zticks = [0, 10, 20, 30, 40, 50, 60]
96+
97+
filename = 'output1d.png'
98+
c.read_data(dataset1d)
99+
assert c.n_dim == 1
100+
c.run_clue()
101+
c.cluster_plotter()
102+
c.cluster_plotter(grid=True, xticks=xticks)
103+
c.cluster_plotter(filename)
104+
assert os.path.isfile(filename)
105+
106+
filename = 'output2d.png'
107+
c.read_data(dataset2d)
108+
assert c.n_dim == 2
109+
c.run_clue()
110+
c.cluster_plotter()
111+
c.cluster_plotter(grid=True, xticks=xticks, yticks=yticks)
112+
c.cluster_plotter(filename)
113+
assert os.path.isfile(filename)
114+
115+
filename = 'output3d.png'
116+
c.read_data(dataset3d)
117+
assert c.n_dim == 3
118+
c.run_clue()
119+
c.cluster_plotter()
120+
c.cluster_plotter(grid=True, xticks=xticks, yticks=yticks, zticks=zticks)
121+
c.cluster_plotter(filename)
122+
assert os.path.isfile(filename)

0 commit comments

Comments
 (0)