Skip to content

Commit 7fb1742

Browse files
committed
shallow-water: remove fromfunction initialization
1 parent 0d34e27 commit 7fb1742

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed

Diff for: examples/shallow_water.py

+25-26
Original file line numberDiff line numberDiff line change
@@ -54,25 +54,26 @@ def run(n, backend, datatype, benchmark_mode):
5454
if backend == "sharpy":
5555
import sharpy as np
5656
from sharpy import fini, init, sync
57-
from sharpy.numpy import fromfunction as _fromfunction
5857

5958
device = os.getenv("SHARPY_DEVICE", "")
6059
create_full = partial(np.full, device=device)
61-
fromfunction = partial(_fromfunction, device=device)
60+
61+
def transpose(a):
62+
return np.permute_dims(a, [1, 0])
6263

6364
all_axes = [0, 1]
6465
init(False)
6566

6667
elif backend == "numpy":
6768
import numpy as np
68-
from numpy import fromfunction
6969

7070
if comm is not None:
7171
assert (
7272
comm.Get_size() == 1
7373
), "Numpy backend only supports serial execution."
7474

7575
create_full = np.full
76+
transpose = np.transpose
7677

7778
fini = sync = lambda x=None: None
7879
all_axes = None
@@ -110,34 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
110111
t_export = 0.02
111112
t_end = 1.0
112113

113-
# coordinate arrays
114-
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2,
116-
(nx, ny),
117-
dtype=dtype,
118-
)
119-
y_t_2d = fromfunction(
120-
lambda i, j: ymin + j * dy + dy / 2,
121-
(nx, ny),
122-
dtype=dtype,
123-
)
124-
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype)
125-
y_u_2d = fromfunction(
126-
lambda i, j: ymin + j * dy + dy / 2,
127-
(nx + 1, ny),
128-
dtype=dtype,
129-
)
130-
x_v_2d = fromfunction(
131-
lambda i, j: xmin + i * dx + dx / 2,
132-
(nx, ny + 1),
133-
dtype=dtype,
134-
)
135-
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype)
114+
def ind_arr(shape, columns=False):
115+
"""Construct an (nx, ny) array where each row/col is an arange"""
116+
nx, ny = shape
117+
if columns:
118+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % nx
119+
ind = transpose(np.reshape(ind, (ny, nx)))
120+
else:
121+
ind = np.arange(0, nx * ny, 1, dtype=np.int32) % ny
122+
ind = np.reshape(ind, (nx, ny))
123+
return ind.astype(dtype)
136124

125+
# coordinate arrays
137126
T_shape = (nx, ny)
138127
U_shape = (nx + 1, ny)
139128
V_shape = (nx, ny + 1)
140129
F_shape = (nx + 1, ny + 1)
130+
sync()
131+
x_t_2d = xmin + ind_arr(T_shape, True) * dx + dx / 2
132+
y_t_2d = ymin + ind_arr(T_shape) * dy + dy / 2
133+
134+
x_u_2d = xmin + ind_arr(U_shape, True) * dx
135+
y_u_2d = ymin + ind_arr(U_shape) * dy + dy / 2
136+
137+
x_v_2d = xmin + ind_arr(V_shape, True) * dx + dx / 2
138+
y_v_2d = ymin + ind_arr(V_shape) * dy
139+
sync()
141140

142141
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
143142
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))

0 commit comments

Comments
 (0)