Skip to content

Commit d344bec

Browse files
Matteo Piombo
Matteo Piombo
authored and
Matteo Piombo
committed
Introduce Associated Labels to centroids.
This could help show how to use KMeans as a classifier. Made some changes to KMeans properties and functions. Updated the tests accordingly. Demonstrated the usage of labels in one small test.
1 parent 619e604 commit d344bec

File tree

3 files changed

+67
-38
lines changed

3 files changed

+67
-38
lines changed

Diff for: K-Means/KMeans.swift

+38-20
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,24 @@
55

66
import Foundation
77

8-
class KMeans {
8+
class KMeans<Label: Hashable> {
99
let numCenters: Int
10-
let convergeDist: Double
10+
let labels: Array<Label>
11+
private(set) var centroids: Array<Vector>
1112

12-
init(numCenters: Int, convergeDist: Double) {
13-
self.numCenters = numCenters
14-
self.convergeDist = convergeDist
13+
init(labels: Array<Label>) {
14+
assert(labels.count > 1, "Exception: KMeans with less than 2 centers.")
15+
self.labels = labels
16+
self.numCenters = labels.count
17+
centroids = []
1518
}
1619

17-
private func nearestCenter(x: Vector, centers: [Vector]) -> Int {
20+
private func nearestCenterIndex(x: Vector, centers: [Vector]) -> Int {
1821
var nearestDist = DBL_MAX
1922
var minIndex = 0
20-
21-
for (idx, c) in centers.enumerate() {
22-
let dist = x.distTo(c)
23+
24+
for (idx, center) in centers.enumerate() {
25+
let dist = x.distTo(center)
2326
if dist < nearestDist {
2427
minIndex = idx
2528
nearestDist = dist
@@ -28,24 +31,26 @@ class KMeans {
2831
return minIndex
2932
}
3033

31-
func findCenters(points: [Vector]) -> [Vector] {
34+
35+
36+
func trainCenters(points: [Vector], convergeDist: Double) {
37+
3238
var centerMoveDist = 0.0
33-
let zeros = [Double](count: points[0].length, repeatedValue: 0.0)
39+
let zeroVector = Vector(d: [Double](count: points[0].length, repeatedValue: 0.0))
3440

3541
var kCenters = reservoirSample(points, k: numCenters)
3642

3743
repeat {
38-
var cnts = [Double](count: numCenters, repeatedValue: 0.0)
39-
var newCenters = [Vector](count:numCenters, repeatedValue: Vector(d:zeros))
40-
44+
45+
var classification: Array<[Vector]> = Array(count: numCenters, repeatedValue: [])
46+
4147
for p in points {
42-
let c = nearestCenter(p, centers: kCenters)
43-
cnts[c] += 1
44-
newCenters[c] += p
48+
let classIndex = nearestCenterIndex(p, centers: kCenters)
49+
classification[classIndex].append(p)
4550
}
4651

47-
for idx in 0..<numCenters {
48-
newCenters[idx] /= cnts[idx]
52+
let newCenters = classification.map { assignedPoints in
53+
assignedPoints.reduce(zeroVector, combine: +) / Double(assignedPoints.count)
4954
}
5055

5156
centerMoveDist = 0.0
@@ -56,7 +61,20 @@ class KMeans {
5661
kCenters = newCenters
5762
} while centerMoveDist > convergeDist
5863

59-
return kCenters
64+
centroids = kCenters
65+
}
66+
67+
func fit(point: Vector) -> Label {
68+
assert(!centroids.isEmpty, "Exception: KMeans tried to fit on a non trained model.")
69+
70+
let centroidIndex = nearestCenterIndex(point, centers: centroids)
71+
return labels[centroidIndex]
72+
}
73+
74+
func fit(points: [Vector]) -> [Label] {
75+
assert(!centroids.isEmpty, "Exception: KMeans tried to fit on a non trained model.")
76+
77+
return points.map(fit)
6078
}
6179
}
6280

Diff for: K-Means/Tests/KMeansTests.swift

+28-17
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,66 @@ import Foundation
1010
import XCTest
1111

1212
class KMeansTests: XCTestCase {
13-
var points = [Vector]()
1413

15-
func genPoints(numPoints:Int, numDimmensions:Int) {
14+
func genPoints(numPoints:Int, numDimmensions:Int) -> [Vector] {
15+
var points = [Vector]()
16+
1617
for _ in 0..<numPoints {
1718
var data = [Double]()
1819
for _ in 0..<numDimmensions {
1920
data.append(Double(arc4random_uniform(UInt32(numPoints*numDimmensions))))
2021
}
2122
points.append(Vector(d: data))
2223
}
23-
24+
return points
2425
}
2526

2627
func testSmall_2D() {
27-
genPoints(10, numDimmensions: 2)
28-
28+
let points = genPoints(10, numDimmensions: 2)
29+
2930
print("\nCenters")
30-
let kmm = KMeans(numCenters: 3, convergeDist: 0.01)
31-
for c in kmm.findCenters(points) {
32-
print(c)
31+
let kmm = KMeans<Character>(labels: ["A", "B", "C"])
32+
kmm.trainCenters(points, convergeDist: 0.01)
33+
34+
for (label, centroid) in zip(kmm.labels, kmm.centroids) {
35+
print("\(label): \(centroid)")
36+
}
37+
38+
print("\nClassifications")
39+
for (label, point) in zip(kmm.fit(points), points) {
40+
print("\(label): \(point)")
3341
}
3442
}
3543

3644
func testSmall_10D() {
37-
genPoints(10, numDimmensions: 10)
45+
let points = genPoints(10, numDimmensions: 10)
3846

3947
print("\nCenters")
40-
let kmm = KMeans(numCenters: 3, convergeDist: 0.01)
41-
for c in kmm.findCenters(points) {
48+
let kmm = KMeans<Int>(labels: [1, 2, 3])
49+
kmm.trainCenters(points, convergeDist: 0.01)
50+
for c in kmm.centroids {
4251
print(c)
4352
}
4453
}
4554

4655
func testLarge_2D() {
47-
genPoints(10000, numDimmensions: 2)
56+
let points = genPoints(10000, numDimmensions: 2)
4857

4958
print("\nCenters")
50-
let kmm = KMeans(numCenters: 5, convergeDist: 0.01)
51-
for c in kmm.findCenters(points) {
59+
let kmm = KMeans<Character>(labels: ["A","B","C","D","E"])
60+
kmm.trainCenters(points, convergeDist: 0.01)
61+
for c in kmm.centroids {
5262
print(c)
5363
}
5464
}
5565

5666
func testLarge_10D() {
57-
genPoints(10000, numDimmensions: 10)
67+
let points = genPoints(10000, numDimmensions: 10)
5868

5969
print("\nCenters")
60-
let kmm = KMeans(numCenters: 5, convergeDist: 0.01)
61-
for c in kmm.findCenters(points) {
70+
let kmm = KMeans<Int>(labels: [1,2,3,4,5])
71+
kmm.trainCenters(points, convergeDist: 0.01)
72+
for c in kmm.centroids {
6273
print(c)
6374
}
6475
}

Diff for: K-Means/Tests/Tests.xcodeproj/project.pbxproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
/* End PBXBuildFile section */
1414

1515
/* Begin PBXFileReference section */
16-
B80894DB1C852CFA0018730E /* KMeans.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = KMeans.swift; path = ../KMeans.swift; sourceTree = SOURCE_ROOT; };
16+
B80894DB1C852CFA0018730E /* KMeans.swift */ = {isa = PBXFileReference; indentWidth = 2; lastKnownFileType = sourcecode.swift; name = KMeans.swift; path = ../KMeans.swift; sourceTree = SOURCE_ROOT; tabWidth = 2; };
1717
B80894E01C852D100018730E /* Tests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = Tests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
1818
B80894E31C852D100018730E /* KMeansTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KMeansTests.swift; sourceTree = SOURCE_ROOT; };
1919
B80894E51C852D100018730E /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = "<group>"; };

0 commit comments

Comments
 (0)