Skip to content

Commit 2835108

Browse files
authored
Refactor to better support type promotion due to thinks like Dual Numbers (#30)
* Refactor to better support type promotion due to thinks like ForwardDiff.Dual * bump version
1 parent d26062c commit 2835108

File tree

10 files changed

+164
-79
lines changed

10 files changed

+164
-79
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "GraphDynamics"
22
uuid = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
3-
version = "0.4.1"
3+
version = "0.4.2"
44

55
[deps]
66
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ function GraphDynamics.to_subsystem(p::Particle)
3939
# Use `name`, `m`, and `q` as parameters
4040
# Every subsystem should have a unique name symbol.
4141
params = SubsystemParams{Particle}(
42-
;name,
43-
m,
42+
;m,
4443
q,
4544
)
4645
# Assemble a Subsystem
@@ -82,8 +81,7 @@ function GraphDynamics.to_subsystem(p::Oscillator)
8281
# Use `name`, `m`, `k`, `x₀`, and `q` as parameters
8382
# Every subsystem should have a unique name symbol.
8483
params = SubsystemParams{Oscillator}(
85-
;name,
86-
m,
84+
;m,
8785
k,
8886
x₀,
8987
q,

src/GraphDynamics.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,11 +300,21 @@ abstract type ConnectionRule end
300300
(c::ConnectionRule)(src, dst, t) = c(src, dst)
301301
Base.zero(::T) where {T <: ConnectionRule} = zero(T)
302302

303-
struct NotConnected <: ConnectionRule end
304-
(::NotConnected)(l, r) = zero(promote_type(eltype(l), eltype(r)))
305-
struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected, AbstractMatrix{CR}}}}}
303+
struct NotConnected{CR <: ConnectionRule} end
304+
Base.getindex(::NotConnected{CR}, inds...) where {CR} = zero(CR)
305+
Base.copy(c::NotConnected) = c
306+
struct ConnectionMatrix{N, CR, Tup <: NTuple{N, NTuple{N, Union{NotConnected{CR}, AbstractMatrix{CR}}}}}
306307
data::Tup
307308
end
309+
function Base.copy(c::ConnectionMatrix)
310+
data′ = map(c.data) do col
311+
map(col) do mat
312+
copy(mat)
313+
end
314+
end
315+
ConnectionMatrix(data′)
316+
end
317+
308318
struct ConnectionMatrices{NConn, Tup <: NTuple{NConn, ConnectionMatrix}}
309319
matrices::Tup
310320
end
@@ -314,7 +324,7 @@ Base.getindex(m::ConnectionMatrix, ::Val{i}, ::Val{j}) where {i, j} = m.data[i][
314324
@inline Base.getindex(m::ConnectionMatrices, i) = m.matrices[i]
315325
Base.length(m::ConnectionMatrices) = length(m.matrices)
316326
Base.size(m::ConnectionMatrix{N}) where {N} = (N, N)
317-
327+
Base.copy(c::ConnectionMatrices) = ConnectionMatrices(map(copy, c.matrices))
318328

319329
#----------------------------------------------------------
320330
# Infrastructure for subsystems

src/graph_system.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ end
6060
function system_wiring_rule!(g, node)
6161
add_node!(g, node)
6262
end
63-
function system_wiring_rule!(g, src, dst; conn, kwargs...)
64-
add_connection!(g, src, dst; conn, kwargs...)
63+
function system_wiring_rule!(g, src, dst; kwargs...)
64+
if !haskey(kwargs, :conn)
65+
error("conn keyword argument not specified for connection between $src and $dst")
66+
end
67+
add_connection!(g, src, dst; conn=kwargs[:conn], kwargs...)
6568
end
6669

6770
@kwdef struct PartitionedGraphSystem{CM <: ConnectionMatrices, S, P, EVT, Ns, EP, SNM, PNM, CNM}
@@ -140,10 +143,10 @@ function PartitionedGraphSystem(g::GraphSystem)
140143
@named n2 = SysType1(x=1, y=3)
141144
@named n3 = SysType2(a=1, b=2, c=3)
142145
143-
add_connection!(g, n1, n2; conn=C1(1))
144-
add_connection!(g, n2, n3; conn=C1(2))
145-
add_connection!(g, n3, n1; conn=C2(3))
146-
add_connection!(g, n3, n2; conn=C3(4))
146+
add_connection!(g, n1, n2; conn=Conn1(1))
147+
add_connection!(g, n2, n3; conn=Conn1(2))
148+
add_connection!(g, n3, n1; conn=Conn2(3))
149+
add_connection!(g, n3, n2; conn=Conn2(4))
147150
148151
we'd get
149152
connection_matrix_1 = Conn1[⎡. 1⎤⎡.⎤
@@ -226,7 +229,7 @@ function make_connection_matrices(g_flat, nodes_partitioned=make_partitioned_nod
226229
end
227230
end
228231
rule_matrix = if isempty(conns)
229-
NotConnected()#{CT}(length(nodeks), length(nodeis))
232+
NotConnected{CT}() #{CT}(length(nodeks), length(nodeis))
230233
else
231234
sparse(ls, js, conns, length(nodeks), length(nodeis))
232235
end

src/problems.jl

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,10 @@ function SciMLBase.ODEProblem(g::PartitionedGraphSystem, u0map, tspan, param_ma
2727
end
2828
tstops = vcat(tstops, nt.tstops)
2929
prob = ODEProblem{true, SciMLBase.FullSpecialize}(f, u, tspan, p; callback, tstops, kwargs...)
30-
# let ukeys = map(first, u0map),
31-
# uvals = map(last, u0map)
32-
# setu(prob, ukeys)(prob, uvals)
33-
# end
34-
# let pkeys = map(first, param_map),
35-
# pvals = map(last, param_map)
36-
# setp(prob, pkeys)(prob, pvals)
37-
# end
38-
for (k, v) u0map
39-
setu(prob, k)(prob, v)
40-
end
41-
for (k, v) param_map
42-
setp(prob, k)(prob, v)
30+
for (k, v) u0map
31+
setu(prob, k)(prob, v)
4332
end
33+
@reset prob.p = set_params!!(prob.p, param_map)
4434
prob
4535
end
4636

@@ -58,27 +48,32 @@ function SciMLBase.SDEProblem(g::PartitionedGraphSystem, u0map, tspan, param_map
5848
uvals = map(last, u0map)
5949
setu(prob, ukeys)(prob, uvals)
6050
end
61-
let pkeys = map(first, param_map),
62-
pvals = map(last, param_map)
63-
setp(prob, pkeys)(prob, pvals)
64-
end
51+
@reset prob.p = set_params!!(prob.p, param_map)
6552
prob
6653
end
6754

68-
Base.@kwdef struct GraphSystemParameters{PP, CM, S, PAP, DEC, EP<:NamedTuple}
55+
Base.@kwdef struct GraphSystemParameters{PP, CM, S, PAP, DEC, NP, SNM, PNM, CNM, EP<:NamedTuple}
6956
params_partitioned::PP
7057
connection_matrices::CM
7158
scheduler::S
7259
partition_plan::PAP
7360
discrete_event_cache::DEC
61+
names_partitioned::NP
62+
state_namemap::SNM
63+
param_namemap::PNM
64+
compu_namemap::CNM
7465
extra_params::EP=(;)
7566
end
7667

7768
function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete, u0map, param_map, global_events)
7869
(; states_partitioned,
7970
params_partitioned,
8071
connection_matrices,
81-
tstops) = g
72+
tstops,
73+
names_partitioned,
74+
state_namemap,
75+
param_namemap,
76+
compu_namemap) = g
8277

8378
total_eltype = let
8479
states_eltype = mapreduce(promote_type, states_partitioned) do v
@@ -87,12 +82,24 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete
8782
u0map_eltype = mapreduce(promote_type, u0map; init=Union{}) do (k, v)
8883
typeof(v)
8984
end
90-
promote_type(states_eltype, u0map_eltype)
85+
numeric_params_eltype = mapreduce(promote_type, params_partitioned) do v
86+
if isconcretetype(eltype(v))
87+
promote_numeric_param_eltype(eltype(v))
88+
else
89+
mapreduce(promote_type, v) do params
90+
promote_numeric_param_eltype(typeof(params))
91+
end
92+
end
93+
end
94+
numeric_param_map_eltype = let numeric_params_from_map = [v for (_, v) in param_map if v isa Number]
95+
mapreduce(typeof, promote_type, numeric_params_from_map; init=Union{})
96+
end
97+
promote_type(states_eltype, u0map_eltype, numeric_params_eltype, numeric_param_map_eltype)
9198
end
9299

93100
re_eltype(s::SubsystemStates{T}) where {T} = convert(SubsystemStates{T, total_eltype}, s)
94101
states_partitioned = map(states_partitioned) do v
95-
if eltype(eltype(v)) <: total_eltype && eltype(eltype(v)) !== Union{}
102+
if eltype(eltype(v)) <: total_eltype
96103
v
97104
else
98105
re_eltype.(v)
@@ -159,7 +166,11 @@ function _problem(g::PartitionedGraphSystem, tspan; scheduler, allow_nonconcrete
159166
connection_matrices,
160167
scheduler,
161168
partition_plan,
162-
discrete_event_cache)
169+
discrete_event_cache,
170+
names_partitioned,
171+
state_namemap,
172+
param_namemap,
173+
compu_namemap)
163174

164175
(; f, u, tspan, p, callback, tstops)
165176
end

src/subsystems.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,23 @@ function ConstructionBase.getproperties(s::SubsystemParams)
1414
end
1515

1616
function ConstructionBase.setproperties(s::SubsystemParams{T}, patch::NamedTuple) where {T}
17+
set_param_prop(s, patch; allow_typechange=false)
18+
end
19+
function set_param_prop(s::SubsystemParams{T}, key, val; allow_typechange=false) where {T}
20+
set_param_prop(s, NamedTuple{(key,)}(val); allow_typechange)
21+
end
22+
function set_param_prop(s::SubsystemParams{T}, patch; allow_typechange=false) where {T}
1723
props = NamedTuple(s)
1824
props′ = merge(props, patch)
19-
if typeof(props) != typeof(props′)
25+
if typeof(props) != typeof(props′) && !allow_typechange
2026
param_setproperror(props, props′)
2127
end
2228
SubsystemParams{T}(props′)
2329
end
30+
2431
@noinline function param_setproperror(props, props′)
2532
error("Type unstable change to subsystem params! Expected properties of type\n $(typeof(props))\nbut got\n $(typeof(props′))")
26-
end
33+
end
2734

2835
get_tag(::SubsystemParams{Name}) where {Name} = Name
2936
get_tag(::Type{<:SubsystemParams{Name}}) where {Name} = Name
@@ -37,6 +44,9 @@ end
3744
function Base.convert(::Type{SubsystemParams{Name, NT}}, p::SubsystemParams{Name}) where {Name, NT}
3845
SubsystemParams{Name}(convert(NT, NamedTuple(p)))
3946
end
47+
@generated function promote_numeric_param_eltype(::Type{SubsystemParams{Name, NamedTuple{props, Tup}}}) where {Name, props, Tup}
48+
:(promote_type($(param for param in Tup.parameters if param <: Number)...))
49+
end
4050

4151
#------------------------------------------------------------
4252
# Subsystem states
@@ -156,8 +166,9 @@ function Base.convert(::Type{Subsystem{Name, Eltype}}, s::Subsystem{Name}) where
156166
end
157167

158168
@generated function promote_nt_type(::Type{NamedTuple{names, T1}},
159-
::Type{NamedTuple{names, T2}}) where {names, T1, T2}
160-
NamedTuple{names, Tuple{(promote_type(T1.parameters[i], T2.parameters[i]) for i eachindex(names))...}}
169+
::Type{NamedTuple{names, T2}}) where {names, T1, T2}
170+
proms = [:(promote_type($(T1.parameters[i]), $(T2.parameters[i]))) for i in eachindex(names)]
171+
:(NamedTuple{names, Tuple{$(proms...)}})
161172
end
162173

163174
function Base.promote_rule(::Type{SubsystemParams{Name, NT1}},

src/symbolic_indexing.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,36 @@ end
200200
# :t
201201
# )
202202
# end
203+
204+
function SymbolicIndexingInterface.remake_buffer(sys, oldbuffer::GraphSystemParameters, idxs, vals)
205+
newbuffer = @set oldbuffer.params_partitioned = copy.(oldbuffer.params_partitioned)
206+
set_params!!(newbuffer, zip(idxs, vals))
207+
end
208+
209+
function set_params!!(buffer::GraphSystemParameters, param_map)
210+
(; params_partitioned, param_namemap) = buffer
211+
for (key, val) param_map
212+
let (;tup_index, v_index, prop) = param_namemap[key]
213+
params = params_partitioned[tup_index][v_index]
214+
params_new = set_param_prop(params, prop, val; allow_typechange=true)
215+
peltype = eltype(params_partitioned[tup_index])
216+
if !(typeof(params_new) <: peltype)
217+
new_eltype = promote_type(typeof(params_new), peltype)
218+
@reset params_partitioned[tup_index] = convert.(new_eltype, params_partitioned[tup_index])
219+
end
220+
params_partitioned[tup_index][v_index] = params_new
221+
end
222+
end
223+
@reset buffer.params_partitioned = params_partitioned#re_eltype_params(params_partitioned)
224+
end
225+
226+
function re_eltype_params(params_partitioned)
227+
map(params_partitioned) do v
228+
ptype = mapreduce(typeof, promote_type, v)
229+
if ptype == eltype(v)
230+
v
231+
else
232+
convert.(ptype, v)
233+
end
234+
end
235+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
34
GraphDynamics = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
45
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
56
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"

0 commit comments

Comments
 (0)