Skip to content

Commit 95f9fdd

Browse files
feat: retain original equations of the system in TearingState
1 parent 4fb31f5 commit 95f9fdd

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

src/systems/systems.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
8080

8181
@unpack structure, fullvars = state
8282
@unpack graph, var_to_diff, var_types = structure
83-
eqs = equations(state)
8483
brown_vars = Int[]
8584
new_idxs = zeros(Int, length(var_types))
8685
idx = 0
@@ -98,7 +97,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
9897
Is = Int[]
9998
Js = Int[]
10099
vals = Num[]
101-
new_eqs = copy(eqs)
100+
make_eqs_zero_equals!(state)
101+
new_eqs = copy(equations(state))
102102
dvar2eq = Dict{Any, Int}()
103103
for (v, dv) in enumerate(var_to_diff)
104104
dv === nothing && continue

src/systems/systemstructure.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ end
203203
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
204204
"""The system of equations."""
205205
sys::T
206+
original_eqs::Vector{Equation}
206207
"""The set of variables of the system."""
207208
fullvars::Vector{BasicSymbolic}
208209
structure::SystemStructure
@@ -213,6 +214,7 @@ end
213214
TransformationState(sys::AbstractSystem) = TearingState(sys)
214215
function system_subset(ts::TearingState, ieqs::Vector{Int})
215216
eqs = equations(ts)
217+
@set! ts.original_eqs = ts.original_eqs[ieqs]
216218
@set! ts.sys.eqs = eqs[ieqs]
217219
@set! ts.structure = system_subset(ts.structure, ieqs)
218220
ts
@@ -274,8 +276,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
274276
sys = process_parameter_equations(sys)
275277
ivs = independent_variables(sys)
276278
iv = length(ivs) == 1 ? ivs[1] : nothing
277-
# flatten array equations
278-
eqs = flatten_equations(equations(sys))
279+
# scalarize array equations, without scalarizing arguments to registered functions
280+
original_eqs = flatten_equations(copy(equations(sys)))
281+
eqs = copy(original_eqs)
279282
neqs = length(eqs)
280283
param_derivative_map = Dict{BasicSymbolic, Any}()
281284
# * Scalarize unknowns
@@ -513,7 +516,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
513516

514517
eq_to_diff = DiffGraph(nsrcs(graph))
515518

516-
ts = TearingState(sys, fullvars,
519+
ts = TearingState(sys, original_eqs, fullvars,
517520
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
518521
complete(graph), nothing, var_types, false),
519522
Any[], param_derivative_map)
@@ -696,6 +699,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
696699
printstyled(io, " SelectedState")
697700
end
698701

702+
function make_eqs_zero_equals!(ts::TearingState)
703+
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
704+
i, eq = kvp
705+
isalgeq = true
706+
for j in 𝑠neighbors(ts.structure.graph, i)
707+
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
708+
end
709+
if isalgeq
710+
return 0 ~ eq.rhs - eq.lhs
711+
else
712+
return eq
713+
end
714+
end
715+
copyto!(get_eqs(ts.sys), neweqs)
716+
end
717+
699718
function mtkcompile!(state::TearingState; simplify = false,
700719
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
701720
inputs = Any[], outputs = Any[],
@@ -722,6 +741,7 @@ function mtkcompile!(state::TearingState; simplify = false,
722741
"""))
723742
end
724743
if length(tss) > 1
744+
make_eqs_zero_equals!(tss[continuous_id])
725745
# simplify as normal
726746
sys = _mtkcompile!(tss[continuous_id]; simplify,
727747
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,

0 commit comments

Comments
 (0)