Skip to content

Commit a935132

Browse files
committed
wave equation: add partial gpu support, warm up jit cache
1 parent 763ee57 commit a935132

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

Diff for: examples/wave_equation.py

+32-8
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,10 @@ def ind_arr(shape, columns=False):
126126
T_shape = (nx, ny)
127127
U_shape = (nx + 1, ny)
128128
V_shape = (nx, ny + 1)
129+
sync()
129130
x_t_2d = xmin + ind_arr(T_shape, True) * dx + dx / 2
130131
y_t_2d = ymin + ind_arr(T_shape) * dy + dy / 2
132+
sync()
131133

132134
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
133135
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))
@@ -151,6 +153,8 @@ def ind_arr(shape, columns=False):
151153
u2 = create_full(U_shape, 0.0, dtype)
152154
v2 = create_full(V_shape, 0.0, dtype)
153155

156+
sync()
157+
154158
def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
155159
"""
156160
Exact solution for elevation field.
@@ -224,7 +228,7 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
224228
sync()
225229

226230
# initial solution
227-
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
231+
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly).to_device(device)
228232
u[:, :] = create_full(U_shape, 0.0, dtype)
229233
v[:, :] = create_full(V_shape, 0.0, dtype)
230234
sync()
@@ -240,9 +244,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
240244
t = i * dt
241245

242246
if t >= next_t_export - 1e-8:
243-
_elev_max = np.max(e, all_axes)
244-
_u_max = np.max(u, all_axes)
245-
_total_v = np.sum(e + h, all_axes)
247+
if device:
248+
# FIXME gpu.memcpy to host requires identity layout
249+
# FIXME reduction on gpu
250+
# e_host = e.to_device()
251+
# u_host = u.to_device()
252+
# h_host = h.to_device()
253+
# _elev_max = np.max(e_host, all_axes)
254+
# _u_max = np.max(u_host, all_axes)
255+
# _total_v = np.sum(e_host + h, all_axes)
256+
_elev_max = 0
257+
_u_max = 0
258+
_total_v = 0
259+
else:
260+
_elev_max = np.max(e, all_axes)
261+
_u_max = np.max(u, all_axes)
262+
_total_v = np.sum(e + h, all_axes)
246263

247264
elev_max = float(_elev_max)
248265
u_max = float(_u_max)
@@ -277,12 +294,19 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
277294
duration = time_mod.perf_counter() - tic
278295
info(f"Duration: {duration:.2f} s")
279296

280-
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
281-
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
282-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
297+
if device:
298+
# FIXME gpu.memcpy to host requires identity layout
299+
# FIXME reduction on gpu
300+
# err2_host = err2.to_device()
301+
# err_L2 = math.sqrt(float(np.sum(err2_host, all_axes)))
302+
err_L2 = 0
303+
else:
304+
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
305+
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
306+
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
283307
info(f"L2 error: {err_L2:7.5e}")
284308

285-
if nx == 128 and ny == 128 and not benchmark_mode:
309+
if nx == 128 and ny == 128 and not benchmark_mode and not device:
286310
if datatype == "f32":
287311
assert numpy.allclose(err_L2, 7.2235471e-03, rtol=1e-4)
288312
else:

0 commit comments

Comments
 (0)