Skip to content

Commit cd558a2

Browse files
Merge pull request #3531 from vyudu/cost_coalesce
feat: add `cost` and `coalesce` to ODESystem
2 parents 5b8cc35 + 220db08 commit cd558a2

File tree

5 files changed

+168
-28
lines changed

5 files changed

+168
-28
lines changed

src/systems/abstractsystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,9 @@ for prop in [:eqs
926926
:tstops
927927
:index_cache
928928
:is_scalar_noise
929-
:isscheduled]
929+
:isscheduled
930+
:costs
931+
:consolidate]
930932
fname_get = Symbol(:get_, prop)
931933
fname_has = Symbol(:has_, prop)
932934
@eval begin

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,7 @@ end
697697
```julia
698698
DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan,
699699
parammap = DiffEqBase.NullParameters();
700+
allow_cost = false,
700701
version = nothing, tgrad = false,
701702
jac = false,
702703
checkbounds = false, sparse = false,
@@ -730,6 +731,7 @@ end
730731
function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
731732
tspan = get_tspan(sys),
732733
parammap = DiffEqBase.NullParameters();
734+
allow_cost = false,
733735
callback = nothing,
734736
check_length = true,
735737
warn_initialize_determined = true,
@@ -745,6 +747,12 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
745747
Consider a BVProblem instead.")
746748
end
747749

750+
if !isempty(get_costs(sys)) && !allow_cost
751+
error("ODEProblem will not optimize solutions of ODESystems that have associated cost functions.
752+
Solvers for optimal control problems are forthcoming. In order to bypass this error (e.g.
753+
to check the cost of a regular solution), pass `allow_cost` = true into the constructor.")
754+
end
755+
748756
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
749757
t = tspan !== nothing ? tspan[1] : tspan,
750758
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)
@@ -796,21 +804,19 @@ If an ODESystem without `constraints` is specified, it will be treated as an ini
796804
797805
```julia
798806
@parameters g t_c = 0.5
799-
@variables x(..) y(t) [state_priority = 10] λ(t)
807+
@variables x(..) y(t) λ(t)
800808
eqs = [D(D(x(t))) ~ λ * x(t)
801809
D(D(y)) ~ λ * y - g
802810
x(t)^2 + y^2 ~ 1]
803811
cstr = [x(0.5) ~ 1]
804-
@named cstrs = ConstraintsSystem(cstr, t)
805-
@mtkbuild pend = ODESystem(eqs, t)
812+
@mtkbuild pend = ODESystem(eqs, t; constraints = cstrs)
806813
807814
tspan = (0.0, 1.5)
808815
u0map = [x(t) => 0.6, y => 0.8]
809816
parammap = [g => 1]
810817
guesses = [λ => 1]
811-
constraints = [x(0.5) ~ 1]
812818
813-
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
819+
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; guesses, check_length = false)
814820
```
815821
816822
If the `ODESystem` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
@@ -839,6 +845,7 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
839845
tspan = get_tspan(sys),
840846
parammap = DiffEqBase.NullParameters();
841847
guesses = Dict(),
848+
allow_cost = false,
842849
version = nothing, tgrad = false,
843850
callback = nothing,
844851
check_length = true,
@@ -852,6 +859,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
852859
end
853860
!isnothing(callback) && error("BVP solvers do not support callbacks.")
854861

862+
if !isempty(get_costs(sys)) && !allow_cost
863+
error("BVProblem will not optimize solutions of ODESystems that have associated cost functions.
864+
Solvers for optimal control problems are forthcoming. In order to bypass this error (e.g.
865+
to check the cost of a regular solution), pass `allow_cost` = true into the constructor.")
866+
end
867+
855868
has_alg_eqs(sys) &&
856869
error("The BVProblem constructor currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.
857870

@@ -924,7 +937,7 @@ function generate_function_bc(sys::ODESystem, u0, u0_idxs, tspan; kwargs...)
924937
exprs = vcat(init_conds, cons)
925938
_p = reorder_parameters(sys, ps)
926939

927-
build_function_wrapper(sys, exprs, sol, _p..., t; output_type = Array, kwargs...)
940+
build_function_wrapper(sys, exprs, sol, _p..., iv; output_type = Array, kwargs...)
928941
end
929942

930943
"""
@@ -952,11 +965,19 @@ end
952965

953966
function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
954967
parammap = DiffEqBase.NullParameters();
968+
allow_cost = false,
955969
warn_initialize_determined = true,
956970
check_length = true, eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip}
957971
if !iscomplete(sys)
958972
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEProblem`.")
959973
end
974+
975+
if !isempty(get_costs(sys)) && !allow_cost
976+
error("DAEProblem will not optimize solutions of ODESystems that have associated cost functions.
977+
Solvers for optimal control problems are forthcoming. In order to bypass this error (e.g.
978+
to check the cost of a regular solution), pass `allow_cost` = true into the constructor.")
979+
end
980+
960981
f, du0, u0, p = process_SciMLProblem(DAEFunction{iip}, sys, u0map, parammap;
961982
implicit_dae = true, du0map = du0map, check_length,
962983
t = tspan !== nothing ? tspan[1] : tspan,

src/systems/diffeqs/odesystem.jl

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ struct ODESystem <: AbstractODESystem
5151
observed::Vector{Equation}
5252
"""System of constraints that must be satisfied by the solution to the system."""
5353
constraintsystem::Union{Nothing, ConstraintsSystem}
54+
"""A set of expressions defining the costs of the system for optimal control."""
55+
costs::Vector
56+
"""Takes the cost vector and returns a scalar for optimization."""
57+
consolidate::Union{Nothing, Function}
5458
"""
5559
Time-derivative matrix. Note: this field will not be defined until
5660
[`calculate_tgrad`](@ref) is called on the system.
@@ -205,7 +209,8 @@ struct ODESystem <: AbstractODESystem
205209
parent::Any
206210

207211
function ODESystem(
208-
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, constraints, tgrad,
212+
tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls,
213+
observed, constraints, costs, consolidate, tgrad,
209214
jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses,
210215
torn_matching, initializesystem, initialization_eqs, schedule,
211216
connector_type, preface, cevents,
@@ -229,7 +234,7 @@ struct ODESystem <: AbstractODESystem
229234
check_units(u, deqs)
230235
end
231236
new(tag, deqs, iv, dvs, ps, tspan, var_to_name,
232-
ctrls, observed, constraints, tgrad, jac,
237+
ctrls, observed, constraints, costs, consolidate, tgrad, jac,
233238
ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, torn_matching,
234239
initializesystem, initialization_eqs, schedule, connector_type, preface,
235240
cevents, devents, parameter_dependencies, assertions, metadata,
@@ -243,6 +248,8 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
243248
controls = Num[],
244249
observed = Equation[],
245250
constraintsystem = nothing,
251+
costs = Num[],
252+
consolidate = nothing,
246253
systems = ODESystem[],
247254
tspan = nothing,
248255
name = nothing,
@@ -323,22 +330,27 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps;
323330
cons = get_constraintsystem(sys)
324331
cons !== nothing && push!(conssystems, cons)
325332
end
326-
@show conssystems
327333
@set! constraintsystem.systems = conssystems
328334
end
335+
costs = wrap.(costs)
336+
337+
if length(costs) > 1 && isnothing(consolidate)
338+
error("Must specify a consolidation function for the costs vector.")
339+
end
329340

330341
assertions = Dict{BasicSymbolic, Any}(unwrap(k) => v for (k, v) in assertions)
331342

332343
ODESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)),
333-
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, constraintsystem, tgrad, jac,
344+
deqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed,
345+
constraintsystem, costs, consolidate, tgrad, jac,
334346
ctrl_jac, Wfact, Wfact_t, name, description, systems,
335347
defaults, guesses, nothing, initializesystem,
336348
initialization_eqs, schedule, connector_type, preface, cont_callbacks,
337349
disc_callbacks, parameter_dependencies, assertions,
338350
metadata, gui_metadata, is_dde, tstops, checks = checks)
339351
end
340352

341-
function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
353+
function ODESystem(eqs, iv; constraints = Equation[], costs = Num[], kwargs...)
342354
diffvars, allunknowns, ps, eqs = process_equations(eqs, iv)
343355

344356
for eq in get(kwargs, :parameter_dependencies, Equation[])
@@ -384,8 +396,16 @@ function ODESystem(eqs, iv; constraints = Equation[], kwargs...)
384396
end
385397
end
386398

399+
if !isempty(costs)
400+
coststs, costps = process_costs(costs, allunknowns, new_ps, iv)
401+
for p in costps
402+
!in(p, new_ps) && push!(new_ps, p)
403+
end
404+
end
405+
costs = wrap.(costs)
406+
387407
return ODESystem(eqs, iv, collect(Iterators.flatten((diffvars, algevars, consvars))),
388-
collect(new_ps); constraintsystem, kwargs...)
408+
collect(new_ps); constraintsystem, costs, kwargs...)
389409
end
390410

391411
# NOTE: equality does not check cached Jacobian
@@ -400,7 +420,9 @@ function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
400420
_eq_unordered(get_ps(sys1), get_ps(sys2)) &&
401421
_eq_unordered(continuous_events(sys1), continuous_events(sys2)) &&
402422
_eq_unordered(discrete_events(sys1), discrete_events(sys2)) &&
403-
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2)))
423+
all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) &&
424+
isequal(get_constraintsystem(sys1), get_constraintsystem(sys2)) &&
425+
_eq_unordered(get_costs(sys1), get_costs(sys2))
404426
end
405427

406428
function flatten(sys::ODESystem, noeqs = false)
@@ -734,22 +756,53 @@ function Base.show(io::IO, mime::MIME"text/plain", sys::ODESystem; hint = true,
734756
return nothing
735757
end
736758

737-
# Validate that all the variables in the BVP constraints are well-formed states or parameters.
738-
# - Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
739-
# - Callable/delay parameters should be parameters of the system (and have one arg, etc.)
759+
"""
760+
Build the constraint system for the ODESystem.
761+
"""
740762
function process_constraint_system(
741763
constraints::Vector{Equation}, sts, ps, iv; consname = :cons)
742764
isempty(constraints) && return nothing
743765

744766
constraintsts = OrderedSet()
745767
constraintps = OrderedSet()
746-
747768
for cons in constraints
748769
collect_vars!(constraintsts, constraintps, cons, iv)
749770
end
750771

751772
# Validate the states.
752-
for var in constraintsts
773+
validate_vars_and_find_ps!(constraintsts, constraintps, sts, iv)
774+
775+
ConstraintsSystem(
776+
constraints, collect(constraintsts), collect(constraintps); name = consname)
777+
end
778+
779+
"""
780+
Process the costs for the constraint system.
781+
"""
782+
function process_costs(costs::Vector, sts, ps, iv)
783+
coststs = OrderedSet()
784+
costps = OrderedSet()
785+
for cost in costs
786+
collect_vars!(coststs, costps, cost, iv)
787+
end
788+
789+
validate_vars_and_find_ps!(coststs, costps, sts, iv)
790+
coststs, costps
791+
end
792+
793+
"""
794+
Validate that all the variables in an auxiliary system of the ODESystem (constraint or costs) are
795+
well-formed states or parameters.
796+
- Callable/delay variables (e.g. of the form x(0.6) should be unknowns of the system (and have one arg, etc.)
797+
- Callable/delay parameters should be parameters of the system
798+
799+
Return the set of additional parameters found in the system, e.g. in x(p) ~ 3 then p should be added as a
800+
parameter of the system.
801+
"""
802+
function validate_vars_and_find_ps!(auxvars, auxps, sysvars, iv)
803+
sts = sysvars
804+
805+
for var in auxvars
753806
if !iscall(var)
754807
occursin(iv, var) && (var sts ||
755808
throw(ArgumentError("Time-dependent variable $var is not an unknown of the system.")))
@@ -764,13 +817,42 @@ function process_constraint_system(
764817
arg isa AbstractFloat ||
765818
throw(ArgumentError("Invalid argument specified for variable $var. The argument of the variable should be either $iv, a parameter, or a value specifying the time that the constraint holds."))
766819

767-
isparameter(arg) && push!(constraintps, arg)
820+
isparameter(arg) && push!(auxps, arg)
768821
else
769822
var sts &&
770823
@warn "Variable $var has no argument. It will be interpreted as $var($iv), and the constraint will apply to the entire interval."
771824
end
772825
end
826+
end
773827

774-
ConstraintsSystem(
775-
constraints, collect(constraintsts), collect(constraintps); name = consname)
828+
"""
829+
Generate a function that takes a solution object and computes the cost function obtained by coalescing the costs vector.
830+
"""
831+
function generate_cost_function(sys::ODESystem, kwargs...)
832+
costs = get_costs(sys)
833+
consolidate = get_consolidate(sys)
834+
iv = get_iv(sys)
835+
836+
ps = parameters(sys; initial_parameters = false)
837+
sts = unknowns(sys)
838+
np = length(ps)
839+
ns = length(sts)
840+
stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
841+
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])
842+
843+
@variables sol(..)[1:ns]
844+
for st in vars(costs)
845+
x = operation(st)
846+
t = only(arguments(st))
847+
idx = stidxmap[x(iv)]
848+
849+
costs = map(c -> Symbolics.fast_substitute(c, Dict(x(t) => sol(t)[idx])), costs)
850+
end
851+
852+
_p = reorder_parameters(sys, ps)
853+
fs = build_function_wrapper(sys, costs, sol, _p..., t; output_type = Array, kwargs...)
854+
vc_oop, vc_iip = eval_or_rgf.(fs)
855+
856+
cost(sol, p, t) = consolidate(vc_oop(sol, p, t))
857+
return cost
776858
end

0 commit comments

Comments
 (0)