Skip to content

Commit c75de41

Browse files
committed
Fix derivative of complete polynomial and add additional test
1 parent 736fef7 commit c75de41

File tree

2 files changed

+68
-21
lines changed

2 files changed

+68
-21
lines changed

interpolation/complete_poly.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,23 @@ def complete_polynomial(z, d):
9999
"""
100100
# check inputs
101101
assert d >= 0, "d must be non-negative"
102-
z = np.asarray(z)
103-
104-
# compute inds allocate space for output
105-
nvar, nobs = z.shape
106-
out = np.zeros((n_complete(nvar, d), nobs))
107-
108102
if d > 5:
109103
raise ValueError("Complete polynomial only implemeted up to degree 5")
110104

111-
# populate out with jitted function
112-
_complete_poly_impl(z, d, out)
105+
# Assure z is array
106+
z = np.asarray(z)
107+
108+
# compute inds allocate space for output
109+
if np.ndim(z) == 1:
110+
nvar = z.size
111+
out = np.zeros(n_complete(nvar, d))
112+
# populate out with jitted function
113+
_complete_poly_impl_vec(z, d, out)
114+
else:
115+
nvar, nobs = z.shape
116+
out = np.zeros((n_complete(nvar, d), nobs))
117+
# populate out with jitted function
118+
_complete_poly_impl(z, d, out)
113119

114120
return out
115121

@@ -313,18 +319,25 @@ def complete_polynomial_der(z, d, der):
313319
# check inputs
314320
assert d >= 0, "d must be non-negative"
315321
assert der >= 0, "derivative must be non-negative"
316-
z = np.asarray(z)
317-
318-
# compute inds allocate space for output
319-
nvar, nobs = z.shape
320-
assert der < nvar, "derivative integer must be smaller than nobs in z"
321-
out = np.zeros((n_complete(nvar, d), nobs))
322-
323322
if d > 5:
324323
raise ValueError("Complete polynomial only implemeted up to degree 5")
325324

326-
# populate out with jitted function
327-
_complete_poly_der_impl(z, d, der, out)
325+
# Ensure z is a numpy array
326+
z = np.asarray(z)
327+
328+
# compute inds allocate space for output
329+
if np.ndim(z) == 1:
330+
nvar = z.size
331+
assert der < nvar, "derivative integer must be smaller than nobs in z"
332+
out = np.zeros(n_complete(nvar, d))
333+
# populate with jitted function
334+
_complete_poly_der_impl_vec(z, d, der, out)
335+
else:
336+
nvar, nobs = z.shape
337+
assert der < nvar, "derivative integer must be smaller than nobs in z"
338+
out = np.zeros((n_complete(nvar, d), nobs))
339+
# populate out with jitted function
340+
_complete_poly_der_impl(z, d, der, out)
328341

329342
return out
330343

@@ -474,6 +487,8 @@ def _complete_poly_der_impl(z, d, der, out):
474487
for i3 in range(i2, nvar):
475488
ix += 1
476489
for k in range(nobs):
490+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
491+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
477492
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
478493
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
479494

@@ -491,12 +506,17 @@ def _complete_poly_der_impl(z, d, der, out):
491506
for i3 in range(i2, nvar):
492507
ix += 1
493508
for k in range(nobs):
509+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
510+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
494511
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
495512
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
496513

497514
for i4 in range(i3, nvar):
498515
ix += 1
499516
for k in range(nobs):
517+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
518+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
519+
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
500520
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
501521
out[ix, k] = c4*t1*t2*t3*t4*z[der, k]**(c4-1) if c4>0 else 0.0
502522

@@ -514,18 +534,27 @@ def _complete_poly_der_impl(z, d, der, out):
514534
for i3 in range(i2, nvar):
515535
ix += 1
516536
for k in range(nobs):
537+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
538+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
517539
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
518540
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
519541

520542
for i4 in range(i3, nvar):
521543
ix += 1
522544
for k in range(nobs):
545+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
546+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
547+
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
523548
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
524549
out[ix, k] = c4*t1*t2*t3*t4*z[der, k]**(c4-1) if c4>0 else 0.0
525550

526551
for i5 in range(i4, nvar):
527552
ix += 1
528553
for k in range(nobs):
554+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
555+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
556+
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
557+
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
529558
c5, t5 = (c4+1, 1.0) if i5==der else (c4, z[i5, k])
530559
out[ix, k] = c5*t1*t2*t3*t4*t5*z[der, k]**(c5-1) if c5>0 else 0.0
531560

interpolation/tests/test_complete.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ def f2(x, y): return x**3 - y
5050

5151
def test_complete_derivative():
5252

53-
# TODO: Currently if z has a 0 value then it breaks because occasionally
54-
# tries to raise 0 to a negative power -- This can be fixed by
55-
# checking whether coefficient is 0 before trying to do anything...
56-
5753
# Test derivative vector
5854
z = np.array([1, 2, 3])
5955
sol_vec = np.array([0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 3.0, 0.0, 0.0, 0.0])
@@ -68,9 +64,31 @@ def test_complete_derivative():
6864
assert(abs(out_mat[2, :] - np.ones(2)).max() < 1e-10)
6965
assert(abs(out_mat[-2, :] - np.array([5.0, 6.0])).max() < 1e-10)
7066

67+
def test_complete_vec_vs_mat():
68+
# Matrix for allocation
69+
temp = np.ones(n_complete(2, 3))*5.0
70+
temp_mat = np.ones((n_complete(2, 3), 3))
71+
72+
# Point at which to evaluate
73+
z = np.array([0.9, 1.05])
74+
z_mat = np.array([[0.9, 0.95, 1.0], [1.05, 1.0, 0.95]])
75+
76+
foo = complete_polynomial(z, 2)
77+
bar = complete_polynomial(z_mat, 2)[:, 0]
78+
assert np.allclose(foo, bar)
79+
80+
foo = complete_polynomial_der(z, 2, 0)
81+
bar = complete_polynomial_der(z_mat, 2, 0)[:, 0]
82+
assert np.allclose(foo, bar)
83+
84+
foo = complete_polynomial_der(z, 4, 0)
85+
bar = complete_polynomial_der(z_mat, 4, 0)[:, 0]
86+
assert np.allclose(foo, bar)
87+
7188

7289
if __name__ == '__main__':
7390
test_complete_scalar()
7491
test_complete_vector()
7592
test_complete_derivative()
93+
test_complete_vec_vs_mat()
7694

0 commit comments

Comments
 (0)