Skip to content

Commit

Permalink
Add fixed length solver loop
Browse files Browse the repository at this point in the history
  • Loading branch information
agdestein committed Dec 4, 2023
1 parent 08fbc47 commit 88316bc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
5 changes: 2 additions & 3 deletions src/closures/cnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@ function cnn(;
rng = Random.default_rng(),
)
r, c, σ, b = radii, channels, activations, use_bias
(; grid) = setup
(; dimension, x, Δu) = grid
(; T, grid) = setup
(; dimension) = grid
D = dimension()

dx = map(d -> d[2:end-1], Δu)

# Weight initializer
T = eltype(x[1])
glorot_uniform_T(rng::AbstractRNG, dims...) = glorot_uniform(rng, T, dims...)

# Make sure there are two force fields in output
Expand Down
38 changes: 17 additions & 21 deletions src/solvers/solve_unsteady.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Δt = zero(eltype(u₀[1])),
cfl = 1,
n_adapt_Δt = 1,
inplace = true,
docopy = true,
processors = (;),
)
Expand Down Expand Up @@ -40,7 +39,6 @@ function solve_unsteady(
Δt = zero(eltype(u₀[1])),
cfl = 1,
n_adapt_Δt = 1,
inplace = true,
docopy = true,
processors = (;),
)
Expand All @@ -51,45 +49,43 @@ function solve_unsteady(

t_start, t_end = tlims
isadaptive = isnothing(Δt)
if !isadaptive
nstep = round(Int, (t_end - t_start) / Δt)
Δt = (t_end - t_start) / nstep
end

if inplace
cache = ode_method_cache(method, setup, u₀, p₀)
end
# Cache arrays for intermediate computations
cache = ode_method_cache(method, setup, u₀, p₀)

# Time stepper
stepper = create_stepper(method; setup, pressure_solver, u = u₀, p = p₀, t = t_start)

# Get initial time step
isadaptive && (Δt = get_timestep(stepper, cfl))

# Initialize processors for iteration results
state = Observable(get_state(stepper))
initialized = (; (k => v.initialize(state) for (k, v) in pairs(processors))...)

while stepper.t < t_end
if isadaptive
if isadaptive
while stepper.t < t_end
if stepper.n % n_adapt_Δt == 0
# Change timestep based on operators
Δt = get_timestep(stepper, cfl)
end

# Make sure not to step past `t_end`
Δt = min(Δt, t_end - stepper.t)
end

# Perform a single time step with the time integration method
if inplace
# Perform a single time step with the time integration method
stepper = timestep!(method, stepper, Δt; cache)
else
stepper = timestep(method, stepper, Δt)

# Process iteration results with each processor
state[] = get_state(stepper)
end
else
nstep = round(Int, (t_end - t_start) / Δt)
Δt = (t_end - t_start) / nstep
for it = 1:nstep
# Perform a single time step with the time integration method
stepper = timestep!(method, stepper, Δt; cache)

# Process iteration results with each processor
state[] = get_state(stepper)
# Process iteration results with each processor
state[] = get_state(stepper)
end
end

# Final state
Expand Down

0 comments on commit 88316bc

Please sign in to comment.