Skip to content

Commit 591fa45

Browse files
committed
Merge pull request #76 from perlfly/K-Means
Introduce associated labels to centroids
2 parents 13e6956 + d344bec commit 591fa45

File tree

3 files changed

+67
-38
lines changed

3 files changed

+67
-38
lines changed

Diff for: K-Means/KMeans.swift

+38-20
Original file line number | Diff line number | Diff line change
@@ -5,21 +5,24 @@
55

66
import Foundation
77

8-
class KMeans {
8+
class KMeans<Label: Hashable> {
99
let numCenters: Int
10-
let convergeDist: Double
10+
let labels: Array<Label>
11+
private(set) var centroids: Array<Vector>
1112

12-
init(numCenters: Int, convergeDist: Double) {
13-
self.numCenters = numCenters
14-
self.convergeDist = convergeDist
13+
init(labels: Array<Label>) {
14+
assert(labels.count > 1, "Exception: KMeans with less than 2 centers.")
15+
self.labels = labels
16+
self.numCenters = labels.count
17+
centroids = []
1518
}
1619

17-
private func nearestCenter(x: Vector, centers: [Vector]) -> Int {
20+
private func nearestCenterIndex(x: Vector, centers: [Vector]) -> Int {
1821
var nearestDist = DBL_MAX
1922
var minIndex = 0
20-
21-
for (idx, c) in centers.enumerate() {
22-
let dist = x.distTo(c)
23+
24+
for (idx, center) in centers.enumerate() {
25+
let dist = x.distTo(center)
2326
if dist < nearestDist {
2427
minIndex = idx
2528
nearestDist = dist
@@ -28,24 +31,26 @@ class KMeans {
2831
return minIndex
2932
}
3033

31-
func findCenters(points: [Vector]) -> [Vector] {
34+
35+
36+
func trainCenters(points: [Vector], convergeDist: Double) {
37+
3238
var centerMoveDist = 0.0
33-
let zeros = [Double](count: points[0].length, repeatedValue: 0.0)
39+
let zeroVector = Vector(d: [Double](count: points[0].length, repeatedValue: 0.0))
3440

3541
var kCenters = reservoirSample(points, k: numCenters)
3642

3743
repeat {
38-
var cnts = [Double](count: numCenters, repeatedValue: 0.0)
39-
var newCenters = [Vector](count:numCenters, repeatedValue: Vector(d:zeros))
40-
44+
45+
var classification: Array<[Vector]> = Array(count: numCenters, repeatedValue: [])
46+
4147
for p in points {
42-
let c = nearestCenter(p, centers: kCenters)
43-
cnts[c] += 1
44-
newCenters[c] += p
48+
let classIndex = nearestCenterIndex(p, centers: kCenters)
49+
classification[classIndex].append(p)
4550
}
4651

47-
for idx in 0..<numCenters {
48-
newCenters[idx] /= cnts[idx]
52+
let newCenters = classification.map { assignedPoints in
53+
assignedPoints.reduce(zeroVector, combine: +) / Double(assignedPoints.count)
4954
}
5055

5156
centerMoveDist = 0.0
@@ -56,7 +61,20 @@ class KMeans {
5661
kCenters = newCenters
5762
} while centerMoveDist > convergeDist
5863

59-
return kCenters
64+
centroids = kCenters
65+
}
66+
67+
func fit(point: Vector) -> Label {
68+
assert(!centroids.isEmpty, "Exception: KMeans tried to fit on a non trained model.")
69+
70+
let centroidIndex = nearestCenterIndex(point, centers: centroids)
71+
return labels[centroidIndex]
72+
}
73+
74+
func fit(points: [Vector]) -> [Label] {
75+
assert(!centroids.isEmpty, "Exception: KMeans tried to fit on a non trained model.")
76+
77+
return points.map(fit)
6078
}
6179
}
6280

Diff for: K-Means/Tests/KMeansTests.swift

+28-17
Original file line number | Diff line number | Diff line change
@@ -10,55 +10,66 @@ import Foundation
1010
import XCTest
1111

1212
class KMeansTests: XCTestCase {
13-
var points = [Vector]()
1413

15-
func genPoints(numPoints:Int, numDimmensions:Int) {
14+
func genPoints(numPoints:Int, numDimmensions:Int) -> [Vector] {
15+
var points = [Vector]()
16+
1617
for _ in 0..<numPoints {
1718
var data = [Double]()
1819
for _ in 0..<numDimmensions {
1920
data.append(Double(arc4random_uniform(UInt32(numPoints*numDimmensions))))
2021
}
2122
points.append(Vector(d: data))
2223
}
23-
24+
return points
2425
}
2526

2627
func testSmall_2D() {
27-
genPoints(10, numDimmensions: 2)
28-
28+
let points = genPoints(10, numDimmensions: 2)
29+
2930
print("\nCenters")
30-
let kmm = KMeans(numCenters: 3, convergeDist: 0.01)
31-
for c in kmm.findCenters(points) {
32-
print(c)
31+
let kmm = KMeans<Character>(labels: ["A", "B", "C"])
32+
kmm.trainCenters(points, convergeDist: 0.01)
33+
34+
for (label, centroid) in zip(kmm.labels, kmm.centroids) {
35+
print("\(label): \(centroid)")
36+
}
37+
38+
print("\nClassifications")
39+
for (label, point) in zip(kmm.fit(points), points) {
40+
print("\(label): \(point)")
3341
}
3442
}
3543

3644
func testSmall_10D() {
37-
genPoints(10, numDimmensions: 10)
45+
let points = genPoints(10, numDimmensions: 10)
3846

3947
print("\nCenters")
40-
let kmm = KMeans(numCenters: 3, convergeDist: 0.01)
41-
for c in kmm.findCenters(points) {
48+
let kmm = KMeans<Int>(labels: [1, 2, 3])
49+
kmm.trainCenters(points, convergeDist: 0.01)
50+
for c in kmm.centroids {
4251
print(c)
4352
}
4453
}
4554

4655
func testLarge_2D() {
47-
genPoints(10000, numDimmensions: 2)
56+
let points = genPoints(10000, numDimmensions: 2)
4857

4958
print("\nCenters")
50-
let kmm = KMeans(numCenters: 5, convergeDist: 0.01)
51-
for c in kmm.findCenters(points) {
59+
let kmm = KMeans<Character>(labels: ["A","B","C","D","E"])
60+
kmm.trainCenters(points, convergeDist: 0.01)
61+
for c in kmm.centroids {
5262
print(c)
5363
}
5464
}
5565

5666
func testLarge_10D() {
57-
genPoints(10000, numDimmensions: 10)
67+
let points = genPoints(10000, numDimmensions: 10)
5868

5969
print("\nCenters")
60-
let kmm = KMeans(numCenters: 5, convergeDist: 0.01)
61-
for c in kmm.findCenters(points) {
70+
let kmm = KMeans<Int>(labels: [1,2,3,4,5])
71+
kmm.trainCenters(points, convergeDist: 0.01)
72+
for c in kmm.centroids {
6273
print(c)
6374
}
6475
}

Diff for: K-Means/Tests/Tests.xcodeproj/project.pbxproj

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
/* End PBXBuildFile section */
1414

1515
/* Begin PBXFileReference section */
16-
B80894DB1C852CFA0018730E /* KMeans.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; name = KMeans.swift; path = ../KMeans.swift; sourceTree = SOURCE_ROOT; };
16+
B80894DB1C852CFA0018730E /* KMeans.swift */ = {isa = PBXFileReference; indentWidth = 2; lastKnownFileType = sourcecode.swift; name = KMeans.swift; path = ../KMeans.swift; sourceTree = SOURCE_ROOT; tabWidth = 2; };
1717
B80894E01C852D100018730E /* Tests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = Tests.xctest; sourceTree = BUILT_PRODUCTS_DIR; };
1818
B80894E31C852D100018730E /* KMeansTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = KMeansTests.swift; sourceTree = SOURCE_ROOT; };
1919
B80894E51C852D100018730E /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist; path = Info.plist; sourceTree = "<group>"; };

0 commit comments

Comments
 (0)