@@ -209,6 +209,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
209
209
structure:: SystemStructure
210
210
extra_eqs:: Vector
211
211
param_derivative_map:: Dict{BasicSymbolic, Any}
212
+ statemachines:: Vector{T}
212
213
end
213
214
214
215
TransformationState (sys:: AbstractSystem ) = TearingState (sys)
@@ -217,6 +218,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
217
218
@set! ts. original_eqs = ts. original_eqs[ieqs]
218
219
@set! ts. sys. eqs = eqs[ieqs]
219
220
@set! ts. structure = system_subset (ts. structure, ieqs)
221
+ if all (eq -> eq. rhs isa StateMachineOperator, get_eqs (ts. sys))
222
+ names = Symbol[]
223
+ for eq in get_eqs (ts. sys)
224
+ if eq. lhs isa Transition
225
+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. from))))
226
+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. to))))
227
+ elseif eq. lhs isa InitialState
228
+ push! (names, first (namespace_hierarchy (nameof (eq. rhs. s))))
229
+ else
230
+ error (" Unhandled state machine operator" )
231
+ end
232
+ end
233
+ @set! ts. statemachines = filter (x -> nameof (x) in names, ts. statemachines)
234
+ else
235
+ @set! ts. statemachines = eltype (ts. statemachines)[]
236
+ end
220
237
ts
221
238
end
222
239
@@ -270,6 +287,49 @@ function symbolic_contains(var, set)
270
287
all (x -> x in set, Symbolics. scalarize (var))
271
288
end
272
289
290
+ """
291
+ $(TYPEDSIGNATURES)
292
+
293
+ Descend through the system hierarchy and look for statemachines. Remove equations from
294
+ the inner statemachine systems. Return the new `sys` and an array of top-level
295
+ statemachines.
296
+ """
297
+ function extract_top_level_statemachines (sys:: AbstractSystem )
298
+ eqs = get_eqs (sys)
299
+
300
+ if ! isempty (eqs) && all (eq -> eq. lhs isa StateMachineOperator, eqs)
301
+ # top-level statemachine
302
+ with_removed = @set sys. systems = map (remove_child_equations, get_systems (sys))
303
+ return with_removed, [sys]
304
+ elseif ! isempty (eqs) && any (eq -> eq. lhs isa StateMachineOperator, eqs)
305
+ # error: can't mix
306
+ error (" Mixing statemachine equations and standard equations in a top-level statemachine is not allowed." )
307
+ else
308
+ # descend
309
+ subsystems = get_systems (sys)
310
+ newsubsystems = eltype (subsystems)[]
311
+ statemachines = eltype (subsystems)[]
312
+ for subsys in subsystems
313
+ newsubsys, sub_statemachines = extract_top_level_statemachines (subsys)
314
+ push! (newsubsystems, newsubsys)
315
+ append! (statemachines, sub_statemachines)
316
+ end
317
+ @set! sys. systems = newsubsystems
318
+ return sys, statemachines
319
+ end
320
+ end
321
+
322
+ """
323
+ $(TYPEDSIGNATURES)
324
+
325
+ Return `sys` with all equations (including those in subsystems) removed.
326
+ """
327
+ function remove_child_equations (sys:: AbstractSystem )
328
+ @set! sys. eqs = eltype (get_eqs (sys))[]
329
+ @set! sys. systems = map (remove_child_equations, get_systems (sys))
330
+ return sys
331
+ end
332
+
273
333
function TearingState (sys; quick_cancel = false , check = true , sort_eqs = true )
274
334
# flatten system
275
335
sys = flatten (sys)
@@ -334,9 +394,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
334
394
# change the equation if the RHS is `missing` so the rest of this loop works
335
395
eq = 0.0 ~ coalesce (eq. rhs, 0.0 )
336
396
end
337
- rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
338
- if ! _iszero (eq. lhs)
397
+ is_statemachine_equation = false
398
+ if eq. lhs isa StateMachineOperator
399
+ is_statemachine_equation = true
400
+ eq = eq
401
+ rhs = eq. rhs
402
+ elseif _iszero (eq. lhs)
403
+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
404
+ else
339
405
lhs = quick_cancel ? quick_cancel_expr (eq. lhs) : eq. lhs
406
+ rhs = quick_cancel ? quick_cancel_expr (eq. rhs) : eq. rhs
340
407
eq = 0 ~ rhs - lhs
341
408
end
342
409
empty! (varsbuf)
@@ -400,8 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
400
467
addvar! (v, VARIABLE)
401
468
end
402
469
end
403
-
404
- if isalgeq
470
+ if isalgeq || is_statemachine_equation
405
471
eqs[i] = eq
406
472
else
407
473
eqs[i] = eqs[i]. lhs ~ rhs
@@ -519,8 +585,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
519
585
ts = TearingState (sys, original_eqs, fullvars,
520
586
SystemStructure (complete (var_to_diff), complete (eq_to_diff),
521
587
complete (graph), nothing , var_types, false ),
522
- Any[], param_derivative_map)
523
-
588
+ Any[], param_derivative_map, typeof (sys)[])
524
589
return ts
525
590
end
526
591
@@ -747,6 +812,7 @@ function mtkcompile!(state::TearingState; simplify = false,
747
812
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
748
813
check_consistency, fully_determined,
749
814
kwargs... )
815
+ additional_passes = get (kwargs, :additional_passes , nothing )
750
816
if ! isnothing (additional_passes) && any (discrete_compile_pass, additional_passes)
751
817
discrete_pass_idx = findfirst (discrete_compile_pass, additional_passes)
752
818
discrete_compile = additional_passes[discrete_pass_idx]
0 commit comments