Skip to content

Commit 014b9fe

Browse files
committed
Add polynomial fitting and operations
1 parent f8bffee commit 014b9fe

File tree

6 files changed

+249
-41
lines changed

6 files changed

+249
-41
lines changed

src/main/kotlin/com/kylecorry/sol/math/algebra/Polynomial.kt

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ import com.kylecorry.sol.math.sumOfFloat
66
data class PolynomialTerm(
77
val coefficient: Float,
88
val exponent: Int
9-
)
9+
) {
10+
operator fun times(other: PolynomialTerm): PolynomialTerm {
11+
return PolynomialTerm(this.coefficient * other.coefficient, this.exponent + other.exponent)
12+
}
13+
}
1014

1115
class Polynomial(terms: List<PolynomialTerm>) {
1216

@@ -44,6 +48,27 @@ class Polynomial(terms: List<PolynomialTerm>) {
4448
return Polynomial(integratedTerms + listOf(PolynomialTerm(c, 0)))
4549
}
4650

51+
operator fun plus(other: Polynomial): Polynomial {
52+
return Polynomial(this.terms + other.terms)
53+
}
54+
55+
operator fun minus(other: Polynomial): Polynomial {
56+
val negatedTerms = other.terms.map {
57+
PolynomialTerm(-it.coefficient, it.exponent)
58+
}
59+
return Polynomial(this.terms + negatedTerms)
60+
}
61+
62+
operator fun times(other: Polynomial): Polynomial {
63+
val productTerms = mutableListOf<PolynomialTerm>()
64+
for (term1 in this.terms) {
65+
for (term2 in other.terms) {
66+
productTerms.add(term1 * term2)
67+
}
68+
}
69+
return Polynomial(productTerms)
70+
}
71+
4772
override fun equals(other: Any?): Boolean {
4873
if (this === other) return true
4974
if (other !is Polynomial) return false
Lines changed: 5 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,13 @@
11
package com.kylecorry.sol.math.interpolation
22

33
import com.kylecorry.sol.math.Vector2
4+
import com.kylecorry.sol.math.regression.NewtonPolynomialRegression
45

5-
class NewtonInterpolator(points: List<Vector2>, private val order: Int = points.size - 1) : Interpolator {
6-
7-
private var sortedPoints = points.sortedBy { it.x }
8-
private var cachedA = emptyArray<Float>()
9-
10-
private val lock = Any()
6+
class NewtonInterpolator(points: List<Vector2>, order: Int = points.size - 1) : Interpolator {
7+
private val regression = NewtonPolynomialRegression(order)
8+
private var polynomial = regression.fit(points)
119

1210
override fun interpolate(x: Float): Float {
13-
synchronized(lock) {
14-
if (cachedA.isEmpty()) {
15-
cachedA = getDividedDifferenceCoefficients(sortedPoints)
16-
}
17-
18-
return dividedDifferencePrecomputed(x, sortedPoints, cachedA)
19-
}
20-
}
21-
22-
private fun getDividedDifferenceCoefficients(
23-
points: List<Vector2>
24-
): Array<Float> {
25-
val n = order + 1
26-
val a = Array(n) { 0f }
27-
for (i in 0 until n) {
28-
a[i] = points[i].y
29-
}
30-
for (j in 1 until n) {
31-
for (i in n - 1 downTo j) {
32-
a[i] = (a[i] - a[i - 1]) / (points[i].x - points[i - j].x)
33-
}
34-
}
35-
return a
11+
return polynomial.evaluate(x)
3612
}
37-
38-
39-
private fun dividedDifferencePrecomputed(x: Float, points: List<Vector2>, a: Array<Float>): Float {
40-
val n = a.size
41-
var y = a[n - 1]
42-
for (i in n - 2 downTo 0) {
43-
y = a[i] + (x - points[i].x) * y
44-
}
45-
return y
46-
}
47-
4813
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package com.kylecorry.sol.math.regression
2+
3+
import com.kylecorry.sol.math.Vector2
4+
import com.kylecorry.sol.math.algebra.Polynomial
5+
import com.kylecorry.sol.math.algebra.PolynomialTerm
6+
7+
class NewtonPolynomialRegression(private val order: Int) : PolynomialRegression {
8+
override fun fit(points: List<Vector2>): Polynomial {
9+
val sortedPoints = points.sortedBy { it.x }
10+
val coefficients = getDividedDifferenceCoefficients(sortedPoints)
11+
val terms = getPolynomialTerms(coefficients, sortedPoints)
12+
return Polynomial(terms)
13+
}
14+
15+
private fun getPolynomialTerms(coefficients: Array<Float>, points: List<Vector2>): List<PolynomialTerm> {
16+
val n = coefficients.size
17+
var polynomial = Polynomial.fromCoefficients(coefficients[n - 1])
18+
for (i in n - 2 downTo 0) {
19+
polynomial = Polynomial.fromCoefficients(coefficients[i]) + polynomial * Polynomial(
20+
listOf(
21+
PolynomialTerm(1f, 1),
22+
PolynomialTerm(-points[i].x, 0)
23+
)
24+
)
25+
}
26+
return polynomial.terms
27+
}
28+
29+
private fun getDividedDifferenceCoefficients(
30+
points: List<Vector2>
31+
): Array<Float> {
32+
val n = order + 1
33+
val a = Array(n) { 0f }
34+
for (i in 0 until n) {
35+
a[i] = points[i].y
36+
}
37+
for (j in 1 until n) {
38+
for (i in n - 1 downTo j) {
39+
a[i] = (a[i] - a[i - 1]) / (points[i].x - points[i - j].x)
40+
}
41+
}
42+
return a
43+
}
44+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.kylecorry.sol.math.regression
2+
3+
import com.kylecorry.sol.math.Vector2
4+
import com.kylecorry.sol.math.algebra.Polynomial
5+
6+
interface PolynomialRegression {
7+
fun fit(points: List<Vector2>): Polynomial
8+
}

src/test/kotlin/com/kylecorry/sol/math/algebra/PolynomialTest.kt

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,131 @@ class PolynomialTest {
141141
val poly = Polynomial.of("2x^3 - 4x^2 + x - 5.1234")
142142
assertEquals("2x^3 - 4x^2 + x - 5.1234", poly.toString())
143143
}
144+
145+
@Test
146+
fun addition() {
147+
// Test: (2x + 3) + (x + 1) = 3x + 4
148+
val poly1 = Polynomial.of("2x + 3")
149+
val poly2 = Polynomial.of("x + 1")
150+
val result = poly1 + poly2
151+
assertEquals(Polynomial.of("3x + 4"), result)
152+
153+
// Test: (x^2 + 2x + 1) + (x^2 + x - 1) = 2x^2 + 3x
154+
val quadratic1 = Polynomial.fromCoefficients(1f, 2f, 1f)
155+
val quadratic2 = Polynomial.of("x^2 + x - 1")
156+
val result2 = quadratic1 + quadratic2
157+
assertEquals(Polynomial.of("2x^2 + 3x"), result2)
158+
159+
// Test: (5x^3 + 2x) + (-5x^3 + 3x) = 5x
160+
val cubic1 = Polynomial.of("5x^3 + 2x")
161+
val cubic2 = Polynomial.of("-5x^3 + 3x")
162+
val result3 = cubic1 + cubic2
163+
assertEquals(Polynomial.of("5x"), result3)
164+
165+
// Test: (3) + (7) = 10 (constant addition)
166+
val const1 = Polynomial.of("3")
167+
val const2 = Polynomial.of("7")
168+
val result4 = const1 + const2
169+
assertEquals(Polynomial.of("10"), result4)
170+
171+
// Test: (x^2) + (0) = x^2 (adding zero)
172+
val quadratic = Polynomial.of("x^2")
173+
val zero = Polynomial.of()
174+
val result5 = quadratic + zero
175+
assertEquals(quadratic, result5)
176+
177+
// Test by evaluation: (x^2 + x + 1) + (2x - 1) at x=2 should equal 10
178+
val p1 = Polynomial.of("x^2 + x + 1")
179+
val p2 = Polynomial.of("2x - 1")
180+
val sum = p1 + p2
181+
assertEquals(10f, sum.evaluate(2f), 0.001f)
182+
}
183+
184+
@Test
185+
fun subtraction() {
186+
// Test: (3x + 5) - (x + 2) = 2x + 3
187+
val poly1 = Polynomial.of("3x + 5")
188+
val poly2 = Polynomial.of("x + 2")
189+
val result = poly1 - poly2
190+
assertEquals(Polynomial.of("2x + 3"), result)
191+
192+
// Test: (x^2 + 3x + 2) - (x^2 + x - 1) = 2x + 3
193+
val quadratic1 = Polynomial.of("x^2 + 3x + 2")
194+
val quadratic2 = Polynomial.of("x^2 + x - 1")
195+
val result2 = quadratic1 - quadratic2
196+
assertEquals(Polynomial.of("2x + 3"), result2)
197+
198+
// Test: (5x^3 + 2x) - (2x^3 + 3x) = 3x^3 - x
199+
val cubic1 = Polynomial.of("5x^3 + 2x")
200+
val cubic2 = Polynomial.of("2x^3 + 3x")
201+
val result3 = cubic1 - cubic2
202+
assertEquals(Polynomial.of("3x^3 - x"), result3)
203+
204+
// Test: (10) - (3) = 7 (constant subtraction)
205+
val const1 = Polynomial.of("10")
206+
val const2 = Polynomial.of("3")
207+
val result4 = const1 - const2
208+
assertEquals(Polynomial.of("7"), result4)
209+
210+
// Test: (x^2) - (x^2) = 0
211+
val quadratic = Polynomial.of("x^2")
212+
val result5 = quadratic - quadratic
213+
assertEquals(Polynomial.of("0"), result5)
214+
215+
// Test by evaluation: (3x^2 + 2x + 5) - (x^2 + x + 1) at x=3 should equal 25
216+
val p1 = Polynomial.of("3x^2 + 2x + 5")
217+
val p2 = Polynomial.of("x^2 + x + 1")
218+
val diff = p1 - p2
219+
assertEquals(25f, diff.evaluate(3f), 0.001f)
220+
}
221+
222+
@Test
223+
fun multiplication() {
224+
// Test: (2x + 1) * (x + 3) = 2x^2 + 7x + 3
225+
val poly1 = Polynomial.of("2x + 1")
226+
val poly2 = Polynomial.of("x + 3")
227+
val result = poly1 * poly2
228+
assertEquals(Polynomial.of("2x^2 + 7x + 3"), result)
229+
230+
// Test: (x + 1) * (x - 1) = x^2 - 1
231+
val binomial1 = Polynomial.of("x + 1")
232+
val binomial2 = Polynomial.of("x - 1")
233+
val result2 = binomial1 * binomial2
234+
assertEquals(Polynomial.of("x^2 - 1"), result2)
235+
236+
// Test: (x + 2)^2 = x^2 + 4x + 4
237+
val binomial = Polynomial.of("x + 2")
238+
val result3 = binomial * binomial
239+
assertEquals(Polynomial.of("x^2 + 4x + 4"), result3)
240+
241+
// Test: (3x^2 + 2x + 1) * (2x + 1) = 6x^3 + 7x^2 + 4x + 1
242+
val quadratic = Polynomial.of("3x^2 + 2x + 1")
243+
val linear = Polynomial.of("2x + 1")
244+
val result4 = quadratic * linear
245+
assertEquals(Polynomial.of("6x^3 + 7x^2 + 4x + 1"), result4)
246+
247+
// Test: (5) * (3) = 15 (constant multiplication)
248+
val const1 = Polynomial.of("5")
249+
val const2 = Polynomial.of("3")
250+
val result5 = const1 * const2
251+
assertEquals(Polynomial.of("15"), result5)
252+
253+
// Test: (x) * (0) = 0
254+
val x = Polynomial.of("x")
255+
val zero = Polynomial.of()
256+
val result6 = x * zero
257+
assertEquals(Polynomial.of("0"), result6)
258+
259+
// Test by evaluation: (x + 2) * (x - 1) at x=4 should equal 18
260+
val p1 = Polynomial.of("x + 2")
261+
val p2 = Polynomial.of("x - 1")
262+
val product = p1 * p2
263+
assertEquals(18f, product.evaluate(4f), 0.001f)
264+
265+
// Test: (x^2 + 1) * (x^2 - 1) = x^4 - 1
266+
val sq1 = Polynomial.of("x^2 + 1")
267+
val sq2 = Polynomial.of("x^2 - 1")
268+
val result7 = sq1 * sq2
269+
assertEquals(Polynomial.of("x^4 - 1"), result7)
270+
}
144271
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.kylecorry.sol.math.regression
2+
3+
import com.kylecorry.sol.math.Vector2
4+
import com.kylecorry.sol.math.algebra.Polynomial
5+
import org.junit.jupiter.api.Assertions.assertEquals
6+
import org.junit.jupiter.params.ParameterizedTest
7+
import org.junit.jupiter.params.provider.CsvSource
8+
9+
class NewtonPolynomialRegressionTest {
10+
11+
@ParameterizedTest
12+
@CsvSource(
13+
"'2x + 1', '2x + 1', 1",
14+
"'x^2 + 2x + 1', 'x^2 + 2x + 1', 2",
15+
"'x^3 - 2x^2 + 3x - 4', 'x^3 - 2x^2 + 3x - 4', 3",
16+
"'5', '5', 1",
17+
"'x^2', 'x^2', 2",
18+
"'2x - 3', '2x - 3', 1",
19+
"'0.5x', '0.5x', 1",
20+
"'x^4', 'x^4', 4",
21+
"'x', 'x', 1",
22+
"'x^3 + 2x^2 - x + 3', 'x^3 + 2x^2 - x + 3', 3",
23+
"'x + 1', 'x + 1', 1",
24+
"'3x^2 - 2x + 7', '3x^2 - 2x + 7', 2",
25+
"'x^3', 'x^3', 3",
26+
"'4x - 5', '4x - 5', 1",
27+
// Lower order than required, it will approximate
28+
"'x^2 + 2x + 1', '3x + 1', 1"
29+
)
30+
fun testFit(sourcePolynomial: String, expectedPolynomial: String, order: Int) {
31+
val sourcePoly = Polynomial.of(sourcePolynomial)
32+
val points = (0..10).map { x -> Vector2(x.toFloat(), sourcePoly.evaluate(x.toFloat())) }
33+
34+
val regression = NewtonPolynomialRegression(order)
35+
val polynomial = regression.fit(points)
36+
37+
assertEquals(Polynomial.of(expectedPolynomial), polynomial)
38+
}
39+
}

0 commit comments

Comments
 (0)