Skip to content

Commit f5d5ced

Browse files
committed
refactored MLJ interface based on feedback
1 parent 0e275c4 commit f5d5ced

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

src/mlj_interface.jl

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
21
# Expose all instances of user specified structs and package artifacts.
3-
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
2+
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all available variants of the KMeans clustering algorithm
3+
in native Julia. Compatible with Julia 1.3+"
44

55
# available variants for reference
66
const MLJDICT = Dict(:Lloyd => Lloyd(),
@@ -10,7 +10,6 @@ const MLJDICT = Dict(:Lloyd => Lloyd(),
1010
####
1111
#### MODEL DEFINITION
1212
####
13-
# TODO 2: MLJ-compatible model types and constructors
1413

1514
mutable struct KMeans <: MLJModelInterface.Unsupervised
1615
algo::Symbol
@@ -40,7 +39,7 @@ function MLJModelInterface.clean!(m::KMeans)
4039
warning = ""
4140

4241
if !(m.algo keys(MLJDICT))
43-
warning *= "Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm."
42+
warning *= "Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm."
4443
m.algo = :Lloyd
4544

4645
elseif m.k_init != "k-means++"
@@ -71,24 +70,22 @@ function MLJModelInterface.clean!(m::KMeans)
7170
end
7271

7372

74-
# TODO 3: implementation of fit, predict, and fitted_params of the model
7573
####
7674
#### FIT FUNCTION
7775
####
7876
"""
79-
TODO 3.1: Docs
80-
# fit the specified struct as a ParaKMeans model
77+
Fit the specified ParaKMeans model constructed by the user.
8178
8279
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
8380
"""
8481
function MLJModelInterface.fit(m::KMeans, X)
8582
# convert tabular input data into the matrix the model expects. Columns are assumed to be features, so the input data is permuted
8683
if !m.copy
87-
# transpose input table without copying and pass to model
88-
DMatrix = convert(Array{Float64, 2}, X)'
84+
# permutes dimensions of the input table without copying and passes it to the model
85+
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(X)')
8986
else
90-
# tranposes input table as a column major matrix after making a copy of the data
91-
DMatrix = MLJModelInterface.matrix(X; transpose=true)
87+
# permutes dimensions of input table as a column major matrix from a copy of the data
88+
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(X, transpose=true))
9289
end
9390

9491
# lookup available algorithms
@@ -109,9 +106,6 @@ function MLJModelInterface.fit(m::KMeans, X)
109106
end
110107

111108

112-
"""
113-
TODO 3.2: Docs
114-
"""
115109
function MLJModelInterface.fitted_params(model::KMeans, fitresult)
116110
# extract what's relevant from `fitresult`
117111
results, _, _ = fitresult # unpack fitresult
@@ -129,15 +123,26 @@ end
129123
####
130124
#### PREDICT FUNCTION
131125
####
132-
"""
133-
TODO 3.3: Docs
134-
"""
126+
135127
function MLJModelInterface.transform(m::KMeans, fitresult, Xnew)
136128
# make predictions/assignments using the learned centroids
129+
130+
if !m.copy
131+
# permutes dimensions of the input table without copying and passes it to the model
132+
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(Xnew)')
133+
else
134+
# permutes dimensions of input table as a column major matrix from a copy of the data
135+
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(Xnew, transpose=true))
136+
end
137+
138+
# warn users if fitresult is from a `non-converged` fit
139+
if !fitresult[end].converged
140+
@warn "Failed to converged. Using last assignments to make transformations."
141+
end
142+
143+
# results from fitted model
137144
results = fitresult[1]
138-
DMatrix = MLJModelInterface.matrix(Xnew, transpose=true)
139145

140-
# TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
141146
# use centroid matrix to assign clusters for new data
142147
centroids = results.centers
143148
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
@@ -153,12 +158,11 @@ end
153158
# TODO 4: metadata for the package and for each of the model interfaces
154159
metadata_pkg.(KMeans,
155160
name = "ParallelKMeans",
156-
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af", # see your Project.toml
157-
url = "https://github.com/PyDataBlog/ParallelKMeans.jl", # URL to your package repo
158-
julia = true, # is it written entirely in Julia?
159-
license = "MIT", # your package license
160-
is_wrapper = false, # does it wrap around some other package?
161-
)
161+
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af",
162+
url = "https://github.com/PyDataBlog/ParallelKMeans.jl",
163+
julia = true,
164+
license = "MIT",
165+
is_wrapper = false)
162166

163167

164168
# Metadata for ParaKMeans model interface

test/test07_mlj_interface.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ end
2424

2525

2626
@testset "Test bad struct warings" begin
27-
@test_logs (:warn, "Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm.") ParallelKMeans.KMeans(algo=:Fake)
27+
@test_logs (:warn, "Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm.") ParallelKMeans.KMeans(algo=:Fake)
2828
@test_logs (:warn, "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding.") ParallelKMeans.KMeans(k_init="abc")
2929
@test_logs (:warn, "Number of clusters must be greater than 0. Defaulting to 3 clusters.") ParallelKMeans.KMeans(k=0)
3030
@test_logs (:warn, "Tolerance level must be less than 1. Defaulting to tol of 1e-6.") ParallelKMeans.KMeans(tol=2)
@@ -121,4 +121,16 @@ end
121121
@test preds[:x1][1] == 2
122122
end
123123

124+
@testset "Testing " begin
125+
Random.seed!(2020)
126+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
127+
X_test = table([10 1])
128+
129+
model = KMeans(k=2, max_iters=1)
130+
results = fit(model, X)
131+
132+
@test_logs (:warn, "Failed to converged. Using last assignments to make transformations.") transform(model, results, X_test)
133+
end
134+
124135
end # module
136+

0 commit comments

Comments
 (0)