Skip to content

Commit c153ca8

Browse files
authored
Merge pull request #154 from kylecorry31/matrix-rework
Matrix rework
2 parents 2a8b2f8 + 70b901f commit c153ca8

File tree

18 files changed

+366
-248
lines changed

18 files changed

+366
-248
lines changed

src/main/kotlin/com/kylecorry/sol/math/algebra/LinearAlgebra.kt

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package com.kylecorry.sol.math.algebra
22

33
import com.kylecorry.sol.math.SolMath
4-
import com.kylecorry.sol.shared.ArrayUtils.swap
54
import kotlin.math.absoluteValue
65
import kotlin.math.min
76
import kotlin.math.sqrt
@@ -13,7 +12,7 @@ object LinearAlgebra {
1312
throw Exception("Matrix 1 columns must be the same size as matrix 2 rows")
1413
}
1514

16-
val product = createMatrix(mat1.rows(), mat2.columns()) { _, _ -> 0f }
15+
val product = Matrix.zeros(mat1.rows(), mat2.columns())
1716
for (r in 0 until mat1.rows()) {
1817
for (otherC in 0 until mat2.columns()) {
1918
var sum = 0.0f
@@ -36,7 +35,7 @@ object LinearAlgebra {
3635
throw Exception("Matrix 1 rows must be the same size as matrix 2 rows")
3736
}
3837

39-
return createMatrix(mat1.rows(), mat1.columns()) { row, col ->
38+
return Matrix.create(mat1.rows(), mat1.columns()) { row, col ->
4039
mat1[row, col] - mat2[min(row, mat2.rows() - 1), min(col, mat2.columns() - 1)]
4140
}
4241
}
@@ -54,13 +53,13 @@ object LinearAlgebra {
5453
throw Exception("Matrix 1 rows must be the same size as matrix 2 rows")
5554
}
5655

57-
return createMatrix(mat1.rows(), mat1.columns()) { row, col ->
56+
return Matrix.create(mat1.rows(), mat1.columns()) { row, col ->
5857
mat1[row, col] + mat2[min(row, mat2.rows() - 1), min(col, mat2.columns() - 1)]
5958
}
6059
}
6160

6261
fun add(mat1: Matrix, value: Float): Matrix {
63-
return createMatrix(mat1.rows(), mat1.columns()) { row, col ->
62+
return Matrix.create(mat1.rows(), mat1.columns()) { row, col ->
6463
mat1[row, col] + value
6564
}
6665
}
@@ -74,13 +73,13 @@ object LinearAlgebra {
7473
throw Exception("Matrix 1 rows must be the same size as matrix 2 rows")
7574
}
7675

77-
return createMatrix(mat1.rows(), mat1.columns()) { row, col ->
76+
return Matrix.create(mat1.rows(), mat1.columns()) { row, col ->
7877
mat1[row, col] * mat2[min(row, mat2.rows() - 1), min(col, mat2.columns() - 1)]
7978
}
8079
}
8180

8281
fun multiply(mat1: Matrix, scale: Float): Matrix {
83-
return createMatrix(mat1.rows(), mat1.columns()) { row, col ->
82+
return Matrix.create(mat1.rows(), mat1.columns()) { row, col ->
8483
mat1[row, col] * scale
8584
}
8685
}
@@ -94,7 +93,7 @@ object LinearAlgebra {
9493
throw Exception("Matrix 1 rows must be the same size as matrix 2 rows")
9594
}
9695

97-
return createMatrix(mat1.rows(), mat1.columns()) { row, col ->
96+
return Matrix.create(mat1.rows(), mat1.columns()) { row, col ->
9897
mat1[row, col] / mat2[min(row, mat2.rows() - 1), min(col, mat2.columns() - 1)]
9998
}
10099
}
@@ -104,50 +103,60 @@ object LinearAlgebra {
104103
}
105104

106105
fun transpose(mat: Matrix): Matrix {
107-
return createMatrix(mat.columns(), mat.rows()) { row, col ->
106+
return Matrix.create(mat.columns(), mat.rows()) { row, col ->
108107
mat[col, row]
109108
}
110109
}
111110

112111
fun map(mat: Matrix, fn: (value: Float) -> Float): Matrix {
113-
return createMatrix(mat.rows(), mat.columns()) { row, col ->
112+
return Matrix.create(mat.rows(), mat.columns()) { row, col ->
114113
fn(mat[row, col])
115114
}
116115
}
117116

118117
fun mapRows(mat: Matrix, fn: (row: FloatArray) -> FloatArray): Matrix {
119-
return mat.map { fn(it.toFloatArray()).toTypedArray() }.toTypedArray()
118+
val copy = mat.clone()
119+
val temp = FloatArray(mat.columns())
120+
for (row in 0 until mat.rows()) {
121+
mat.getRow(row, temp)
122+
copy.setRow(row, fn(temp))
123+
}
124+
return copy
120125
}
121126

122127
fun mapColumns(mat: Matrix, fn: (row: FloatArray) -> FloatArray): Matrix {
123128
return mapRows(mat.transpose(), fn).transpose()
124129
}
125130

126131
fun sum(mat: Matrix): Float {
127-
return mat.sumOf { it.sum().toDouble() }.toFloat()
132+
return mat.sum()
128133
}
129134

130135
fun sumColumns(mat: Matrix): Matrix {
131136
return sumRows(mat.transpose()).transpose()
132137
}
133138

134139
fun sumRows(mat: Matrix): Matrix {
135-
return createMatrix(mat.rows(), 1) { row, _ ->
136-
mat[row].sum()
140+
val temp = FloatArray(mat.columns())
141+
return Matrix.create(mat.rows(), 1) { row, _ ->
142+
mat.getRow(row, temp)
143+
temp.sum()
137144
}
138145
}
139146

140147
fun max(mat: Matrix): Float {
141-
return mat.maxOf { it.max() }
148+
return mat.max()
142149
}
143150

144151
fun maxColumns(mat: Matrix): Matrix {
145152
return maxRows(mat.transpose()).transpose()
146153
}
147154

148155
fun maxRows(mat: Matrix): Matrix {
149-
return createMatrix(mat.rows(), 1) { row, _ ->
150-
mat[row].max()
156+
val temp = FloatArray(mat.columns())
157+
return Matrix.create(mat.rows(), 1) { row, _ ->
158+
mat.getRow(row, temp)
159+
temp.max()
151160
}
152161
}
153162

@@ -159,7 +168,7 @@ object LinearAlgebra {
159168
val det = determinant(m)
160169
if (SolMath.isZero(det)) {
161170
// No inverse exists
162-
return createMatrix(m.rows(), m.columns(), 0f)
171+
return Matrix.zeros(m.rows(), m.columns())
163172
}
164173
return adjugate(m).transpose().divide(determinant(m))
165174
}
@@ -171,7 +180,7 @@ object LinearAlgebra {
171180

172181
var colMultiplier: Int
173182
var rowMultiplier: Int
174-
return createMatrix(m.rows(), m.columns()) { r, c ->
183+
return Matrix.create(m.rows(), m.columns()) { r, c ->
175184
rowMultiplier = if (r % 2 == 0) {
176185
1
177186
} else {
@@ -208,7 +217,7 @@ object LinearAlgebra {
208217
}
209218

210219
fun cofactor(m: Matrix, r: Int, c: Int): Matrix {
211-
return createMatrix(m.rows() - 1, m.columns() - 1) { r1, c1 ->
220+
return Matrix.create(m.rows() - 1, m.columns() - 1) { r1, c1 ->
212221
val sr = if (r1 < r) {
213222
r1
214223
} else {
@@ -227,20 +236,20 @@ object LinearAlgebra {
227236
val rows = m.rows()
228237
val cols = m.columns()
229238

230-
val q = createMatrix(rows, cols) { _, _ -> 0f }
231-
val r = createMatrix(cols, cols) { _, _ -> 0f }
239+
val q = Matrix.zeros(rows, cols)
240+
val r = Matrix.zeros(cols, cols)
232241

233242
for (j in 0 until cols) {
234-
var v = m.column(j)
243+
var v = Matrix.row(values = m.getColumn(j))
235244

236245
for (i in 0 until j) {
237-
val qi = q.column(i)
246+
val qi = Matrix.row(values = q.getColumn(i))
238247
r[i, j] = qi.dot(v.transpose())[0, 0]
239248
v = v.subtract(qi.multiply(r[i, j]))
240249
}
241250

242251
r[j, j] = norm(v)
243-
q.setColumn(j, v.divide(r[j, j])[0])
252+
q.setColumn(j, v.divide(r[j, j]).getRow(0))
244253
}
245254

246255
return q to r
@@ -250,12 +259,12 @@ object LinearAlgebra {
250259
* Returns a column matrix of the diagonal of the matrix
251260
*/
252261
fun diagonal(m: Matrix): Matrix {
253-
return createMatrix(1, min(m.rows(), m.columns())) { i, _ ->
262+
return Matrix.create(1, min(m.rows(), m.columns())) { i, _ ->
254263
m[i, i]
255264
}
256265
}
257266

258-
fun eigenvalues(m: Matrix, tolerance: Float = 1e-12f, maxIterations: Int = 1000): Array<Float> {
267+
fun eigenvalues(m: Matrix, tolerance: Float = 1e-12f, maxIterations: Int = 1000): FloatArray {
259268
var old = m
260269
var new = m
261270

@@ -269,25 +278,19 @@ object LinearAlgebra {
269278
iterations++
270279
}
271280

272-
return diagonal(new)[0]
281+
return diagonal(new).getRow(0)
273282
}
274283

275284
fun norm(m: Matrix): Float {
276-
return sqrt(
277-
m.sumOf { values ->
278-
values.sumOf { value ->
279-
value * value.toDouble()
280-
}
281-
}
282-
).toFloat()
285+
return m.norm()
283286
}
284287

285-
fun solveLinear(a: Matrix, b: Array<Float>): Array<Float> {
288+
fun solveLinear(a: Matrix, b: FloatArray): FloatArray {
286289
require(a.rows() == a.columns()) { "Matrix must be square" }
287290
require(a.rows() == b.size) { "Matrix rows must be the same size as the vector" }
288291

289292
val n = a.columns()
290-
val ab = a.appendColumn(b.toFloatArray())
293+
val ab = a.appendColumn(b)
291294

292295
// Convert to row echelon form
293296
for (i in 0 until n) {
@@ -298,7 +301,7 @@ object LinearAlgebra {
298301
}
299302
}
300303

301-
ab.swap(i, maxRow)
304+
ab.swapRows(i, maxRow)
302305

303306
for (j in i + 1 until n) {
304307
val factor = ab[j, i] / ab[i, i]
@@ -317,13 +320,13 @@ object LinearAlgebra {
317320
}
318321
}
319322

320-
return x.toTypedArray()
323+
return x
321324
}
322325

323-
fun leastNorm(a: Matrix, b: Array<Float>): Array<Float> {
326+
fun leastNorm(a: Matrix, b: Array<Float>): FloatArray {
324327
val (q, r) = qr(a.transpose())
325-
val y = q.dot(r.inverse().transpose()).dot(createMatrix(b.size, 1) { i, _ -> b[i] })
326-
return y.transpose()[0]
328+
val y = q.dot(r.inverse().transpose()).dot(Matrix.create(b.size, 1) { i, _ -> b[i] })
329+
return y.getColumn(0)
327330
}
328331

329332
/**
@@ -336,17 +339,17 @@ object LinearAlgebra {
336339
fun leastSquares(a: Matrix, b: Vector): Vector {
337340
val isUnderdetermined = a.rows() < a.columns()
338341
if (isUnderdetermined) {
339-
return leastNorm(a, b)
342+
return leastNorm(a, b).toTypedArray()
340343
}
341344

342345
val jt = a.transpose()
343346
val jtj = jt.dot(a)
344347
val jtr = jt.dot(b.toColumnMatrix())
345-
return solveLinear(jtj, jtr.transpose()[0])
348+
return solveLinear(jtj, jtr.getColumn(0)).toTypedArray()
346349
}
347350

348351
fun appendColumn(m: Matrix, col: FloatArray): Matrix {
349-
return createMatrix(m.rows(), m.columns() + 1) { r, c ->
352+
return Matrix.create(m.rows(), m.columns() + 1) { r, c ->
350353
if (c < m.columns()) {
351354
m[r, c]
352355
} else {
@@ -356,7 +359,7 @@ object LinearAlgebra {
356359
}
357360

358361
fun appendColumn(m: Matrix, value: Float): Matrix {
359-
return createMatrix(m.rows(), m.columns() + 1) { r, c ->
362+
return Matrix.create(m.rows(), m.columns() + 1) { r, c ->
360363
if (c < m.columns()) {
361364
m[r, c]
362365
} else {
@@ -366,7 +369,7 @@ object LinearAlgebra {
366369
}
367370

368371
fun appendRow(m: Matrix, row: FloatArray): Matrix {
369-
return createMatrix(m.rows() + 1, m.columns()) { r, c ->
372+
return Matrix.create(m.rows() + 1, m.columns()) { r, c ->
370373
if (r < m.rows()) {
371374
m[r, c]
372375
} else {
@@ -376,7 +379,7 @@ object LinearAlgebra {
376379
}
377380

378381
fun appendRow(m: Matrix, value: Float): Matrix {
379-
return createMatrix(m.rows() + 1, m.columns()) { r, c ->
382+
return Matrix.create(m.rows() + 1, m.columns()) { r, c ->
380383
if (r < m.rows()) {
381384
m[r, c]
382385
} else {

0 commit comments

Comments
 (0)