@@ -134,7 +134,7 @@ def run(n, backend, datatype, benchmark_mode):
134
134
info (f"Total DOFs: { dofs_T + dofs_U + dofs_V } " )
135
135
136
136
# prognostic variables: elevation, (u, v) velocity
137
- # e = create_full(T_shape, 0.0, dtype)
137
+ e = create_full (T_shape , 0.0 , dtype )
138
138
u = create_full (U_shape , 0.0 , dtype )
139
139
v = create_full (V_shape , 0.0 , dtype )
140
140
@@ -167,9 +167,7 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
167
167
return amp * sol_x * sol_y * sol_t
168
168
169
169
# initial elevation
170
- # e[:, :] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
171
- # NOTE assignment fails, do not pre-allocate e
172
- e = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly ).to_device (device )
170
+ e [:, :] = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly )
173
171
sync ()
174
172
175
173
# compute time step
@@ -235,8 +233,8 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
235
233
t = i * dt
236
234
237
235
if t >= next_t_export - 1e-8 :
238
- _elev_max = 0 # np.max(e, all_axes)
239
- _u_max = 0 # np.max(u, all_axes)
236
+ _elev_max = e [ 0 , 0 ]. to_device () # np.max(e, all_axes)
237
+ _u_max = u [ 0 , 0 ]. to_device () # np.max(u, all_axes)
240
238
_total_v = 0 # np.sum(e + h, all_axes)
241
239
242
240
elev_max = float (_elev_max )
0 commit comments