1
- # TODO 1: a using MLJModelInterface or import MLJModelInterface statement
2
1
# Expose all instances of user specified structs and package artifcats.
3
- const ParallelKMeans_Desc = " Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
2
+ const ParallelKMeans_Desc = " Parallel & lightning fast implementation of all available variants of the KMeans clustering algorithm
3
+ in native Julia. Compatible with Julia 1.3+"
4
4
5
5
# availalbe variants for reference
6
6
const MLJDICT = Dict (:Lloyd => Lloyd (),
@@ -10,7 +10,6 @@ const MLJDICT = Dict(:Lloyd => Lloyd(),
10
10
# ###
11
11
# ### MODEL DEFINITION
12
12
# ###
13
- # TODO 2: MLJ-compatible model types and constructors
14
13
15
14
mutable struct KMeans <: MLJModelInterface.Unsupervised
16
15
algo:: Symbol
@@ -40,7 +39,7 @@ function MLJModelInterface.clean!(m::KMeans)
40
39
warning = " "
41
40
42
41
if ! (m. algo ∈ keys (MLJDICT))
43
- warning *= " Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm."
42
+ warning *= " Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm."
44
43
m. algo = :Lloyd
45
44
46
45
elseif m. k_init != " k-means++"
@@ -71,24 +70,22 @@ function MLJModelInterface.clean!(m::KMeans)
71
70
end
72
71
73
72
74
- # TODO 3: implementation of fit, predict, and fitted_params of the model
75
73
# ###
76
74
# ### FIT FUNCTION
77
75
# ###
78
76
"""
79
- TODO 3.1: Docs
80
- # fit the specified struct as a ParaKMeans model
77
+ Fit the specified ParaKMeans model constructed by the user.
81
78
82
79
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
83
80
"""
84
81
function MLJModelInterface. fit (m:: KMeans , X)
85
82
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
86
83
if ! m. copy
87
- # transpose input table without copying and pass to model
88
- DMatrix = convert (Array{Float64, 2 }, X)'
84
+ # permutes dimensions of input table without copying and pass to model
85
+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface . matrix ( X)' )
89
86
else
90
- # tranposes input table as a column major matrix after making a copy of the data
91
- DMatrix = MLJModelInterface. matrix (X; transpose= true )
87
+ # permutes dimensions of input table as a column major matrix from a copy of the data
88
+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface. matrix (X, transpose= true ) )
92
89
end
93
90
94
91
# lookup available algorithms
@@ -109,9 +106,6 @@ function MLJModelInterface.fit(m::KMeans, X)
109
106
end
110
107
111
108
112
- """
113
- TODO 3.2: Docs
114
- """
115
109
function MLJModelInterface. fitted_params (model:: KMeans , fitresult)
116
110
# extract what's relevant from `fitresult`
117
111
results, _, _ = fitresult # unpack fitresult
@@ -129,15 +123,26 @@ end
129
123
# ###
130
124
# ### PREDICT FUNCTION
131
125
# ###
132
- """
133
- TODO 3.3: Docs
134
- """
126
+
135
127
function MLJModelInterface. transform (m:: KMeans , fitresult, Xnew)
136
128
# make predictions/assignments using the learned centroids
129
+
130
+ if ! m. copy
131
+ # permutes dimensions of input table without copying and pass to model
132
+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface. matrix (Xnew)' )
133
+ else
134
+ # permutes dimensions of input table as a column major matrix from a copy of the data
135
+ DMatrix = convert (Array{Float64, 2 }, MLJModelInterface. matrix (Xnew, transpose= true ))
136
+ end
137
+
138
+ # TODO : Warn users if fitresult is from a `non-converged` fit?
139
+ if ! fitresult[end ]. converged
140
+ @warn " Failed to converged. Using last assignments to make transformations."
141
+ end
142
+
143
+ # results from fitted model
137
144
results = fitresult[1 ]
138
- DMatrix = MLJModelInterface. matrix (Xnew, transpose= true )
139
145
140
- # TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
141
146
# use centroid matrix to assign clusters for new data
142
147
centroids = results. centers
143
148
distances = Distances. pairwise (Distances. SqEuclidean (), DMatrix, centroids; dims= 2 )
@@ -153,12 +158,11 @@ end
153
158
# TODO 4: metadata for the package and for each of the model interfaces
154
159
metadata_pkg .(KMeans,
155
160
name = " ParallelKMeans" ,
156
- uuid = " 42b8e9d4-006b-409a-8472-7f34b3fb58af" , # see your Project.toml
157
- url = " https://github.com/PyDataBlog/ParallelKMeans.jl" , # URL to your package repo
158
- julia = true , # is it written entirely in Julia?
159
- license = " MIT" , # your package license
160
- is_wrapper = false , # does it wrap around some other package?
161
- )
161
+ uuid = " 42b8e9d4-006b-409a-8472-7f34b3fb58af" ,
162
+ url = " https://github.com/PyDataBlog/ParallelKMeans.jl" ,
163
+ julia = true ,
164
+ license = " MIT" ,
165
+ is_wrapper = false )
162
166
163
167
164
168
# Metadata for ParaKMeans model interface
0 commit comments