Skip to content

Commit

Permalink
Merge pull request #45 from bsc-wdc/kmeans-fix
Browse files Browse the repository at this point in the history
Kmeans fix
  • Loading branch information
kafkasl authored Nov 29, 2018
2 parents 8c850be + 2ba6778 commit 0cf5d34
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 167 deletions.
4 changes: 2 additions & 2 deletions dislib/cluster/dbscan/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pycompss.api.api import compss_wait_on
from pycompss.api.task import task

from dislib.cluster.dbscan.classes import Square
from dislib.cluster.dbscan.classes import Region
from dislib.data import Dataset


Expand Down Expand Up @@ -101,7 +101,7 @@ def fit(self, dataset):

# Compute dbscan in each region of the grid
for idx in np.ndindex(grid.shape):
grid[idx] = Square(idx, self._eps, grid.shape, region_sizes)
grid[idx] = Region(idx, self._eps, grid.shape, region_sizes)
grid[idx].init_data(sorted_data, grid.shape)
grid[idx].partial_scan(self._min_samples, self._max_samples)

Expand Down
2 changes: 1 addition & 1 deletion dislib/cluster/dbscan/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
_NO_CP = -3


class Square(object):
class Region(object):
def __init__(self, coord, epsilon, grid_shape, region_sizes):
self.coord = coord
self.epsilon = epsilon
Expand Down
12 changes: 6 additions & 6 deletions dislib/cluster/kmeans/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fit(self, dataset):
-----
This method modifies the input Dataset by setting the cluster labels.
"""
centers = _init_centers(dataset[0], self._n_clusters,
centers = _init_centers(dataset.n_features, self._n_clusters,
self._random_state)
self.centers = compss_wait_on(centers)

Expand Down Expand Up @@ -105,7 +105,7 @@ def fit_predict(self, dataset):
for subset in dataset:
labels.append(_get_label(subset))

return np.array(compss_wait_on(labels))
return np.concatenate(compss_wait_on(labels))

def predict(self, x):
""" Predict the closest cluster each sample in x belongs to.
Expand All @@ -122,9 +122,10 @@ def predict(self, x):
"""
labels = []

for x in x:
dist = np.linalg.norm(x - self.centers, axis=1)
for sample in x:
dist = np.linalg.norm(sample - self.centers, axis=1)
labels.append(np.argmin(dist))

return np.array(labels)

def _converged(self, old_centers, iter):
Expand Down Expand Up @@ -154,9 +155,8 @@ def _get_label(subset):


@task(returns=np.array)
def _init_centers(subset, n_clusters, random_state):
def _init_centers(n_features, n_clusters, random_state):
np.random.seed(random_state)
n_features = subset.samples.shape[1]
centers = np.random.random((n_clusters, n_features))
return centers

Expand Down
1 change: 1 addition & 0 deletions examples/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def main():

kmeans = KMeans(n_clusters=3, random_state=random_state)
y_pred = kmeans.fit_predict(dataset)

plt.subplot(224)
plt.scatter(x_filtered[:, 0], x_filtered[:, 1], c=y_pred)
centers = kmeans.centers
Expand Down
Loading

0 comments on commit 0cf5d34

Please sign in to comment.