
Commit 003f739

fix: update for array
1 parent 3b3c7fe commit 003f739

16 files changed: +214 -153 lines changed

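Note (not part of the diff): reading across the file diffs below, the commit moves the velocity data from a tuple of per-component arrays to a single array whose trailing axes hold the component and, for trajectories, the time index. A rough sketch of the two layouts (all sizes here are made up for illustration; stack requires Julia >= 1.9):

    # Old layout: one array per velocity component (here 2-D, with ghost cells).
    u_tuple = (rand(Float32, 18, 18), rand(Float32, 18, 18))

    # New layout: a single array with a trailing component dimension.
    u_array = stack(u_tuple)   # size (18, 18, 2)

    # Trajectories additionally get a trailing time axis, e.g. (18, 18, 2, nt).
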
lib/NeuralClosure/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 IncompressibleNavierStokes = "5e318141-6589-402b-868d-77d7df8c442e"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -28,6 +29,7 @@ Accessors = "0.1"
 ComponentArrays = "0.15"
 DocStringExtensions = "0.9"
 IncompressibleNavierStokes = "2"
+JLD2 = "0.5.7"
 KernelAbstractions = "0.9"
 LinearAlgebra = "1"
 Lux = "1"

lib/NeuralClosure/src/NeuralClosure.jl

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ using ComponentArrays: ComponentArray
 using DocStringExtensions
 using IncompressibleNavierStokes
 using IncompressibleNavierStokes: Dimension, momentum!, apply_bc_u!, project!
+using JLD2
 using KernelAbstractions
 using LinearAlgebra
 using Lux

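Note (not part of the diff): JLD2 is brought in so that the filtered data can be written to disk from filtersaver in data_generation.jl below. A minimal sketch of the two JLD2 calls used there; the file name and fields are made up:

    using JLD2

    # `jldsave` stores each keyword argument as a named variable in the file.
    results = (; u = rand(Float32, 16, 16, 2, 4), t = collect(0.0f0:0.1f0:0.3f0), comptime = 1.2)
    jldsave("example_filtered.jld2"; results...)

    # `load` reads the file back as a Dict keyed by variable name.
    data = load("example_filtered.jld2")
    size(data["u"])  # (16, 16, 2, 4)
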
lib/NeuralClosure/src/closure.jl

Lines changed: 8 additions & 2 deletions

@@ -1,14 +1,20 @@
 """
 Wrap closure model and parameters so that it can be used in the solver.
 """
-wrappedclosure(m) =
+function wrappedclosure(m, setup)
+    (; Iu) = setup.grid
+    inside = Iu[1]
+    @assert all(==(inside), Iu) "Only periodic grids are supported"
     function neuralclosure(u, θ)
         s = size(u)
-        u = reshape(u, s..., 1) # Add sample dim
+        # u = view(u, inside, :)
+        u = u[inside, :]
+        u = reshape(u, size(u)..., 1) # Add sample dim
         mu = m(u, θ)
         mu = pad_circular(mu, 1)
         mu = reshape(mu, s) # Remove sample dim
     end
+end
 
 """
 Create neural closure model from layers.

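Note (not part of the diff): with the new signature, the wrapped closure receives the full velocity array including ghost cells, strips the interior given by Iu[1], adds a sample dimension for the Lux model, and pads circularly before restoring the original shape. A standalone sketch of just the array handling around the model call, with made-up sizes and a hand-rolled circular pad standing in for pad_circular:

    # Illustrative shapes: 16×16 interior, one ghost layer, D = 2 components.
    u = rand(Float32, 18, 18, 2)
    inside = CartesianIndices((2:17, 2:17))
    s = size(u)
    v = u[inside, :]                  # strip ghost cells    -> (16, 16, 2)
    v = reshape(v, size(v)..., 1)     # add sample dimension -> (16, 16, 2, 1)
    # ... the Lux model would be applied here ...
    w = v[:, :, :, 1]
    w = w[[end; 1:end; 1], [end; 1:end; 1], :]  # circular pad back to (18, 18, 2)
    @assert size(w) == s
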
lib/NeuralClosure/src/data_generation.jl

Lines changed: 22 additions & 6 deletions

@@ -57,20 +57,20 @@ end
 """
 Save filtered DNS data.
 """
-filtersaver(
+function filtersaver(
     dns,
     les,
     filters,
     compression,
     psolver_dns,
     psolver_les;
     nupdate = 1,
+    filenames,
     F = vectorfield(dns),
     p = scalarfield(dns),
-) =
-    processor(
-        (results, state) -> (; results..., comptime = time() - results.comptime),
-    ) do state
+)
+    @assert isnothing(filenames) || length(filenames) == length(les) * length(filters)
+    function initialize(state)
         comptime = time()
         t = fill(state[].t, 0)
         dnsobs = Observable((; state[].u, F, state[].t))
@@ -105,6 +105,20 @@ filtersaver(
         state[] = state[] # Save initial conditions
         results
     end
+    function finalize(results, state)
+        comptime = time() - results.comptime
+        (; data, t) = results
+        map(enumerate(data)) do (i, data)
+            (; u, c) = data
+            u = stack(u)
+            c = stack(c)
+            results = (; u, c, t, comptime)
+            isnothing(filenames) || jldsave(filenames[i]; results...)
+            results
+        end
+    end
+    processor(initialize, finalize)
+end
 
 """
 Create filtered DNS data.
@@ -126,6 +140,7 @@ function create_les_data(;
     icfunc = (setup, psolver, rng) -> random_field(setup, typeof(Re)(0); psolver, rng),
     processors = (; log = timelogger(; nupdate = 10)),
     rng,
+    filenames = nothing,
     kwargs...,
 )
     T = typeof(Re)
@@ -194,6 +209,7 @@ function create_les_data(;
         psolver,
         psolver_les;
         nupdate = savefreq,
+        filenames,
 
         # Reuse arrays from cache to save memory in 3D DNS.
         # Since processors are called outside
@@ -224,7 +240,7 @@ function create_io_arrays(data, setup)
     for it = 1:nt, α = 1:D
         copyto!(
             view(u, colons..., α, it),
-            view(getfield(trajectory, usym)[it], Iu[α], α),
+            view(getfield(trajectory, usym), Iu[α], α, it),
         )
     end
     u

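Note (not part of the diff): filenames must supply one path per (LES setup, filter) pair, and the new finalize step stacks the per-snapshot arrays into one array with time as the trailing axis before handing the result to jldsave. A standalone sketch of that stacking step (sizes made up; stack requires Julia >= 1.9):

    # Each stored snapshot is an (N..., D) array; `stack` turns the vector of
    # snapshots into a single (N..., D, nt) array, matching the new IO layout.
    snapshots = [rand(Float32, 16, 16, 2) for _ = 1:5]
    u = stack(snapshots)
    size(u)  # (16, 16, 2, 5)
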
lib/NeuralClosure/src/filter.jl

Lines changed: 7 additions & 7 deletions

@@ -25,8 +25,8 @@ struct VolumeAverage <: AbstractFilter end
 
 function (::FaceAverage)(v, u, setup_les, comp)
     (; grid, backend, workgroupsize) = setup_les
-    (; Nu, Iu) = grid
-    D = length(u)
+    (; dimension, Nu, Iu) = grid
+    D = dimension()
     @kernel function Φ!(v, u, ::Val{α}, face, I0) where {α}
         I = @index(Global, Cartesian)
         J = I0 + comp * (I - oneunit(I))
@@ -48,8 +48,8 @@ end
 "Reconstruct DNS velocity `u` from LES velocity `v`."
 function reconstruct!(u, v, setup_dns, setup_les, comp)
     (; grid, boundary_conditions, backend, workgroupsize) = setup_les
-    (; N) = grid
-    D = length(u)
+    (; dimension, N) = grid
+    D = dimension()
     e = Offset(D)
     @assert all(bc -> bc[1] isa PeriodicBC && bc[2] isa PeriodicBC, boundary_conditions)
     @kernel function R!(u, v, ::Val{α}, volume) where {α}
@@ -79,13 +79,13 @@ reconstruct(v, setup_dns, setup_les, comp) =
 
 function (::VolumeAverage)(v, u, setup_les, comp)
     (; grid, boundary_conditions, backend, workgroupsize) = setup_les
-    (; N, Nu, Iu) = grid
-    D = length(u)
+    (; dimension, N, Nu, Iu) = grid
+    D = dimension()
     @assert all(bc -> bc[1] isa PeriodicBC && bc[2] isa PeriodicBC, boundary_conditions)
     @kernel function Φ!(v, u, ::Val{α}, volume, I0) where {α}
         I = @index(Global, Cartesian)
         J = I0 + comp * (I - oneunit(I))
-        s = zero(eltype(v[α]))
+        s = zero(eltype(v))
         # n = 0
         for i in volume
             # Periodic extension

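Note (not part of the diff): these edits follow from the array storage: u and v are now single arrays of size (N..., D) rather than D-tuples, so length(u) no longer gives the spatial dimension and the element type comes from the array itself. A one-liner showing the distinction (sizes made up):

    u = rand(Float32, 18, 18, 2)  # length(u) == 648, not the spatial dimension
    D = ndims(u) - 1              # dimension implied by the trailing component axis
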
lib/NeuralClosure/src/groupconv.jl

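Note (not part of the diff below): rotating a planar vector field by 90° rotates each component array and maps the components (ux, uy) to (-uy, ux), which is what the new r == 1 branches encode. A standalone check with plain matrices, assuming the scalar rot2(·, 1) behaves like rotl90; stack requires Julia >= 1.9:

    ux = [1.0 2.0; 3.0 4.0]
    uy = [5.0 6.0; 7.0 8.0]

    # Rotate the field by 90 degrees: rotate each component, then mix components.
    rx, ry = rotl90(ux), rotl90(uy)
    rotated = stack((-ry, rx))   # size (2, 2, 2), analogous to vecrot2(u, 1)
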
Lines changed: 34 additions & 11 deletions

@@ -41,20 +41,42 @@ function rot2(u, r)
     u[I, chans...]
 end
 
-# For vector fields (u, v)
+"Rotate vector fields `[ux;;; uy]`"
 function rot2(u::Tuple{T,T}, r) where {T}
+    # ux, uy = eachslice(u; dims = ndims(u))
+    ux, uy = u
     r = mod(r, 4)
-    ru = rot2(u[1], r)
-    rv = rot2(u[2], r)
-    if r == 0
-        (ru, rv)
+    rx = rot2(ux, r)
+    ry = rot2(uy, r)
+    ru = if r == 0
+        (rx, ry)
+    elseif r == 1
+        (-ry, rx)
+    elseif r == 2
+        (-rx, -ry)
+    elseif r == 3
+        (ry, -rx)
+    end
+    ru
+end
+
+"Rotate vector fields `[ux;;; uy]`"
+function vecrot2(u, r)
+    # ux, uy = eachslice(u; dims = ndims(u))
+    ux, uy = u[:, :, 1], u[:, :, 2]
+    r = mod(r, 4)
+    rx = rot2(ux, r)
+    ry = rot2(uy, r)
+    ru = if r == 0
+        (rx, ry)
     elseif r == 1
-        (-rv, ru)
+        (-ry, rx)
     elseif r == 2
-        (-ru, -rv)
+        (-rx, -ry)
     elseif r == 3
-        (rv, -ru)
+        (ry, -rx)
     end
+    stack(ru)
 end
 
 # # For augmented vector fields (u, v, -u, -v)
@@ -77,8 +99,9 @@ end
 "Rotate staggered grid velocity field. See also [`rot2`](@ref)."
 function rot2stag(u, g)
     g = mod(g, 4)
-    u = rot2(u, g)
-    ux, uy = u
+    u = vecrot2(u, g)
+    # ux, uy = eachslice(u; dims = ndims(u))
+    ux, uy = u[:, :, 1], u[:, :, 2]
     if g in (1, 2)
         ux = circshift(ux, -1)
         ux[end, :] .= ux[2, :]
@@ -87,7 +110,7 @@ function rot2stag(u, g)
         uy = circshift(uy, (0, -1))
         uy[:, end] .= uy[:, 2]
     end
-    (ux, uy)
+    cat(ux, uy; dims = 3)
 end
 
 """
lib/NeuralClosure/src/training.jl

Lines changed: 36 additions & 37 deletions

@@ -32,7 +32,8 @@ create_dataloader_post(trajectories; ntrajectory, nunroll, device = identity) =
             @assert nt ≥ nunroll "Trajectory too short for nunroll = $nunroll"
             istart = rand(rng, 1:nt-nunroll)
             it = istart:istart+nunroll
-            (; u = device.(u[it]), t = t[it])
+            u = selectdim(u, ndims(u), it) |> Array |> device # convert view to array first
+            (; u, t = t[it])
         end
         data, rng
     end
@@ -115,25 +116,25 @@ Create a-posteriori loss function.
 function create_loss_post(; setup, method, psolver, closure, nsubstep = 1)
     closure_model = wrappedclosure(closure, setup)
     setup = (; setup..., closure_model)
-    (; dimension, Iu, x) = setup.grid
-    D = dimension()
+    (; Iu) = setup.grid
+    inside = Iu[1]
+    @assert all(==(inside), Iu)
     loss_post(data, θ) =
         sum(data) do (; u, t)
            T = eltype(θ)
-            v = u[1]
+            ules = selectdim(u, ndims(u), 1) |> collect
            stepper =
-                create_stepper(method; setup, psolver, u = v, temp = nothing, t = t[1])
+                create_stepper(method; setup, psolver, u = ules, temp = nothing, t = t[1])
            loss = zero(T)
            for it = 2:length(t)
                Δt = (t[it] - t[it-1]) / nsubstep
                for isub = 1:nsubstep
                    stepper = timestep(method, stepper, Δt; θ)
                end
-                a, b = T(0), T(0)
-                for α = 1:length(u[1])
-                    a += sum(abs2, (stepper.u[α]-u[it][α])[Iu[α]])
-                    b += sum(abs2, u[it][α][Iu[α]])
-                end
+                uref = view(u, inside, :, it)
+                ules = view(stepper.u, inside, :)
+                a = sum(abs2, ules - uref)
+                b = sum(abs2, uref)
                loss += a / b
            end
            loss / (length(t) - 1)
@@ -145,14 +146,15 @@ Create a-posteriori relative error.
 """
 function create_relerr_post(; data, setup, method, psolver, closure_model, nsubstep = 1)
     setup = (; setup..., closure_model)
-    (; dimension, Iu) = setup.grid
-    D = dimension()
+    (; Iu) = setup.grid
+    inside = Iu[1]
+    @assert all(==(inside), Iu)
     (; u, t) = data
-    v = copy.(u[1])
+    v = selectdim(u, ndims(u), 1) |> collect
     cache = IncompressibleNavierStokes.ode_method_cache(method, setup)
     function relerr_post(θ)
-        T = eltype(u[1][1])
-        copyto!.(v, u[1])
+        T = eltype(u)
+        copyto!(v, selectdim(u, ndims(u), 1))
        stepper = create_stepper(method; setup, psolver, u = v, temp = nothing, t = t[1])
        e = zero(T)
        for it = 2:length(t)
@@ -161,13 +163,10 @@ function create_relerr_post(; data, setup, method, psolver, closure_model, nsubstep = 1)
                stepper =
                    IncompressibleNavierStokes.timestep!(method, stepper, Δt; θ, cache)
            end
-            a, b = T(0), T(0)
-            for α = 1:D
-                # a += sum(abs2, (stepper.u[α]-u[it][α])[Iu[α]])
-                # b += sum(abs2, u[it][α][Iu[α]])
-                a += sum(abs2, view(stepper.u[α] - u[it][α], Iu[α]))
-                b += sum(abs2, view(u[it][α], Iu[α]))
-            end
+            uref = view(u, inside, :, it)
+            ules = view(stepper.u, inside, :)
+            a = sum(abs2, ules - uref)
+            b = sum(abs2, uref)
            e += sqrt(a) / sqrt(b)
        end
        e / (length(t) - 1)
@@ -189,15 +188,17 @@ function create_relerr_symmetry_post(;
     (; dimension, Iu) = setup.grid
     D = dimension()
     T = eltype(u[1])
+    inside = Iu[1]
+    @assert all(==(inside), Iu)
     cache = IncompressibleNavierStokes.ode_method_cache(method, setup)
     function err(θ)
        stepper =
-            create_stepper(method; setup, psolver, u = copy.(u), temp = nothing, t = T(0))
+            create_stepper(method; setup, psolver, u = copy(u), temp = nothing, t = T(0))
        stepper_rot = create_stepper(
            method;
            setup,
            psolver,
-            u = rot2stag(copy.(u), g),
+            u = rot2stag(copy(u), g),
            temp = nothing,
            t = T(0),
        )
@@ -207,11 +208,8 @@ function create_relerr_symmetry_post(;
            stepper_rot =
                IncompressibleNavierStokes.timestep!(method, stepper_rot, Δt; θ, cache)
            u_rot = rot2stag(stepper.u, g)
-            a, b = T(0), T(0)
-            for α = 1:D
-                a += sum(abs2, view(stepper_rot.u[α] - u_rot[α], Iu[α]))
-                b += sum(abs2, view(u_rot[α], Iu[α]))
-            end
+            a = sum(abs2, view(stepper_rot.u - u_rot, inside, :))
+            b = sum(abs2, view(u_rot, inside, :))
            e += sqrt(a) / sqrt(b)
        end
        e / nstep
@@ -225,16 +223,17 @@ function create_relerr_symmetry_prior(; u, setup, g = 1)
     (; grid, closure_model) = setup
     (; dimension, Iu) = grid
     D = dimension()
-    T = eltype(u[1][1])
+    T = eltype(u[1])
+    inside = Iu[1]
+    @assert all(==(inside), Iu)
     function err(θ)
-        e = sum(u) do u
+        e = sum(eachslice(u; dims = ndims(u))) do u
            cr = closure_model(rot2stag(u, g), θ)
            rc = rot2stag(closure_model(u, θ), g)
-            a, b = T(0), T(0)
-            for α = 1:D
-                a += sum(abs2, view(rc[α] - cr[α], Iu[α]))
-                b += sum(abs2, view(cr[α], Iu[α]))
-            end
+            cr = view(cr, inside, :)
+            rc = view(rc, inside, :)
+            a = sum(abs2, rc - cr)
+            b = sum(abs2, cr)
            sqrt(a) / sqrt(b)
        end
        e / length(u)
@@ -306,6 +305,6 @@ function create_callback(
     (; callbackstate, callback)
 end
 
+getlearningrate(r) = -1 # Fallback
 getlearningrate(r::Adam) = r.eta
 getlearningrate(r::OptimiserChain{Tuple{Adam,WeightDecay}}) = r.opts[1].eta
-getlearningrate(r) = -1

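Note (not part of the diff): the rewritten loss and error functions all assume the same trajectory layout: one array of size (N..., D, nt) per trajectory, an interior index set inside shared by all components, and snapshots taken along the last axis with view/selectdim. A small sketch of that convention (shapes made up):

    u = rand(Float32, 18, 18, 2, 10)         # (N..., D, nt) trajectory
    inside = CartesianIndices((2:17, 2:17))  # interior, same for every component
    uref = view(u, inside, :, 3)             # interior of snapshot 3 -> (16, 16, 2)
    ules = view(u, inside, :, 4)
    relerr = sqrt(sum(abs2, ules - uref) / sum(abs2, uref))  # pattern used above
    u1 = selectdim(u, ndims(u), 1) |> collect  # first snapshot, materialized
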
0 commit comments
