Skip to content

Commit a75b06a

Browse files
feat: propagate state machines in structural simplification
1 parent 5d13ee9 commit a75b06a

File tree

2 files changed

+75
-7
lines changed

2 files changed

+75
-7
lines changed

src/systems/systems.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
7575
return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)
7676
end
7777

78+
sys, statemachines = extract_top_level_statemachines(sys)
7879
sys = expand_connections(sys)
79-
state = TearingState(sys; sort_eqs)
80+
state = TearingState(sys)
81+
append!(state.statemachines, statemachines)
8082

8183
@unpack structure, fullvars = state
8284
@unpack graph, var_to_diff, var_types = structure

src/systems/systemstructure.jl

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
209209
structure::SystemStructure
210210
extra_eqs::Vector
211211
param_derivative_map::Dict{BasicSymbolic, Any}
212+
statemachines::Vector{T}
212213
end
213214

214215
TransformationState(sys::AbstractSystem) = TearingState(sys)
@@ -217,6 +218,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
217218
@set! ts.original_eqs = ts.original_eqs[ieqs]
218219
@set! ts.sys.eqs = eqs[ieqs]
219220
@set! ts.structure = system_subset(ts.structure, ieqs)
221+
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
222+
names = Symbol[]
223+
for eq in get_eqs(ts.sys)
224+
if eq.lhs isa Transition
225+
push!(names, first(namespace_hierarchy(nameof(eq.rhs.from))))
226+
push!(names, first(namespace_hierarchy(nameof(eq.rhs.to))))
227+
elseif eq.lhs isa InitialState
228+
push!(names, first(namespace_hierarchy(nameof(eq.rhs.s))))
229+
else
230+
error("Unhandled state machine operator")
231+
end
232+
end
233+
@set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines)
234+
else
235+
@set! ts.statemachines = eltype(ts.statemachines)[]
236+
end
220237
ts
221238
end
222239

@@ -270,6 +287,49 @@ function symbolic_contains(var, set)
270287
all(x -> x in set, Symbolics.scalarize(var))
271288
end
272289

290+
"""
291+
$(TYPEDSIGNATURES)
292+
293+
Descend through the system hierarchy and look for statemachines. Remove equations from
294+
the inner statemachine systems. Return the new `sys` and an array of top-level
295+
statemachines.
296+
"""
297+
function extract_top_level_statemachines(sys::AbstractSystem)
298+
eqs = get_eqs(sys)
299+
300+
if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs)
301+
# top-level statemachine
302+
with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys))
303+
return with_removed, [sys]
304+
elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs)
305+
# error: can't mix
306+
error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.")
307+
else
308+
# descend
309+
subsystems = get_systems(sys)
310+
newsubsystems = eltype(subsystems)[]
311+
statemachines = eltype(subsystems)[]
312+
for subsys in subsystems
313+
newsubsys, sub_statemachines = extract_top_level_statemachines(subsys)
314+
push!(newsubsystems, newsubsys)
315+
append!(statemachines, sub_statemachines)
316+
end
317+
@set! sys.systems = newsubsystems
318+
return sys, statemachines
319+
end
320+
end
321+
322+
"""
323+
$(TYPEDSIGNATURES)
324+
325+
Return `sys` with all equations (including those in subsystems) removed.
326+
"""
327+
function remove_child_equations(sys::AbstractSystem)
328+
@set! sys.eqs = eltype(get_eqs(sys))[]
329+
@set! sys.systems = map(remove_child_equations, get_systems(sys))
330+
return sys
331+
end
332+
273333
function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
274334
# flatten system
275335
sys = flatten(sys)
@@ -334,9 +394,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
334394
# change the equation if the RHS is `missing` so the rest of this loop works
335395
eq = 0.0 ~ coalesce(eq.rhs, 0.0)
336396
end
337-
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
338-
if !_iszero(eq.lhs)
397+
is_statemachine_equation = false
398+
if eq.lhs isa StateMachineOperator
399+
is_statemachine_equation = true
400+
eq = eq
401+
rhs = eq.rhs
402+
elseif _iszero(eq.lhs)
403+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
404+
else
339405
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
406+
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
340407
eq = 0 ~ rhs - lhs
341408
end
342409
empty!(varsbuf)
@@ -400,8 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
400467
addvar!(v, VARIABLE)
401468
end
402469
end
403-
404-
if isalgeq
470+
if isalgeq || is_statemachine_equation
405471
eqs[i] = eq
406472
else
407473
eqs[i] = eqs[i].lhs ~ rhs
@@ -519,8 +585,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
519585
ts = TearingState(sys, original_eqs, fullvars,
520586
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
521587
complete(graph), nothing, var_types, false),
522-
Any[], param_derivative_map)
523-
588+
Any[], param_derivative_map, typeof(sys)[])
524589
return ts
525590
end
526591

@@ -747,6 +812,7 @@ function mtkcompile!(state::TearingState; simplify = false,
747812
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
748813
check_consistency, fully_determined,
749814
kwargs...)
815+
additional_passes = get(kwargs, :additional_passes, nothing)
750816
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
751817
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
752818
discrete_compile = additional_passes[discrete_pass_idx]

0 commit comments

Comments
 (0)