@@ -45,9 +45,7 @@ def get_local_mesh(FFT, L):
45
45
"""Returns local mesh."""
46
46
X = np .ogrid [FFT .local_slice (False )]
47
47
N = FFT .global_shape ()
48
- for i in range (len (N )):
49
- X [i ] = (X [i ]* L [i ]/ N [i ])
50
- X = [np .broadcast_to (x , FFT .shape (False )) for x in X ]
48
+ X = [np .broadcast_to (x * L [i ]/ N [i ], FFT .shape (False )) for i , x in enumerate (X )]
51
49
return X
52
50
53
51
def get_local_wavenumbermesh (FFT , L ):
@@ -60,9 +58,7 @@ def get_local_wavenumbermesh(FFT, L):
60
58
K = [ki [si ] for ki , si in zip (k , s )]
61
59
Ks = np .meshgrid (* K , indexing = 'ij' , sparse = True )
62
60
Lp = 2 * np .pi / L
63
- for i in range (3 ):
64
- Ks [i ] = (Ks [i ]* Lp [i ]).astype (float )
65
- return [np .broadcast_to (k , FFT .shape (True )) for k in Ks ]
61
+ return [np .broadcast_to (k * Lp [i ], FFT .shape (True )) for i , k in enumerate (Ks )]
66
62
67
63
X = get_local_mesh (FFT , L )
68
64
K = get_local_wavenumbermesh (FFT , L )
@@ -131,3 +127,5 @@ def compute_rhs(rhs):
131
127
if MPI .COMM_WORLD .Get_rank () == 0 :
132
128
print ('Time = {}' .format (time ()- t0 ))
133
129
assert round (float (k ) - 0.124953117517 , 7 ) == 0
130
+
131
+ FFT .destroy ()
0 commit comments