@@ -32,7 +32,8 @@ create_dataloader_post(trajectories; ntrajectory, nunroll, device = identity) =
3232 @assert nt ≥ nunroll " Trajectory too short for nunroll = $nunroll "
3333 istart = rand (rng, 1 : nt- nunroll)
3434 it = istart: istart+ nunroll
35- (; u = device .(u[it]), t = t[it])
35+ u = selectdim (u, ndims (u), it) |> Array |> device # convert view to array first
36+ (; u, t = t[it])
3637 end
3738 data, rng
3839 end
@@ -115,25 +116,25 @@ Create a-posteriori loss function.
115116function create_loss_post (; setup, method, psolver, closure, nsubstep = 1 )
116117 closure_model = wrappedclosure (closure, setup)
117118 setup = (; setup... , closure_model)
118- (; dimension, Iu, x) = setup. grid
119- D = dimension ()
119+ (; Iu) = setup. grid
120+ inside = Iu[1 ]
121+ @assert all (== (inside), Iu)
120122 loss_post (data, θ) =
121123 sum (data) do (; u, t)
122124 T = eltype (θ)
123- v = u[ 1 ]
125+ ules = selectdim (u, ndims (u), 1 ) |> collect
124126 stepper =
125- create_stepper (method; setup, psolver, u = v , temp = nothing , t = t[1 ])
127+ create_stepper (method; setup, psolver, u = ules , temp = nothing , t = t[1 ])
126128 loss = zero (T)
127129 for it = 2 : length (t)
128130 Δt = (t[it] - t[it- 1 ]) / nsubstep
129131 for isub = 1 : nsubstep
130132 stepper = timestep (method, stepper, Δt; θ)
131133 end
132- a, b = T (0 ), T (0 )
133- for α = 1 : length (u[1 ])
134- a += sum (abs2, (stepper. u[α]- u[it][α])[Iu[α]])
135- b += sum (abs2, u[it][α][Iu[α]])
136- end
134+ uref = view (u, inside, :, it)
135+ ules = view (stepper. u, inside, :)
136+ a = sum (abs2, ules - uref)
137+ b = sum (abs2, uref)
137138 loss += a / b
138139 end
139140 loss / (length (t) - 1 )
@@ -145,14 +146,15 @@ Create a-posteriori relative error.
145146"""
146147function create_relerr_post (; data, setup, method, psolver, closure_model, nsubstep = 1 )
147148 setup = (; setup... , closure_model)
148- (; dimension, Iu) = setup. grid
149- D = dimension ()
149+ (; Iu) = setup. grid
150+ inside = Iu[1 ]
151+ @assert all (== (inside), Iu)
150152 (; u, t) = data
151- v = copy .(u[ 1 ])
153+ v = selectdim (u, ndims (u), 1 ) |> collect
152154 cache = IncompressibleNavierStokes. ode_method_cache (method, setup)
153155 function relerr_post (θ)
154- T = eltype (u[ 1 ][ 1 ] )
155- copyto! . (v, u[ 1 ] )
156+ T = eltype (u)
157+ copyto! (v, selectdim (u, ndims (u), 1 ) )
156158 stepper = create_stepper (method; setup, psolver, u = v, temp = nothing , t = t[1 ])
157159 e = zero (T)
158160 for it = 2 : length (t)
@@ -161,13 +163,10 @@ function create_relerr_post(; data, setup, method, psolver, closure_model, nsubs
161163 stepper =
162164 IncompressibleNavierStokes. timestep! (method, stepper, Δt; θ, cache)
163165 end
164- a, b = T (0 ), T (0 )
165- for α = 1 : D
166- # a += sum(abs2, (stepper.u[α]-u[it][α])[Iu[α]])
167- # b += sum(abs2, u[it][α][Iu[α]])
168- a += sum (abs2, view (stepper. u[α] - u[it][α], Iu[α]))
169- b += sum (abs2, view (u[it][α], Iu[α]))
170- end
166+ uref = view (u, inside, :, it)
167+ ules = view (stepper. u, inside, :)
168+ a = sum (abs2, ules - uref)
169+ b = sum (abs2, uref)
171170 e += sqrt (a) / sqrt (b)
172171 end
173172 e / (length (t) - 1 )
@@ -189,15 +188,17 @@ function create_relerr_symmetry_post(;
189188 (; dimension, Iu) = setup. grid
190189 D = dimension ()
191190 T = eltype (u[1 ])
191+ inside = Iu[1 ]
192+ @assert all (== (inside), Iu)
192193 cache = IncompressibleNavierStokes. ode_method_cache (method, setup)
193194 function err (θ)
194195 stepper =
195- create_stepper (method; setup, psolver, u = copy . (u), temp = nothing , t = T (0 ))
196+ create_stepper (method; setup, psolver, u = copy (u), temp = nothing , t = T (0 ))
196197 stepper_rot = create_stepper (
197198 method;
198199 setup,
199200 psolver,
200- u = rot2stag (copy . (u), g),
201+ u = rot2stag (copy (u), g),
201202 temp = nothing ,
202203 t = T (0 ),
203204 )
@@ -207,11 +208,8 @@ function create_relerr_symmetry_post(;
207208 stepper_rot =
208209 IncompressibleNavierStokes. timestep! (method, stepper_rot, Δt; θ, cache)
209210 u_rot = rot2stag (stepper. u, g)
210- a, b = T (0 ), T (0 )
211- for α = 1 : D
212- a += sum (abs2, view (stepper_rot. u[α] - u_rot[α], Iu[α]))
213- b += sum (abs2, view (u_rot[α], Iu[α]))
214- end
211+ a = sum (abs2, view (stepper_rot. u - u_rot, inside, :))
212+ b = sum (abs2, view (u_rot, inside, :))
215213 e += sqrt (a) / sqrt (b)
216214 end
217215 e / nstep
@@ -225,16 +223,17 @@ function create_relerr_symmetry_prior(; u, setup, g = 1)
225223 (; grid, closure_model) = setup
226224 (; dimension, Iu) = grid
227225 D = dimension ()
228- T = eltype (u[1 ][1 ])
226+ T = eltype (u[1 ])
227+ inside = Iu[1 ]
228+ @assert all (== (inside), Iu)
229229 function err (θ)
230- e = sum (u ) do u
230+ e = sum (eachslice (u; dims = ndims (u)) ) do u
231231 cr = closure_model (rot2stag (u, g), θ)
232232 rc = rot2stag (closure_model (u, θ), g)
233- a, b = T (0 ), T (0 )
234- for α = 1 : D
235- a += sum (abs2, view (rc[α] - cr[α], Iu[α]))
236- b += sum (abs2, view (cr[α], Iu[α]))
237- end
233+ cr = view (cr, inside, :)
234+ rc = view (rc, inside, :)
235+ a = sum (abs2, rc - cr)
236+ b = sum (abs2, cr)
238237 sqrt (a) / sqrt (b)
239238 end
240239 e / length (u)
@@ -306,6 +305,6 @@ function create_callback(
306305 (; callbackstate, callback)
307306end
308307
308+ getlearningrate (r) = - 1 # Fallback
309309getlearningrate (r:: Adam ) = r. eta
310310getlearningrate (r:: OptimiserChain{Tuple{Adam,WeightDecay}} ) = r. opts[1 ]. eta
311- getlearningrate (r) = - 1
0 commit comments