Skip to content

Commit 936d519

Browse files
committed
wave-equation: remove fromfunction initialization
1 parent 4520d97 commit 936d519

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

examples/wave_equation.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,15 @@ def run(n, backend, datatype, benchmark_mode):
5454
if backend == "sharpy":
5555
import sharpy as np
5656
from sharpy import fini, init, sync
57-
from sharpy.numpy import fromfunction as _fromfunction
5857

5958
device = os.getenv("SHARPY_DEVICE", "")
6059
create_full = partial(np.full, device=device)
61-
fromfunction = partial(_fromfunction, device=device)
6260

6361
all_axes = [0, 1]
6462
init(False)
6563

6664
elif backend == "numpy":
6765
import numpy as np
68-
from numpy import fromfunction
6966

7067
if comm is not None:
7168
assert (
@@ -110,17 +107,24 @@ def run(n, backend, datatype, benchmark_mode):
110107
t_export = 0.02
111108
t_end = 1.0
112109

113-
# coordinate arrays
114-
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype
116-
)
117-
y_t_2d = fromfunction(
118-
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype
119-
)
110+
def ind_arr(shape, columns=False):
111+
"""Construct an (nx, ny) array where each row/col is an arange"""
112+
nx, ny = shape
113+
if columns:
114+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % nx
115+
ind = np.reshape(ind, (ny, nx))
116+
ind = np.permute_dims(ind, [1, 0])
117+
else:
118+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % ny
119+
ind = np.reshape(ind, (nx, ny))
120+
return ind.astype(dtype)
120121

122+
# coordinate arrays
121123
T_shape = (nx, ny)
122124
U_shape = (nx + 1, ny)
123125
V_shape = (nx, ny + 1)
126+
x_t_2d = xmin + ind_arr(T_shape, True) * dx + dx / 2
127+
y_t_2d = ymin + ind_arr(T_shape) * dy + dy / 2
124128

125129
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
126130
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))

0 commit comments

Comments
 (0)