From 88316bca69d83c8f418bc9c62340f01673ec45a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Syver=20D=C3=B8ving=20Agdestein?= Date: Mon, 4 Dec 2023 13:38:04 +0100 Subject: [PATCH] Add fixed length solver loop --- src/closures/cnn.jl | 5 ++--- src/solvers/solve_unsteady.jl | 38 ++++++++++++++++------------------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/src/closures/cnn.jl b/src/closures/cnn.jl index 255a9f875..a33f8fd88 100644 --- a/src/closures/cnn.jl +++ b/src/closures/cnn.jl @@ -22,14 +22,13 @@ function cnn(; rng = Random.default_rng(), ) r, c, σ, b = radii, channels, activations, use_bias - (; grid) = setup - (; dimension, x, Δu) = grid + (; T, grid) = setup + (; dimension) = grid D = dimension() dx = map(d -> d[2:end-1], Δu) # Weight initializer - T = eltype(x[1]) glorot_uniform_T(rng::AbstractRNG, dims...) = glorot_uniform(rng, T, dims...) # Make sure there are two force fields in output diff --git a/src/solvers/solve_unsteady.jl b/src/solvers/solve_unsteady.jl index 4e10ff0e6..601fb9941 100644 --- a/src/solvers/solve_unsteady.jl +++ b/src/solvers/solve_unsteady.jl @@ -9,7 +9,6 @@ Δt = zero(eltype(u₀[1])), cfl = 1, n_adapt_Δt = 1, - inplace = true, docopy = true, processors = (;), ) @@ -40,7 +39,6 @@ function solve_unsteady( Δt = zero(eltype(u₀[1])), cfl = 1, n_adapt_Δt = 1, - inplace = true, docopy = true, processors = (;), ) @@ -51,27 +49,19 @@ function solve_unsteady( t_start, t_end = tlims isadaptive = isnothing(Δt) - if !isadaptive - nstep = round(Int, (t_end - t_start) / Δt) - Δt = (t_end - t_start) / nstep - end - if inplace - cache = ode_method_cache(method, setup, u₀, p₀) - end + # Cache arrays for intermediate computations + cache = ode_method_cache(method, setup, u₀, p₀) # Time stepper stepper = create_stepper(method; setup, pressure_solver, u = u₀, p = p₀, t = t_start) - # Get initial time step - isadaptive && (Δt = get_timestep(stepper, cfl)) - # Initialize processors for iteration results state = Observable(get_state(stepper)) initialized = (; (k => v.initialize(state) for (k, v) in pairs(processors))...) - while stepper.t < t_end - if isadaptive + if isadaptive + while stepper.t < t_end if stepper.n % n_adapt_Δt == 0 # Change timestep based on operators Δt = get_timestep(stepper, cfl) @@ -79,17 +69,23 @@ function solve_unsteady( # Make sure not to step past `t_end` Δt = min(Δt, t_end - stepper.t) - end - # Perform a single time step with the time integration method - if inplace + # Perform a single time step with the time integration method stepper = timestep!(method, stepper, Δt; cache) - else - stepper = timestep(method, stepper, Δt) + + # Process iteration results with each processor + state[] = get_state(stepper) end + else + nstep = round(Int, (t_end - t_start) / Δt) + Δt = (t_end - t_start) / nstep + for it = 1:nstep + # Perform a single time step with the time integration method + stepper = timestep!(method, stepper, Δt; cache) - # Process iteration results with each processor - state[] = get_state(stepper) + # Process iteration results with each processor + state[] = get_state(stepper) + end end # Final state