
Commit 3cab42a

Merge pull request #101 from PyDataBlog/minibatch
* Added MiniBatch algorithm
2 parents 9c221c9 + 6010166 commit 3cab42a

11 files changed: +336 -40 lines changed

Project.toml
+1 -1

@@ -1,7 +1,7 @@
 name = "ParallelKMeans"
 uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af"
 authors = ["Bernard Brenyah", "Andrey Oskin"]
-version = "0.2.0"
+version = "0.2.1"
 
 [deps]
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

README.md
+9 -6

@@ -2,8 +2,8 @@
 
 [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://PyDataBlog.github.io/ParallelKMeans.jl/stable)
 [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://PyDataBlog.github.io/ParallelKMeans.jl/dev)
-[![Build Status](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl.svg?branch=master)](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl)
-[![Coverage Status](https://coveralls.io/repos/github/PyDataBlog/ParallelKMeans.jl/badge.svg?branch=master)](https://coveralls.io/github/PyDataBlog/ParallelKMeans.jl?branch=master)
+[![Build Status](https://github.com/PyDataBlog/ParallelKMeans.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/PyDataBlog/ParallelKMeans.jl/actions/workflows/CI.yml/badge.svg)
+[![codecov](https://codecov.io/gh/PyDataBlog/ParallelKMeans.jl/branch/master/graph/badge.svg?token=799USS6BPH)](https://codecov.io/gh/PyDataBlog/ParallelKMeans.jl)
 [![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2FPyDataBlog%2FParallelKMeans.jl.svg?type=shield)](https://app.fossa.com/projects/git%2Bgithub.com%2FPyDataBlog%2FParallelKMeans.jl?ref=badge_shield)
 [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/PyDataBlog/ParallelKMeans.jl/master)
 _________________________________________________________________________________________________________

@@ -22,10 +22,13 @@ ________________________________________________________________________________
 
 ## Table Of Content
 
-1. [Documentation](#Documentation)
-2. [Installation](#Installation)
-3. [Features](#Features)
-4. [License](#License)
+- [ParallelKMeans](#parallelkmeans)
+- [Table Of Content](#table-of-content)
+- [Documentation](#documentation)
+- [Installation](#installation)
+- [Features](#features)
+- [Benchmarks](#benchmarks)
+- [License](#license)
 
 _________________________________________________________________________________________________________
 

docs/src/index.md
+8 -6

@@ -78,11 +78,11 @@ pkg> free ParallelKMeans
 - [X] Implementation of [Coresets](http://proceedings.mlr.press/v51/lucic16-supp.pdf).
 - [X] Support for weighted K-means.
 - [X] Support of MLJ Random generation hyperparameter.
+- [X] Implementation of [Mini-batch KMeans variant](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf)
 - [ ] Support for other distance metrics supported by [Distances.jl](https://github.com/JuliaStats/Distances.jl#supported-distances).
 - [ ] Implementation of [Geometric methods to accelerate k-means algorithm](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf).
 - [ ] Native support for tabular data inputs outside of MLJModels' interface.
-- [ ] Refactoring and finalization of API design.
-- [ ] GPU support.
+- [ ] GPU support?
 - [ ] Distributed calculations support.
 - [ ] Optimization of code base.
 - [ ] Improved Documentation

@@ -127,8 +127,8 @@ r.converged # whether the procedure converged
 - [Elkan()](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf) - Recommended for high dimensional data.
 - [Yinyang()](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/ding15.pdf) - Recommended for large dimensions and/or large number of clusters.
 - [Coreset()](http://proceedings.mlr.press/v51/lucic16-supp.pdf) - Recommended for very fast clustering of very large datasets, when extreme accuracy is not important.
+- [MiniBatch()](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf) - Recommended for extremely large datasets, when extreme accuracy is not important.
 - [Geometric()](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf) - (Coming soon)
-- [MiniBatch()](https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf) - (Coming soon)
 
 ### Practical Usage Examples
 

@@ -175,9 +175,9 @@ Currently, this package is benchmarked against similar implementations in both P
 
 *Note*: All benchmark tests are made on the same computer to help eliminate any bias.
 
-|PC Name                      |CPU                       |Ram               |
-|:---------------------------:|:------------------------:|:----------------:|
-|iMac (Retina 5K 27-inch 2019)|3 GHz 6-Core Intel Core i5|8 GB 2667 MHz DDR4|
+|PC Name                      |CPU                       |Ram                |
+|:---------------------------:|:------------------------:|:-----------------:|
+|iMac (Retina 5K 27-inch 2019)|3 GHz 6-Core Intel Core i5|24 GB 2667 MHz DDR4|
 
 Currently, the benchmark speed tests are based on the search for optimal number of clusters using the [Elbow Method](https://en.wikipedia.org/wiki/Elbow_method_(clustering)) since this is a practical use case for most practioners employing the K-Means algorithm.

@@ -213,6 +213,8 @@ ________________________________________________________________________________
 - 0.1.7 Added `Yinyang` and `Coreset` support in MLJ interface; added `weights` support in MLJ; added RNG seed support in MLJ interface and through all algorithms; added metric support.
 - 0.1.8 Minor cleanup
 - 0.1.9 Added travis support for Julia 1.5
+- 0.2.0 Updated MLJ Interface
+- 0.2.1 Mini-batch implementation
 
 ## Contributing
 
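With this change, `MiniBatch()` moves from "(Coming soon)" to a drop-in choice alongside the other algorithms listed above. A minimal sketch of its use (the data matrix and cluster count here are illustrative, not part of the commit; field access follows the `r.converged` example shown in the docs above):

```julia
using ParallelKMeans

X = rand(30, 100_000)  # 30 features, 100_000 observations (one per column)

# Swap the algorithm argument to pick a variant from the list above
r = kmeans(MiniBatch(100), X, 3)  # batch of 100 samples per iteration

r.converged  # whether the procedure converged
r.totalcost  # final objective evaluated on the full dataset
```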

src/ParallelKMeans.jl
+2 -1

@@ -15,9 +15,10 @@ include("hamerly.jl")
 include("elkan.jl")
 include("yinyang.jl")
 include("coreset.jl")
+include("mini_batch.jl")
 include("mlj_interface.jl")
 
 export kmeans
-export Lloyd, Hamerly, Elkan, Yinyang, 阴阳, Coreset
+export Lloyd, Hamerly, Elkan, Yinyang, 阴阳, Coreset, MiniBatch
 
 end # module
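With the new `include` and export in place, the algorithm is reachable directly from the package namespace. A quick illustrative check (the default constructor and batch size come from `src/mini_batch.jl` below):

```julia
using ParallelKMeans

alg = MiniBatch()  # default constructor, b = 100
alg.b              # => 100
alg isa ParallelKMeans.AbstractKMeansAlg  # => true
```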

src/kmeans.jl
+17 -7

@@ -115,7 +115,7 @@ Allocationless calculation of square eucledean distance between vectors X1[:, i1
 @inline function distance(metric::Euclidean, X1, X2, i1, i2)
     # here goes my definition
     d = zero(eltype(X1))
-    # TODO: break of the loop if d is larger than threshold (known minimum disatnce)
+    # TODO: break of the loop if d is larger than threshold (known minimum distance)
     @inbounds @simd for i in axes(X1, 1)
         d += (X1[i, i1] - X2[i, i2])^2
     end

@@ -170,20 +170,30 @@ alternatively one can use `rand` to choose random points for init.
 
 A `KmeansResult` structure representing labels, centroids, and sum_squares is returned.
 """
-function kmeans(alg::AbstractKMeansAlg, design_matrix, k; weights = nothing,
+function kmeans(alg::AbstractKMeansAlg, design_matrix, k;
+                weights = nothing,
                 n_threads = Threads.nthreads(),
-                k_init = "k-means++", max_iters = 300,
-                tol = eltype(design_matrix)(1e-6), verbose = false,
-                init = nothing, rng = Random.GLOBAL_RNG, metric = Euclidean())
+                k_init = "k-means++",
+                max_iters = 300,
+                tol = eltype(design_matrix)(1e-6),
+                verbose = false,
+                init = nothing,
+                rng = Random.GLOBAL_RNG,
+                metric = Euclidean())
 
     nrow, ncol = size(design_matrix)
 
     # Create containers based on the dimensions and specifications
     containers = create_containers(alg, design_matrix, k, nrow, ncol, n_threads)
 
     return kmeans!(alg, containers, design_matrix, k, weights, metric;
-                   n_threads = n_threads, k_init = k_init, max_iters = max_iters,
-                   tol = tol, verbose = verbose, init = init, rng = rng)
+                   n_threads = n_threads,
+                   k_init = k_init,
+                   max_iters = max_iters,
+                   tol = tol,
+                   verbose = verbose,
+                   init = init,
+                   rng = rng)
 
 end
 
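For reference, a sketch of a call exercising the reformatted keyword interface (values are illustrative; `Lloyd()` stands in for any exported algorithm, and `tol` defaults to `eltype(design_matrix)(1e-6)`, so it tracks the input's element type):

```julia
using Random
using Distances: Euclidean
using ParallelKMeans

X = rand(Float32, 10, 5_000)  # Float32 design matrix, one observation per column

r = kmeans(Lloyd(), X, 4;
           n_threads = 2,
           k_init = "k-means++",
           max_iters = 200,
           tol = Float32(1e-6),          # matches the Float32 element type
           verbose = false,
           rng = MersenneTwister(2020),
           metric = Euclidean())
```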

src/lloyd.jl
+1 -1

@@ -6,7 +6,7 @@ Basic algorithm for k-means calculation.
 struct Lloyd <: AbstractKMeansAlg end
 
 """
-    Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
+    kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
 
 Mutable version of `kmeans` function. Definition of arguments and results can be
 found in `kmeans`.

src/mini_batch.jl
+183

@@ -0,0 +1,183 @@
+"""
+    MiniBatch(b::Int)
+    `b` represents the size of the batch which should be sampled.
+
+Sculley et al. 2007 Mini batch k-means algorithm implementation.
+
+```julia
+X = rand(30, 100_000) # 100_000 random points in 30 dimensions
+
+kmeans(MiniBatch(100), X, 3) # 3 clusters, MiniBatch algorithm with 100 batch samples at each iteration
+```
+"""
+struct MiniBatch <: AbstractKMeansAlg
+    b::Int # batch size
+end
+
+
+MiniBatch() = MiniBatch(100)
+
+function kmeans!(alg::MiniBatch, containers, X, k,
+                 weights = nothing, metric = Euclidean(); n_threads = Threads.nthreads(),
+                 k_init = "k-means++", init = nothing, max_iters = 300,
+                 tol = eltype(X)(1e-6), max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)
+
+    # Retrieve initialized artifacts from the container
+    centroids = containers.centroids_new
+    batch_rand_idx = containers.batch_rand_idx
+    labels = containers.labels
+
+    # Get the type and dimensions of design matrix, X - (Step 1)
+    T = eltype(X)
+    nrow, ncol = size(X)
+
+    # Initiate cluster centers - (Step 2) in paper
+    centroids .= isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init = k_init).centroids : deepcopy(init)
+
+    # Initialize counter for the no. of data in each cluster - (Step 3) in paper
+    N = zeros(T, k)
+
+    # Initialize various artifacts
+    converged = false
+    niters = 1
+    counter = 0
+    J_previous = zero(T)
+    J = zero(T)
+    totalcost = zero(T)
+
+    # Main Steps. Batch update centroids until convergence
+    while niters <= max_iters # Step 4 in paper
+
+        # b examples picked randomly from X (Step 5 in paper)
+        isnothing(weights) ? rand!(rng, batch_rand_idx, 1:ncol) : wsample!(rng, 1:ncol, weights, batch_rand_idx)
+
+        # Cache/label the batch samples nearest to the centers (Step 6 & 7)
+        @inbounds for i in batch_rand_idx
+            min_dist = distance(metric, X, centroids, i, 1)
+            label = 1
+
+            for j in 2:size(centroids, 2)
+                dist = distance(metric, X, centroids, i, j)
+                label = dist < min_dist ? j : label
+                min_dist = dist < min_dist ? dist : min_dist
+            end
+
+            labels[i] = label
+
+            ##### Batch gradient step #####
+            # iterate over examples (each column) ==> (Step 9)
+            # Get cached center/label for each example label = labels[i] => (Step 10)
+
+            # Update per-center counts
+            N[label] += isnothing(weights) ? 1 : weights[i] # (Step 11)
+
+            # Get per-center learning rate (Step 12)
+            lr = 1 / N[label]
+
+            # Take gradient step (Step 13) # TODO: Replace with faster loop?
+            @views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ (lr .* X[:, i])
+        end
+
+        # Reassign all labels based on new centres generated from the latest sample
+        labels .= reassign_labels(X, metric, labels, centroids)
+
+        # Calculate cost on whole dataset after reassignment and check for convergence
+        @parallelize 1 ncol sum_of_squares(containers, X, labels, centroids, weights, metric)
+        J = sum(containers.sum_of_squares)
+
+        if verbose
+            # Show progress and terminate if J stopped decreasing.
+            println("Iteration $niters: Jclust = $J")
+        end
+
+        # Check for early stopping convergence
+        if (niters > 1) & (abs(J - J_previous) < (tol * J))
+            counter += 1
+
+            # Declare convergence if max_no_improvement criterion is met
+            if counter >= max_no_improvement
+                converged = true
+                # Compute label assignment for the complete dataset
+                labels .= reassign_labels(X, metric, labels, centroids)
+
+                # Compute totalcost for the complete dataset
+                @parallelize 1 ncol sum_of_squares(containers, X, labels, centroids, weights, metric)
+                totalcost = sum(containers.sum_of_squares)
+
+                # Print convergence message to user
+                if verbose
+                    println("Successfully terminated with convergence.")
+                end
+
+                break
+            end
+        else
+            counter = 0
+        end
+
+        # Warn users if model doesn't converge at max iterations
+        if (niters >= max_iters) & (!converged)
+
+            if verbose
+                println("Clustering model failed to converge. Labelling data with latest centroids.")
+            end
+
+            labels .= reassign_labels(X, metric, labels, centroids)
+
+            # Compute totalcost for unconverged model
+            @parallelize 1 ncol sum_of_squares(containers, X, labels, centroids, weights, metric)
+            totalcost = sum(containers.sum_of_squares)
+
+            break
+        end
+
+        J_previous = J
+        niters += 1
+    end
+
+    # Push learned artifacts to KmeansResult
+    return KmeansResult(centroids, labels, T[], Int[], T[], totalcost, niters, converged)
+end
+
+"""
+    reassign_labels(DMatrix, metric, labels, centres)
+
+An internal function to relabel DMatrix based on centres and metric.
+"""
+function reassign_labels(DMatrix, metric, labels, centres)
+    @inbounds for i in axes(DMatrix, 2)
+        min_dist = distance(metric, DMatrix, centres, i, 1)
+        label = 1
+
+        for j in 2:size(centres, 2)
+            dist = distance(metric, DMatrix, centres, i, j)
+            label = dist < min_dist ? j : label
+            min_dist = dist < min_dist ? dist : min_dist
+        end
+
+        labels[i] = label
+    end
+    return labels
+end
+
+"""
+    create_containers(::MiniBatch, k, nrow, ncol, n_threads)
+
+Internal function for the creation of all necessary intermidiate structures.
+
+- `centroids_new` - container which holds new positions of centroids
+- `labels` - vector which holds labels of corresponding points
+- `sum_of_squares` - vector which holds the sum of squares values for each thread
+- `batch_rand_idx` - vector which holds the selected batch indices
+"""
+function create_containers(alg::MiniBatch, X, k, nrow, ncol, n_threads)
+    # Initiate placeholders to avoid allocations
+    T = eltype(X)
+    labels = Vector{Int}(undef, ncol) # labels vector
+    sum_of_squares = Vector{T}(undef, 1) # total_sum_calculation
+    batch_rand_idx = Vector{Int}(undef, alg.b) # selected batch indices
+    centroids_new = Matrix{T}(undef, nrow, k) # centroids
+
+    return (batch_rand_idx = batch_rand_idx, centroids_new = centroids_new,
+            labels = labels, sum_of_squares = sum_of_squares)
+end
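Two details of the new file are worth spelling out. First, the gradient step (Steps 11-13) keeps each centroid equal to the running mean of every batch sample ever assigned to it: with `lr = 1 / N[label]`, the update `(1 - lr) * c + lr * x` is exactly the incremental-mean formula. A standalone sketch of that identity (illustrative, independent of the package):

```julia
using Statistics

# Apply the MiniBatch-style gradient step (Steps 12-13) to a stream of samples
function stream_center(samples)
    center = zeros(length(first(samples)))
    N = 0
    for x in samples
        N += 1
        lr = 1 / N                               # per-center learning rate (Step 12)
        center .= (1 - lr) .* center .+ lr .* x  # gradient step (Step 13)
    end
    return center
end

samples = [rand(3) for _ in 1:50]
stream_center(samples) ≈ mean(samples)  # => true: the center is the mean of its samples
```

Second, `max_no_improvement` (default 10) is a keyword of this `kmeans!` method only; the top-level `kmeans` wrapper in `src/kmeans.jl` forwards a fixed set of keywords (`n_threads`, `k_init`, `max_iters`, `tol`, `verbose`, `init`, `rng`), so calls routed through `kmeans` use the default early-stopping window.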
