Skip to content

Commit 2bfc81d

Browse files
author
Andrey Oskin
committed
Initial hamerly implementation
1 parent 626dec3 commit 2bfc81d

File tree

4 files changed

+373
-3
lines changed

4 files changed

+373
-3
lines changed

src/ParallelKMeans.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ include("seeding.jl")
77
include("kmeans.jl")
88
include("lloyd.jl")
99
include("light_elkan.jl")
10+
include("hamerly.jl")
1011

1112
export kmeans
12-
export Lloyd, LightElkan
13+
export Lloyd, LightElkan, Hamerly
1314

1415
end # module

src/hamerly.jl

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
struct Hamerly <: AbstractKMeansAlg end
2+
3+
function kmeans(alg::Hamerly, design_matrix, k;
4+
n_threads = Threads.nthreads(),
5+
k_init = "k-means++", max_iters = 300,
6+
tol = 1e-6, verbose = true, init = nothing)
7+
nrow, ncol = size(design_matrix)
8+
containers = create_containers(alg, k, nrow, ncol, n_threads)
9+
10+
return kmeans!(alg, containers, design_matrix, k, n_threads = n_threads,
11+
k_init = k_init, max_iters = max_iters, tol = tol,
12+
verbose = verbose, init = init)
13+
end
14+
15+
function kmeans!(alg::Hamerly, containers, design_matrix, k;
16+
n_threads = Threads.nthreads(),
17+
k_init = "k-means++", max_iters = 300,
18+
tol = 1e-6, verbose = true, init = nothing)
19+
nrow, ncol = size(design_matrix)
20+
centroids = init == nothing ? smart_init(design_matrix, k, n_threads, init=k_init).centroids : deepcopy(init)
21+
22+
initialize!(alg, containers, centroids, design_matrix, n_threads)
23+
24+
converged = false
25+
niters = 1
26+
J_previous = 0.0
27+
28+
# Update centroids & labels with closest members until convergence
29+
30+
while niters <= max_iters
31+
update_containers!(containers, alg, centroids, n_threads)
32+
update_centroids!(centroids, containers, alg, design_matrix, n_threads)
33+
J = sum(containers.ub)
34+
move_centers!(centroids, containers, alg)
35+
update_bounds!(containers, n_threads)
36+
37+
if verbose
38+
# Show progress and terminate if J stopped decreasing.
39+
println("Iteration $niters: Jclust = $J")
40+
end
41+
42+
# Check for convergence
43+
if (niters > 1) & (abs(J - J_previous) < (tol * J))
44+
converged = true
45+
break
46+
end
47+
48+
J_previous = J
49+
niters += 1
50+
end
51+
52+
totalcost = sum_of_squares(design_matrix, containers.labels, centroids)
53+
54+
# Terminate algorithm with the assumption that K-means has converged
55+
if verbose & converged
56+
println("Successfully terminated with convergence.")
57+
end
58+
59+
# TODO empty placeholder vectors should be calculated
60+
# TODO Float64 type definitions is too restrictive, should be relaxed
61+
# especially during GPU related development
62+
return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
63+
end
64+
65+
function collect_containers(alg::Hamerly, containers, n_threads)
66+
if n_threads == 1
67+
@inbounds containers.centroids_new[end] .= containers.centroids_new[1] ./ containers.centroids_cnt[1]'
68+
else
69+
@inbounds containers.centroids_new[end] .= containers.centroids_new[1]
70+
@inbounds containers.centroids_cnt[end] .= containers.centroids_cnt[1]
71+
@inbounds for i in 2:n_threads
72+
containers.centroids_new[end] .+= containers.centroids_new[i]
73+
containers.centroids_cnt[end] .+= containers.centroids_cnt[i]
74+
end
75+
76+
@inbounds containers.centroids_new[end] .= containers.centroids_new[end] ./ containers.centroids_cnt[end]'
77+
end
78+
end
79+
80+
function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
81+
lng = n_threads + 1
82+
centroids_new = Vector{Array{Float64,2}}(undef, lng)
83+
centroids_cnt = Vector{Vector{Int}}(undef, lng)
84+
85+
for i = 1:lng
86+
centroids_new[i] = zeros(nrow, k)
87+
centroids_cnt[i] = zeros(k)
88+
end
89+
90+
# Upper bound to the closest center
91+
ub = Vector{Float64}(undef, ncol)
92+
93+
# lower bound to the second closest center
94+
lb = Vector{Float64}(undef, ncol)
95+
96+
labels = zeros(Int, ncol)
97+
98+
# distance that centroid moved
99+
p = Vector{Float64}(undef, k)
100+
101+
# distance from the center to the closest other center
102+
s = Vector{Float64}(undef, k)
103+
104+
return (
105+
centroids_new = centroids_new,
106+
centroids_cnt = centroids_cnt,
107+
labels = labels,
108+
ub = ub,
109+
lb = lb,
110+
p = p,
111+
s = s,
112+
)
113+
end
114+
115+
function initialize!(alg::Hamerly, containers, centroids, design_matrix, n_threads)
116+
ncol = size(design_matrix, 2)
117+
118+
if n_threads == 1
119+
r = axes(design_matrix, 2)
120+
chunk_initialize!(alg, containers, centroids, design_matrix, r, 1)
121+
else
122+
ranges = splitter(ncol, n_threads)
123+
124+
waiting_list = Vector{Task}(undef, n_threads - 1)
125+
126+
for i in 1:n_threads - 1
127+
waiting_list[i] = @spawn chunk_initialize!(alg, containers, centroids,
128+
design_matrix, ranges[i], i + 1)
129+
end
130+
131+
chunk_initialize!(alg, containers, centroids, design_matrix, ranges[end], 1)
132+
133+
wait.(waiting_list)
134+
end
135+
end
136+
137+
function chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
138+
centroids_cnt = containers.centroids_cnt[idx]
139+
centroids_new = containers.centroids_new[idx]
140+
141+
@inbounds for i in r
142+
label = point_all_centers!(containers, centroids, design_matrix, i)
143+
centroids_cnt[label] += 1
144+
for j in axes(design_matrix, 1)
145+
centroids_new[j, label] += design_matrix[j, i]
146+
end
147+
end
148+
end
149+
150+
function update_containers!(containers, ::Hamerly, centroids, n_threads)
151+
s = containers.s
152+
s .= Inf
153+
@inbounds for i in axes(centroids, 2)
154+
for j in i+1:size(centroids, 2)
155+
d = distance(centroids, centroids, i, j)
156+
d = 0.25*d
157+
s[i] = s[i] > d ? d : s[i]
158+
s[j] = s[j] > d ? d : s[j]
159+
end
160+
end
161+
end
162+
163+
function update_centroids!(centroids, containers, alg::Hamerly, design_matrix, n_threads)
164+
165+
if n_threads == 1
166+
r = axes(design_matrix, 2)
167+
chunk_update_centroids!(centroids, containers, alg, design_matrix, r, 1)
168+
else
169+
ncol = size(design_matrix, 2)
170+
ranges = splitter(ncol, n_threads)
171+
172+
waiting_list = Vector{Task}(undef, n_threads - 1)
173+
174+
for i in 1:length(ranges) - 1
175+
waiting_list[i] = @spawn chunk_update_centroids!(centroids, containers,
176+
alg, design_matrix, ranges[i], i)
177+
end
178+
179+
chunk_update_centroids!(centroids, containers, alg, design_matrix, ranges[end], n_threads)
180+
181+
wait.(waiting_list)
182+
183+
end
184+
185+
collect_containers(alg, containers, n_threads)
186+
end
187+
188+
function chunk_update_centroids!(
189+
centroids,
190+
containers,
191+
alg::Hamerly,
192+
design_matrix,
193+
r,
194+
idx,
195+
)
196+
197+
# unpack containers for easier manipulations
198+
centroids_new = containers.centroids_new[idx]
199+
centroids_cnt = containers.centroids_cnt[idx]
200+
labels = containers.labels
201+
s = containers.s
202+
lb = containers.lb
203+
ub = containers.ub
204+
205+
@inbounds for i in r
206+
# m ← max(s(a(i))/2, l(i))
207+
m = max(s[labels[i]], lb[i])
208+
# first bound test
209+
if ub[i] > m
210+
# tighten upper bound
211+
label = labels[i]
212+
ub[i] = distance(design_matrix, centroids, i, label)
213+
# second bound test
214+
if ub[i] > m
215+
label_new = point_all_centers!(containers, centroids, design_matrix, i)
216+
if label != label_new
217+
labels[i] = label_new
218+
centroids_cnt[label_new] += 1
219+
centroids_cnt[label] -= 1
220+
for j in axes(design_matrix, 1)
221+
centroids_new[j, label_new] += design_matrix[j, i]
222+
centroids_new[j, label] -= design_matrix[j, i]
223+
end
224+
end
225+
end
226+
end
227+
end
228+
end
229+
230+
function point_all_centers!(containers, centroids, design_matrix, i)
231+
ub = containers.ub
232+
lb = containers.lb
233+
labels = containers.labels
234+
235+
min_distance = Inf
236+
min_distance2 = Inf
237+
label = 1
238+
@inbounds for k in axes(centroids, 2)
239+
dist = distance(design_matrix, centroids, i, k)
240+
if min_distance > dist
241+
label = k
242+
min_distance2 = min_distance
243+
min_distance = dist
244+
elseif min_distance2 > dist
245+
min_distance2 = dist
246+
end
247+
end
248+
249+
ub[i] = min_distance
250+
lb[i] = min_distance2
251+
labels[i] = label
252+
253+
return label
254+
end
255+
256+
function move_centers!(centroids, containers, ::Hamerly)
257+
centroids_new = containers.centroids_new[end]
258+
p = containers.p
259+
260+
@inbounds for i in axes(centroids, 2)
261+
d = 0.0
262+
for j in axes(centroids, 1)
263+
d += (centroids[j, i] - centroids_new[j, i])^2
264+
centroids[j, i] = centroids_new[j, i]
265+
end
266+
p[i] = d
267+
end
268+
end
269+
270+
function update_bounds!(containers, n_threads)
271+
p = containers.p
272+
273+
r1, r2 = double_argmax(p)
274+
pr1 = p[r1]
275+
pr2 = p[r2]
276+
277+
if n_threads == 1
278+
r = axes(containers.ub, 1)
279+
chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
280+
else
281+
ncol = length(containers.ub)
282+
ranges = splitter(ncol, n_threads)
283+
284+
waiting_list = Vector{Task}(undef, n_threads - 1)
285+
286+
for i in 1:n_threads - 1
287+
waiting_list[i] = @spawn chunk_update_bounds!(containers, ranges[i], r1, r2, pr1, pr2)
288+
end
289+
290+
chunk_update_bounds!(containers, ranges[end], r1, r2, pr1, pr2)
291+
292+
for i in 1:n_threads - 1
293+
wait(waiting_list[i])
294+
end
295+
end
296+
end
297+
298+
function chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
299+
p = containers.p
300+
ub = containers.ub
301+
lb = containers.lb
302+
labels = containers.labels
303+
304+
@inbounds for i in r
305+
label = labels[i]
306+
ub[i] += 2*sqrt(ub[i] * p[label]) + p[label]
307+
if r1 == label
308+
lb[i] += pr2 - 2*sqrt(pr2*lb[i])
309+
else
310+
lb[i] += pr1 - 2*sqrt(pr1*lb[i])
311+
end
312+
end
313+
end
314+
315+
function double_argmax(p)
316+
r1, r2 = 1, 1
317+
d1 = p[1]
318+
d2 = -1.0
319+
for i in 2:length(p)
320+
if p[i] > d1
321+
r2 = r1
322+
r1 = i
323+
d2 = d1
324+
d1 = p[i]
325+
elseif p[i] > d2
326+
d2 = p[i]
327+
r2 = i
328+
end
329+
end
330+
331+
r1, r2
332+
end
333+
334+
"""
335+
distance(X1, X2, i1, i2)
336+
337+
Allocation less calculation of square eucledean distance between vectors X1[:, i1] and X2[:, i2]
338+
"""
339+
function distance(X1, X2, i1, i2)
340+
d = 0.0
341+
@inbounds for i in axes(X1, 1)
342+
d += (X1[i, i1] - X2[i, i2])^2
343+
end
344+
345+
return d
346+
end

src/kmeans.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function kmeans!(alg, containers, design_matrix, k;
129129

130130
if verbose
131131
# Show progress and terminate if J stopped decreasing.
132-
println("Iteration $iter: Jclust = $J")
132+
println("Iteration $niters: Jclust = $J")
133133
end
134134

135135
# Check for convergence
@@ -192,5 +192,5 @@ function update_centroids!(centroids, containers, alg, design_matrix, n_threads)
192192
centroids .= containers.new_centroids[1] ./ containers.centroids_cnt[1]'
193193
end
194194

195-
return J/ncol
195+
return J
196196
end

test/test05_hamerly.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
module TestHamerly
2+
3+
using ParallelKMeans
4+
using ParallelKMeans: initialize!, double_argmax
5+
using Test
6+
using Random
7+
8+
@testset "initialize" begin
9+
X = permutedims([1.0 2; 2 1; 4 5; 6 6])
10+
centroids = permutedims([1.0 2; 4 5; 6 6])
11+
nrow, ncol = size(X)
12+
containers = ParallelKMeans.create_containers(Hamerly(), 3, nrow, ncol, 1)
13+
14+
ParallelKMeans.initialize!(Hamerly(), containers, centroids, X, 1)
15+
@test containers.lb == [18.0, 20.0, 5.0, 5.0]
16+
@test containers.ub == [0.0, 2.0, 0.0, 0.0]
17+
end
18+
19+
@testset "double argmax" begin
20+
@test double_argmax([0.5, 0, 0]) == (1, 2)
21+
end
22+
23+
end # module

0 commit comments

Comments
 (0)