Skip to content

Commit 924bd6c

Browse files
Merge pull request #2591 from SciML/initialization_equations
Add initialization_equations flattening constructor
2 parents 352e86e + f2bbd06 commit 924bd6c

File tree

7 files changed

+177
-7
lines changed

7 files changed

+177
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ LinearAlgebra = "1"
9393
MLStyle = "0.4.17"
9494
NaNMath = "0.3, 1"
9595
OrderedCollections = "1"
96-
OrdinaryDiffEq = "6.73.0"
96+
OrdinaryDiffEq = "6.82.0"
9797
PrecompileTools = "1"
9898
RecursiveArrayTools = "2.3, 3"
9999
Reexport = "0.2, 1"

src/systems/abstractsystem.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,13 @@ function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sy
889889
map(eq -> namespace_equation(eq, sys; ivs), eqs)
890890
end
891891

892+
function namespace_initialization_equations(
893+
sys::AbstractSystem, ivs = independent_variables(sys))
894+
eqs = initialization_equations(sys)
895+
isempty(eqs) && return Equation[]
896+
map(eq -> namespace_equation(eq, sys; ivs), eqs)
897+
end
898+
892899
function namespace_equation(eq::Equation,
893900
sys,
894901
n = nameof(sys);
@@ -1080,6 +1087,20 @@ function equations(sys::AbstractSystem)
10801087
end
10811088
end
10821089

1090+
function initialization_equations(sys::AbstractSystem)
1091+
eqs = get_initialization_eqs(sys)
1092+
systems = get_systems(sys)
1093+
if isempty(systems)
1094+
return eqs
1095+
else
1096+
eqs = Equation[eqs;
1097+
reduce(vcat,
1098+
namespace_initialization_equations.(get_systems(sys));
1099+
init = Equation[])]
1100+
return eqs
1101+
end
1102+
end
1103+
10831104
function preface(sys::AbstractSystem)
10841105
has_preface(sys) || return nothing
10851106
pre = get_preface(sys)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -908,10 +908,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
908908
# TODO: make it work with clocks
909909
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
910910
if sys isa ODESystem && build_initializeprob &&
911-
(implicit_dae || !isempty(missingvars)) &&
912-
all(isequal(Continuous()), ci.var_domain) &&
913-
ModelingToolkit.get_tearing_state(sys) !== nothing &&
914-
t !== nothing
911+
(((implicit_dae || !isempty(missingvars)) &&
912+
all(isequal(Continuous()), ci.var_domain) &&
913+
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
914+
!isempty(initialization_equations(sys))) && t !== nothing
915915
if eltype(u0map) <: Number
916916
u0map = unknowns(sys) .=> u0map
917917
end

src/systems/diffeqs/odesystem.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ function flatten(sys::ODESystem, noeqs = false)
362362
discrete_events = discrete_events(sys),
363363
defaults = defaults(sys),
364364
name = nameof(sys),
365+
initialization_eqs = initialization_equations(sys),
365366
checks = false)
366367
end
367368
end

src/systems/nonlinear/initializesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ function generate_initializesystem(sys::ODESystem;
2323
diffmap = Dict(getfield.(eqs_diff, :lhs) .=> getfield.(eqs_diff, :rhs))
2424
observed_diffmap = Dict(Differential(get_iv(sys)).(getfield.((observed(sys)), :lhs)) .=>
2525
Differential(get_iv(sys)).(getfield.((observed(sys)), :rhs)))
26+
full_diffmap = merge(diffmap, observed_diffmap)
2627

2728
full_states = unique([sts; getfield.((observed(sys)), :lhs)])
2829
set_full_states = Set(full_states)
@@ -39,8 +40,7 @@ function generate_initializesystem(sys::ODESystem;
3940
filtered_u0 = Pair[]
4041
for x in u0map
4142
y = get(schedule.dummy_sub, x[1], x[1])
42-
y = ModelingToolkit.fixpoint_sub(y, observed_diffmap)
43-
y = get(diffmap, y, y)
43+
y = ModelingToolkit.fixpoint_sub(y, full_diffmap)
4444

4545
if y isa Symbolics.Arr
4646
_y = collect(y)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
using ModelingToolkit, OrdinaryDiffEq
2+
3+
t = only(@variables(t))
4+
D = Differential(t)
5+
"""
6+
A simple linear resistor model
7+
8+
![Resistor](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTpJkiEyqh-BRx27pvVH0GLZ4MP_D1oriBwJhnZdgIq7m17z9VKUWaW9MeNQAz1rTML2ho&usqp=CAU)
9+
"""
10+
@component function Resistor(; name, R = 1.0)
11+
systems = @named begin
12+
p = Pin()
13+
n = Pin()
14+
end
15+
vars = @variables begin
16+
v(t), [guess = 0.0]
17+
i(t), [guess = 0.0]
18+
end
19+
params = @parameters begin
20+
R = R, [description = "Resistance of this Resistor"]
21+
end
22+
eqs = [v ~ p.v - n.v
23+
i ~ p.i
24+
p.i + n.i ~ 0
25+
# Ohm's Law
26+
v ~ i * R]
27+
return ODESystem(eqs, t, vars, params; systems, name)
28+
end
29+
@connector Pin begin
30+
v(t)
31+
i(t), [connect = Flow]
32+
end
33+
@component function ConstantVoltage(; name, V = 1.0)
34+
systems = @named begin
35+
p = Pin()
36+
n = Pin()
37+
end
38+
vars = @variables begin
39+
v(t), [guess = 0.0]
40+
i(t), [guess = 0.0]
41+
end
42+
params = @parameters begin
43+
V = 10
44+
end
45+
eqs = [v ~ p.v - n.v
46+
i ~ p.i
47+
p.i + n.i ~ 0
48+
v ~ V]
49+
return ODESystem(eqs, t, vars, params; systems, name)
50+
end
51+
52+
@component function Capacitor(; name, C = 1.0)
53+
systems = @named begin
54+
p = Pin()
55+
n = Pin()
56+
end
57+
vars = @variables begin
58+
v(t), [guess = 0.0]
59+
i(t), [guess = 0.0]
60+
end
61+
params = @parameters begin
62+
C = C
63+
end
64+
initialization_eqs = [
65+
v ~ 0
66+
]
67+
eqs = [v ~ p.v - n.v
68+
i ~ p.i
69+
p.i + n.i ~ 0
70+
C * D(v) ~ i]
71+
return ODESystem(eqs, t, vars, params; systems, name, initialization_eqs)
72+
end
73+
74+
@component function Ground(; name)
75+
systems = @named begin
76+
g = Pin()
77+
end
78+
eqs = [
79+
g.v ~ 0
80+
]
81+
return ODESystem(eqs, t, [], []; systems, name)
82+
end
83+
84+
@component function Inductor(; name, L = 1.0)
85+
systems = @named begin
86+
p = Pin()
87+
n = Pin()
88+
end
89+
vars = @variables begin
90+
v(t), [guess = 0.0]
91+
i(t), [guess = 0.0]
92+
end
93+
params = @parameters begin
94+
(L = L)
95+
end
96+
eqs = [v ~ p.v - n.v
97+
i ~ p.i
98+
p.i + n.i ~ 0
99+
L * D(i) ~ v]
100+
return ODESystem(eqs, t, vars, params; systems, name)
101+
end
102+
103+
"""
104+
This is an RLC model. This should support markdown. That includes
105+
HTML as well.
106+
"""
107+
@component function RLCModel(; name)
108+
systems = @named begin
109+
resistor = Resistor(R = 100)
110+
capacitor = Capacitor(C = 0.001)
111+
inductor = Inductor(L = 1)
112+
source = ConstantVoltage(V = 30)
113+
ground = Ground()
114+
end
115+
initialization_eqs = [
116+
inductor.i ~ 0
117+
]
118+
eqs = [connect(source.p, inductor.n)
119+
connect(inductor.p, resistor.p, capacitor.p)
120+
connect(resistor.n, ground.g, capacitor.n, source.n)]
121+
return ODESystem(eqs, t, [], []; systems, name, initialization_eqs)
122+
end
123+
"""Run model RLCModel from 0 to 10"""
124+
function simple()
125+
@mtkbuild model = RLCModel()
126+
u0 = []
127+
prob = ODEProblem(model, u0, (0.0, 10.0))
128+
sol = solve(prob)
129+
end
130+
@test SciMLBase.successful_retcode(simple())
131+
132+
@named model = RLCModel()
133+
@test length(ModelingToolkit.get_initialization_eqs(model)) == 1
134+
syslist = ModelingToolkit.get_systems(model)
135+
@test length(ModelingToolkit.get_initialization_eqs(syslist[1])) == 0
136+
@test length(ModelingToolkit.get_initialization_eqs(syslist[2])) == 1
137+
@test length(ModelingToolkit.get_initialization_eqs(syslist[3])) == 0
138+
@test length(ModelingToolkit.get_initialization_eqs(syslist[4])) == 0
139+
@test length(ModelingToolkit.get_initialization_eqs(syslist[5])) == 0
140+
@test length(ModelingToolkit.initialization_equations(model)) == 2
141+
142+
u0 = []
143+
prob = ODEProblem(structural_simplify(model), u0, (0.0, 10.0))
144+
sol = solve(prob, Rodas5P())
145+
@test length(sol[end]) == 2
146+
@test length(equations(prob.f.initializeprob.f.sys)) == 0
147+
@test length(unknowns(prob.f.initializeprob.f.sys)) == 0

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ end
3838
@safetestset "DDESystem Test" include("dde.jl")
3939
@safetestset "NonlinearSystem Test" include("nonlinearsystem.jl")
4040
@safetestset "InitializationSystem Test" include("initializationsystem.jl")
41+
@safetestset "Hierarchical Initialization Equations" include("hierarchical_initialization_eqs.jl")
4142
@safetestset "PDE Construction Test" include("pde.jl")
4243
@safetestset "JumpSystem Test" include("jumpsystem.jl")
4344
@safetestset "Constraints Test" include("constraints.jl")

0 commit comments

Comments
 (0)