@@ -111,28 +111,32 @@ def run(n, backend, datatype, benchmark_mode):
111111 t_end = 1.0
112112
113113 # coordinate arrays
114+ sync ()
114115 x_t_2d = fromfunction (
115- lambda i , j : xmin + i * dx + dx / 2 ,
116- (nx , ny ),
117- dtype = dtype ,
116+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype , device = ""
118117 )
119118 y_t_2d = fromfunction (
120- lambda i , j : ymin + j * dy + dy / 2 ,
121- (nx , ny ),
122- dtype = dtype ,
119+ lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype , device = ""
120+ )
121+ x_u_2d = fromfunction (
122+ lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype , device = ""
123123 )
124- x_u_2d = fromfunction (lambda i , j : xmin + i * dx , (nx + 1 , ny ), dtype = dtype )
125124 y_u_2d = fromfunction (
126125 lambda i , j : ymin + j * dy + dy / 2 ,
127126 (nx + 1 , ny ),
128127 dtype = dtype ,
128+ device = "" ,
129129 )
130130 x_v_2d = fromfunction (
131131 lambda i , j : xmin + i * dx + dx / 2 ,
132132 (nx , ny + 1 ),
133133 dtype = dtype ,
134+ device = "" ,
134135 )
135- y_v_2d = fromfunction (lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype )
136+ y_v_2d = fromfunction (
137+ lambda i , j : ymin + j * dy , (nx , ny + 1 ), dtype = dtype , device = ""
138+ )
139+ sync ()
136140
137141 T_shape = (nx , ny )
138142 U_shape = (nx + 1 , ny )
@@ -157,7 +161,7 @@ def run(n, backend, datatype, benchmark_mode):
157161 q = create_full (F_shape , 0.0 , dtype )
158162
159163 # bathymetry
160- h = create_full (T_shape , 0 .0 , dtype )
164+ h = create_full (T_shape , 1 .0 , dtype ) # HACK init with 1
161165
162166 hu = create_full (U_shape , 0.0 , dtype )
163167 hv = create_full (V_shape , 0.0 , dtype )
@@ -205,22 +209,16 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
205209 bath = 1.0
206210 return bath * create_full (T_shape , 1.0 , dtype )
207211
208- # inital elevation
209- u0 , v0 , e0 = exact_solution (
210- 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
211- )
212- e [:, :] = e0
213- u [:, :] = u0
214- v [:, :] = v0
215-
216212 # set bathymetry
217- h [:, :] = bathymetry (x_t_2d , y_t_2d , lx , ly )
213+ # h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device )
218214 # steady state potential energy
219- pe_offset = 0.5 * g * float (np .sum (h ** 2.0 , all_axes )) / nx / ny
215+ # pe_offset = 0.5 * g * float(np.sum(h**2.0, all_axes)) / nx / ny
216+ pe_offset = 0.5 * g * float (1.0 ) / nx / ny
220217
221218 # compute time step
222219 alpha = 0.5
223- h_max = float (np .max (h , all_axes ))
220+ # h_max = float(np.max(h, all_axes))
221+ h_max = float (1.0 )
224222 c = (g * h_max ) ** 0.5
225223 dt = alpha * dx / c
226224 dt = t_export / int (math .ceil (t_export / dt ))
@@ -329,6 +327,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
329327 v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
330328 e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
331329
330+ # warm jit cache
331+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
332+ sync ()
333+
334+ # initial solution
335+ u0 , v0 , e0 = exact_solution (
336+ 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
337+ )
338+ e [:, :] = e0 .to_device (device )
339+ u [:, :] = u0 .to_device (device )
340+ v [:, :] = v0 .to_device (device )
341+
332342 t = 0
333343 i_export = 0
334344 next_t_export = 0
@@ -341,30 +351,41 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
341351 t = i * dt
342352
343353 if t >= next_t_export - 1e-8 :
344- _elev_max = np .max (e , all_axes )
345- _u_max = np .max (u , all_axes )
346- _q_max = np .max (q , all_axes )
347- _total_v = np .sum (e + h , all_axes )
348-
349- # potential energy
350- _pe = 0.5 * g * (e + h ) * (e - h ) + pe_offset
351- _total_pe = np .sum (_pe , all_axes )
352-
353- # kinetic energy
354- u2 = u * u
355- v2 = v * v
356- u2_at_t = 0.5 * (u2 [1 :, :] + u2 [:- 1 , :])
357- v2_at_t = 0.5 * (v2 [:, 1 :] + v2 [:, :- 1 ])
358- _ke = 0.5 * (u2_at_t + v2_at_t ) * (e + h )
359- _total_ke = np .sum (_ke , all_axes )
360-
361- total_pe = float (_total_pe ) * dx * dy
362- total_ke = float (_total_ke ) * dx * dy
363- total_e = total_ke + total_pe
364- elev_max = float (_elev_max )
365- u_max = float (_u_max )
366- q_max = float (_q_max )
367- total_v = float (_total_v ) * dx * dy
354+ if device :
355+ # FIXME gpu.memcpy to host requires identity layout
356+ # FIXME reduction on gpu
357+ elev_max = 0
358+ u_max = 0
359+ q_max = 0
360+ diff_e = 0
361+ diff_v = 0
362+ total_pe = 0
363+ total_ke = 0
364+ else :
365+ _elev_max = np .max (e , all_axes )
366+ _u_max = np .max (u , all_axes )
367+ _q_max = np .max (q , all_axes )
368+ _total_v = np .sum (e + h , all_axes )
369+
370+ # potential energy
371+ _pe = 0.5 * g * (e + h ) * (e - h ) + pe_offset
372+ _total_pe = np .sum (_pe , all_axes )
373+
374+ # kinetic energy
375+ u2 = u * u
376+ v2 = v * v
377+ u2_at_t = 0.5 * (u2 [1 :, :] + u2 [:- 1 , :])
378+ v2_at_t = 0.5 * (v2 [:, 1 :] + v2 [:, :- 1 ])
379+ _ke = 0.5 * (u2_at_t + v2_at_t ) * (e + h )
380+ _total_ke = np .sum (_ke , all_axes )
381+
382+ total_pe = float (_total_pe ) * dx * dy
383+ total_ke = float (_total_ke ) * dx * dy
384+ total_e = total_ke + total_pe
385+ elev_max = float (_elev_max )
386+ u_max = float (_u_max )
387+ q_max = float (_q_max )
388+ total_v = float (_total_v ) * dx * dy
368389
369390 if i_export == 0 :
370391 initial_v = total_v
@@ -399,35 +420,40 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
399420 duration = time_mod .perf_counter () - tic
400421 info (f"Duration: { duration :.2f} s" )
401422
402- e_exact = exact_solution (t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d )[
403- 2
404- ]
405- err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
406- err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
407- info (f"L2 error: { err_L2 :7.15e} " )
408-
409- if nx < 128 or ny < 128 :
410- info ("Skipping correctness test due to small problem size." )
411- elif not benchmark_mode :
412- tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
413- assert (
414- diff_e < tolerance_ene
415- ), f"Energy error exceeds tolerance: { diff_e } > { tolerance_ene } "
416- if nx == 128 and ny == 128 :
417- if datatype == "f32" :
418- assert numpy .allclose (
419- err_L2 , 4.3127859e-05 , rtol = 1e-5
420- ), "L2 error does not match"
421- else :
422- assert numpy .allclose (
423- err_L2 , 4.315799035627906e-05
424- ), "L2 error does not match"
425- else :
426- tolerance_l2 = 1e-4
423+ if device :
424+ # FIXME gpu.memcpy to host requires identity layout
425+ # FIXME reduction on gpu
426+ pass
427+ else :
428+ e_exact = exact_solution (
429+ t , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
430+ )[2 ]
431+ err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
432+ err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
433+ info (f"L2 error: { err_L2 :7.15e} " )
434+
435+ if nx < 128 or ny < 128 :
436+ info ("Skipping correctness test due to small problem size." )
437+ elif not benchmark_mode :
438+ tolerance_ene = 1e-7 if datatype == "f32" else 1e-9
427439 assert (
428- err_L2 < tolerance_l2
429- ), f"L2 error exceeds tolerance: { err_L2 } > { tolerance_l2 } "
430- info ("SUCCESS" )
440+ diff_e < tolerance_ene
441+ ), f"Energy error exceeds tolerance: { diff_e } > { tolerance_ene } "
442+ if nx == 128 and ny == 128 :
443+ if datatype == "f32" :
444+ assert numpy .allclose (
445+ err_L2 , 4.3127859e-05 , rtol = 1e-5
446+ ), "L2 error does not match"
447+ else :
448+ assert numpy .allclose (
449+ err_L2 , 4.315799035627906e-05
450+ ), "L2 error does not match"
451+ else :
452+ tolerance_l2 = 1e-4
453+ assert (
454+ err_L2 < tolerance_l2
455+ ), f"L2 error exceeds tolerance: { err_L2 } > { tolerance_l2 } "
456+ info ("SUCCESS" )
431457
432458 fini ()
433459
0 commit comments