Skip to content

Commit 0c674f4

Browse files
committed
wave equation: add partial gpu support, warm up jit cache
1 parent 5304a84 commit 0c674f4

File tree

1 file changed

+55
-26
lines changed

1 file changed

+55
-26
lines changed

Diff for: examples/wave_equation.py

+55-26
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,14 @@ def run(n, backend, datatype, benchmark_mode):
111111
t_end = 1.0
112112

113113
# coordinate arrays
114+
sync()
114115
x_t_2d = fromfunction(
115-
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype
116+
lambda i, j: xmin + i * dx + dx / 2, (nx, ny), dtype=dtype, device=""
116117
)
117118
y_t_2d = fromfunction(
118-
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype
119+
lambda i, j: ymin + j * dy + dy / 2, (nx, ny), dtype=dtype, device=""
119120
)
121+
sync()
120122

121123
T_shape = (nx, ny)
122124
U_shape = (nx + 1, ny)
@@ -144,6 +146,22 @@ def run(n, backend, datatype, benchmark_mode):
144146
u2 = create_full(U_shape, 0.0, dtype)
145147
v2 = create_full(V_shape, 0.0, dtype)
146148

149+
# compute time step
150+
alpha = 0.5
151+
c = (g * h) ** 0.5
152+
dt = alpha * dx / c
153+
dt = t_export / int(math.ceil(t_export / dt))
154+
nt = int(math.ceil(t_end / dt))
155+
if benchmark_mode:
156+
dt = 1e-5
157+
nt = 100
158+
t_export = dt * 25
159+
160+
info(f"Time step: {dt} s")
161+
info(f"Total run time: {t_end} s, {nt} time steps")
162+
163+
sync()
164+
147165
def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
148166
"""
149167
Exact solution for elevation field.
@@ -162,25 +180,6 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
162180
sol_t = numpy.cos(2 * omega * t)
163181
return amp * sol_x * sol_y * sol_t
164182

165-
# inital elevation
166-
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
167-
168-
# compute time step
169-
alpha = 0.5
170-
c = (g * h) ** 0.5
171-
dt = alpha * dx / c
172-
dt = t_export / int(math.ceil(t_export / dt))
173-
nt = int(math.ceil(t_end / dt))
174-
if benchmark_mode:
175-
dt = 1e-5
176-
nt = 100
177-
t_export = dt * 25
178-
179-
info(f"Time step: {dt} s")
180-
info(f"Total run time: {t_end} s, {nt} time steps")
181-
182-
sync()
183-
184183
def rhs(u, v, e):
185184
"""
186185
Evaluate right hand side of the equations
@@ -215,6 +214,16 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
215214
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
216215
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)
217216

217+
# warm up jit cache
218+
step(u, v, e, u1, v1, e1, u2, v2, e2)
219+
sync()
220+
221+
# initial solution
222+
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly).to_device(device)
223+
u[:, :] = create_full(U_shape, 0.0, dtype)
224+
v[:, :] = create_full(V_shape, 0.0, dtype)
225+
sync()
226+
218227
t = 0
219228
i_export = 0
220229
next_t_export = 0
@@ -226,9 +235,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
226235
t = i * dt
227236

228237
if t >= next_t_export - 1e-8:
229-
_elev_max = np.max(e, all_axes)
230-
_u_max = np.max(u, all_axes)
231-
_total_v = np.sum(e + h, all_axes)
238+
if device:
239+
# FIXME gpu.memcpy to host requires identity layout
240+
# FIXME reduction on gpu
241+
# e_host = e.to_device()
242+
# u_host = u.to_device()
243+
# h_host = h.to_device()
244+
# _elev_max = np.max(e_host, all_axes)
245+
# _u_max = np.max(u_host, all_axes)
246+
# _total_v = np.sum(e_host + h, all_axes)
247+
_elev_max = 0
248+
_u_max = 0
249+
_total_v = 0
250+
else:
251+
_elev_max = np.max(e, all_axes)
252+
_u_max = np.max(u, all_axes)
253+
_total_v = np.sum(e + h, all_axes)
232254

233255
elev_max = float(_elev_max)
234256
u_max = float(_u_max)
@@ -265,10 +287,17 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
265287

266288
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
267289
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
268-
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
290+
if device:
291+
# FIXME gpu.memcpy to host requires identity layout
292+
# FIXME reduction on gpu
293+
# err2_host = err2.to_device()
294+
# err_L2 = math.sqrt(float(np.sum(err2_host, all_axes)))
295+
err_L2 = 0
296+
else:
297+
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
269298
info(f"L2 error: {err_L2:7.5e}")
270299

271-
if nx == 128 and ny == 128 and not benchmark_mode:
300+
if nx == 128 and ny == 128 and not benchmark_mode and not device:
272301
if datatype == "f32":
273302
assert numpy.allclose(err_L2, 7.2235471e-03, rtol=1e-4)
274303
else:

0 commit comments

Comments
 (0)