Skip to content

Commit baa3150

Browse files
committed
Updated MLJ interface with mini batch algorithm
1 parent ebb7b35 commit baa3150

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

src/mlj_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ const MLJDICT = Dict(:Lloyd => Lloyd(),
1010
:Elkan => Elkan(),
1111
:Yinyang => Yinyang(),
1212
:Coreset => Coreset(),
13-
:阴阳 => Coreset())
13+
:阴阳 => Coreset(),
14+
:MiniBatch => MiniBatch())
1415

1516
####
1617
#### MODEL DEFINITION
@@ -123,12 +124,11 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
123124
totalcost=result.totalcost, assignments=result.assignments, labels=cluster_labels)
124125

125126

126-
"""
127-
# TODO: warn users about non convergence
127+
# Warn users about non convergence
128128
if verbose & (!fitresult.converged)
129129
@warn "Specified model failed to converge."
130130
end
131-
"""
131+
132132
return (fitresult, cache, report)
133133
end
134134

test/test80_mlj_interface.jl

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,33 @@ end
187187
@test yhat[1] == 2
188188
end
189189

190+
@testset "Test MiniBatch model fitting" begin
191+
rng = StableRNG(2020)
192+
X = table(rand(rng, 3, 100)')
193+
X_test = table([0.25 0.17 0.29; 0.52 0.71 0.75]) # similar to first 2 examples
194+
195+
model = KMeans(algo = MiniBatch(50), k=2, rng=rng, max_iters=2_000)
196+
results, cache, report = fit(model, 0, X)
197+
198+
@test cache == nothing
199+
@test results.converged
200+
@test report.totalcost 18.03007733451847
201+
202+
params = fitted_params(model, results)
203+
@test all(params.cluster_centers .≈ [0.39739206832613827 0.4818900563319951;
204+
0.7695625526281311 0.30986081763964723;
205+
0.6175496080776439 0.3911138270823586])
206+
207+
# Use trained model to cluster new data X_test
208+
preds = transform(model, results, X_test)
209+
@test preds[:x1][1] 0.48848842207123555
210+
@test preds[:x2][1] 0.08355805256372761
211+
212+
# Make predictions on new data X_test with fitted params
213+
yhat = predict(model, results, X_test)
214+
@test yhat == report.assignments[1:2]
215+
end
216+
190217
@testset "Testing weights support" begin
191218
rng = StableRNG(2020)
192219
X = table(rand(rng, 3, 100)')
@@ -210,7 +237,7 @@ end
210237
@test_logs (:warn, "Failed to converge. Using last assignments to make transformations.") transform(model, results, X_test)
211238
end
212239

213-
"""
240+
214241
@testset "Testing non convergence warning during model fitting" begin
215242
Random.seed!(2020)
216243
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
@@ -219,5 +246,5 @@ end
219246
model = KMeans(k=2, max_iters=1)
220247
@test_logs (:warn, "Specified model failed to converge.") fit(model, 1, X);
221248
end
222-
"""
249+
223250
end # module

0 commit comments

Comments
 (0)