Skip to content

Commit c195a74

Browse files
committed
wave equation: add gpu reductions
1 parent ad5c58d commit c195a74

File tree

1 file changed

+14
-26
lines changed

1 file changed

+14
-26
lines changed

examples/wave_equation.py

+14-26
Original file line numberDiff line numberDiff line change
@@ -244,22 +244,15 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
244244
t = i * dt
245245

246246
if t >= next_t_export - 1e-8:
247-
if device:
248-
# FIXME gpu.memcpy to host requires identity layout
249-
# FIXME reduction on gpu
250-
# e_host = e.to_device()
251-
# u_host = u.to_device()
252-
# h_host = h.to_device()
253-
# _elev_max = np.max(e_host, all_axes)
254-
# _u_max = np.max(u_host, all_axes)
255-
# _total_v = np.sum(e_host + h, all_axes)
256-
_elev_max = 0
257-
_u_max = 0
258-
_total_v = 0
259-
else:
260-
_elev_max = np.max(e, all_axes)
261-
_u_max = np.max(u, all_axes)
262-
_total_v = np.sum(e + h, all_axes)
247+
sync()
248+
H_tmp = e + h
249+
sync()
250+
_elev_max = np.max(e, all_axes).to_device()
251+
# NOTE: max(u) segfaults because shape (n+1, n) is too large for tiling; as a workaround the reduction below runs over u[1:, :], which skips the first row of u
252+
_u_max = np.max(u[1:, :], all_axes).to_device()
253+
_total_v = np.sum(H_tmp, all_axes).to_device()
254+
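# NOTE: passing the expression e + h directly to np.sum (commented-out line below) segfaults, so the sum is taken over the precomputed H_tmp instead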
255+
# _total_v = np.sum(e + h, all_axes).to_device() # segfaults
263256

264257
elev_max = float(_elev_max)
265258
u_max = float(_u_max)
@@ -294,16 +287,11 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
294287
duration = time_mod.perf_counter() - tic
295288
info(f"Duration: {duration:.2f} s")
296289

297-
if device:
298-
# FIXME gpu.memcpy to host requires identity layout
299-
# FIXME reduction on gpu
300-
# err2_host = err2.to_device()
301-
# err_L2 = math.sqrt(float(np.sum(err2_host, all_axes)))
302-
err_L2 = 0
303-
else:
304-
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
305-
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
306-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
290+
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly).to_device(device)
291+
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
292+
err2sum = np.sum(err2, all_axes).to_device()
293+
sync()
294+
err_L2 = math.sqrt(float(err2sum))
307295
info(f"L2 error: {err_L2:7.5e}")
308296

309297
if nx == 128 and ny == 128 and not benchmark_mode and not device:

0 commit comments

Comments
 (0)