203
203
mutable struct TearingState{T <: AbstractSystem } <: AbstractTearingState{T}
204
204
""" The system of equations."""
205
205
sys:: T
206
+ original_eqs:: Vector{Equation}
206
207
""" The set of variables of the system."""
207
208
fullvars:: Vector{BasicSymbolic}
208
209
structure:: SystemStructure
213
214
TransformationState (sys:: AbstractSystem ) = TearingState (sys)
214
215
function system_subset (ts:: TearingState , ieqs:: Vector{Int} )
215
216
eqs = equations (ts)
217
+ @set! ts. original_eqs = ts. original_eqs[ieqs]
216
218
@set! ts. sys. eqs = eqs[ieqs]
217
219
@set! ts. structure = system_subset (ts. structure, ieqs)
218
220
ts
@@ -274,8 +276,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
274
276
sys = process_parameter_equations (sys)
275
277
ivs = independent_variables (sys)
276
278
iv = length (ivs) == 1 ? ivs[1 ] : nothing
277
- # flatten array equations
278
- eqs = flatten_equations (equations (sys))
279
+ # scalarize array equations, without scalarizing arguments to registered functions
280
+ original_eqs = flatten_equations (copy (equations (sys)))
281
+ eqs = copy (original_eqs)
279
282
neqs = length (eqs)
280
283
param_derivative_map = Dict {BasicSymbolic, Any} ()
281
284
# * Scalarize unknowns
@@ -513,7 +516,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
513
516
514
517
eq_to_diff = DiffGraph (nsrcs (graph))
515
518
516
- ts = TearingState (sys, fullvars,
519
+ ts = TearingState (sys, original_eqs, fullvars,
517
520
SystemStructure (complete (var_to_diff), complete (eq_to_diff),
518
521
complete (graph), nothing , var_types, false ),
519
522
Any[], param_derivative_map)
@@ -696,6 +699,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
696
699
printstyled (io, " SelectedState" )
697
700
end
698
701
702
+ function make_eqs_zero_equals! (ts:: TearingState )
703
+ neweqs = map (enumerate (get_eqs (ts. sys))) do kvp
704
+ i, eq = kvp
705
+ isalgeq = true
706
+ for j in 𝑠neighbors (ts. structure. graph, i)
707
+ isalgeq &= invview (ts. structure. var_to_diff)[j] === nothing
708
+ end
709
+ if isalgeq
710
+ return 0 ~ eq. rhs - eq. lhs
711
+ else
712
+ return eq
713
+ end
714
+ end
715
+ copyto! (get_eqs (ts. sys), neweqs)
716
+ end
717
+
699
718
function mtkcompile! (state:: TearingState ; simplify = false ,
700
719
check_consistency = true , fully_determined = true , warn_initialize_determined = true ,
701
720
inputs = Any[], outputs = Any[],
@@ -722,6 +741,7 @@ function mtkcompile!(state::TearingState; simplify = false,
722
741
""" ))
723
742
end
724
743
if length (tss) > 1
744
+ make_eqs_zero_equals! (tss[continuous_id])
725
745
# simplify as normal
726
746
sys = _mtkcompile! (tss[continuous_id]; simplify,
727
747
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
0 commit comments