|
"""
    Hamerly()

Algorithm tag for the Hamerly flavour of k-means. It carries no data and exists
only so that `kmeans`, `kmeans!` and their helper routines can dispatch on it.
"""
struct Hamerly <: AbstractKMeansAlg
end
| 2 | + |
"""
    kmeans(alg::Hamerly, design_matrix, k; n_threads, k_init, max_iters, tol, verbose, init)

Allocate the working containers for `design_matrix` and hand over to `kmeans!`.
Every keyword argument is forwarded to `kmeans!` unchanged.
"""
function kmeans(alg::Hamerly, design_matrix, k;
                n_threads = Threads.nthreads(),
                k_init = "k-means++", max_iters = 300,
                tol = 1e-6, verbose = true, init = nothing)
    n_features, n_points = size(design_matrix)
    workspace = create_containers(alg, k, n_features, n_points, n_threads)

    return kmeans!(alg, workspace, design_matrix, k;
                   n_threads = n_threads,
                   k_init = k_init,
                   max_iters = max_iters,
                   tol = tol,
                   verbose = verbose,
                   init = init)
end
| 14 | + |
"""
    kmeans!(alg::Hamerly, containers, design_matrix, k; n_threads, k_init, max_iters, tol, verbose, init)

Run the Hamerly k-means iteration on the columns of `design_matrix`, reusing the
preallocated `containers` (see `create_containers`).

# Keywords
- `n_threads`: number of chunks the per-point work is split into.
- `k_init`: seeding strategy passed to `smart_init` when `init` is `nothing`.
- `max_iters`: maximum number of iterations.
- `tol`: relative tolerance on the change of the objective proxy `J`.
- `verbose`: print `J` at every iteration and a message on convergence.
- `init`: optional initial centroids; a deep copy is used so the caller's array
  is never modified.

# Returns
A `KmeansResult` holding the centroids, the labels, the total cost, the number of
iterations actually performed and the convergence flag.
"""
function kmeans!(alg::Hamerly, containers, design_matrix, k;
                 n_threads = Threads.nthreads(),
                 k_init = "k-means++", max_iters = 300,
                 tol = 1e-6, verbose = true, init = nothing)
    # `===` is the identity test; `==` could dispatch to a user-defined method.
    centroids = if init === nothing
        smart_init(design_matrix, k, n_threads, init = k_init).centroids
    else
        deepcopy(init)
    end

    initialize!(alg, containers, centroids, design_matrix, n_threads)

    converged = false
    # Counted at the top of the loop so that the reported value is the number of
    # iterations really executed (never `max_iters + 1`).
    niters = 0
    J_previous = 0.0

    # Update centroids & labels with closest members until convergence
    while niters < max_iters
        niters += 1

        update_containers!(containers, alg, centroids, n_threads)
        update_centroids!(centroids, containers, alg, design_matrix, n_threads)
        # Sum of the upper bounds is used as the objective value for the
        # convergence test.
        J = sum(containers.ub)
        move_centers!(centroids, containers, alg)
        update_bounds!(containers, n_threads)

        if verbose
            # Show progress and terminate if J stopped decreasing.
            println("Iteration $niters: Jclust = $J")
        end

        # Check for convergence (there is no previous J in the first iteration)
        if niters > 1 && abs(J - J_previous) < tol * J
            converged = true
            break
        end

        J_previous = J
    end

    totalcost = sum_of_squares(design_matrix, containers.labels, centroids)

    # Terminate algorithm with the assumption that K-means has converged
    if verbose && converged
        println("Successfully terminated with convergence.")
    end

    # TODO empty placeholder vectors should be calculated
    # TODO Float64 type definitions is too restrictive, should be relaxed
    # especially during GPU related development
    return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[],
                        totalcost, niters, converged)
end
| 64 | + |
"""
    collect_containers(alg::Hamerly, containers, n_threads)

Combine the per-chunk coordinate sums and member counts and store the resulting
centroid means in `containers.centroids_new[end]`.

With a single thread the means are computed directly from slot 1. Otherwise
slots `1:n_threads` are accumulated into the last slot of both `centroids_new`
and `centroids_cnt` before dividing.
"""
function collect_containers(alg::Hamerly, containers, n_threads)
    sums = containers.centroids_new
    counts = containers.centroids_cnt
    total_sum = sums[end]

    if n_threads == 1
        # Counts are transposed so each column is divided by its own count.
        total_sum .= sums[1] ./ counts[1]'
        return total_sum
    end

    total_cnt = counts[end]
    total_sum .= sums[1]
    total_cnt .= counts[1]
    @inbounds for t in 2:n_threads
        total_sum .+= sums[t]
        total_cnt .+= counts[t]
    end

    total_sum ./= total_cnt'
    return total_sum
end
| 79 | + |
"""
    create_containers(alg::Hamerly, k, nrow, ncol, n_threads)

Allocate the working storage used by the Hamerly iteration for `k` clusters and
an `nrow × ncol` design matrix.

# Returns
A named tuple with fields
- `centroids_new`: `n_threads + 1` zeroed `nrow × k` matrices of coordinate sums
  (the extra slot receives the combined result),
- `centroids_cnt`: `n_threads + 1` zeroed integer count vectors of length `k`,
- `labels`: zeroed label vector of length `ncol`,
- `ub`, `lb`: uninitialised per-point upper / lower bound vectors,
- `p`, `s`: uninitialised per-centroid vectors (centroid movement and distance
  to the closest other centre).
"""
function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
    lng = n_threads + 1

    centroids_new = [zeros(nrow, k) for _ in 1:lng]
    # Allocate the counters as `Int` directly; `zeros(k)` would build a Float64
    # vector that then has to be converted element by element.
    centroids_cnt = [zeros(Int, k) for _ in 1:lng]

    # Upper bound to the closest center
    ub = Vector{Float64}(undef, ncol)

    # lower bound to the second closest center
    lb = Vector{Float64}(undef, ncol)

    labels = zeros(Int, ncol)

    # distance that centroid moved
    p = Vector{Float64}(undef, k)

    # distance from the center to the closest other center
    s = Vector{Float64}(undef, k)

    return (
        centroids_new = centroids_new,
        centroids_cnt = centroids_cnt,
        labels = labels,
        ub = ub,
        lb = lb,
        p = p,
        s = s,
    )
end
| 114 | + |
"""
    initialize!(alg::Hamerly, containers, centroids, design_matrix, n_threads)

Perform the initial assignment of every column of `design_matrix` via
`chunk_initialize!`. With more than one thread the columns are divided with
`splitter`; all chunks but the last run as spawned tasks (container slots
`2:n_threads`) while the last chunk is processed by the caller in slot 1.
"""
function initialize!(alg::Hamerly, containers, centroids, design_matrix, n_threads)
    if n_threads == 1
        every_column = axes(design_matrix, 2)
        return chunk_initialize!(alg, containers, centroids, design_matrix,
                                 every_column, 1)
    end

    chunks = splitter(size(design_matrix, 2), n_threads)

    tasks = map(1:n_threads - 1) do t
        @spawn chunk_initialize!(alg, containers, centroids, design_matrix,
                                 chunks[t], t + 1)
    end

    chunk_initialize!(alg, containers, centroids, design_matrix, chunks[end], 1)

    return wait.(tasks)
end
| 136 | + |
"""
    chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)

For every column index in `r`, assign the point to a centre with
`point_all_centers!` and add the point to the count and coordinate sum of that
centre in container slot `idx`.
"""
function chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
    counts = containers.centroids_cnt[idx]
    sums = containers.centroids_new[idx]
    features = axes(design_matrix, 1)

    @inbounds for point in r
        nearest = point_all_centers!(containers, centroids, design_matrix, point)
        counts[nearest] += 1
        for f in features
            sums[f, nearest] += design_matrix[f, point]
        end
    end
end
| 149 | + |
"""
    update_containers!(containers, ::Hamerly, centroids, n_threads)

Refresh `containers.s`: for each centroid, a quarter of the value returned by
`distance` to its closest other centroid. Entries stay `Inf` when there is no
other centroid.
"""
function update_containers!(containers, ::Hamerly, centroids, n_threads)
    nearest = containers.s
    fill!(nearest, Inf)
    n_centers = size(centroids, 2)

    # Visit every unordered pair once and update both ends.
    @inbounds for a in axes(centroids, 2), b in (a + 1):n_centers
        quarter = 0.25 * distance(centroids, centroids, a, b)
        if nearest[a] > quarter
            nearest[a] = quarter
        end
        if nearest[b] > quarter
            nearest[b] = quarter
        end
    end
end
| 162 | + |
"""
    update_centroids!(centroids, containers, alg::Hamerly, design_matrix, n_threads)

Re-examine the assignment of every column of `design_matrix` with
`chunk_update_centroids!` (split over `n_threads` chunks when more than one
thread is requested) and then merge the per-chunk accumulators with
`collect_containers`.
"""
function update_centroids!(centroids, containers, alg::Hamerly, design_matrix, n_threads)
    if n_threads == 1
        chunk_update_centroids!(centroids, containers, alg, design_matrix,
                                axes(design_matrix, 2), 1)
    else
        chunks = splitter(size(design_matrix, 2), n_threads)
        tasks = Vector{Task}(undef, n_threads - 1)

        for t in 1:length(chunks) - 1
            tasks[t] = @spawn chunk_update_centroids!(centroids, containers, alg,
                                                      design_matrix, chunks[t], t)
        end

        # The caller handles the final chunk itself, using the last thread slot.
        chunk_update_centroids!(centroids, containers, alg, design_matrix,
                                chunks[end], n_threads)

        foreach(wait, tasks)
    end

    return collect_containers(alg, containers, n_threads)
end
| 187 | + |
"""
    chunk_update_centroids!(centroids, containers, alg::Hamerly, design_matrix, r, idx)

For each column index in `r`, use the stored bounds to decide whether the point
may have changed cluster. Only when both bound tests fail are all centres
examined; if the label changes, the point is moved between the counts and
coordinate sums of container slot `idx`.
"""
function chunk_update_centroids!(
    centroids,
    containers,
    alg::Hamerly,
    design_matrix,
    r,
    idx,
)
    sums = containers.centroids_new[idx]
    counts = containers.centroids_cnt[idx]
    labels = containers.labels
    upper = containers.ub
    lower = containers.lb
    nearest_other = containers.s

    @inbounds for point in r
        old_label = labels[point]
        threshold = max(nearest_other[old_label], lower[point])

        # first bound test: the stored upper bound already rules out a change
        upper[point] > threshold || continue

        # tighten the upper bound with the exact value and test again
        upper[point] = distance(design_matrix, centroids, point, old_label)
        upper[point] > threshold || continue

        new_label = point_all_centers!(containers, centroids, design_matrix, point)
        new_label == old_label && continue

        labels[point] = new_label
        counts[new_label] += 1
        counts[old_label] -= 1
        for f in axes(design_matrix, 1)
            sums[f, new_label] += design_matrix[f, point]
            sums[f, old_label] -= design_matrix[f, point]
        end
    end
end
| 229 | + |
"""
    point_all_centers!(containers, centroids, design_matrix, i)

Compare column `i` of `design_matrix` against every column of `centroids`.
Stores the smallest `distance` in `containers.ub[i]`, the second smallest in
`containers.lb[i]` and the index of the closest centre in `containers.labels[i]`.
Returns that index.
"""
function point_all_centers!(containers, centroids, design_matrix, i)
    nearest = 1
    best = Inf
    runner_up = Inf

    @inbounds for c in axes(centroids, 2)
        d = distance(design_matrix, centroids, i, c)
        if d < best
            # the previous best becomes the second best
            runner_up = best
            best = d
            nearest = c
        elseif d < runner_up
            runner_up = d
        end
    end

    containers.ub[i] = best
    containers.lb[i] = runner_up
    containers.labels[i] = nearest

    return nearest
end
| 255 | + |
"""
    move_centers!(centroids, containers, ::Hamerly)

Overwrite `centroids` with `containers.centroids_new[end]` and record in
`containers.p` the sum of squared coordinate differences by which each
centroid moved.
"""
function move_centers!(centroids, containers, ::Hamerly)
    updated = containers.centroids_new[end]
    shift = containers.p

    @inbounds for c in axes(centroids, 2)
        moved = 0.0
        for f in axes(centroids, 1)
            target = updated[f, c]
            moved += (centroids[f, c] - target)^2
            centroids[f, c] = target
        end
        shift[c] = moved
    end
end
| 269 | + |
"""
    update_bounds!(containers, n_threads)

Adjust the per-point bounds after the centroids moved. The two largest entries
of `containers.p` are located with `double_argmax`, then `chunk_update_bounds!`
is applied to all points, split over `n_threads` chunks when more than one
thread is requested.
"""
function update_bounds!(containers, n_threads)
    moved = containers.p
    first_idx, second_idx = double_argmax(moved)
    largest = moved[first_idx]
    runner_up = moved[second_idx]

    if n_threads == 1
        return chunk_update_bounds!(containers, axes(containers.ub, 1),
                                    first_idx, second_idx, largest, runner_up)
    end

    chunks = splitter(length(containers.ub), n_threads)

    tasks = map(1:n_threads - 1) do t
        @spawn chunk_update_bounds!(containers, chunks[t],
                                    first_idx, second_idx, largest, runner_up)
    end

    # The caller processes the final chunk while the spawned tasks run.
    chunk_update_bounds!(containers, chunks[end],
                         first_idx, second_idx, largest, runner_up)

    foreach(wait, tasks)
    return nothing
end
| 297 | + |
"""
    chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)

Update the bounds of the points with indices in `r` after the centroids moved.

The bounds are kept as squared distances, so moving by a squared distance `p`
turns an upper bound `u` into `(sqrt(u) + sqrt(p))^2` and a lower bound `l` into
`(sqrt(l) - sqrt(p))^2`.

# Arguments
- `containers`: must provide the fields `p`, `ub`, `lb` and `labels`.
- `r`: iterable of point indices to process.
- `r1`, `r2`: indices of the centroids with the largest and second largest
  movement.
- `pr1`, `pr2`: the corresponding movements `p[r1]` and `p[r2]`.

The lower bound is clamped to zero whenever the movement is at least as large as
the bound itself: in that case `sqrt(l) - sqrt(p)` is not positive and squaring
it would produce a positive value that is no longer a valid lower bound.
"""
function chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
    p = containers.p
    ub = containers.ub
    lb = containers.lb
    labels = containers.labels

    @inbounds for i in r
        label = labels[i]
        # (sqrt(ub) + sqrt(p))^2 written out with a single square root
        ub[i] += 2 * sqrt(ub[i] * p[label]) + p[label]

        # The point's own centre cannot be its second closest one, so when the
        # own centre is the biggest mover use the second biggest movement.
        shift = r1 == label ? pr2 : pr1
        if lb[i] <= shift
            lb[i] = zero(lb[i])
        else
            # Factored form keeps an infinite bound infinite instead of NaN.
            lb[i] = (sqrt(lb[i]) - sqrt(shift))^2
        end
    end
end
| 314 | + |
"""
    double_argmax(p)

Return the tuple `(r1, r2)` with the index of the largest element of `p` and the
index of the second largest one, found in a single pass. For a one-element `p`
both indices are `1`.
"""
function double_argmax(p)
    best, second = 1, 1
    best_val = p[1]
    second_val = -1.0

    for idx in 2:length(p)
        val = p[idx]
        if val > best_val
            # the former leader is demoted to second place
            second, second_val = best, best_val
            best, best_val = idx, val
        elseif val > second_val
            second, second_val = idx, val
        end
    end

    return best, second
end
| 333 | + |
"""
    distance(X1, X2, i1, i2)

Allocation-free calculation of the squared Euclidean distance between the
columns `X1[:, i1]` and `X2[:, i2]`. The rows iterated are those of `X1`.
"""
function distance(X1, X2, i1, i2)
    total = 0.0
    @inbounds for row in axes(X1, 1)
        delta = X1[row, i1] - X2[row, i2]
        total += delta^2
    end

    return total
end
0 commit comments