Skip to content

Commit 2da8b03

Browse files
committed
shallow-water: remove fromfunction initialization
1 parent 936d519 commit 2da8b03

File tree

1 file changed

+22
-26
lines changed

1 file changed

+22
-26
lines changed

examples/shallow_water.py

+22-26
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,15 @@ def run(n, backend, datatype, benchmark_mode):
5454
if backend == "sharpy":
5555
import sharpy as np
5656
from sharpy import fini, init, sync
57-
from sharpy.numpy import fromfunction as _fromfunction
5857

5958
device = os.getenv("SHARPY_DEVICE", "")
6059
create_full = partial(np.full, device=device)
61-
fromfunction = partial(_fromfunction, device=device)
6260

6361
all_axes = [0, 1]
6462
init(False)
6563

6664
elif backend == "numpy":
6765
import numpy as np
68-
from numpy import fromfunction
6966

7067
if comm is not None:
7168
assert (
@@ -110,34 +107,33 @@ def run(n, backend, datatype, benchmark_mode):
110107
t_export = 0.02
111108
t_end = 1.0
112109

113-
# coordinate arrays
114-
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2,
116-
(nx, ny),
117-
dtype=dtype,
118-
)
119-
y_t_2d = fromfunction(
120-
lambda i, j: ymin + j * dy + dy / 2,
121-
(nx, ny),
122-
dtype=dtype,
123-
)
124-
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype)
125-
y_u_2d = fromfunction(
126-
lambda i, j: ymin + j * dy + dy / 2,
127-
(nx + 1, ny),
128-
dtype=dtype,
129-
)
130-
x_v_2d = fromfunction(
131-
lambda i, j: xmin + i * dx + dx / 2,
132-
(nx, ny + 1),
133-
dtype=dtype,
134-
)
135-
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype)
110+
def ind_arr(shape, columns=False):
111+
"""Construct an (nx, ny) array where each row/col is an arange"""
112+
nx, ny = shape
113+
if columns:
114+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % nx
115+
ind = np.reshape(ind, (ny, nx))
116+
ind = np.permute_dims(ind, [1, 0])
117+
else:
118+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % ny
119+
ind = np.reshape(ind, (nx, ny))
120+
return ind.astype(dtype)
136121

122+
# coordinate arrays
137123
T_shape = (nx, ny)
138124
U_shape = (nx + 1, ny)
139125
V_shape = (nx, ny + 1)
140126
F_shape = (nx + 1, ny + 1)
127+
sync()
128+
x_t_2d = xmin + ind_arr(T_shape, True) * dx + dx / 2
129+
y_t_2d = ymin + ind_arr(T_shape) * dy + dy / 2
130+
131+
x_u_2d = xmin + ind_arr(U_shape, True) * dx
132+
y_u_2d = ymin + ind_arr(U_shape) * dy + dy / 2
133+
134+
x_v_2d = xmin + ind_arr(V_shape, True) * dx + dx / 2
135+
y_v_2d = ymin + ind_arr(V_shape) * dy
136+
sync()
141137

142138
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
143139
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))

0 commit comments

Comments
 (0)