Skip to content

Commit bf9888c

Browse files
authored
Remove seed buffer and properties from Python API (#303)
Remove clusters getter
1 parent ead43ea commit bf9888c

File tree

2 files changed

+15
-71
lines changed

2 files changed

+15
-71
lines changed

CLUEstering/CLUEstering.py

Lines changed: 14 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class ClusteringDataSoA:
157157
158158
:param coords: Input coordinates including weights.
159159
:type coords: np.ndarray
160-
:param results: Clustering results including cluster IDs and seed flags.
160+
:param results: Clustering results including cluster IDs.
161161
:type results: np.ndarray
162162
:param n_dim: Number of dimensions.
163163
:type n_dim: int
@@ -184,40 +184,27 @@ class cluster_properties:
184184
185185
:param n_clusters: Number of clusters constructed.
186186
:type n_clusters: int
187-
:param n_seeds: Number of seeds found (clusters excluding outliers).
188-
:type n_seeds: int
189-
:param clusters: List of clusters found.
190-
:type clusters: np.ndarray
191187
:param cluster_ids: Cluster ID for each point.
192188
:type cluster_ids: np.ndarray
193-
:param is_seed: Array where 1 indicates a seed and 0 otherwise.
194-
:type is_seed: np.ndarray
195189
:param cluster_points: Lists of point IDs belonging to each cluster.
196190
:type cluster_points: np.ndarray
197191
:param points_per_cluster: Number of points per cluster.
198192
:type points_per_cluster: np.ndarray
199-
:param output_df: DataFrame containing is_seed and cluster_ids as columns.
193+
:param output_df: DataFrame containing the cluster_ids.
200194
:type output_df: pd.DataFrame
201195
"""
202196

203197
n_clusters : int
204-
n_seeds : int
205-
clusters : np.ndarray
206198
cluster_ids : np.ndarray
207-
is_seed : np.ndarray
208199
cluster_points : np.ndarray
209200
points_per_cluster : np.ndarray
210201
output_df : pd.DataFrame
211202

212203
def __eq__(self, other):
213204
if self.n_clusters != other.n_clusters:
214205
return False
215-
if self.n_seeds != other.n_seeds:
216-
return False
217206
if not (self.cluster_ids == other.cluster_ids).all():
218207
return False
219-
if not (self.is_seed == other.is_seed).all():
220-
return False
221208
return True
222209

223210

@@ -317,11 +304,8 @@ def _read_array(self, input_data: Union[list, np.ndarray]) -> None:
317304
coords = np.vstack([input_data[:-1], # coordinates SoA
318305
input_data[-1]], # weights
319306
dtype=np.float32)
320-
results = np.vstack([np.zeros(npoints, dtype=np.int32), # cluster ids
321-
np.zeros(npoints, dtype=np.int32)], # is_seed
322-
dtype=np.int32)
323307
coords = np.ascontiguousarray(coords, dtype=np.float32)
324-
results = np.ascontiguousarray(results, dtype=np.int32)
308+
results = np.zeros(npoints, dtype=np.int32) # cluster ids
325309
self.clust_data = ClusteringDataSoA(coords,
326310
results,
327311
ndim,
@@ -337,11 +321,8 @@ def _read_array(self, input_data: Union[list, np.ndarray]) -> None:
337321
coords = np.vstack([input_data[:-1].T, # coordinates SoA
338322
input_data[-1]], # weights
339323
dtype=np.float32)
340-
results = np.vstack([np.zeros(npoints, dtype=np.int32), # cluster ids
341-
np.zeros(npoints, dtype=np.int32)], # is_seed
342-
dtype=np.int32)
343324
coords = np.ascontiguousarray(coords, dtype=np.float32)
344-
results = np.ascontiguousarray(results, dtype=np.int32)
325+
results = np.zeros(npoints, dtype=np.int32) # cluster ids
345326
self.clust_data = ClusteringDataSoA(coords,
346327
results,
347328
ndim,
@@ -399,10 +380,8 @@ def _handle_dataframe(self, df_: pd.DataFrame) -> None:
399380
npoints = len(df_.index)
400381
coords = df_.iloc[:, 0:-1].to_numpy()
401382
coords = np.vstack([coords.T, df_.iloc[:, -1]], dtype=np.float32)
402-
results = np.vstack([np.zeros(npoints, dtype=np.int32),
403-
np.zeros(npoints, dtype=np.int32)], dtype=np.int32)
404383
coords = np.ascontiguousarray(coords, dtype=np.float32)
405-
results = np.ascontiguousarray(results, dtype=np.int32)
384+
results = np.zeros(npoints, dtype=np.int32)
406385

407386
self.clust_data = ClusteringDataSoA(coords, results, ndim, npoints)
408387

@@ -675,25 +654,19 @@ def run_clue(self,
675654
print("HIP module not found. Please re-compile the library and try again.")
676655

677656
finish = time.time_ns()
678-
cluster_ids = data.results[0]
679-
is_seed = data.results[1]
680-
n_seeds = np.sum(is_seed)
657+
cluster_ids = data.results
681658
n_clusters = np.max(cluster_ids) + 1
682-
clusters = np.arange(n_clusters, dtype=np.int32)
683659

684660
cluster_points = [[] for _ in range(n_clusters)]
685661
for i in range(self.clust_data.n_points):
686662
if cluster_ids[i] != -1:
687663
cluster_points[cluster_ids[i]].append(i)
688664

689665
points_per_cluster = np.array([len(clust) for clust in cluster_points])
690-
output_df = pd.DataFrame({'cluster_ids': cluster_ids, 'is_seed': is_seed})
666+
output_df = pd.DataFrame({'cluster_ids': cluster_ids})
691667

692668
self.clust_prop = cluster_properties(n_clusters,
693-
n_seeds,
694-
clusters,
695669
cluster_ids,
696-
is_seed,
697670
np.asarray(cluster_points, dtype=object),
698671
points_per_cluster,
699672
output_df)
@@ -784,17 +757,6 @@ def n_clusters(self) -> int:
784757
"""
785758
return self.clust_prop.n_clusters
786759

787-
@property
788-
def clusters(self) -> np.ndarray:
789-
"""
790-
List of clusters found.
791-
792-
:return: Array of cluster identifiers.
793-
:rtype: np.ndarray
794-
"""
795-
796-
return self.clust_prop.clusters
797-
798760
@property
799761
def cluster_ids(self) -> np.ndarray:
800762
"""
@@ -842,9 +804,9 @@ def points_per_cluster(self) -> np.ndarray:
842804
@property
843805
def output_df(self) -> pd.DataFrame:
844806
"""
845-
DataFrame containing cluster_ids and seed information.
807+
DataFrame containing cluster_ids.
846808
847-
:return: Pandas DataFrame with cluster assignments and seed flags.
809+
:return: Pandas DataFrame with cluster assignments.
848810
:rtype: pd.DataFrame
849811
"""
850812

@@ -1025,11 +987,11 @@ def input_plotter(self, filepath: Union[str, None] = None, plot_title: str = '',
1025987
def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '',
1026988
title_size: float = 16, x_label: str = 'x', y_label: str = 'y',
1027989
z_label: str = 'z', label_size: float = 16, outl_size: float = 10,
1028-
pt_size: float = 10, seed_size: float = 25, grid: bool = True,
990+
pt_size: float = 10, grid: bool = True,
1029991
grid_style: str = '--', grid_size: float = 0.2, x_ticks=None,
1030992
y_ticks=None, z_ticks=None, **kwargs) -> None:
1031993
"""
1032-
Plots clusters with different colors, seeds as stars, and outliers as gray crosses.
994+
Plots clusters with different colors and outliers as gray crosses.
1033995
1034996
:param filepath: Path to save the plot. If None, the plot is shown interactively.
1035997
:type filepath: str or None
@@ -1049,8 +1011,6 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
10491011
:type outl_size: float
10501012
:param pt_size: Marker size for cluster points.
10511013
:type pt_size: float
1052-
:param seed_size: Marker size for seed points.
1053-
:type seed_size: float
10541014
:param grid: Whether to display a grid.
10551015
:type grid: bool
10561016
:param grid_style: Line style of the grid.
@@ -1073,8 +1033,7 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
10731033
if self.clust_data.n_dim == 1:
10741034
data = {'x0': self.coords[0],
10751035
'x1': np.zeros(self.clust_data.n_points),
1076-
'cluster_ids': self.clust_prop.cluster_ids,
1077-
'isSeed': self.clust_prop.is_seed}
1036+
'cluster_ids': self.clust_prop.cluster_ids}
10781037
df_ = pd.DataFrame(data)
10791038

10801039
max_clusterid = max(df_["cluster_ids"])
@@ -1084,8 +1043,6 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
10841043
for i in range(0, max_clusterid+1):
10851044
dfi = df_[df_.cluster_ids == i] # ith cluster
10861045
plt.scatter(dfi.x0, dfi.x1, s=pt_size, marker='.')
1087-
df_seed = df_[df_.isSeed == 1] # Only Seeds
1088-
plt.scatter(df_seed.x0, df_seed.x1, s=seed_size, color='r', marker='*')
10891046

10901047
# Customization of the plot title
10911048
plt.title(plot_title, fontsize=title_size)
@@ -1145,8 +1102,7 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
11451102
data = {'x0': self.coords[0],
11461103
'x1': self.coords[1],
11471104
'x2': self.coords[2],
1148-
'cluster_ids': self.clust_prop.cluster_ids,
1149-
'isSeed': self.clust_prop.is_seed}
1105+
'cluster_ids': self.clust_prop.cluster_ids}
11501106
df_ = pd.DataFrame(data)
11511107

11521108
max_clusterid = max(df_["cluster_ids"])
@@ -1159,9 +1115,6 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
11591115
dfi = df_[df_.cluster_ids == i]
11601116
ax_.scatter(dfi.x0, dfi.x1, dfi.x2, s=pt_size, marker = '.')
11611117

1162-
df_seed = df_[df_.isSeed == 1] # Only Seeds
1163-
ax_.scatter(df_seed.x0, df_seed.x1, df_seed.x2, s=seed_size, color = 'r', marker = '*')
1164-
11651118
# Customization of the plot title
11661119
ax_.set_title(plot_title, fontsize=title_size)
11671120

@@ -1190,7 +1143,7 @@ def cluster_plotter(self, filepath: Union[str, None] = None, plot_title: str = '
11901143

11911144
def to_csv(self, output_folder: str, file_name: str) -> None:
11921145
"""
1193-
Creates a file containing the coordinates of all the points, their cluster_ids, and is_seed.
1146+
Creates a file containing the coordinates of all the points and their cluster_ids.
11941147
11951148
:param output_folder: Full path to the desired output folder.
11961149
:type output_folder: str
@@ -1211,7 +1164,6 @@ def to_csv(self, output_folder: str, file_name: str) -> None:
12111164
data['x' + str(i)] = self.clust_data.coords[i]
12121165
data['weight'] = self.clust_data.coords[-1]
12131166
data['cluster_ids'] = self.clust_prop.cluster_ids
1214-
data['is_seed'] = self.clust_prop.is_seed
12151167

12161168
df_ = pd.DataFrame(data)
12171169
df_.to_csv(out_path, index=False)
@@ -1235,13 +1187,10 @@ def import_clusterer(self, input_folder: str, file_name: str) -> None:
12351187
in_path = input_folder + file_name
12361188
df_ = pd.read_csv(in_path, dtype=float)
12371189
cluster_ids = np.asarray(df_["cluster_ids"], dtype=int)
1238-
is_seed = np.array(df_["is_seed"], dtype=int)
12391190

12401191
self._handle_dataframe(df_.iloc[:, :-2])
12411192

1242-
n_seeds = np.sum(is_seed)
12431193
n_clusters = np.max(cluster_ids) + 1
1244-
clusters = np.arange(n_clusters, dtype=np.int32)
12451194

12461195
cluster_points = [[] for _ in range(n_clusters)]
12471196
for i in range(self.clust_data.n_points):
@@ -1250,10 +1199,7 @@ def import_clusterer(self, input_folder: str, file_name: str) -> None:
12501199
points_per_cluster = np.array([len(clust) for clust in cluster_points])
12511200
self.clust_prop = cluster_properties(
12521201
n_clusters,
1253-
n_seeds,
1254-
clusters,
12551202
cluster_ids,
1256-
is_seed,
12571203
np.asarray(cluster_points, dtype=object),
12581204
points_per_cluster,
12591205
df_

tests/test_clusterer_properties.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,10 @@ def test_clusterer_properties(dataset):
4343
assert nclusters > 0
4444
cluster_ids = c.cluster_ids
4545
assert np.max(cluster_ids) + 1 == nclusters
46-
clusters = c.clusters
47-
assert len(clusters) == nclusters
4846
labels = c.labels
4947
assert len(labels) == 999
5048
assert (cluster_ids == labels).all()
5149
cluster_points = c.cluster_points
5250
assert len(cluster_points) == nclusters
5351
output_df = c.output_df
54-
assert output_df.shape == (999, 2)
52+
assert output_df.shape == (999, 1)

0 commit comments

Comments
 (0)