Skip to content

Commit 7dfaaa3

Browse files
committed
shallow water: add partial gpu support, warm up jit cache
1 parent a935132 commit 7dfaaa3

File tree

1 file changed

+77
-59
lines changed

1 file changed

+77
-59
lines changed

Diff for: examples/shallow_water.py

+77-59
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def ind_arr(shape, columns=False):
156156
q = create_full(F_shape, 0.0, dtype)
157157

158158
# bathymetry
159-
h = create_full(T_shape, 0.0, dtype)
159+
h = create_full(T_shape, 1.0, dtype) # HACK init with 1
160160

161161
hu = create_full(U_shape, 0.0, dtype)
162162
hv = create_full(V_shape, 0.0, dtype)
@@ -205,13 +205,15 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205205
return bath * create_full(T_shape, 1.0, dtype)
206206

207207
# set bathymetry
208-
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
208+
# h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
209209
# steady state potential energy
210-
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
210+
# pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
211+
pe_offset = 0.5 * g * float(1.0) / nx / ny
211212

212213
# compute time step
213214
alpha = 0.5
214-
h_max = float(np.max(h, all_axes))
215+
# h_max = float(np.max(h, all_axes))
216+
h_max = float(1.0)
215217
c = (g * h_max) ** 0.5
216218
dt = alpha * dx / c
217219
dt = t_export / int(math.ceil(t_export / dt))
@@ -328,9 +330,9 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
328330
u0, v0, e0 = exact_solution(
329331
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
330332
)
331-
e[:, :] = e0
332-
u[:, :] = u0
333-
v[:, :] = v0
333+
e[:, :] = e0.to_device(device)
334+
u[:, :] = u0.to_device(device)
335+
v[:, :] = v0.to_device(device)
334336

335337
t = 0
336338
i_export = 0
@@ -344,30 +346,41 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
344346
t = i * dt
345347

346348
if t >= next_t_export - 1e-8:
347-
_elev_max = np.max(e, all_axes)
348-
_u_max = np.max(u, all_axes)
349-
_q_max = np.max(q, all_axes)
350-
_total_v = np.sum(e + h, all_axes)
351-
352-
# potential energy
353-
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
354-
_total_pe = np.sum(_pe, all_axes)
355-
356-
# kinetic energy
357-
u2 = u * u
358-
v2 = v * v
359-
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
360-
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
361-
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
362-
_total_ke = np.sum(_ke, all_axes)
363-
364-
total_pe = float(_total_pe) * dx * dy
365-
total_ke = float(_total_ke) * dx * dy
366-
total_e = total_ke + total_pe
367-
elev_max = float(_elev_max)
368-
u_max = float(_u_max)
369-
q_max = float(_q_max)
370-
total_v = float(_total_v) * dx * dy
349+
if device:
350+
# FIXME gpu.memcpy to host requires identity layout
351+
# FIXME reduction on gpu
352+
elev_max = 0
353+
u_max = 0
354+
q_max = 0
355+
diff_e = 0
356+
diff_v = 0
357+
total_pe = 0
358+
total_ke = 0
359+
else:
360+
_elev_max = np.max(e, all_axes)
361+
_u_max = np.max(u, all_axes)
362+
_q_max = np.max(q, all_axes)
363+
_total_v = np.sum(e + h, all_axes)
364+
365+
# potential energy
366+
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
367+
_total_pe = np.sum(_pe, all_axes)
368+
369+
# kinetic energy
370+
u2 = u * u
371+
v2 = v * v
372+
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
373+
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
374+
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
375+
_total_ke = np.sum(_ke, all_axes)
376+
377+
total_pe = float(_total_pe) * dx * dy
378+
total_ke = float(_total_ke) * dx * dy
379+
total_e = total_ke + total_pe
380+
elev_max = float(_elev_max)
381+
u_max = float(_u_max)
382+
q_max = float(_q_max)
383+
total_v = float(_total_v) * dx * dy
371384

372385
if i_export == 0:
373386
initial_v = total_v
@@ -402,35 +415,40 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
402415
duration = time_mod.perf_counter() - tic
403416
info(f"Duration: {duration:.2f} s")
404417

405-
e_exact = exact_solution(t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d)[
406-
2
407-
]
408-
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
409-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
410-
info(f"L2 error: {err_L2:7.15e}")
411-
412-
if nx < 128 or ny < 128:
413-
info("Skipping correctness test due to small problem size.")
414-
elif not benchmark_mode:
415-
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
416-
assert (
417-
diff_e < tolerance_ene
418-
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
419-
if nx == 128 and ny == 128:
420-
if datatype == "f32":
421-
assert numpy.allclose(
422-
err_L2, 4.3127859e-05, rtol=1e-5
423-
), "L2 error does not match"
424-
else:
425-
assert numpy.allclose(
426-
err_L2, 4.315799035627906e-05
427-
), "L2 error does not match"
428-
else:
429-
tolerance_l2 = 1e-4
418+
if device:
419+
# FIXME gpu.memcpy to host requires identity layout
420+
# FIXME reduction on gpu
421+
pass
422+
else:
423+
e_exact = exact_solution(
424+
t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
425+
)[2]
426+
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
427+
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
428+
info(f"L2 error: {err_L2:7.15e}")
429+
430+
if nx < 128 or ny < 128:
431+
info("Skipping correctness test due to small problem size.")
432+
elif not benchmark_mode:
433+
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
430434
assert (
431-
err_L2 < tolerance_l2
432-
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
433-
info("SUCCESS")
435+
diff_e < tolerance_ene
436+
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
437+
if nx == 128 and ny == 128:
438+
if datatype == "f32":
439+
assert numpy.allclose(
440+
err_L2, 4.3127859e-05, rtol=1e-5
441+
), "L2 error does not match"
442+
else:
443+
assert numpy.allclose(
444+
err_L2, 4.315799035627906e-05
445+
), "L2 error does not match"
446+
else:
447+
tolerance_l2 = 1e-4
448+
assert (
449+
err_L2 < tolerance_l2
450+
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
451+
info("SUCCESS")
434452

435453
fini()
436454

0 commit comments

Comments
 (0)