Skip to content

Commit 0e275c4

Browse files
authored
Merge pull request #44 from PyDataBlog/mljinterface
Mljinterface draft
2 parents 2a564dc + 7cd46e6 commit 0e275c4

File tree

2 files changed

+76
-20
lines changed

2 files changed

+76
-20
lines changed

src/mlj_interface.jl

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,87 @@
11
# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
2+
# Expose all instances of user specified structs and package artifacts.
3+
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
4+
5+
# available variants for reference
6+
const MLJDICT = Dict(:Lloyd => Lloyd(),
7+
:Hamerly => Hamerly(),
8+
:LightElkan => LightElkan())
29

310
####
411
#### MODEL DEFINITION
512
####
613
# TODO 2: MLJ-compatible model types and constructors
7-
@mlj_model mutable struct KMeans <: MLJModelInterface.Unsupervised
8-
# Hyperparameters of the model
9-
algo::Symbol = :Lloyd::(_ in (:Lloyd, :Hamerly, :LightElkan))
10-
k_init::String = "k-means++"::(_ in ("k-means++", String)) # allow user seeding?
11-
k::Int = 3::(_ > 0)
12-
tol::Float64 = 1e-6::(_ < 1)
13-
max_iters::Int = 300::(_ > 0)
14-
copy::Bool = true
15-
threads::Int = Threads.nthreads()::(_ > 0)
16-
verbosity::Int = 0::(_ in (0, 1)) # Temp fix. Do we need to follow mlj verbosity style?
17-
init = nothing
14+
15+
"""
    KMeans

Mutable container for the hyperparameters of the MLJ-facing KMeans model.

Instances are normally created through the keyword constructor, which supplies defaults
and runs `MLJModelInterface.clean!` on the result.
"""
mutable struct KMeans <: MLJModelInterface.Unsupervised
    # Clustering variant; validated against the keys of `MLJDICT`.
    algo::Symbol
    # Name of the seeding strategy, e.g. "k-means++".
    k_init::String
    # Number of clusters; must be greater than 0.
    k::Int
    # Tolerance level; must be less than 1.
    tol::Float64
    # Number of permitted iterations; must be greater than 0.
    max_iters::Int
    # When `false`, `fit` transposes the input table without copying it.
    copy::Bool
    # Number of threads; must be at least 1.
    threads::Int
    # Either 0 (no info) or 1 (info requested).
    verbosity::Int
    # Untyped on purpose; `nothing` by default.
    # NOTE(review): presumably user-supplied initial centroids — confirm against `fit`.
    init
end
1926

2027

21-
# Expose all instances of user specified structs and package artifcats.
22-
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
28+
"""
    KMeans(; algo=:Lloyd, k_init="k-means++", k=3, tol=1e-6, max_iters=300,
           copy=true, threads=Threads.nthreads(), verbosity=0, init=nothing)

Build a `KMeans` model from keyword arguments.

The freshly built model is passed through `MLJModelInterface.clean!`; when that call
returns a non-empty message it is emitted with `@warn`. The (possibly corrected) model
is returned.
"""
function KMeans(;
    algo=:Lloyd,
    k_init="k-means++",
    k=3,
    tol=1e-6,
    max_iters=300,
    copy=true,
    threads=Threads.nthreads(),
    verbosity=0,
    init=nothing,
)
    kmeans_model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, verbosity, init)

    # `clean!` mutates the model in place and reports what it changed.
    notice = MLJModelInterface.clean!(kmeans_model)
    if !isempty(notice)
        @warn notice
    end

    return kmeans_model
end
37+
38+
39+
"""
    MLJModelInterface.clean!(m::KMeans)

Validate the hyperparameters of `m`, resetting every invalid one to its default in place.

Each hyperparameter is checked independently, so a model carrying several invalid values
is fully repaired by a single call.

# Returns
A `String` holding one warning sentence per corrected hyperparameter, joined by a single
space. The string is empty when nothing had to be changed.
"""
function MLJModelInterface.clean!(m::KMeans)
    warnings = String[]

    if !(m.algo ∈ keys(MLJDICT))
        # NOTE(review): the wording mentions "KMeans++ seeding" although the fallback is
        # the `:Lloyd` algorithm; the text is kept verbatim since callers may match on it.
        push!(warnings, "Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm.")
        m.algo = :Lloyd
    end

    # Both seeding names mentioned in the message are accepted without a warning.
    if !(m.k_init ∈ ("k-means++", "random"))
        push!(warnings, "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding.")
        m.k_init = "random"
    end

    if m.k < 1
        push!(warnings, "Number of clusters must be greater than 0. Defaulting to 3 clusters.")
        m.k = 3
    end

    if !(m.tol < 1.0)
        push!(warnings, "Tolerance level must be less than 1. Defaulting to tol of 1e-6.")
        m.tol = 1e-6
    end

    if !(m.max_iters > 0)
        push!(warnings, "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations.")
        m.max_iters = 300
    end

    if !(m.threads > 0)
        push!(warnings, "Number of threads must be at least 1. Defaulting to all threads available.")
        m.threads = Threads.nthreads()
    end

    if !(m.verbosity ∈ (0, 1))
        push!(warnings, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0.")
        m.verbosity = 0
    end

    # A single violation yields exactly its own sentence, with no separator added.
    return join(warnings, " ")
end
2372

24-
# availalbe variants for reference
25-
const MLJDICT = Dict(:Lloyd => Lloyd(),
26-
:Hamerly => Hamerly(),
27-
:LightElkan => LightElkan())
2873

2974
# TODO 3: implementation of fit, predict, and fitted_params of the model
3075
####
3176
#### FIT FUNCTION
3277
####
3378
"""
3479
TODO 3.1: Docs
80+
# fit the specified struct as a ParaKMeans model
3581
3682
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
3783
"""
3884
function MLJModelInterface.fit(m::KMeans, X)
39-
# fit the specified struct as a ParaKMeans model
40-
4185
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
4286
if !m.copy
4387
# transpose input table without copying and pass to model
@@ -123,4 +167,4 @@ metadata_model(KMeans,
123167
output = MLJModelInterface.Table(MLJModelInterface.Count),
124168
weights = false,
125169
descr = ParallelKMeans_Desc,
126-
path = "ParallelKMeans.src.mlj_interface.KMeans")
170+
path = "ParallelKMeans.KMeans")

test/test07_mlj_interface.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TestMLJInterface
22

33
using ParallelKMeans
4+
using ParallelKMeans: KMeans
45
using Random
56
using Test
67
using Suppressor
@@ -22,6 +23,17 @@ using MLJBase
2223
end
2324

2425

26+
@testset "Test bad struct warings" begin
    # Each entry pairs the exact warning text with the single invalid keyword that
    # should trigger it when constructing the model.
    bad_settings = (
        ("Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm.", (algo = :Fake,)),
        ("Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding.", (k_init = "abc",)),
        ("Number of clusters must be greater than 0. Defaulting to 3 clusters.", (k = 0,)),
        ("Tolerance level must be less than 1. Defaulting to tol of 1e-6.", (tol = 2,)),
        ("Number of permitted iterations must be greater than 0. Defaulting to 300 iterations.", (max_iters = 0,)),
        ("Number of threads must be at least 1. Defaulting to all threads available.", (threads = 0,)),
        ("Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0.", (verbosity = 100,)),
    )

    for (expected_warning, bad_kwargs) in bad_settings
        @test_logs (:warn, expected_warning) ParallelKMeans.KMeans(; bad_kwargs...)
    end
end
35+
36+
2537
@testset "Test model fitting verbosity" begin
2638
Random.seed!(2020)
2739
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])

0 commit comments

Comments
 (0)