Skip to content

Commit 9d8b816

Browse files
Test parameters + defaults in initialization propagation
Fixes #2774
1 parent 7170f64 commit 9d8b816

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,9 @@ function get_u0(
823823
if parammap !== nothing
824824
defs = mergedefaults(defs, parammap, ps)
825825
end
826-
obs = filter!(x -> !(x[1] isa Number), map(x -> x.rhs => x.lhs, observed(sys)))
826+
827+
obs = filter!(x -> !(x[1] isa Number),
828+
map(x -> isparameter(x.rhs) ? x.lhs => x.rhs : x.rhs => x.lhs, observed(sys)))
827829
observedmap = isempty(obs) ? Dict() : todict(obs)
828830
defs = mergedefaults(defs, observedmap, u0map, dvs)
829831
if symbolic_u0

test/guess_propagation.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,38 @@ prob = ODEProblem(sys, [], tspan, [])
7373

7474
@test prob.f.initializeprob[x] == -1.0
7575
sol = solve(prob.f.initializeprob; show_trace = Val(true))
76+
77+
# Test parameters + defaults
78+
# https://github.com/SciML/ModelingToolkit.jl/issues/2774
79+
80+
@parameters x0
81+
@variables x(t)
82+
@variables y(t) = x
83+
@mtkbuild sys = ODESystem([x ~ x0, D(y) ~ x], t)
84+
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
85+
@test prob[x] == 1.0
86+
@test prob[y] == 1.0
87+
88+
@parameters x0
89+
@variables x(t)
90+
@variables y(t) = x0
91+
@mtkbuild sys = ODESystem([x ~ x0, D(y) ~ x], t)
92+
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
93+
prob[x] == 1.0
94+
prob[y] == 1.0
95+
96+
@parameters x0
97+
@variables x(t)
98+
@variables y(t) = x0
99+
@mtkbuild sys = ODESystem([x ~ y, D(y) ~ x], t)
100+
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
101+
prob[x] == 1.0
102+
prob[y] == 1.0
103+
104+
@parameters x0
105+
@variables x(t) = x0
106+
@variables y(t) = x
107+
@mtkbuild sys = ODESystem([x ~ y, D(y) ~ x], t)
108+
prob = ODEProblem(sys, [], (0.0, 1.0), [x0 => 1.0])
109+
prob[x] == 1.0
110+
prob[y] == 1.0

0 commit comments

Comments
 (0)