Skip to content

Commit 39885d8

Browse files
Add propagation of guesses from parameters and observed
Fixes #2716
1 parent 3b340c2 commit 39885d8

File tree

5 files changed

+94
-3
lines changed

5 files changed

+94
-3
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,8 @@ function get_u0_p(sys,
790790
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
791791
end
792792
end
793-
defs = mergedefaults(defs, u0map, dvs)
793+
observedmap = todict(map(x->x.rhs => x.lhs,observed(sys)))
794+
defs = mergedefaults(defs, observedmap, u0map, dvs)
794795
for (k, v) in defs
795796
if Symbolics.isarraysymbolic(k)
796797
ks = scalarize(k)
@@ -821,7 +822,9 @@ function get_u0(
821822
if parammap !== nothing
822823
defs = mergedefaults(defs, parammap, ps)
823824
end
824-
defs = mergedefaults(defs, u0map, dvs)
825+
obs = map(x->x.rhs => x.lhs, observed(sys))
826+
observedmap = isempty(obs) ? Dict() : todict(obs)
827+
defs = mergedefaults(defs, observedmap, u0map, dvs)
825828
if symbolic_u0
826829
u0 = varmap_to_vars(
827830
u0map, dvs; defaults = defs, tofloat = false, use_union = false, toterm)

src/systems/nonlinear/initializesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function generate_initializesystem(sys::ODESystem;
3131
schedule = getfield(sys, :schedule)
3232

3333
if schedule !== nothing
34-
guessmap = [x[2] => get(guesses, x[1], default_dd_value)
34+
guessmap = [x[1] => get(guesses, x[1], default_dd_value)
3535
for x in schedule.dummy_sub]
3636
dd_guess = Dict(filter(x -> !isnothing(x[1]), guessmap))
3737
if u0map === nothing || isempty(u0map)

src/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,18 @@ function mergedefaults(defaults, varmap, vars)
617617
end
618618
end
619619

620+
function mergedefaults(defaults, observedmap, varmap, vars)
621+
defs = if varmap isa Dict
622+
merge(observedmap, defaults, varmap)
623+
elseif eltype(varmap) <: Pair
624+
merge(observedmap, defaults, Dict(varmap))
625+
elseif eltype(varmap) <: Number
626+
merge(observedmap, defaults, Dict(zip(vars, varmap)))
627+
else
628+
merge(observedmap, defaults)
629+
end
630+
end
631+
620632
@noinline function throw_missingvars_in_sys(vars)
621633
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's unknowns/parameters list."))
622634
end

test/guess_propagation.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using ModelingToolkit, OrdinaryDiffEq
2+
using ModelingToolkit: D, t_nounits as t
3+
using Test
4+
5+
# Standard case
6+
7+
@variables x(t) [guess = 2]
8+
@variables y(t)
9+
eqs = [D(x) ~ 1
10+
x ~ y]
11+
initialization_eqs = [1 ~ exp(1 + x)]
12+
13+
@named sys = ODESystem(eqs, t; initialization_eqs)
14+
sys = complete(structural_simplify(sys))
15+
tspan = (0.0, 0.2)
16+
prob = ODEProblem(sys, [], tspan, [])
17+
18+
@test prob.f.initializeprob[y] == 2.0
19+
@test prob.f.initializeprob[x] == 2.0
20+
sol = solve(prob.f.initializeprob; show_trace=Val(true))
21+
22+
# Guess via observed
23+
24+
@variables x(t)
25+
@variables y(t) [guess = 2]
26+
eqs = [D(x) ~ 1
27+
x ~ y]
28+
initialization_eqs = [1 ~ exp(1 + x)]
29+
30+
@named sys = ODESystem(eqs, t; initialization_eqs)
31+
sys = complete(structural_simplify(sys))
32+
tspan = (0.0, 0.2)
33+
prob = ODEProblem(sys, [], tspan, [])
34+
35+
@test prob.f.initializeprob[x] == 2.0
36+
@test prob.f.initializeprob[y] == 2.0
37+
sol = solve(prob.f.initializeprob; show_trace=Val(true))
38+
39+
# Guess via parameter
40+
41+
@parameters a = -1.0
42+
@variables x(t) [guess = a]
43+
44+
eqs = [D(x) ~ a]
45+
46+
initialization_eqs = [1 ~ exp(1 + x)]
47+
48+
@named sys = ODESystem(eqs, t; initialization_eqs)
49+
sys = complete(structural_simplify(sys))
50+
51+
tspan = (0.0, 0.2)
52+
prob = ODEProblem(sys, [], tspan, [])
53+
54+
@test prob.f.initializeprob[x] == -1.0
55+
sol = solve(prob.f.initializeprob; show_trace=Val(true))
56+
57+
# Guess via observed parameter
58+
59+
@parameters a = -1.0
60+
@variables x(t)
61+
@variables y(t) [guess = a]
62+
63+
eqs = [D(x) ~ a,
64+
y ~ x]
65+
66+
initialization_eqs = [1 ~ exp(1 + x)]
67+
68+
@named sys = ODESystem(eqs, t; initialization_eqs)
69+
sys = complete(structural_simplify(sys))
70+
71+
tspan = (0.0, 0.2)
72+
prob = ODEProblem(sys, [], tspan, [])
73+
74+
@test prob.f.initializeprob[x] == -1.0
75+
sol = solve(prob.f.initializeprob; show_trace=Val(true))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ end
3838
@safetestset "DDESystem Test" include("dde.jl")
3939
@safetestset "NonlinearSystem Test" include("nonlinearsystem.jl")
4040
@safetestset "InitializationSystem Test" include("initializationsystem.jl")
41+
@safetestset "Guess Propagation" include("guess_propagation.jl")
4142
@safetestset "Hierarchical Initialization Equations" include("hierarchical_initialization_eqs.jl")
4243
@safetestset "PDE Construction Test" include("pde.jl")
4344
@safetestset "JumpSystem Test" include("jumpsystem.jl")

0 commit comments

Comments
 (0)