@@ -99,17 +99,23 @@ def complete_polynomial(z, d):
99
99
"""
100
100
# check inputs
101
101
assert d >= 0 , "d must be non-negative"
102
- z = np .asarray (z )
103
-
104
- # compute inds allocate space for output
105
- nvar , nobs = z .shape
106
- out = np .zeros ((n_complete (nvar , d ), nobs ))
107
-
108
102
if d > 5 :
109
103
raise ValueError ("Complete polynomial only implemeted up to degree 5" )
110
104
111
- # populate out with jitted function
112
- _complete_poly_impl (z , d , out )
105
+ # Assure z is array
106
+ z = np .asarray (z )
107
+
108
+ # compute inds allocate space for output
109
+ if np .ndim (z ) == 1 :
110
+ nvar = z .size
111
+ out = np .zeros (n_complete (nvar , d ))
112
+ # populate out with jitted function
113
+ _complete_poly_impl_vec (z , d , out )
114
+ else :
115
+ nvar , nobs = z .shape
116
+ out = np .zeros ((n_complete (nvar , d ), nobs ))
117
+ # populate out with jitted function
118
+ _complete_poly_impl (z , d , out )
113
119
114
120
return out
115
121
@@ -313,18 +319,25 @@ def complete_polynomial_der(z, d, der):
313
319
# check inputs
314
320
assert d >= 0 , "d must be non-negative"
315
321
assert der >= 0 , "derivative must be non-negative"
316
- z = np .asarray (z )
317
-
318
- # compute inds allocate space for output
319
- nvar , nobs = z .shape
320
- assert der < nvar , "derivative integer must be smaller than nobs in z"
321
- out = np .zeros ((n_complete (nvar , d ), nobs ))
322
-
323
322
if d > 5 :
324
323
raise ValueError ("Complete polynomial only implemeted up to degree 5" )
325
324
326
- # populate out with jitted function
327
- _complete_poly_der_impl (z , d , der , out )
325
+ # Ensure z is a numpy array
326
+ z = np .asarray (z )
327
+
328
+ # compute inds allocate space for output
329
+ if np .ndim (z ) == 1 :
330
+ nvar = z .size
331
+ assert der < nvar , "derivative integer must be smaller than nobs in z"
332
+ out = np .zeros (n_complete (nvar , d ))
333
+ # populate with jitted function
334
+ _complete_poly_der_impl_vec (z , d , der , out )
335
+ else :
336
+ nvar , nobs = z .shape
337
+ assert der < nvar , "derivative integer must be smaller than nobs in z"
338
+ out = np .zeros ((n_complete (nvar , d ), nobs ))
339
+ # populate out with jitted function
340
+ _complete_poly_der_impl (z , d , der , out )
328
341
329
342
return out
330
343
@@ -474,6 +487,8 @@ def _complete_poly_der_impl(z, d, der, out):
474
487
for i3 in range (i2 , nvar ):
475
488
ix += 1
476
489
for k in range (nobs ):
490
+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
491
+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
477
492
c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
478
493
out [ix , k ] = c3 * t1 * t2 * t3 * z [der , k ]** (c3 - 1 ) if c3 > 0 else 0.0
479
494
@@ -491,12 +506,17 @@ def _complete_poly_der_impl(z, d, der, out):
491
506
for i3 in range (i2 , nvar ):
492
507
ix += 1
493
508
for k in range (nobs ):
509
+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
510
+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
494
511
c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
495
512
out [ix , k ] = c3 * t1 * t2 * t3 * z [der , k ]** (c3 - 1 ) if c3 > 0 else 0.0
496
513
497
514
for i4 in range (i3 , nvar ):
498
515
ix += 1
499
516
for k in range (nobs ):
517
+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
518
+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
519
+ c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
500
520
c4 , t4 = (c3 + 1 , 1.0 ) if i4 == der else (c3 , z [i4 , k ])
501
521
out [ix , k ] = c4 * t1 * t2 * t3 * t4 * z [der , k ]** (c4 - 1 ) if c4 > 0 else 0.0
502
522
@@ -514,18 +534,27 @@ def _complete_poly_der_impl(z, d, der, out):
514
534
for i3 in range (i2 , nvar ):
515
535
ix += 1
516
536
for k in range (nobs ):
537
+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
538
+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
517
539
c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
518
540
out [ix , k ] = c3 * t1 * t2 * t3 * z [der , k ]** (c3 - 1 ) if c3 > 0 else 0.0
519
541
520
542
for i4 in range (i3 , nvar ):
521
543
ix += 1
522
544
for k in range (nobs ):
545
+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
546
+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
547
+ c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
523
548
c4 , t4 = (c3 + 1 , 1.0 ) if i4 == der else (c3 , z [i4 , k ])
524
549
out [ix , k ] = c4 * t1 * t2 * t3 * t4 * z [der , k ]** (c4 - 1 ) if c4 > 0 else 0.0
525
550
526
551
for i5 in range (i4 , nvar ):
527
552
ix += 1
528
553
for k in range (nobs ):
554
+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
555
+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
556
+ c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
557
+ c4 , t4 = (c3 + 1 , 1.0 ) if i4 == der else (c3 , z [i4 , k ])
529
558
c5 , t5 = (c4 + 1 , 1.0 ) if i5 == der else (c4 , z [i5 , k ])
530
559
out [ix , k ] = c5 * t1 * t2 * t3 * t4 * t5 * z [der , k ]** (c5 - 1 ) if c5 > 0 else 0.0
531
560
0 commit comments