Skip to content

Commit 7d9829c

Browse files
committed
Implement DiffEqBase.anyeltypedual for GraphDynamics types
The PR #32 somehow broke a couple things in the Neuroblox.jl docs building pipeline due to how it interacted with the GraphDynamics tutorial. This change (which is good to do anyways) should fix that breakage. It basically just helps the differential equation solvers tell if anything in the problem is a dual number. I also did a little cleanup of unwanted stuff, and implemented a couple SymbolicIndexingInterface methods for `Num`s with observed variables.
1 parent 86cfae6 commit 7d9829c

File tree

5 files changed

+26
-16
lines changed

5 files changed

+26
-16
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.4.4"
55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
8+
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
89
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
910
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1011
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -22,6 +23,7 @@ MTKExt = ["Symbolics", "ModelingToolkit"]
2223
[compat]
2324
Accessors = "0.1"
2425
ConstructionBase = "1.5"
26+
DiffEqBase = "6"
2527
ModelingToolkit = "9"
2628
OhMyThreads = "0.6, 0.7, 0.8"
2729
OrderedCollections = "1.6.3"

ext/MTKExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,12 @@ function SymbolicIndexingInterface.is_independent_variable(sys::PartitionedGraph
2525
SymbolicIndexingInterface.is_independent_variable(sys, tosymbol(var; escape=false))
2626
end
2727

28+
function SymbolicIndexingInterface.is_observed(sys::PartitionedGraphSystem, var::Num)
29+
SymbolicIndexingInterface.is_observed(sys, tosymbol(var; escape=false))
30+
end
31+
32+
function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, var::Num)
33+
SymbolicIndexingInterface.observed(sys, tosymbol(var; escape=false))
34+
end
35+
2836
end

src/GraphDynamics.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ using OrderedCollections:
125125
OrderedCollections,
126126
OrderedDict
127127

128+
using DiffEqBase:
129+
DiffEqBase,
130+
anyeltypedual
128131

129132
#----------------------------------------------------------
130133
# Random utils

src/problems.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ function SciMLBase.SDEProblem(g::PartitionedGraphSystem, u0map, tspan, param_map
4444
end
4545
noise_rate_prototype = nothing # this'll need to change once we support correlated noise
4646
prob = SDEProblem(f, graph_noise!, u, tspan, p; callback, noise_rate_prototype, tstops = vcat(tstops, nt.tstops), kwargs...)
47-
let ukeys = map(first, u0map),
48-
uvals = map(last, u0map)
49-
setu(prob, ukeys)(prob, uvals)
47+
for (k, v) u0map
48+
setu(prob, k)(prob, v)
5049
end
5150
@reset prob.p = set_params!!(prob.p, param_map)
5251
prob
@@ -67,17 +66,6 @@ Base.@kwdef struct GraphSystemParameters{PP, CM, S, PAP, DEC, NP, CONM, SNM, PNM
6766
end
6867

6968
function Base.copy(p::GraphSystemParameters)
70-
copy.(p.params_partitioned)
71-
copy(p.connection_matrices)
72-
p.scheduler
73-
p.partition_plan
74-
copy.(p.discrete_event_cache)
75-
copy.(p.names_partitioned)
76-
copy(p.connection_namemap)
77-
copy(p.state_namemap)
78-
copy(p.param_namemap)
79-
copy(p.compu_namemap)
80-
map(copy, p.extra_params)
8169
GraphSystemParameters(
8270
copy.(p.params_partitioned),
8371
copy(p.connection_matrices),
@@ -93,6 +81,16 @@ function Base.copy(p::GraphSystemParameters)
9381
)
9482
end
9583

84+
function DiffEqBase.anyeltypedual(p::GraphSystemParameters, ::Type{Val{counter}}) where {counter}
85+
anyeltypedual((p.params_partitioned, p.connection_matrices))
86+
end
87+
function DiffEqBase.anyeltypedual(p::ConnectionMatrices, ::Type{Val{counter}}) where {counter}
88+
anyeltypedual(p.matrices)
89+
end
90+
function DiffEqBase.anyeltypedual(p::ConnectionMatrix, ::Type{Val{counter}}) where {counter}
91+
anyeltypedual(p.data)
92+
end
93+
9694
function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map, global_events)
9795
(; states_partitioned,
9896
params_partitioned,

src/symbolic_indexing.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,7 @@ function set_params!!(buffer::GraphSystemParameters, param_map)
249249
elseif haskey(connection_namemap, key)
250250
buffer = set_param!!(buffer, key, connection_namemap[key], val)
251251
else
252-
@info "" keys(connection_namemap)
253-
error("Key $key does not correspond to a known parameter")
252+
error("Key $key does not correspond to a known parameter. ")
254253
end
255254
end
256255
buffer

0 commit comments

Comments
 (0)