@@ -209,14 +209,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
209
209
bath = 1.0
210
210
return bath * create_full (T_shape , 1.0 , dtype )
211
211
212
- # inital elevation
213
- u0 , v0 , e0 = exact_solution (
214
- 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
215
- )
216
- e [:, :] = e0 .to_device (device )
217
- u [:, :] = u0 .to_device (device )
218
- v [:, :] = v0 .to_device (device )
219
-
220
212
# set bathymetry
221
213
# h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
222
214
# steady state potential energy
@@ -335,6 +327,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
335
327
v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
336
328
e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
337
329
330
+ # warm jit cache
331
+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
332
+ sync ()
333
+
334
+ # initial solution
335
+ u0 , v0 , e0 = exact_solution (
336
+ 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
337
+ )
338
+ e [:, :] = e0 .to_device (device )
339
+ u [:, :] = u0 .to_device (device )
340
+ v [:, :] = v0 .to_device (device )
341
+
338
342
t = 0
339
343
i_export = 0
340
344
next_t_export = 0
0 commit comments