Skip to content

Commit ebb7b35

Browse files
committed
Improved tests & fixes
1 parent f2b8276 commit ebb7b35

File tree

4 files changed

+43
-12
lines changed

4 files changed

+43
-12
lines changed

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://PyDataBlog.github.io/ParallelKMeans.jl/stable)
44
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://PyDataBlog.github.io/ParallelKMeans.jl/dev)
5-
[![Build Status](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl.svg?branch=master)](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl)
6-
[![Coverage Status](https://coveralls.io/repos/github/PyDataBlog/ParallelKMeans.jl/badge.svg?branch=master)](https://coveralls.io/github/PyDataBlog/ParallelKMeans.jl?branch=master)
5+
[![Build Status](https://github.com/PyDataBlog/ParallelKMeans.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/PyDataBlog/ParallelKMeans.jl/actions/workflows/CI.yml/badge.svg)
6+
[![codecov](https://codecov.io/gh/PyDataBlog/ParallelKMeans.jl/branch/master/graph/badge.svg?token=799USS6BPH)](https://codecov.io/gh/PyDataBlog/ParallelKMeans.jl)
77
[![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2FPyDataBlog%2FParallelKMeans.jl.svg?type=shield)](https://app.fossa.com/projects/git%2Bgithub.com%2FPyDataBlog%2FParallelKMeans.jl?ref=badge_shield)
88
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/PyDataBlog/ParallelKMeans.jl/master)
99
_________________________________________________________________________________________________________
@@ -22,10 +22,13 @@ ________________________________________________________________________________
2222

2323
## Table Of Content
2424

25-
1. [Documentation](#Documentation)
26-
2. [Installation](#Installation)
27-
3. [Features](#Features)
28-
4. [License](#License)
25+
- [ParallelKMeans](#parallelkmeans)
26+
- [Table Of Content](#table-of-content)
27+
- [Documentation](#documentation)
28+
- [Installation](#installation)
29+
- [Features](#features)
30+
- [Benchmarks](#benchmarks)
31+
- [License](#license)
2932

3033
_________________________________________________________________________________________________________
3134

src/mini_batch.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function kmeans!(alg::MiniBatch, containers, X, k,
3434

3535
# Initialize nearest centers for both batch and whole dataset labels
3636
converged = false
37-
niters = 0
37+
niters = 1
3838
counter = 0
3939
J_previous = zero(T)
4040
J = zero(T)
@@ -99,19 +99,25 @@ function kmeans!(alg::MiniBatch, containers, X, k,
9999
# Compute totalcost for the complete dataset
100100
@parallelize 1 ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
101101
totalcost = sum(containers.sum_of_squares)
102+
103+
# Print convergence message to user
104+
if verbose
105+
println("Successfully terminated with convergence.")
106+
end
107+
102108
break
103109
end
104110
else
105111
counter = 0
106112
end
107113

108114
# Warn users if model doesn't converge at max iterations
109-
if (niters > max_iters) & (!converged)
115+
if (niters >= max_iters) & (!converged)
110116

111117
if verbose
112118
println("Clustering model failed to converge. Labelling data with latest centroids.")
113119
end
114-
containers.labels = reassign_labels(X, metric, containers.labels, centroids)
120+
containers.labels .= reassign_labels(X, metric, containers.labels, centroids)
115121

116122
# Compute totalcost for unconverged model
117123
@parallelize 1 ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)

test/test70_verbose.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,16 @@ end
4848

4949
# Capture output and compare
5050
r = @capture_out kmeans(Coreset(), X, 3; n_threads=1, max_iters=1, verbose=true, rng = rng)
51-
# This test is broken on 1.5 dev, see https://github.com/rfourquet/StableRNGs.jl/issues/3
52-
# @test startswith(r, "Iteration 1: Jclust = 32.8028409136")
51+
@test startswith(r, "Iteration 1: Jclust = 32.8028409136")
52+
end
53+
54+
@testset "MiniBatch: Testing verbosity of implementation" begin
55+
rng = StableRNG(2020)
56+
X = rand(rng, 3, 100)
57+
58+
# Capture output and compare
59+
r = @capture_out kmeans(MiniBatch(10), X, 2; n_threads=1, max_iters=1, verbose=true, rng = rng)
60+
@test startswith(r, "Iteration 1: Jclust = 18.298067523612104")
5361
end
5462

5563
end # module

test/test90_minibatch.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Test
55
using StableRNGs
66
using StatsBase
77
using Distances
8+
using Suppressor
89

910

1011
@testset "MiniBatch default batch size" begin
@@ -15,14 +16,27 @@ end
1516
@testset "MiniBatch convergence" begin
1617
rng = StableRNG(2020)
1718
X = rand(rng, 3, 100)
19+
rng_orig = deepcopy(rng)
1820

21+
# use lloyd as baseline
1922
baseline = [kmeans(Lloyd(), X, 2; max_iters=100_000).totalcost for i in 1:200] |> mean |> round
20-
23+
# mini batch results
2124
res = [kmeans(MiniBatch(10), X, 2; max_iters=100_000).totalcost for i in 1:200] |> mean |> round
25+
# test for verbosity with convergence
26+
r = @capture_out kmeans(MiniBatch(50), X, 2; max_iters=2_000, verbose=true, rng=rng_orig)
2227

2328
@test baseline == res
29+
@test endswith(r, "Successfully terminated with convergence.\n")
2430
end
2531

32+
@testset "MiniBatch non-convergence warning" begin
33+
rng = StableRNG(2020)
34+
X = rand(rng, 3, 100)
35+
36+
# Capture output and compare
37+
r = @capture_out kmeans(MiniBatch(10), X, 2; n_threads=1, max_iters=10, verbose=true, rng = rng)
38+
@test endswith(r, "Clustering model failed to converge. Labelling data with latest centroids.\n")
39+
end
2640

2741
@testset "MiniBatch metric support" begin
2842
rng = StableRNG(2020)

0 commit comments

Comments
 (0)