|
187 | 187 | @test yhat[1] == 2
|
188 | 188 | end
|
189 | 189 |
|
| 190 | +@testset "Test MiniBatch model fitting" begin |
| 191 | + rng = StableRNG(2020) |
| 192 | + X = table(rand(rng, 3, 100)') |
| 193 | + X_test = table([0.25 0.17 0.29; 0.52 0.71 0.75]) # similar to first 2 examples |
| 194 | + |
| 195 | + model = KMeans(algo = MiniBatch(50), k=2, rng=rng, max_iters=2_000) |
| 196 | + results, cache, report = fit(model, 0, X) |
| 197 | + |
| 198 | + @test cache == nothing |
| 199 | + @test results.converged |
| 200 | + @test report.totalcost ≈ 18.03007733451847 |
| 201 | + |
| 202 | + params = fitted_params(model, results) |
| 203 | + @test all(params.cluster_centers .≈ [0.39739206832613827 0.4818900563319951; |
| 204 | + 0.7695625526281311 0.30986081763964723; |
| 205 | + 0.6175496080776439 0.3911138270823586]) |
| 206 | + |
| 207 | + # Use trained model to cluster new data X_test |
| 208 | + preds = transform(model, results, X_test) |
| 209 | + @test preds[:x1][1] ≈ 0.48848842207123555 |
| 210 | + @test preds[:x2][1] ≈ 0.08355805256372761 |
| 211 | + |
| 212 | + # Make predictions on new data X_test with fitted params |
| 213 | + yhat = predict(model, results, X_test) |
| 214 | + @test yhat == report.assignments[1:2] |
| 215 | +end |
| 216 | + |
190 | 217 | @testset "Testing weights support" begin
|
191 | 218 | rng = StableRNG(2020)
|
192 | 219 | X = table(rand(rng, 3, 100)')
|
|
210 | 237 | @test_logs (:warn, "Failed to converge. Using last assignments to make transformations.") transform(model, results, X_test)
|
211 | 238 | end
|
212 | 239 |
|
213 |
| -""" |
| 240 | + |
214 | 241 | @testset "Testing non convergence warning during model fitting" begin
|
215 | 242 | Random.seed!(2020)
|
216 | 243 | X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
|
|
219 | 246 | model = KMeans(k=2, max_iters=1)
|
220 | 247 | @test_logs (:warn, "Specified model failed to converge.") fit(model, 1, X);
|
221 | 248 | end
|
222 |
| -""" |
| 249 | + |
223 | 250 | end # module
|
0 commit comments