Skip to content

Commit 730ad20

Browse files
fix: avoid infinite loops in MTKParameters initialization
1 parent 4c6e061 commit 730ad20

File tree

1 file changed

+36
-20
lines changed

1 file changed

+36
-20
lines changed

src/systems/parameter_buffer.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,42 @@ function MTKParameters(
4444
defs = merge(defs, u0)
4545
defs = merge(Dict(eq.lhs => eq.rhs for eq in observed(sys)), defs)
4646
bigdefs = merge(defs, p)
47-
p = merge(Dict(unwrap(k) => v for (k, v) in p),
48-
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
47+
p = Dict()
48+
missing_params = Set()
49+
50+
for sym in all_ps
51+
ttsym = default_toterm(sym)
52+
isarr = iscall(sym) && operation(sym) === getindex
53+
arrparent = isarr ? arguments(sym)[1] : nothing
54+
ttarrparent = isarr ? default_toterm(arrparent) : nothing
55+
pname = hasname(sym) ? getname(sym) : nothing
56+
ttpname = hasname(ttsym) ? getname(ttsym) : nothing
57+
p[sym] = p[ttsym] = if haskey(bigdefs, sym)
58+
bigdefs[sym]
59+
elseif haskey(bigdefs, ttsym)
60+
bigdefs[ttsym]
61+
elseif haskey(bigdefs, pname)
62+
isarr ? bigdefs[pname][arguments(sym)[2:end]...] : bigdefs[pname]
63+
elseif haskey(bigdefs, ttpname)
64+
isarr ? bigdefs[ttpname][arguments(sym)[2:end]...] : bigdefs[pname]
65+
elseif isarr && haskey(bigdefs, arrparent)
66+
bigdefs[arrparent][arguments(sym)[2:end]...]
67+
elseif isarr && haskey(bigdefs, ttarrparent)
68+
bigdefs[ttarrparent][arguments(sym)[2:end]...]
69+
end
70+
if get(p, sym, nothing) === nothing
71+
push!(missing_params, sym)
72+
continue
73+
end
74+
# We may encounter the `ttsym` version first, add it to `missing_params`
75+
# then encounter the "normal" version of a parameter or vice versa
76+
# Remove the old one in `missing_params` just in case
77+
delete!(missing_params, sym)
78+
delete!(missing_params, ttsym)
79+
end
80+
81+
isempty(missing_params) || throw(MissingParametersError(collect(missing_params)))
82+
4983
p = Dict(unwrap(k) => fixpoint_sub(v, bigdefs) for (k, v) in p)
5084
for (sym, _) in p
5185
if iscall(sym) && operation(sym) === getindex &&
@@ -54,24 +88,6 @@ function MTKParameters(
5488
end
5589
end
5690

57-
missing_params = Set()
58-
for idxmap in (ic.tunable_idx, ic.discrete_idx, ic.constant_idx, ic.nonnumeric_idx)
59-
for sym in keys(idxmap)
60-
sym isa Symbol && continue
61-
haskey(p, sym) && continue
62-
hasname(sym) && haskey(p, getname(sym)) && continue
63-
ttsym = default_toterm(sym)
64-
haskey(p, ttsym) && continue
65-
hasname(ttsym) && haskey(p, getname(ttsym)) && continue
66-
67-
iscall(sym) && operation(sym) === getindex && haskey(p, arguments(sym)[1]) &&
68-
continue
69-
push!(missing_params, sym)
70-
end
71-
end
72-
73-
isempty(missing_params) || throw(MissingParametersError(collect(missing_params)))
74-
7591
tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length)
7692
for temp in ic.tunable_buffer_sizes)
7793
disc_buffer = Tuple(Vector{temp.type}(undef, temp.length)

0 commit comments

Comments
 (0)