Skip to content

Commit d54449d

Browse files
Merge pull request #3502 from AayushSabharwal/as/substitute-component
feat: add `substitute_component`
2 parents 61a64f9 + e747405 commit d54449d

File tree

4 files changed

+460
-1
lines changed

4 files changed

+460
-1
lines changed

src/ModelingToolkit.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
291291
hasunit, getunit, hasconnect, getconnect,
292292
hasmisc, getmisc, state_priority
293293
export ode_order_lowering, dae_order_lowering, liouville_transform,
294-
change_independent_variable
294+
change_independent_variable, substitute_component
295295
export PDESystem
296296
export Differential, expand_derivatives, @derivatives
297297
export Equation, ConstrainedEquation

src/systems/abstractsystem.jl

+185
Original file line numberDiff line numberDiff line change
@@ -3184,3 +3184,188 @@ has_diff_eqs(osys21) # returns `false`.
31843184
```
31853185
"""
31863186
has_diff_eqs(sys::AbstractSystem) = any(is_diff_equation, get_eqs(sys))
3187+
3188+
"""
3189+
$(TYPEDSIGNATURES)
3190+
3191+
Validate the rules for replacement of subcomponents as defined in `substitute_component`.
3192+
"""
3193+
function validate_replacement_rule(
3194+
rule::Pair{T, T}; namespace = []) where {T <: AbstractSystem}
3195+
lhs, rhs = rule
3196+
3197+
iscomplete(lhs) && throw(ArgumentError("LHS of replacement rule cannot be completed."))
3198+
iscomplete(rhs) && throw(ArgumentError("RHS of replacement rule cannot be completed."))
3199+
3200+
rhs_h = namespace_hierarchy(nameof(rhs))
3201+
if length(rhs_h) != 1
3202+
throw(ArgumentError("RHS of replacement rule must not be namespaced."))
3203+
end
3204+
rhs_h[1] == namespace_hierarchy(nameof(lhs))[end] ||
3205+
throw(ArgumentError("LHS and RHS must have the same name."))
3206+
3207+
if !isequal(get_iv(lhs), get_iv(rhs))
3208+
throw(ArgumentError("LHS and RHS of replacement rule must have the same independent variable."))
3209+
end
3210+
3211+
lhs_u = get_unknowns(lhs)
3212+
rhs_u = Dict(get_unknowns(rhs) .=> nothing)
3213+
for u in lhs_u
3214+
if !haskey(rhs_u, u)
3215+
if isempty(namespace)
3216+
throw(ArgumentError("RHS of replacement rule does not contain unknown $u."))
3217+
else
3218+
throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain unknown $u."))
3219+
end
3220+
end
3221+
ru = getkey(rhs_u, u, nothing)
3222+
name = join([namespace; nameof(lhs); (hasname(u) ? getname(u) : Symbol(u))],
3223+
NAMESPACE_SEPARATOR)
3224+
l_connect = something(getconnect(u), Equality)
3225+
r_connect = something(getconnect(ru), Equality)
3226+
if l_connect != r_connect
3227+
throw(ArgumentError("Variable $(name) should have connection metadata $(l_connect),"))
3228+
end
3229+
3230+
l_input = isinput(u)
3231+
r_input = isinput(ru)
3232+
if l_input != r_input
3233+
throw(ArgumentError("Variable $name has differing causality. Marked as `input = $l_input` in LHS and `input = $r_input` in RHS."))
3234+
end
3235+
l_output = isoutput(u)
3236+
r_output = isoutput(ru)
3237+
if l_output != r_output
3238+
throw(ArgumentError("Variable $name has differing causality. Marked as `output = $l_output` in LHS and `output = $r_output` in RHS."))
3239+
end
3240+
end
3241+
3242+
lhs_p = get_ps(lhs)
3243+
rhs_p = Set(get_ps(rhs))
3244+
for p in lhs_p
3245+
if !(p in rhs_p)
3246+
if isempty(namespace)
3247+
throw(ArgumentError("RHS of replacement rule does not contain parameter $p"))
3248+
else
3249+
throw(ArgumentError("Subsystem $(join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)) of RHS does not contain parameter $p."))
3250+
end
3251+
end
3252+
end
3253+
3254+
lhs_s = get_systems(lhs)
3255+
rhs_s = Dict(nameof(s) => s for s in get_systems(rhs))
3256+
3257+
for s in lhs_s
3258+
if haskey(rhs_s, nameof(s))
3259+
rs = rhs_s[nameof(s)]
3260+
if isconnector(s)
3261+
name = join([namespace; nameof(lhs); nameof(s)], NAMESPACE_SEPARATOR)
3262+
if !isconnector(rs)
3263+
throw(ArgumentError("Subsystem $name of RHS is not a connector."))
3264+
end
3265+
if (lct = get_connector_type(s)) !== (rct = get_connector_type(rs))
3266+
throw(ArgumentError("Subsystem $name of RHS has connection type $rct but LHS has $lct."))
3267+
end
3268+
end
3269+
validate_replacement_rule(s => rs; namespace = [namespace; nameof(rhs)])
3270+
continue
3271+
end
3272+
name1 = join([namespace; nameof(lhs)], NAMESPACE_SEPARATOR)
3273+
throw(ArgumentError("$name1 of replacement rule does not contain subsystem $(nameof(s))."))
3274+
end
3275+
end
3276+
3277+
"""
3278+
$(TYPEDSIGNATURES)
3279+
3280+
Chain `getproperty` calls on `root` in the order given in `hierarchy`.
3281+
3282+
# Keyword Arguments
3283+
3284+
- `skip_namespace_first`: Whether to avoid namespacing in the first `getproperty` call.
3285+
"""
3286+
function recursive_getproperty(
3287+
root::AbstractSystem, hierarchy::Vector{Symbol}; skip_namespace_first = true)
3288+
cur = root
3289+
for (i, name) in enumerate(hierarchy)
3290+
cur = getproperty(cur, name; namespace = i > 1 || !skip_namespace_first)
3291+
end
3292+
return unwrap(cur)
3293+
end
3294+
3295+
"""
3296+
$(TYPEDSIGNATURES)
3297+
3298+
Recursively descend through `sys`, finding all connection equations and re-creating them
3299+
using the names of the involved variables/systems and finding the required variables/
3300+
systems in the hierarchy.
3301+
"""
3302+
function recreate_connections(sys::AbstractSystem)
3303+
eqs = map(get_eqs(sys)) do eq
3304+
eq.lhs isa Union{Connection, AnalysisPoint} || return eq
3305+
if eq.lhs isa Connection
3306+
oldargs = get_systems(eq.rhs)
3307+
else
3308+
ap::AnalysisPoint = eq.rhs
3309+
oldargs = [ap.input; ap.outputs]
3310+
end
3311+
newargs = map(get_systems(eq.rhs)) do arg
3312+
rewrap_nameof = arg isa SymbolicWithNameof
3313+
if rewrap_nameof
3314+
arg = arg.var
3315+
end
3316+
name = arg isa AbstractSystem ? nameof(arg) : getname(arg)
3317+
hierarchy = namespace_hierarchy(name)
3318+
newarg = recursive_getproperty(sys, hierarchy)
3319+
if rewrap_nameof
3320+
newarg = SymbolicWithNameof(newarg)
3321+
end
3322+
return newarg
3323+
end
3324+
if eq.lhs isa Connection
3325+
return eq.lhs ~ Connection(newargs)
3326+
else
3327+
return eq.lhs ~ AnalysisPoint(newargs[1], eq.rhs.name, newargs[2:end])
3328+
end
3329+
end
3330+
@set! sys.eqs = eqs
3331+
@set! sys.systems = map(recreate_connections, get_systems(sys))
3332+
return sys
3333+
end
3334+
3335+
"""
3336+
$(TYPEDSIGNATURES)
3337+
3338+
Given a hierarchical system `sys` and a rule `lhs => rhs`, replace the subsystem `lhs` in
3339+
`sys` by `rhs`. The `lhs` must be the namespaced version of a subsystem of `sys` (e.g.
3340+
obtained via `sys.inner.component`). The `rhs` must be valid as per the following
3341+
conditions:
3342+
3343+
1. `rhs` must not be namespaced.
3344+
2. The name of `rhs` must be the same as the unnamespaced name of `lhs`.
3345+
3. Neither one of `lhs` or `rhs` can be marked as complete.
3346+
4. Both `lhs` and `rhs` must share the same independent variable.
3347+
5. `rhs` must contain at least all of the unknowns and parameters present in
3348+
`lhs`.
3349+
6. Corresponding unknowns in `rhs` must share the same connection and causality
3350+
(input/output) metadata as their counterparts in `lhs`.
3351+
7. For each subsystem of `lhs`, there must be an identically named subsystem of `rhs`.
3352+
These two corresponding subsystems must satisfy conditions 3, 4, 5, 6, 7. If the
3353+
subsystem of `lhs` is a connector, the corresponding subsystem of `rhs` must also
3354+
be a connector of the same type.
3355+
3356+
`sys` also cannot be marked as complete.
3357+
"""
3358+
function substitute_component(sys::T, rule::Pair{T, T}) where {T <: AbstractSystem}
3359+
iscomplete(sys) &&
3360+
throw(ArgumentError("Cannot replace subsystems of completed systems"))
3361+
3362+
validate_replacement_rule(rule)
3363+
3364+
lhs, rhs = rule
3365+
hierarchy = namespace_hierarchy(nameof(lhs))
3366+
3367+
newsys, _ = modify_nested_subsystem(sys, hierarchy) do inner
3368+
return rhs, ()
3369+
end
3370+
return recreate_connections(newsys)
3371+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ end
9898
@safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl")
9999
@safetestset "Debugging Test" include("debugging.jl")
100100
@safetestset "Namespacing test" include("namespacing.jl")
101+
@safetestset "Subsystem replacement" include("substitute_component.jl")
101102
end
102103
end
103104

0 commit comments

Comments
 (0)