@@ -111,12 +111,14 @@ def run(n, backend, datatype, benchmark_mode):
111
111
t_end = 1.0
112
112
113
113
# coordinate arrays
114
+ sync ()
114
115
x_t_2d = fromfunction (
115
- lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype
116
+ lambda i , j : xmin + i * dx + dx / 2 , (nx , ny ), dtype = dtype , device = ""
116
117
)
117
118
y_t_2d = fromfunction (
118
- lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype
119
+ lambda i , j : ymin + j * dy + dy / 2 , (nx , ny ), dtype = dtype , device = ""
119
120
)
121
+ sync ()
120
122
121
123
T_shape = (nx , ny )
122
124
U_shape = (nx + 1 , ny )
@@ -144,6 +146,22 @@ def run(n, backend, datatype, benchmark_mode):
144
146
u2 = create_full (U_shape , 0.0 , dtype )
145
147
v2 = create_full (V_shape , 0.0 , dtype )
146
148
149
+ # compute time step
150
+ alpha = 0.5
151
+ c = (g * h ) ** 0.5
152
+ dt = alpha * dx / c
153
+ dt = t_export / int (math .ceil (t_export / dt ))
154
+ nt = int (math .ceil (t_end / dt ))
155
+ if benchmark_mode :
156
+ dt = 1e-5
157
+ nt = 100
158
+ t_export = dt * 25
159
+
160
+ info (f"Time step: { dt } s" )
161
+ info (f"Total run time: { t_end } s, { nt } time steps" )
162
+
163
+ sync ()
164
+
147
165
def exact_elev (t , x_t_2d , y_t_2d , lx , ly ):
148
166
"""
149
167
Exact solution for elevation field.
@@ -162,25 +180,6 @@ def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
162
180
sol_t = numpy .cos (2 * omega * t )
163
181
return amp * sol_x * sol_y * sol_t
164
182
165
- # inital elevation
166
- e [:, :] = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly )
167
-
168
- # compute time step
169
- alpha = 0.5
170
- c = (g * h ) ** 0.5
171
- dt = alpha * dx / c
172
- dt = t_export / int (math .ceil (t_export / dt ))
173
- nt = int (math .ceil (t_end / dt ))
174
- if benchmark_mode :
175
- dt = 1e-5
176
- nt = 100
177
- t_export = dt * 25
178
-
179
- info (f"Time step: { dt } s" )
180
- info (f"Total run time: { t_end } s, { nt } time steps" )
181
-
182
- sync ()
183
-
184
183
def rhs (u , v , e ):
185
184
"""
186
185
Evaluate right hand side of the equations
@@ -215,6 +214,16 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
215
214
v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
216
215
e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
217
216
217
+ # warm git cache
218
+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
219
+ sync ()
220
+
221
+ # initial solution
222
+ e [:, :] = exact_elev (0.0 , x_t_2d , y_t_2d , lx , ly ).to_device (device )
223
+ u [:, :] = create_full (U_shape , 0.0 , dtype )
224
+ v [:, :] = create_full (V_shape , 0.0 , dtype )
225
+ sync ()
226
+
218
227
t = 0
219
228
i_export = 0
220
229
next_t_export = 0
@@ -226,9 +235,22 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
226
235
t = i * dt
227
236
228
237
if t >= next_t_export - 1e-8 :
229
- _elev_max = np .max (e , all_axes )
230
- _u_max = np .max (u , all_axes )
231
- _total_v = np .sum (e + h , all_axes )
238
+ if device :
239
+ # FIXME gpu.memcpy to host requires identity layout
240
+ # FIXME reduction on gpu
241
+ # e_host = e.to_device()
242
+ # u_host = u.to_device()
243
+ # h_host = h.to_device()
244
+ # _elev_max = np.max(e_host, all_axes)
245
+ # _u_max = np.max(u_host, all_axes)
246
+ # _total_v = np.sum(e_host + h, all_axes)
247
+ _elev_max = 0
248
+ _u_max = 0
249
+ _total_v = 0
250
+ else :
251
+ _elev_max = np .max (e , all_axes )
252
+ _u_max = np .max (u , all_axes )
253
+ _total_v = np .sum (e + h , all_axes )
232
254
233
255
elev_max = float (_elev_max )
234
256
u_max = float (_u_max )
@@ -265,10 +287,17 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
265
287
266
288
e_exact = exact_elev (t , x_t_2d , y_t_2d , lx , ly )
267
289
err2 = (e_exact - e ) * (e_exact - e ) * dx * dy / lx / ly
268
- err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
290
+ if device :
291
+ # FIXME gpu.memcpy to host requires identity layout
292
+ # FIXME reduction on gpu
293
+ # err2_host = err2.to_device()
294
+ # err_L2 = math.sqrt(float(np.sum(err2_host, all_axes)))
295
+ err_L2 = 0
296
+ else :
297
+ err_L2 = math .sqrt (float (np .sum (err2 , all_axes )))
269
298
info (f"L2 error: { err_L2 :7.5e} " )
270
299
271
- if nx == 128 and ny == 128 and not benchmark_mode :
300
+ if nx == 128 and ny == 128 and not benchmark_mode and not device :
272
301
if datatype == "f32" :
273
302
assert numpy .allclose (err_L2 , 7.2235471e-03 , rtol = 1e-4 )
274
303
else :
0 commit comments