Skip to content

Commit ded6727

Browse files
committed
wave equation: warm up jit cache
1 parent 634856e commit ded6727

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

Diff for: examples/wave_equation.py

+22-20
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,20 @@ def run(n, backend, datatype, benchmark_mode):
146146
u2 = create_full(U_shape, 0.0, dtype)
147147
v2 = create_full(V_shape, 0.0, dtype)
148148

149+
# compute time step
150+
alpha = 0.5
151+
c = (g * h) ** 0.5
152+
dt = alpha * dx / c
153+
dt = t_export / int(math.ceil(t_export / dt))
154+
nt = int(math.ceil(t_end / dt))
155+
if benchmark_mode:
156+
dt = 1e-5
157+
nt = 100
158+
t_export = dt * 25
159+
160+
info(f"Time step: {dt} s")
161+
info(f"Total run time: {t_end} s, {nt} time steps")
162+
149163
sync()
150164

151165
def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
@@ -166,26 +180,6 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
166180
sol_t = numpy.cos(2 * omega * t)
167181
return amp * sol_x * sol_y * sol_t
168182

169-
# initial elevation
170-
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
171-
sync()
172-
173-
# compute time step
174-
alpha = 0.5
175-
c = (g * h) ** 0.5
176-
dt = alpha * dx / c
177-
dt = t_export / int(math.ceil(t_export / dt))
178-
nt = int(math.ceil(t_end / dt))
179-
if benchmark_mode:
180-
dt = 1e-5
181-
nt = 100
182-
t_export = dt * 25
183-
184-
info(f"Time step: {dt} s")
185-
info(f"Total run time: {t_end} s, {nt} time steps")
186-
187-
sync()
188-
189183
def rhs(u, v, e):
190184
"""
191185
Evaluate right hand side of the equations
@@ -220,6 +214,14 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
220214
v[:, 1:-1] = v[:, 1:-1] / 3.0 + 2.0 / 3.0 * (v2[:, 1:-1] + dt * dvdt)
221215
e[:, :] = e[:, :] / 3.0 + 2.0 / 3.0 * (e2[:, :] + dt * dedt)
222216

217+
# warm git cache
218+
step(u, v, e, u1, v1, e1, u2, v2, e2)
219+
sync()
220+
221+
# initial solution
222+
e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly).to_device(device)
223+
u[:, :] = create_full(U_shape, 0.0, dtype)
224+
v[:, :] = create_full(V_shape, 0.0, dtype)
223225
sync()
224226

225227
t = 0

0 commit comments

Comments
 (0)