|
| 1 | + |
| 2 | +# All Abstract types defined |
| 3 | +""" |
| 4 | + AbstractKMeansAlg |
| 5 | +
|
| 6 | +Abstract base type inherited by all sub-KMeans algorithms. |
| 7 | +""" |
| 8 | +abstract type AbstractKMeansAlg end |
| 9 | + |
| 10 | + |
| 11 | +""" |
| 12 | + ClusteringResult |
| 13 | +
|
| 14 | +Base type for the output of clustering algorithm. |
| 15 | +""" |
| 16 | +abstract type ClusteringResult end |
| 17 | + |
| 18 | + |
| 19 | +# Here we mimic `Clustering` output structure |
| 20 | +""" |
| 21 | + KmeansResult{C,D<:Real,WC<:Real} <: ClusteringResult |
| 22 | +
|
| 23 | +The output of [`kmeans`](@ref) and [`kmeans!`](@ref). |
| 24 | +# Type parameters |
| 25 | + * `C<:AbstractMatrix{<:AbstractFloat}`: type of the `centers` matrix |
| 26 | + * `D<:Real`: type of the assignment cost |
| 27 | + * `WC<:Real`: type of the cluster weight |
| 28 | + # C is the type of centers, an (abstract) matrix of size (d x k) |
| 29 | +# D is the type of pairwise distance computation from points to cluster centers |
| 30 | +# WC is the type of cluster weights, either Int (in the case where points are |
| 31 | +# unweighted) or eltype(weights) (in the case where points are weighted). |
| 32 | +""" |
| 33 | +struct KmeansResult{C<:AbstractMatrix{<:AbstractFloat},D<:Real,WC<:Real} <: ClusteringResult |
| 34 | + centers::C # cluster centers (d x k) |
| 35 | + assignments::Vector{Int} # assignments (n) |
| 36 | + costs::Vector{D} # cost of the assignments (n) |
| 37 | + counts::Vector{Int} # number of points assigned to each cluster (k) |
| 38 | + wcounts::Vector{WC} # cluster weights (k) |
| 39 | + totalcost::D # total cost (i.e. objective) |
| 40 | + iterations::Int # number of elapsed iterations |
| 41 | + converged::Bool # whether the procedure converged |
| 42 | +end |
| 43 | + |
| 44 | +""" |
| 45 | + sum_of_squares(x, labels, centre, k) |
| 46 | +
|
| 47 | +This function computes the total sum of squares based on the assigned (labels) |
| 48 | +design matrix(x), centroids (centre), and the number of desired groups (k). |
| 49 | +
|
| 50 | +A Float type representing the computed metric is returned. |
| 51 | +""" |
| 52 | +function sum_of_squares(x, labels, centre) |
| 53 | + s = 0.0 |
| 54 | + |
| 55 | + @inbounds for j in axes(x, 2) |
| 56 | + for i in axes(x, 1) |
| 57 | + s += (x[i, j] - centre[i, labels[j]])^2 |
| 58 | + end |
| 59 | + end |
| 60 | + |
| 61 | + return s |
| 62 | +end |
| 63 | + |
| 64 | + |
| 65 | +""" |
| 66 | + Kmeans([alg::AbstractKMeansAlg,] design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true) |
| 67 | +
|
| 68 | +This main function employs the K-means algorithm to cluster all examples |
| 69 | +in the training data (design_matrix) into k groups using either the |
| 70 | +`k-means++` or random initialisation technique for selecting the initial |
| 71 | +centroids. |
| 72 | +
|
| 73 | +At the end of the number of iterations specified (max_iters), convergence is |
| 74 | +achieved if difference between the current and last cost objective is |
| 75 | +less than the tolerance level (tol). An error is thrown if convergence fails. |
| 76 | +
|
| 77 | +Arguments: |
| 78 | +- `alg` defines one of the algorithms used to calculate `k-means`. This |
| 79 | +argument can be omitted, by default Lloyd algorithm is used. |
| 80 | +- `n_threads` defines number of threads used for calculations, by default it is equal |
| 81 | +to the `Threads.nthreads()` which is defined by `JULIA_NUM_THREADS` environmental |
| 82 | +variable. For small size design matrices it make sense to set this argument to 1 in order |
| 83 | +to avoid overhead of threads generation. |
| 84 | +- `k_init` is one of the algorithms used for initialization. By default `k-means++` algorithm is used, |
| 85 | +alternatively one can use `rand` to choose random points for init. |
| 86 | +- `max_iters` is the maximum number of iterations |
| 87 | +- `tol` defines tolerance for early stopping. |
| 88 | +- `verbose` is verbosity level. Details of operations can be either printed or not by setting verbose accordingly. |
| 89 | +
|
| 90 | +A `KmeansResult` structure representing labels, centroids, and sum_squares is returned. |
| 91 | +""" |
| 92 | +function kmeans(alg, design_matrix, k; |
| 93 | + n_threads = Threads.nthreads(), |
| 94 | + k_init = "k-means++", max_iters = 300, |
| 95 | + tol = 1e-6, verbose = true, init = nothing) |
| 96 | + nrow, ncol = size(design_matrix) |
| 97 | + containers = create_containers(alg, k, nrow, ncol, n_threads) |
| 98 | + |
| 99 | + return kmeans!(alg, containers, design_matrix, k, n_threads = n_threads, |
| 100 | + k_init = k_init, max_iters = max_iters, tol = tol, |
| 101 | + verbose = verbose, init = init) |
| 102 | +end |
| 103 | + |
| 104 | +""" |
| 105 | + Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true) |
| 106 | +
|
| 107 | +Mutable version of `kmeans` function. Definition of arguments and results can be |
| 108 | +found in `kmeans`. |
| 109 | +
|
| 110 | +Argument `containers` represent algorithm specific containers, such as labels, intermidiate |
| 111 | +centroids and so on, which are used during calculations. |
| 112 | +""" |
| 113 | +function kmeans!(alg, containers, design_matrix, k; |
| 114 | + n_threads = Threads.nthreads(), |
| 115 | + k_init = "k-means++", max_iters = 300, |
| 116 | + tol = 1e-6, verbose = true, init = nothing) |
| 117 | + nrow, ncol = size(design_matrix) |
| 118 | + centroids = init == nothing ? smart_init(design_matrix, k, n_threads, init=k_init).centroids : deepcopy(init) |
| 119 | + |
| 120 | + converged = false |
| 121 | + niters = 1 |
| 122 | + J_previous = 0.0 |
| 123 | + |
| 124 | + # Update centroids & labels with closest members until convergence |
| 125 | + |
| 126 | + while niters <= max_iters |
| 127 | + update_containers!(containers, alg, centroids, n_threads) |
| 128 | + J = update_centroids!(centroids, containers, alg, design_matrix, n_threads) |
| 129 | + |
| 130 | + if verbose |
| 131 | + # Show progress and terminate if J stopped decreasing. |
| 132 | + println("Iteration $iter: Jclust = $J") |
| 133 | + end |
| 134 | + |
| 135 | + # Check for convergence |
| 136 | + if (niters > 1) & (abs(J - J_previous) < (tol * J)) |
| 137 | + converged = true |
| 138 | + break |
| 139 | + end |
| 140 | + |
| 141 | + J_previous = J |
| 142 | + niters += 1 |
| 143 | + end |
| 144 | + |
| 145 | + totalcost = sum_of_squares(design_matrix, containers.labels, centroids) |
| 146 | + |
| 147 | + # Terminate algorithm with the assumption that K-means has converged |
| 148 | + if verbose & converged |
| 149 | + println("Successfully terminated with convergence.") |
| 150 | + end |
| 151 | + |
| 152 | + # TODO empty placeholder vectors should be calculated |
| 153 | + # TODO Float64 type definitions is too restrictive, should be relaxed |
| 154 | + # especially during GPU related development |
| 155 | + return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged) |
| 156 | +end |
| 157 | + |
| 158 | +""" |
| 159 | + update_centroids!(centroids, containers, alg, design_matrix, n_threads) |
| 160 | +
|
| 161 | +Internal function, used to update centroids by utilizing one of `alg`. It works as |
| 162 | +a wrapper of internal `chunk_update_centroids!` function, splitting incoming |
| 163 | +`design_matrix` in chunks and combining results together. |
| 164 | +""" |
| 165 | +function update_centroids!(centroids, containers, alg, design_matrix, n_threads) |
| 166 | + ncol = size(design_matrix, 2) |
| 167 | + |
| 168 | + if n_threads == 1 |
| 169 | + r = axes(design_matrix, 2) |
| 170 | + J = chunk_update_centroids!(centroids, containers, alg, design_matrix, r, 0) |
| 171 | + |
| 172 | + centroids .= containers.new_centroids ./ containers.centroids_cnt' |
| 173 | + else |
| 174 | + ranges = splitter(ncol, n_threads) |
| 175 | + |
| 176 | + waiting_list = Vector{Task}(undef, n_threads - 1) |
| 177 | + |
| 178 | + for i in 1:length(ranges) - 1 |
| 179 | + waiting_list[i] = @spawn chunk_update_centroids!(centroids, containers, |
| 180 | + alg, design_matrix, ranges[i], i + 1) |
| 181 | + end |
| 182 | + |
| 183 | + J = chunk_update_centroids!(centroids, containers, alg, design_matrix, ranges[end], 1) |
| 184 | + |
| 185 | + J += sum(fetch.(waiting_list)) |
| 186 | + |
| 187 | + for i in 1:length(ranges) - 1 |
| 188 | + containers.new_centroids[1] .+= containers.new_centroids[i + 1] |
| 189 | + containers.centroids_cnt[1] .+= containers.centroids_cnt[i + 1] |
| 190 | + end |
| 191 | + |
| 192 | + centroids .= containers.new_centroids[1] ./ containers.centroids_cnt[1]' |
| 193 | + end |
| 194 | + |
| 195 | + return J/ncol |
| 196 | +end |
0 commit comments