Skip to content

Commit 908c32b

Browse files
feat: add substitute_component
1 parent 606a043 commit 908c32b

File tree

4 files changed

+448
-1
lines changed

4 files changed

+448
-1
lines changed

src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
291291
hasunit, getunit, hasconnect, getconnect,
292292
hasmisc, getmisc, state_priority
293293
export ode_order_lowering, dae_order_lowering, liouville_transform,
294-
change_independent_variable
294+
change_independent_variable, substitute_component
295295
export PDESystem
296296
export Differential, expand_derivatives, @derivatives
297297
export Equation, ConstrainedEquation

src/systems/abstractsystem.jl

+177
Original file line numberDiff line numberDiff line change
@@ -3181,3 +3181,180 @@ has_diff_eqs(osys21) # returns `false`.
31813181
```
31823182
"""
31833183
has_diff_eqs(sys::AbstractSystem) = any(is_diff_equation, get_eqs(sys))
3184+
3185+
"""
3186+
$(TYPEDSIGNATURES)
3187+
3188+
Validate the rules for replacement of subcomponents as defined in `substitute_component`.
3189+
"""
3190+
function validate_replacement_rule(
3191+
rule::Pair{T, T}; namespace = []) where {T <: AbstractSystem}
3192+
lhs, rhs = rule
3193+
3194+
iscomplete(lhs) && throw(ArgumentError("LHS of replacement rule cannot be completed."))
3195+
iscomplete(rhs) && throw(ArgumentError("RHS of replacement rule cannot be completed."))
3196+
3197+
rhs_h = namespace_hierarchy(nameof(rhs))
3198+
if length(rhs_h) != 1
3199+
throw(ArgumentError("RHS of replacement rule must not be namespaced."))
3200+
end
3201+
rhs_h[1] == namespace_hierarchy(nameof(lhs))[end] ||
3202+
throw(ArgumentError("LHS and RHS must have the same name."))
3203+
3204+
if !isequal(get_iv(lhs), get_iv(rhs))
3205+
throw(ArgumentError("LHS and RHS of replacement rule must have the same independent variable."))
3206+
end
3207+
3208+
lhs_u = get_unknowns(lhs)
3209+
rhs_u = Dict(get_unknowns(rhs) .=> nothing)
3210+
for u in lhs_u
3211+
if !haskey(rhs_u, u)
3212+
if isempty(namespace)
3213+
throw(ArgumentError("RHS of replacement rule does not contain unknown $u."))
3214+
else
3215+
throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain unknown $u."))
3216+
end
3217+
end
3218+
ru = getkey(rhs_u, u, nothing)
3219+
name = join([namespace; nameof(lhs); (hasname(u) ? getname(u) : Symbol(u))],
3220+
NAMESPACE_SEPARATOR)
3221+
l_connect = something(getconnect(u), Equality)
3222+
r_connect = something(getconnect(ru), Equality)
3223+
if l_connect != r_connect
3224+
throw(ArgumentError("Variable $(name) should have connection metadata $(l_connect),"))
3225+
end
3226+
3227+
l_input = isinput(u)
3228+
r_input = isinput(ru)
3229+
if l_input != r_input
3230+
throw(ArgumentError("Variable $name has differing causality. Marked as `input = $l_input` in LHS and `input = $r_input` in RHS."))
3231+
end
3232+
l_output = isoutput(u)
3233+
r_output = isoutput(ru)
3234+
if l_output != r_output
3235+
throw(ArgumentError("Variable $name has differing causality. Marked as `output = $l_output` in LHS and `output = $r_output` in RHS."))
3236+
end
3237+
end
3238+
3239+
lhs_p = get_ps(lhs)
3240+
rhs_p = Set(get_ps(rhs))
3241+
for p in lhs_p
3242+
if !(p in rhs_p)
3243+
if isempty(namespace)
3244+
throw(ArgumentError("RHS of replacement rule does not contain parameter $p"))
3245+
else
3246+
throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain parameter $p."))
3247+
end
3248+
end
3249+
end
3250+
3251+
lhs_s = get_systems(lhs)
3252+
rhs_s = Dict(nameof(s) => s for s in get_systems(rhs))
3253+
3254+
for s in lhs_s
3255+
if haskey(rhs_s, nameof(s))
3256+
rs = rhs_s[nameof(s)]
3257+
if isconnector(s)
3258+
name = join([namespace; nameof(lhs); nameof(s)], NAMESPACE_SEPARATOR)
3259+
if !isconnector(rs)
3260+
throw(ArgumentError("Subsystem $name of RHS is not a connector."))
3261+
end
3262+
if (lct = get_connector_type(s)) !== (rct = get_connector_type(rs))
3263+
throw(ArgumentError("Subsystem $name of RHS has connection type $rct but LHS has $lct."))
3264+
end
3265+
end
3266+
validate_replacement_rule(s => rs; namespace = [namespace; nameof(rhs)])
3267+
continue
3268+
end
3269+
name1 = join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)
3270+
throw(ArgumentError("$name1 of replacement rule does not contain subsystem $(nameof(s))."))
3271+
end
3272+
end
3273+
3274+
"""
3275+
$(TYPEDSIGNATURES)
3276+
3277+
Chain `getproperty` calls on `root` in the order given in `hierarchy`.
3278+
3279+
# Keyword Arguments
3280+
3281+
- `skip_namespace_first`: Whether to avoid namespacing in the first `getproperty` call.
3282+
"""
3283+
function recursive_getproperty(
3284+
root::AbstractSystem, hierarchy::Vector{Symbol}; skip_namespace_first = true)
3285+
cur = root
3286+
for (i, name) in enumerate(hierarchy)
3287+
cur = getproperty(cur, name; namespace = i > 1 || !skip_namespace_first)
3288+
end
3289+
return cur
3290+
end
3291+
3292+
"""
3293+
$(TYPEDSIGNATURES)
3294+
3295+
Recursively descend through `sys`, finding all connection equations and re-creating them
3296+
using the names of the involved variables/systems and finding the required variables/
3297+
systems in the hierarchy.
3298+
"""
3299+
function recreate_connections(sys::AbstractSystem)
3300+
eqs = map(get_eqs(sys)) do eq
3301+
eq.lhs isa Union{Connection, AnalysisPoint} || return eq
3302+
if eq.lhs isa Connection
3303+
oldargs = get_systems(eq.rhs)
3304+
else
3305+
ap::AnalysisPoint = eq.rhs
3306+
oldargs = [ap.input; ap.outputs]
3307+
end
3308+
newargs = map(get_systems(eq.rhs)) do arg
3309+
name = arg isa AbstractSystem ? nameof(arg) : getname(arg)
3310+
hierarchy = namespace_hierarchy(name)
3311+
return recursive_getproperty(sys, hierarchy)
3312+
end
3313+
if eq.lhs isa Connection
3314+
return eq.lhs ~ Connection(newargs)
3315+
else
3316+
return eq.lhs ~ AnalysisPoint(newargs[1], eq.rhs.name, newargs[2:end])
3317+
end
3318+
end
3319+
@set! sys.eqs = eqs
3320+
@set! sys.systems = map(recreate_connections, get_systems(sys))
3321+
return sys
3322+
end
3323+
3324+
"""
3325+
$(TYPEDSIGNATURES)
3326+
3327+
Given a hierarchical system `sys` and a rule `lhs => rhs`, replace the subsystem `lhs` in
3328+
`sys` by `rhs`. The `lhs` must be the namespaced version of a subsystem of `sys` (e.g.
3329+
obtained via `sys.inner.component`). The `rhs` must be valid as per the following
3330+
conditions:
3331+
3332+
1. `rhs` must not be namespaced.
3333+
2. The name of `rhs` must be the same as the unnamespaced name of `lhs`.
3334+
3. Neither one of `lhs` or `rhs` can be marked as complete.
3335+
4. Both `lhs` and `rhs` must share the same independent variable.
3336+
5. `rhs` must contain at least all of the unknowns and parameters present in
3337+
`lhs`.
3338+
6. Corresponding unknowns in `rhs` must share the same connection and causality
3339+
(input/output) metadata as their counterparts in `lhs`.
3340+
7. For each subsystem of `lhs`, there must be an identically named subsystem of `rhs`.
3341+
These two corresponding subsystems must satisfy conditions 3, 4, 5, 6, 7. If the
3342+
subsystem of `lhs` is a connector, the corresponding subsystem of `rhs` must also
3343+
be a connector of the same type.
3344+
3345+
`sys` also cannot be marked as complete.
3346+
"""
3347+
function substitute_component(sys::T, rule::Pair{T, T}) where {T <: AbstractSystem}
3348+
iscomplete(sys) &&
3349+
throw(ArgumentError("Cannot replace subsystems of completed systems"))
3350+
3351+
validate_replacement_rule(rule)
3352+
3353+
lhs, rhs = rule
3354+
hierarchy = namespace_hierarchy(nameof(lhs))
3355+
3356+
newsys, _ = modify_nested_subsystem(sys, hierarchy) do inner
3357+
return rhs, ()
3358+
end
3359+
return recreate_connections(newsys)
3360+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ end
9898
@safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl")
9999
@safetestset "Debugging Test" include("debugging.jl")
100100
@safetestset "Namespacing test" include("namespacing.jl")
101+
@safetestset "Subsystem replacement" include("substitute_component.jl")
101102
end
102103
end
103104

0 commit comments

Comments
 (0)