Skip to content

Commit 087c96e

Browse files
committed
Initial implementation of KMedoids
1 parent 1de5b52 commit 087c96e

File tree

6 files changed

+157
-2
lines changed

6 files changed

+157
-2
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1818
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2020
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
21+
TableDistances = "e5d66e97-8c70-46bb-8b66-04a2d73ad782"
2122
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2223
TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8"
2324
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
@@ -37,6 +38,7 @@ PrettyTables = "2"
3738
Random = "1.9"
3839
Statistics = "1.9"
3940
StatsBase = "0.33, 0.34"
41+
TableDistances = "1.0"
4042
Tables = "1.6"
4143
TransformsBase = "1.5"
4244
Unitful = "1.17"

src/TableTransforms.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ module TableTransforms
66

77
using Tables
88
using Unitful
9-
using Statistics
109
using PrettyTables
1110
using AbstractTrees
12-
using LinearAlgebra
11+
using TableDistances
1312
using DataScienceTraits
1413
using CategoricalArrays
14+
using LinearAlgebra
15+
using Statistics
1516
using Random
1617
using CoDa
1718

@@ -90,6 +91,7 @@ export
9091
DRS,
9192
SDS,
9293
ProjectionPursuit,
94+
KMedoids,
9395
Closure,
9496
Remainder,
9597
Compose,

src/transforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ include("transforms/quantile.jl")
286286
include("transforms/functional.jl")
287287
include("transforms/eigenanalysis.jl")
288288
include("transforms/projectionpursuit.jl")
289+
include("transforms/kmedoids.jl")
289290
include("transforms/closure.jl")
290291
include("transforms/remainder.jl")
291292
include("transforms/compose.jl")

src/transforms/kmedoids.jl

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
"""
6+
KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng())
7+
8+
Assign labels to rows of table using the `k`-medoids algorithm.
9+
10+
The iterative algorithm is interrupted if the relative change of
11+
the average dissimilarity between successive iterations is smaller
12+
than a tolerance `tol` or if the number of iterations exceeds
13+
the maximum number of iterations `maxiter`.
14+
15+
Optionally, specify a dictionary of `weights` for each column to
16+
affect the underlying table distance from TableDistances.jl, and
17+
a random number generator `rng` to obtain reproducible results.
18+
19+
## Examples
20+
21+
```julia
22+
KMedoids(3)
23+
KMedoids(4, maxiter=20)
24+
KMedoids(5, weights=Dict(:col1 => 1.0, :col2 => 2.0))
25+
```
26+
27+
## References
28+
29+
* Kaufman, L. & Rousseeuw, P. J. 1990. [Partitioning Around Medoids (Program PAM)]
30+
(https://onlinelibrary.wiley.com/doi/10.1002/9780470316801.ch2)
31+
32+
* Kaufman, L. & Rousseeuw, P. J. 1991. [Finding Groups in Data: An Introduction to Cluster Analysis]
33+
(https://www.jstor.org/stable/2532178)
34+
"""
35+
struct KMedoids{W,RNG} <: StatelessFeatureTransform
36+
k::Int
37+
tol::Float64
38+
maxiter::Int
39+
weights::W
40+
rng::RNG
41+
end
42+
43+
function KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng())
44+
# sanity checks
45+
@assert k > 0 "number of clusters must be positive"
46+
@assert tol > 0 "tolerance on relative change must be positive"
47+
@assert maxiter > 0 "maximum number of iterations must be positive"
48+
KMedoids(k, tol, maxiter, weights, rng)
49+
end
50+
51+
parameters(transform::KMedoids) = (; k=transform.k)
52+
53+
function applyfeat(transform::KMedoids, feat, prep)
54+
# retrieve parameters
55+
k = transform.k
56+
tol = transform.tol
57+
maxiter = transform.maxiter
58+
weights = transform.weights
59+
rng = transform.rng
60+
61+
# number of observations
62+
nobs = _nrow(feat)
63+
64+
# sanity checks
65+
k > nobs && throw(ArgumentError("requested number of clusters > number of observations"))
66+
67+
# normalize variables
68+
stdfeat = feat |> StdFeats()
69+
70+
# define table distance
71+
td = TableDistance(normalize=false, weights=weights)
72+
73+
# initialize medoids
74+
medoids = sample(rng, 1:nobs, k, replace=false)
75+
76+
# pre-allocate memory for labels and distances
77+
labels = fill(0, nobs)
78+
dists = fill(Inf, nobs)
79+
80+
# main loop
81+
iter = 0
82+
δcur = mean(dists)
83+
while iter < maxiter
84+
# update labels and medoids
85+
_updatelabels!(td, stdfeat, medoids, labels, dists)
86+
_updatemedoids!(td, stdfeat, medoids, labels)
87+
88+
# average dissimilarity
89+
δnew = mean(dists)
90+
91+
# break upon convergence
92+
abs(δnew - δcur) / δcur < tol && break
93+
94+
# update and continue
95+
δcur = δnew
96+
iter += 1
97+
end
98+
99+
newfeat = (; cluster=labels) |> Tables.materializer(feat)
100+
101+
newfeat, nothing
102+
end
103+
104+
function _updatelabels!(td, table, medoids, labels, dists)
105+
for (k, mₖ) in enumerate(medoids)
106+
inds = 1:_nrow(table)
107+
108+
X = Tables.subset(table, inds)
109+
μ = Tables.subset(table, [mₖ])
110+
111+
δ = pairwise(td, X, μ)
112+
113+
@inbounds for i in inds
114+
if δ[i] < dists[i]
115+
dists[i] = δ[i]
116+
labels[i] = k
117+
end
118+
end
119+
end
120+
end
121+
122+
function _updatemedoids!(td, table, medoids, labels)
123+
for k in eachindex(medoids)
124+
inds = findall(isequal(k), labels)
125+
126+
X = Tables.subset(table, inds)
127+
128+
j = _medoid(td, X)
129+
130+
@inbounds medoids[k] = inds[j]
131+
end
132+
end
133+
134+
function _nrow(table)
135+
cols = Tables.columns(table)
136+
vars = Tables.columnnames(cols)
137+
vals = Tables.getcolumn(cols, first(vars))
138+
length(vals)
139+
end
140+
141+
function _medoid(td, table)
142+
Δ = pairwise(td, table)
143+
_, j = findmin(sum, eachcol(Δ))
144+
j
145+
end

test/transforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ transformfiles = [
3131
"functional.jl",
3232
"eigenanalysis.jl",
3333
"projectionpursuit.jl",
34+
"kmedoids.jl",
3435
"closure.jl",
3536
"remainder.jl",
3637
"compose.jl",

test/transforms/kmedoids.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@testset "KMedoids" begin
2+
@test !isrevertible(KMedoids(3))
3+
@test TT.parameters(KMedoids(3)) == (k=3,)
4+
end

0 commit comments

Comments
 (0)