Skip to content

Commit 74cae69

Browse files
committed
shallow water: add partial gpu support, warm up jit cache
1 parent 0c674f4 commit 74cae69

File tree

1 file changed

+98
-72
lines changed

1 file changed

+98
-72
lines changed

examples/shallow_water.py

+98-72
Original file line number | Diff line number | Diff line change
@@ -111,28 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
111111
t_end = 1.0
112112

113113
# coordinate arrays
114+
sync()
114115
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2,
116-
(nx, ny),
117-
dtype=dtype,
116+
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype, device=""
118117
)
119118
y_t_2d = fromfunction(
120-
lambda i, j: ymin + j * dy + dy / 2,
121-
(nx, ny),
122-
dtype=dtype,
119+
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype, device=""
120+
)
121+
x_u_2d = fromfunction(
122+
lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype, device=""
123123
)
124-
x_u_2d = fromfunction(lambda i, j: xmin + i * dx, (nx + 1, ny), dtype=dtype)
125124
y_u_2d = fromfunction(
126125
lambda i, j: ymin + j * dy + dy / 2,
127126
(nx + 1, ny),
128127
dtype=dtype,
128+
device="",
129129
)
130130
x_v_2d = fromfunction(
131131
lambda i, j: xmin + i * dx + dx / 2,
132132
(nx, ny + 1),
133133
dtype=dtype,
134+
device="",
134135
)
135-
y_v_2d = fromfunction(lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype)
136+
y_v_2d = fromfunction(
137+
lambda i, j: ymin + j * dy, (nx, ny + 1), dtype=dtype, device=""
138+
)
139+
sync()
136140

137141
T_shape = (nx, ny)
138142
U_shape = (nx + 1, ny)
@@ -157,7 +161,7 @@ def run(n, backend, datatype, benchmark_mode):
157161
q = create_full(F_shape, 0.0, dtype)
158162

159163
# bathymetry
160-
h = create_full(T_shape, 0.0, dtype)
164+
h = create_full(T_shape, 1.0, dtype) # HACK init with 1
161165

162166
hu = create_full(U_shape, 0.0, dtype)
163167
hv = create_full(V_shape, 0.0, dtype)
@@ -205,22 +209,16 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205209
bath = 1.0
206210
return bath * create_full(T_shape, 1.0, dtype)
207211

208-
# inital elevation
209-
u0, v0, e0 = exact_solution(
210-
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
211-
)
212-
e[:, :] = e0
213-
u[:, :] = u0
214-
v[:, :] = v0
215-
216212
# set bathymetry
217-
h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly)
213+
# h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
218214
# steady state potential energy
219-
pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
215+
# pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
216+
pe_offset = 0.5 * g * float(1.0) / nx / ny
220217

221218
# compute time step
222219
alpha = 0.5
223-
h_max = float(np.max(h, all_axes))
220+
# h_max = float(np.max(h, all_axes))
221+
h_max = float(1.0)
224222
c = (g * h_max) ** 0.5
225223
dt = alpha * dx / c
226224
dt = t_export / int(math.ceil(t_export / dt))
@@ -329,6 +327,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
329327
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
330328
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)
331329

330+
# warm jit cache
331+
step(u, v, e, u1, v1, e1, u2, v2, e2)
332+
sync()
333+
334+
# initial solution
335+
u0, v0, e0 = exact_solution(
336+
0, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
337+
)
338+
e[:, :] = e0.to_device(device)
339+
u[:, :] = u0.to_device(device)
340+
v[:, :] = v0.to_device(device)
341+
332342
t = 0
333343
i_export = 0
334344
next_t_export = 0
@@ -341,30 +351,41 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
341351
t = i * dt
342352

343353
if t >= next_t_export - 1e-8:
344-
_elev_max = np.max(e, all_axes)
345-
_u_max = np.max(u, all_axes)
346-
_q_max = np.max(q, all_axes)
347-
_total_v = np.sum(e + h, all_axes)
348-
349-
# potential energy
350-
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
351-
_total_pe = np.sum(_pe, all_axes)
352-
353-
# kinetic energy
354-
u2 = u * u
355-
v2 = v * v
356-
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
357-
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
358-
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
359-
_total_ke = np.sum(_ke, all_axes)
360-
361-
total_pe = float(_total_pe) * dx * dy
362-
total_ke = float(_total_ke) * dx * dy
363-
total_e = total_ke + total_pe
364-
elev_max = float(_elev_max)
365-
u_max = float(_u_max)
366-
q_max = float(_q_max)
367-
total_v = float(_total_v) * dx * dy
354+
if device:
355+
# FIXME gpu.memcpy to host requires identity layout
356+
# FIXME reduction on gpu
357+
elev_max = 0
358+
u_max = 0
359+
q_max = 0
360+
diff_e = 0
361+
diff_v = 0
362+
total_pe = 0
363+
total_ke = 0
364+
else:
365+
_elev_max = np.max(e, all_axes)
366+
_u_max = np.max(u, all_axes)
367+
_q_max = np.max(q, all_axes)
368+
_total_v = np.sum(e + h, all_axes)
369+
370+
# potential energy
371+
_pe = 0.5 * g * (e + h) * (e - h) + pe_offset
372+
_total_pe = np.sum(_pe, all_axes)
373+
374+
# kinetic energy
375+
u2 = u * u
376+
v2 = v * v
377+
u2_at_t = 0.5 * (u2[1:, :] + u2[:-1, :])
378+
v2_at_t = 0.5 * (v2[:, 1:] + v2[:, :-1])
379+
_ke = 0.5 * (u2_at_t + v2_at_t) * (e + h)
380+
_total_ke = np.sum(_ke, all_axes)
381+
382+
total_pe = float(_total_pe) * dx * dy
383+
total_ke = float(_total_ke) * dx * dy
384+
total_e = total_ke + total_pe
385+
elev_max = float(_elev_max)
386+
u_max = float(_u_max)
387+
q_max = float(_q_max)
388+
total_v = float(_total_v) * dx * dy
368389

369390
if i_export == 0:
370391
initial_v = total_v
@@ -399,35 +420,40 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
399420
duration = time_mod.perf_counter() - tic
400421
info(f"Duration: {duration:.2f} s")
401422

402-
e_exact = exact_solution(t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d)[
403-
2
404-
]
405-
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
406-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
407-
info(f"L2 error: {err_L2:7.15e}")
408-
409-
if nx < 128 or ny < 128:
410-
info("Skipping correctness test due to small problem size.")
411-
elif not benchmark_mode:
412-
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
413-
assert (
414-
diff_e < tolerance_ene
415-
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
416-
if nx == 128 and ny == 128:
417-
if datatype == "f32":
418-
assert numpy.allclose(
419-
err_L2, 4.3127859e-05, rtol=1e-5
420-
), "L2 error does not match"
421-
else:
422-
assert numpy.allclose(
423-
err_L2, 4.315799035627906e-05
424-
), "L2 error does not match"
425-
else:
426-
tolerance_l2 = 1e-4
423+
if device:
424+
# FIXME gpu.memcpy to host requires identity layout
425+
# FIXME reduction on gpu
426+
pass
427+
else:
428+
e_exact = exact_solution(
429+
t, x_t_2d, y_t_2d, x_u_2d, y_u_2d, x_v_2d, y_v_2d
430+
)[2]
431+
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
432+
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
433+
info(f"L2 error: {err_L2:7.15e}")
434+
435+
if nx < 128 or ny < 128:
436+
info("Skipping correctness test due to small problem size.")
437+
elif not benchmark_mode:
438+
tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
427439
assert (
428-
err_L2 < tolerance_l2
429-
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
430-
info("SUCCESS")
440+
diff_e < tolerance_ene
441+
), f"Energy error exceeds tolerance: {diff_e} > {tolerance_ene}"
442+
if nx == 128 and ny == 128:
443+
if datatype == "f32":
444+
assert numpy.allclose(
445+
err_L2, 4.3127859e-05, rtol=1e-5
446+
), "L2 error does not match"
447+
else:
448+
assert numpy.allclose(
449+
err_L2, 4.315799035627906e-05
450+
), "L2 error does not match"
451+
else:
452+
tolerance_l2 = 1e-4
453+
assert (
454+
err_L2 < tolerance_l2
455+
), f"L2 error exceeds tolerance: {err_L2} > {tolerance_l2}"
456+
info("SUCCESS")
431457

432458
fini()
433459

0 commit comments

Comments
 (0)