11package com.kylecorry.sol.math.algebra
22
33import com.kylecorry.sol.math.SolMath
4- import com.kylecorry.sol.shared.ArrayUtils.swap
54import kotlin.math.absoluteValue
65import kotlin.math.min
76import kotlin.math.sqrt
@@ -13,7 +12,7 @@ object LinearAlgebra {
1312 throw Exception (" Matrix 1 columns must be the same size as matrix 2 rows" )
1413 }
1514
16- val product = createMatrix (mat1.rows(), mat2.columns()) { _, _ -> 0f }
15+ val product = Matrix .zeros (mat1.rows(), mat2.columns())
1716 for (r in 0 until mat1.rows()) {
1817 for (otherC in 0 until mat2.columns()) {
1918 var sum = 0.0f
@@ -36,7 +35,7 @@ object LinearAlgebra {
3635 throw Exception (" Matrix 1 rows must be the same size as matrix 2 rows" )
3736 }
3837
39- return createMatrix (mat1.rows(), mat1.columns()) { row, col ->
38+ return Matrix .create (mat1.rows(), mat1.columns()) { row, col ->
4039 mat1[row, col] - mat2[min(row, mat2.rows() - 1 ), min(col, mat2.columns() - 1 )]
4140 }
4241 }
@@ -54,13 +53,13 @@ object LinearAlgebra {
5453 throw Exception (" Matrix 1 rows must be the same size as matrix 2 rows" )
5554 }
5655
57- return createMatrix (mat1.rows(), mat1.columns()) { row, col ->
56+ return Matrix .create (mat1.rows(), mat1.columns()) { row, col ->
5857 mat1[row, col] + mat2[min(row, mat2.rows() - 1 ), min(col, mat2.columns() - 1 )]
5958 }
6059 }
6160
6261 fun add (mat1 : Matrix , value : Float ): Matrix {
63- return createMatrix (mat1.rows(), mat1.columns()) { row, col ->
62+ return Matrix .create (mat1.rows(), mat1.columns()) { row, col ->
6463 mat1[row, col] + value
6564 }
6665 }
@@ -74,13 +73,13 @@ object LinearAlgebra {
7473 throw Exception (" Matrix 1 rows must be the same size as matrix 2 rows" )
7574 }
7675
77- return createMatrix (mat1.rows(), mat1.columns()) { row, col ->
76+ return Matrix .create (mat1.rows(), mat1.columns()) { row, col ->
7877 mat1[row, col] * mat2[min(row, mat2.rows() - 1 ), min(col, mat2.columns() - 1 )]
7978 }
8079 }
8180
8281 fun multiply (mat1 : Matrix , scale : Float ): Matrix {
83- return createMatrix (mat1.rows(), mat1.columns()) { row, col ->
82+ return Matrix .create (mat1.rows(), mat1.columns()) { row, col ->
8483 mat1[row, col] * scale
8584 }
8685 }
@@ -94,7 +93,7 @@ object LinearAlgebra {
9493 throw Exception (" Matrix 1 rows must be the same size as matrix 2 rows" )
9594 }
9695
97- return createMatrix (mat1.rows(), mat1.columns()) { row, col ->
96+ return Matrix .create (mat1.rows(), mat1.columns()) { row, col ->
9897 mat1[row, col] / mat2[min(row, mat2.rows() - 1 ), min(col, mat2.columns() - 1 )]
9998 }
10099 }
@@ -104,50 +103,60 @@ object LinearAlgebra {
104103 }
105104
106105 fun transpose (mat : Matrix ): Matrix {
107- return createMatrix (mat.columns(), mat.rows()) { row, col ->
106+ return Matrix .create (mat.columns(), mat.rows()) { row, col ->
108107 mat[col, row]
109108 }
110109 }
111110
112111 fun map (mat : Matrix , fn : (value: Float ) -> Float ): Matrix {
113- return createMatrix (mat.rows(), mat.columns()) { row, col ->
112+ return Matrix .create (mat.rows(), mat.columns()) { row, col ->
114113 fn(mat[row, col])
115114 }
116115 }
117116
118117 fun mapRows (mat : Matrix , fn : (row: FloatArray ) -> FloatArray ): Matrix {
119- return mat.map { fn(it.toFloatArray()).toTypedArray() }.toTypedArray()
118+ val copy = mat.clone()
119+ val temp = FloatArray (mat.columns())
120+ for (row in 0 until mat.rows()) {
121+ mat.getRow(row, temp)
122+ copy.setRow(row, fn(temp))
123+ }
124+ return copy
120125 }
121126
122127 fun mapColumns (mat : Matrix , fn : (row: FloatArray ) -> FloatArray ): Matrix {
123128 return mapRows(mat.transpose(), fn).transpose()
124129 }
125130
126131 fun sum (mat : Matrix ): Float {
127- return mat.sumOf { it. sum().toDouble() }.toFloat ()
132+ return mat.sum()
128133 }
129134
130135 fun sumColumns (mat : Matrix ): Matrix {
131136 return sumRows(mat.transpose()).transpose()
132137 }
133138
134139 fun sumRows (mat : Matrix ): Matrix {
135- return createMatrix(mat.rows(), 1 ) { row, _ ->
136- mat[row].sum()
140+ val temp = FloatArray (mat.columns())
141+ return Matrix .create(mat.rows(), 1 ) { row, _ ->
142+ mat.getRow(row, temp)
143+ temp.sum()
137144 }
138145 }
139146
140147 fun max (mat : Matrix ): Float {
141- return mat.maxOf { it. max() }
148+ return mat.max()
142149 }
143150
144151 fun maxColumns (mat : Matrix ): Matrix {
145152 return maxRows(mat.transpose()).transpose()
146153 }
147154
148155 fun maxRows (mat : Matrix ): Matrix {
149- return createMatrix(mat.rows(), 1 ) { row, _ ->
150- mat[row].max()
156+ val temp = FloatArray (mat.columns())
157+ return Matrix .create(mat.rows(), 1 ) { row, _ ->
158+ mat.getRow(row, temp)
159+ temp.max()
151160 }
152161 }
153162
@@ -159,7 +168,7 @@ object LinearAlgebra {
159168 val det = determinant(m)
160169 if (SolMath .isZero(det)) {
161170 // No inverse exists
162- return createMatrix (m.rows(), m.columns(), 0f )
171+ return Matrix .zeros (m.rows(), m.columns())
163172 }
164173 return adjugate(m).transpose().divide(determinant(m))
165174 }
@@ -171,7 +180,7 @@ object LinearAlgebra {
171180
172181 var colMultiplier: Int
173182 var rowMultiplier: Int
174- return createMatrix (m.rows(), m.columns()) { r, c ->
183+ return Matrix .create (m.rows(), m.columns()) { r, c ->
175184 rowMultiplier = if (r % 2 == 0 ) {
176185 1
177186 } else {
@@ -208,7 +217,7 @@ object LinearAlgebra {
208217 }
209218
210219 fun cofactor (m : Matrix , r : Int , c : Int ): Matrix {
211- return createMatrix (m.rows() - 1 , m.columns() - 1 ) { r1, c1 ->
220+ return Matrix .create (m.rows() - 1 , m.columns() - 1 ) { r1, c1 ->
212221 val sr = if (r1 < r) {
213222 r1
214223 } else {
@@ -227,20 +236,20 @@ object LinearAlgebra {
227236 val rows = m.rows()
228237 val cols = m.columns()
229238
230- val q = createMatrix (rows, cols) { _, _ -> 0f }
231- val r = createMatrix (cols, cols) { _, _ -> 0f }
239+ val q = Matrix .zeros (rows, cols)
240+ val r = Matrix .zeros (cols, cols)
232241
233242 for (j in 0 until cols) {
234- var v = m.column(j )
243+ var v = Matrix .row(values = m.getColumn(j) )
235244
236245 for (i in 0 until j) {
237- val qi = q.column(i )
246+ val qi = Matrix .row(values = q.getColumn(i) )
238247 r[i, j] = qi.dot(v.transpose())[0 , 0 ]
239248 v = v.subtract(qi.multiply(r[i, j]))
240249 }
241250
242251 r[j, j] = norm(v)
243- q.setColumn(j, v.divide(r[j, j])[ 0 ] )
252+ q.setColumn(j, v.divide(r[j, j]).getRow( 0 ) )
244253 }
245254
246255 return q to r
@@ -250,12 +259,12 @@ object LinearAlgebra {
250259 * Returns a column matrix of the diagonal of the matrix
251260 */
252261 fun diagonal (m : Matrix ): Matrix {
253- return createMatrix (1 , min(m.rows(), m.columns())) { i, _ ->
262+ return Matrix .create (1 , min(m.rows(), m.columns())) { i, _ ->
254263 m[i, i]
255264 }
256265 }
257266
258- fun eigenvalues (m : Matrix , tolerance : Float = 1e-12f, maxIterations : Int = 1000): Array < Float > {
267+ fun eigenvalues (m : Matrix , tolerance : Float = 1e-12f, maxIterations : Int = 1000): FloatArray {
259268 var old = m
260269 var new = m
261270
@@ -269,25 +278,19 @@ object LinearAlgebra {
269278 iterations++
270279 }
271280
272- return diagonal(new)[ 0 ]
281+ return diagonal(new).getRow( 0 )
273282 }
274283
275284 fun norm (m : Matrix ): Float {
276- return sqrt(
277- m.sumOf { values ->
278- values.sumOf { value ->
279- value * value.toDouble()
280- }
281- }
282- ).toFloat()
285+ return m.norm()
283286 }
284287
285- fun solveLinear (a : Matrix , b : Array < Float > ): Array < Float > {
288+ fun solveLinear (a : Matrix , b : FloatArray ): FloatArray {
286289 require(a.rows() == a.columns()) { " Matrix must be square" }
287290 require(a.rows() == b.size) { " Matrix rows must be the same size as the vector" }
288291
289292 val n = a.columns()
290- val ab = a.appendColumn(b.toFloatArray() )
293+ val ab = a.appendColumn(b)
291294
292295 // Convert to row echelon form
293296 for (i in 0 until n) {
@@ -298,7 +301,7 @@ object LinearAlgebra {
298301 }
299302 }
300303
301- ab.swap (i, maxRow)
304+ ab.swapRows (i, maxRow)
302305
303306 for (j in i + 1 until n) {
304307 val factor = ab[j, i] / ab[i, i]
@@ -317,13 +320,13 @@ object LinearAlgebra {
317320 }
318321 }
319322
320- return x.toTypedArray()
323+ return x
321324 }
322325
323- fun leastNorm (a : Matrix , b : Array <Float >): Array < Float > {
326+ fun leastNorm (a : Matrix , b : Array <Float >): FloatArray {
324327 val (q, r) = qr(a.transpose())
325- val y = q.dot(r.inverse().transpose()).dot(createMatrix (b.size, 1 ) { i, _ -> b[i] })
326- return y.transpose()[ 0 ]
328+ val y = q.dot(r.inverse().transpose()).dot(Matrix .create (b.size, 1 ) { i, _ -> b[i] })
329+ return y.getColumn( 0 )
327330 }
328331
329332 /* *
@@ -336,17 +339,17 @@ object LinearAlgebra {
336339 fun leastSquares (a : Matrix , b : Vector ): Vector {
337340 val isUnderdetermined = a.rows() < a.columns()
338341 if (isUnderdetermined) {
339- return leastNorm(a, b)
342+ return leastNorm(a, b).toTypedArray()
340343 }
341344
342345 val jt = a.transpose()
343346 val jtj = jt.dot(a)
344347 val jtr = jt.dot(b.toColumnMatrix())
345- return solveLinear(jtj, jtr.transpose()[ 0 ] )
348+ return solveLinear(jtj, jtr.getColumn( 0 )).toTypedArray( )
346349 }
347350
348351 fun appendColumn (m : Matrix , col : FloatArray ): Matrix {
349- return createMatrix (m.rows(), m.columns() + 1 ) { r, c ->
352+ return Matrix .create (m.rows(), m.columns() + 1 ) { r, c ->
350353 if (c < m.columns()) {
351354 m[r, c]
352355 } else {
@@ -356,7 +359,7 @@ object LinearAlgebra {
356359 }
357360
358361 fun appendColumn (m : Matrix , value : Float ): Matrix {
359- return createMatrix (m.rows(), m.columns() + 1 ) { r, c ->
362+ return Matrix .create (m.rows(), m.columns() + 1 ) { r, c ->
360363 if (c < m.columns()) {
361364 m[r, c]
362365 } else {
@@ -366,7 +369,7 @@ object LinearAlgebra {
366369 }
367370
368371 fun appendRow (m : Matrix , row : FloatArray ): Matrix {
369- return createMatrix (m.rows() + 1 , m.columns()) { r, c ->
372+ return Matrix .create (m.rows() + 1 , m.columns()) { r, c ->
370373 if (r < m.rows()) {
371374 m[r, c]
372375 } else {
@@ -376,7 +379,7 @@ object LinearAlgebra {
376379 }
377380
378381 fun appendRow (m : Matrix , value : Float ): Matrix {
379- return createMatrix (m.rows() + 1 , m.columns()) { r, c ->
382+ return Matrix .create (m.rows() + 1 , m.columns()) { r, c ->
380383 if (r < m.rows()) {
381384 m[r, c]
382385 } else {
0 commit comments