Skip to content

Commit 1f70bb6

Browse files
committed
Use new vector class
Signed-off-by: Kyle Corry <[email protected]>
1 parent c153ca8 commit 1f70bb6

File tree

6 files changed

+71
-81
lines changed

6 files changed

+71
-81
lines changed

src/main/kotlin/com/kylecorry/sol/math/Vector.kt

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
package com.kylecorry.sol.math
22

3+
import com.kylecorry.sol.math.algebra.Matrix
4+
import kotlin.collections.sumOf
5+
import kotlin.math.sqrt
6+
37
@JvmInline
48
value class Vector(val data: FloatArray) {
59

610
val n: Int
711
get() = data.size
812

13+
val size: Int
14+
get() = data.size
15+
916
operator fun get(index: Int): Float {
1017
return data[index]
1118
}
1219

20+
operator fun set(index: Int, value: Float) {
21+
data[index] = value
22+
}
23+
1324
operator fun times(scalar: Float): Vector {
1425
return Vector(FloatArray(n) { i -> data[i] * scalar })
1526
}
@@ -18,9 +29,33 @@ value class Vector(val data: FloatArray) {
1829
return Vector(FloatArray(n) { i -> data[i] + other.data[i] })
1930
}
2031

32+
operator fun minus(other: Vector): Vector {
33+
return Vector(FloatArray(n) { i -> data[i] - other.data[i] })
34+
}
35+
36+
fun magnitude(): Float {
37+
return norm()
38+
}
39+
40+
fun norm(): Float {
41+
return sqrt(data.sumOf { it.toDouble() * it }).toFloat()
42+
}
43+
44+
fun toColumnMatrix(): Matrix {
45+
return Matrix.column(values = data)
46+
}
47+
48+
fun toRowMatrix(): Matrix {
49+
return Matrix.row(values = data)
50+
}
51+
2152
companion object {
2253
fun from(vararg values: Float): Vector {
2354
return Vector(values)
2455
}
56+
57+
fun create(size: Int): Vector {
58+
return Vector(FloatArray(size))
59+
}
2560
}
2661
}

src/main/kotlin/com/kylecorry/sol/math/algebra/LinearAlgebra.kt

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.kylecorry.sol.math.algebra
22

33
import com.kylecorry.sol.math.SolMath
4+
import com.kylecorry.sol.math.Vector
45
import kotlin.math.absoluteValue
56
import kotlin.math.min
67
import kotlin.math.sqrt
@@ -285,12 +286,12 @@ object LinearAlgebra {
285286
return m.norm()
286287
}
287288

288-
fun solveLinear(a: Matrix, b: FloatArray): FloatArray {
289+
fun solveLinear(a: Matrix, b: Vector): Vector {
289290
require(a.rows() == a.columns()) { "Matrix must be square" }
290-
require(a.rows() == b.size) { "Matrix rows must be the same size as the vector" }
291+
require(a.rows() == b.n) { "Matrix rows must be the same size as the vector" }
291292

292293
val n = a.columns()
293-
val ab = a.appendColumn(b)
294+
val ab = a.appendColumn(b.data)
294295

295296
// Convert to row echelon form
296297
for (i in 0 until n) {
@@ -312,7 +313,7 @@ object LinearAlgebra {
312313
}
313314

314315
// Back substitution
315-
val x = FloatArray(n)
316+
val x = Vector.create(n)
316317
for (i in n - 1 downTo 0) {
317318
x[i] = ab[i, n] / ab[i, i]
318319
for (j in i - 1 downTo 0) {
@@ -323,10 +324,10 @@ object LinearAlgebra {
323324
return x
324325
}
325326

326-
fun leastNorm(a: Matrix, b: Array<Float>): FloatArray {
327+
fun leastNorm(a: Matrix, b: Vector): Vector {
327328
val (q, r) = qr(a.transpose())
328-
val y = q.dot(r.inverse().transpose()).dot(Matrix.create(b.size, 1) { i, _ -> b[i] })
329-
return y.getColumn(0)
329+
val y = q.dot(r.inverse().transpose()).dot(b.toColumnMatrix())
330+
return Vector(y.getColumn(0))
330331
}
331332

332333
/**
@@ -339,13 +340,13 @@ object LinearAlgebra {
339340
fun leastSquares(a: Matrix, b: Vector): Vector {
340341
val isUnderdetermined = a.rows() < a.columns()
341342
if (isUnderdetermined) {
342-
return leastNorm(a, b).toTypedArray()
343+
return leastNorm(a, b)
343344
}
344345

345346
val jt = a.transpose()
346347
val jtj = jt.dot(a)
347348
val jtr = jt.dot(b.toColumnMatrix())
348-
return solveLinear(jtj, jtr.getColumn(0)).toTypedArray()
349+
return solveLinear(jtj, Vector(jtr.getColumn(0)))
349350
}
350351

351352
fun appendColumn(m: Matrix, col: FloatArray): Matrix {

src/main/kotlin/com/kylecorry/sol/math/algebra/Vector.kt

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/main/kotlin/com/kylecorry/sol/math/optimization/LeastSquaresOptimizer.kt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
package com.kylecorry.sol.math.optimization
22

3+
import com.kylecorry.sol.math.Vector
34
import com.kylecorry.sol.math.algebra.LinearAlgebra
45
import com.kylecorry.sol.math.algebra.Matrix
5-
import com.kylecorry.sol.math.algebra.norm
6-
import com.kylecorry.sol.math.algebra.toColumnMatrix
76
import com.kylecorry.sol.math.geometry.Geometry
87
import kotlin.math.abs
98

@@ -38,19 +37,19 @@ class LeastSquaresOptimizer {
3837

3938
for (i in 0 until maxIterations) {
4039

41-
val f = points.mapIndexed { i, point ->
40+
val f = Vector(points.mapIndexed { i, point ->
4241
(errors[i] - distanceFn(point, guess)) * weightingFn(i, point, errors[i])
43-
}.toTypedArray()
42+
}.toFloatArray())
4443

4544
val jacobian = Matrix.create(points.mapIndexed { i, point ->
4645
jacobianFn(i, point, guess).toTypedArray()
4746
}.toTypedArray())
4847

4948
val step = LinearAlgebra.leastSquares(jacobian, f)
5049

51-
if ((step.maxOfOrNull { abs(it) } ?: 0f) > maxAllowedStep) {
52-
val maxStep = step.maxOfOrNull { abs(it) } ?: 0f
53-
step.forEachIndexed { index, it ->
50+
if ((step.data.maxOfOrNull { abs(it) } ?: 0f) > maxAllowedStep) {
51+
val maxStep = step.data.maxOfOrNull { abs(it) } ?: 0f
52+
step.data.forEachIndexed { index, it ->
5453
step[index] = it / maxStep * maxAllowedStep
5554
}
5655
}

src/main/kotlin/com/kylecorry/sol/science/astronomy/stars/StarLocationCalculator.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import com.kylecorry.sol.math.Range
44
import com.kylecorry.sol.math.SolMath.deltaAngle
55
import com.kylecorry.sol.math.SolMath.square
66
import com.kylecorry.sol.math.SolMath.toRadians
7+
import com.kylecorry.sol.math.Vector
78
import com.kylecorry.sol.math.algebra.LinearAlgebra
89
import com.kylecorry.sol.math.algebra.Matrix
9-
import com.kylecorry.sol.math.algebra.norm
1010
import com.kylecorry.sol.math.optimization.ConvergenceOptimizer
1111
import com.kylecorry.sol.math.optimization.SimulatedAnnealingOptimizer
1212
import com.kylecorry.sol.science.astronomy.Astronomy
@@ -116,7 +116,7 @@ internal class StarLocationCalculator {
116116
// Solve using least squares
117117
val ls = LinearAlgebra.leastSquares(
118118
Matrix.create(linesOfPosition.map { it.first }.toTypedArray()),
119-
linesOfPosition.map { it.second }.toTypedArray()
119+
Vector(linesOfPosition.map { it.second }.toFloatArray())
120120
)
121121

122122
if (ls.norm() < 0.000001) {

src/test/kotlin/com/kylecorry/sol/math/algebra/LinearAlgebraTest.kt

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.kylecorry.sol.math.algebra
22

3+
import com.kylecorry.sol.math.Vector
34
import org.junit.jupiter.api.Assertions.assertEquals
45
import org.junit.jupiter.api.Test
56

@@ -304,7 +305,7 @@ class LinearAlgebraTest {
304305
expected[2, 0] = 10f
305306
expected[2, 1] = 12f
306307

307-
val result = LinearAlgebra.map(m1){ it * 2 }
308+
val result = LinearAlgebra.map(m1) { it * 2 }
308309

309310
assertEquals(expected, result)
310311
}
@@ -430,8 +431,8 @@ class LinearAlgebraTest {
430431
val expected = Matrix.create(2, 2, 0f)
431432
expected[0, 0] = -2f
432433
expected[0, 1] = 1f
433-
expected[1, 0] = 3/2f
434-
expected[1, 1] = -1/2f
434+
expected[1, 0] = 3 / 2f
435+
expected[1, 1] = -1 / 2f
435436

436437
val actual = LinearAlgebra.inverse(m1)
437438

@@ -503,14 +504,14 @@ class LinearAlgebraTest {
503504
}
504505

505506
@Test
506-
fun solveLinear(){
507+
fun solveLinear() {
507508
val a1 = Matrix.create(2, 2, 0f)
508509
a1[0, 0] = 2f
509510
a1[0, 1] = 1f
510511
a1[1, 0] = 1f
511512
a1[1, 1] = -1f
512-
val b1 = floatArrayOf(-4f, -2f)
513-
val expected1 = floatArrayOf(-2f, 0f)
513+
val b1 = Vector.from(-4f, -2f)
514+
val expected1 = Vector.from(-2f, 0f)
514515
val actual1 = LinearAlgebra.solveLinear(a1, b1)
515516
assertEquals(expected1, actual1, 0.00001f)
516517

@@ -524,22 +525,22 @@ class LinearAlgebraTest {
524525
a2[2, 0] = 1f
525526
a2[2, 1] = 3f
526527
a2[2, 2] = 2f
527-
val b2 = floatArrayOf(8f, 7f, -3f)
528-
val expected2 = floatArrayOf(6f, -1f, -3f)
528+
val b2 = Vector.from(8f, 7f, -3f)
529+
val expected2 = Vector.from(6f, -1f, -3f)
529530
val actual2 = LinearAlgebra.solveLinear(a2, b2)
530531
assertEquals(expected2, actual2, 0.00001f)
531532
}
532533

533534
@Test
534-
fun leastSquares(){
535+
fun leastSquares() {
535536
// Well conditioned
536537
val a1 = Matrix.create(2, 2, 0f)
537538
a1[0, 0] = 2f
538539
a1[0, 1] = 1f
539540
a1[1, 0] = 1f
540541
a1[1, 1] = -1f
541-
val b1 = arrayOf(7f, -1f)
542-
val expected1 = arrayOf(2f, 3f)
542+
val b1 = Vector.from(7f, -1f)
543+
val expected1 = Vector.from(2f, 3f)
543544
val actual1 = LinearAlgebra.leastSquares(a1, b1)
544545
assertEquals(expected1, actual1, 0.00001f)
545546

@@ -551,8 +552,8 @@ class LinearAlgebraTest {
551552
a2[1, 1] = -1f
552553
a2[2, 0] = 1f
553554
a2[2, 1] = 0f
554-
val b2 = arrayOf(3f, 1f, 2f)
555-
val expected2 = arrayOf(2f, 1f)
555+
val b2 = Vector.from(3f, 1f, 2f)
556+
val expected2 = Vector.from(2f, 1f)
556557
val actual2 = LinearAlgebra.leastSquares(a2, b2)
557558
assertEquals(expected2, actual2, 0.00001f)
558559

@@ -564,8 +565,8 @@ class LinearAlgebraTest {
564565
a3[1, 0] = 0f
565566
a3[1, 1] = 1f
566567
a3[1, 2] = 2f
567-
val b3 = arrayOf(6f, 5f)
568-
val expected3 = arrayOf(2.5f, 2f, 1.5f)
568+
val b3 = Vector.from(6f, 5f)
569+
val expected3 = Vector.from(2.5f, 2f, 1.5f)
569570
val actual3 = LinearAlgebra.leastSquares(a3, b3)
570571
assertEquals(expected3, actual3, 0.00001f)
571572
}
@@ -581,18 +582,10 @@ class LinearAlgebraTest {
581582
}
582583
}
583584

584-
private fun assertEquals(m1: Array<Float>, m2: Array<Float>, tolerance: Float = 0f) {
585+
private fun assertEquals(m1: Vector, m2: Vector, tolerance: Float = 0f) {
585586
assertEquals(m1.size, m2.size)
586587

587-
for (i in m1.indices) {
588-
assertEquals(m1[i], m2[i], tolerance)
589-
}
590-
}
591-
592-
private fun assertEquals(m1: FloatArray, m2: FloatArray, tolerance: Float = 0f) {
593-
assertEquals(m1.size, m2.size)
594-
595-
for (i in m1.indices) {
588+
for (i in m1.data.indices) {
596589
assertEquals(m1[i], m2[i], tolerance)
597590
}
598591
}

0 commit comments

Comments
 (0)