@@ -34,9 +34,7 @@ def spherical_pw(N, k, r, setup):
34
34
n = np .arange (N + 1 )
35
35
36
36
bn = weights (N , kr , setup )
37
- for i , x in enumerate (kr ):
38
- bn [i , :] = bn [i , :] * 4 * np .pi * (1j )** n
39
- return bn
37
+ return 4 * np .pi * (1j )** n * bn
40
38
41
39
42
40
def spherical_ps (N , k , r , rs , setup ):
@@ -72,11 +70,14 @@ def spherical_ps(N, k, r, rs, setup):
72
70
n = np .arange (N + 1 )
73
71
74
72
bn = weights (N , k * r , setup )
73
+ if len (k ) == 1 :
74
+ bn = bn [np .newaxis , :]
75
+
75
76
for i , x in enumerate (krs ):
76
77
hn = special .spherical_jn (n , x ) - 1j * special .spherical_yn (n , x )
77
78
bn [i , :] = bn [i , :] * 4 * np .pi * (- 1j ) * hn * k [i ]
78
79
79
- return bn
80
+ return np . squeeze ( bn )
80
81
81
82
82
83
def weights (N , kr , setup ):
@@ -106,6 +107,7 @@ def weights(N, kr, setup):
106
107
Radial weights for all orders up to N and the given wavenumbers.
107
108
108
109
"""
110
+ kr = util .asarray_1d (kr )
109
111
n = np .arange (N + 1 )
110
112
bns = np .zeros ((len (kr ), N + 1 ), dtype = complex )
111
113
for i , x in enumerate (kr ):
@@ -259,9 +261,7 @@ def circular_pw(N, k, r, setup):
259
261
n = np .roll (np .arange (- N , N + 1 ), - N )
260
262
261
263
bn = circ_radial_weights (N , kr , setup )
262
- for i , x in enumerate (kr ):
263
- bn [i , :] = bn [i , :] * (1j )** (n )
264
- return bn
264
+ return (1j )** (n ) * bn
265
265
266
266
267
267
def circular_ls (N , k , r , rs , setup ):
@@ -297,10 +297,12 @@ def circular_ls(N, k, r, rs, setup):
297
297
n = np .roll (np .arange (- N , N + 1 ), - N )
298
298
299
299
bn = circ_radial_weights (N , k * r , setup )
300
+ if len (k ) == 1 :
301
+ bn = bn [np .newaxis , :]
300
302
for i , x in enumerate (krs ):
301
303
Hn = special .hankel2 (n , x )
302
304
bn [i , :] = bn [i , :] * - 1j / 4 * Hn
303
- return bn
305
+ return np . squeeze ( bn )
304
306
305
307
306
308
def circ_radial_weights (N , kr , setup ):
@@ -329,6 +331,7 @@ def circ_radial_weights(N, kr, setup):
329
331
Radial weights for all orders up to N and the given wavenumbers.
330
332
331
333
"""
334
+ kr = util .asarray_1d (kr )
332
335
n = np .arange (N + 1 )
333
336
Bns = np .zeros ((len (kr ), N + 1 ), dtype = complex )
334
337
for i , x in enumerate (kr ):
0 commit comments