Skip to content

Commit e84dfc2

Browse files
feat: allow specifying variable names in modelingtoolkitize
1 parent d6240ce commit e84dfc2

File tree

1 file changed

+72
-21
lines changed

1 file changed

+72
-21
lines changed

src/systems/diffeqs/modelingtoolkitize.jl

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@ $(TYPEDSIGNATURES)
33
44
Generate `ODESystem`, dependent variables, and parameters from an `ODEProblem`.
55
"""
6-
function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
6+
function modelingtoolkitize(prob::DiffEqBase.ODEProblem; u_names = nothing, p_names = nothing, kwargs...)
77
prob.f isa DiffEqBase.AbstractParameterizedFunction &&
88
return prob.f.sys
9-
@parameters t
10-
9+
t = t_nounits
1110
p = prob.p
1211
has_p = !(p isa Union{DiffEqBase.NullParameters, Nothing})
1312

14-
_vars = define_vars(prob.u0, t)
13+
if u_names !== nothing
14+
_vars = [_defvar(name)(t) for name in u_names]
15+
else
16+
_vars = define_vars(prob.u0, t)
17+
end
1518

1619
vars = prob.u0 isa Number ? _vars : ArrayInterface.restructure(prob.u0, _vars)
1720
params = if has_p
18-
_params = define_params(p)
21+
_params = define_params(p, p_names)
1922
p isa Number ? _params[1] :
2023
(p isa Tuple || p isa NamedTuple || p isa AbstractDict ? _params :
2124
ArrayInterface.restructure(p, _params))
@@ -25,7 +28,7 @@ function modelingtoolkitize(prob::DiffEqBase.ODEProblem; kwargs...)
2528

2629
var_set = Set(vars)
2730

28-
D = Differential(t)
31+
D = D_nounits
2932
mm = prob.f.mass_matrix
3033

3134
if mm === I
@@ -125,41 +128,89 @@ function Base.showerror(io::IO, e::ModelingtoolkitizeParametersNotSupportedError
125128
println(io, e.type)
126129
end
127130

128-
function define_params(p)
131+
function varnames_length_check(vars, names)
132+
if length(names) == length(p)
133+
throw(ArgumentError("""
134+
Number of parameters ($(length(p))) does not match number of names \
135+
($(length(names))).
136+
"""))
137+
end
138+
end
139+
140+
function define_params(p, _ = nothing)
129141
throw(ModelingtoolkitizeParametersNotSupportedError(typeof(p)))
130142
end
131143

132-
function define_params(p::AbstractArray)
133-
[toparam(variable(, i)) for i in eachindex(p)]
144+
function define_params(p::AbstractArray, names = nothing)
145+
if names === nothing
146+
[toparam(variable(, i)) for i in eachindex(p)]
147+
else
148+
varnames_length_check(p, names)
149+
[toparam(variable(name)) for name in names]
150+
end
134151
end
135152

136-
function define_params(p::Number)
137-
[toparam(variable())]
153+
function define_params(p::Number, names = nothing)
154+
if names === nothing
155+
[toparam(variable())]
156+
elseif names isa AbstractArray
157+
varnames_length_check(p, names)
158+
[toparam(variable(name)) for name in names]
159+
else
160+
[toparam(variable(names))]
161+
end
138162
end
139163

140-
function define_params(p::AbstractDict)
141-
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
164+
function define_params(p::AbstractDict, names = nothing)
165+
if names === nothing
166+
OrderedDict(k => toparam(variable(, i)) for (i, k) in zip(1:length(p), keys(p)))
167+
else
168+
varnames_length_check(p, names)
169+
OrderedDict(k => toparam(variable(names[k])) for k in keys(p))
170+
end
142171
end
143172

144-
function define_params(p::Union{SLArray, LArray})
145-
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
173+
function define_params(p::Union{SLArray, LArray}, names = nothing)
174+
if names === nothing
175+
[toparam(variable(x)) for x in LabelledArrays.symnames(typeof(p))]
176+
else
177+
varnames_length_check(p, names)
178+
[toparam(variable(name)) for name in names]
179+
end
146180
end
147181

148-
function define_params(p::Tuple)
149-
tuple((toparam(variable(, i)) for i in eachindex(p))...)
182+
function define_params(p::Tuple, names = nothing)
183+
if names === nothing
184+
tuple((toparam(variable(, i)) for i in eachindex(p))...)
185+
else
186+
varnames_length_check(p, names)
187+
tuple((toparam(variable(name)) for name in names))
188+
end
150189
end
151190

152-
function define_params(p::NamedTuple)
153-
NamedTuple(x => toparam(variable(x)) for x in keys(p))
191+
function define_params(p::NamedTuple, names = nothing)
192+
if names === nothing
193+
NamedTuple(x => toparam(variable(x)) for x in keys(p))
194+
else
195+
varnames_length_check(p, names)
196+
NamedTuple(x => toparam(variable(names[x])) for x in keys(p))
197+
end
154198
end
155199

156-
function define_params(p::MTKParameters)
200+
function define_params(p::MTKParameters, names = nothing)
157201
bufs = (p...,)
158202
i = 1
159203
ps = []
160204
for buf in bufs
161205
for _ in buf
162-
push!(ps, toparam(variable(, i)))
206+
push!(
207+
ps,
208+
if names === nothing
209+
toparam(variable(, i))
210+
else
211+
toparam(variable(names[i]))
212+
end
213+
)
163214
end
164215
end
165216
return identity.(ps)

0 commit comments

Comments
 (0)