@@ -111,12 +111,14 @@ def run(n, backend, datatype, benchmark_mode):
111111 t_end = 1.0
112112
113113 # coordinate arrays
114+ sync ()
114115 x_t_2d = fromfunction (
115- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype
116+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype , device = ""
116117 )
117118 y_t_2d = fromfunction (
118- lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype
119+ lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype , device = ""
119120 )
121+ sync ()
120122
121123 T_shape = (nx , ny )
122124 U_shape = (nx + 1 , ny )
@@ -144,6 +146,22 @@ def run(n, backend, datatype, benchmark_mode):
144146 u2 = create_full (U_shape , 0.0 , dtype )
145147 v2 = create_full (V_shape , 0.0 , dtype )
146148
149+ # compute time step
150+ alpha = 0.5
151+ c = (g * h ) ** 0.5
152+ dt = alpha * dx / c
153+ dt = t_export / int (math .ceil (t_export / dt ))
154+ nt = int (math .ceil (t_end / dt ))
155+ if benchmark_mode :
156+ dt = 1e-5
157+ nt = 100
158+ t_export = dt * 25
159+
160+ info (f"Time step: { dt } s" )
161+ info (f"Total run time: { t_end } s, { nt } time steps" )
162+
163+ sync ()
164+
147165 def exact_elev (t , x_t_2d , y_t_2d , lx , ly ):
148166 """
149167 Exact solution for elevation field.
@@ -162,25 +180,6 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
162180 sol_t = numpy .cos (2 * omega * t )
163181 return amp * sol_x * sol_y * sol_t
164182
165- # inital elevation
166- e [:, :] = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly )
167-
168- # compute time step
169- alpha = 0.5
170- c = (g * h ) ** 0.5
171- dt = alpha * dx / c
172- dt = t_export / int (math .ceil (t_export / dt ))
173- nt = int (math .ceil (t_end / dt ))
174- if benchmark_mode :
175- dt = 1e-5
176- nt = 100
177- t_export = dt * 25
178-
179- info (f"Time step: { dt } s" )
180- info (f"Total run time: { t_end } s, { nt } time steps" )
181-
182- sync ()
183-
184183 def rhs (u , v , e ):
185184 """
186185 Evaluate right hand side of the equations
@@ -215,6 +214,16 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
215214 v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
216215 e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
217216
217+ # warm git cache
218+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
219+ sync ()
220+
221+ # initial solution
222+ e [:, :] = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly ).to_device (device )
223+ u [:, :] = create_full (U_shape , 0.0 , dtype )
224+ v [:, :] = create_full (V_shape , 0.0 , dtype )
225+ sync ()
226+
218227 t = 0
219228 i_export = 0
220229 next_t_export = 0
@@ -226,9 +235,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
226235 t = i * dt
227236
228237 if t >= next_t_export - 1e-8 :
229- _elev_max = np .max (e , all_axes )
230- _u_max = np .max (u , all_axes )
231- _total_v = np .sum (e + h , all_axes )
238+ if device :
239+ # FIXME gpu.memcpy to host requires identity layout
240+ # FIXME reduction on gpu
241+ # e_host = e.to_device()
242+ # u_host = u.to_device()
243+ # h_host = h.to_device()
244+ # _elev_max = np.max(e_host, all_axes)
245+ # _u_max = np.max(u_host, all_axes)
246+ # _total_v = np.sum(e_host + h, all_axes)
247+ _elev_max = 0
248+ _u_max = 0
249+ _total_v = 0
250+ else :
251+ _elev_max = np .max (e , all_axes )
252+ _u_max = np .max (u , all_axes )
253+ _total_v = np .sum (e + h , all_axes )
232254
233255 elev_max = float (_elev_max )
234256 u_max = float (_u_max )
@@ -265,10 +287,17 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
265287
266288 e_exact = exact_elev (t , x_t_2d , y_t_2d , lx , ly )
267289 err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
268- err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
290+ if device :
291+ # FIXME gpu.memcpy to host requires identity layout
292+ # FIXME reduction on gpu
293+ # err2_host = err2.to_device()
294+ # err_L2 = math.sqrt(float(np.sum(err2_host, all_axes)))
295+ err_L2 = 0
296+ else :
297+ err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
269298 info (f"L2 error: { err_L2 :7.5e} " )
270299
271- if nx == 128 and ny == 128 and not benchmark_mode :
300+ if nx == 128 and ny == 128 and not benchmark_mode and not device :
272301 if datatype == "f32" :
273302 assert numpy .allclose (err_L2 , 7.2235471e-03 , rtol = 1e-4 )
274303 else :
0 commit comments