Skip to content

Commit 65473ed

Browse files
authored
Lh/extension refactor (#25)
* Support extending JuMP containers * Add specialization of Base.sum Co-authored-by: Lars Hellemo <[email protected]>
1 parent 0965909 commit 65473ed

File tree

3 files changed

+44
-22
lines changed

3 files changed

+44
-22
lines changed

src/indexedarray.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
Structure for holding an optimization variable with a sparse structure with extra indexing
55
"""
66
struct IndexedVarArray{N,T} <: AbstractSparseArray{VariableRef,N}
7-
model::Model
8-
name::Any
7+
f::Function
98
data::Dictionary{T,VariableRef}
109
index_names::NamedTuple
1110
index_cache::Vector{Dictionary}
@@ -22,8 +21,7 @@ function IndexedVarArray(
2221
N = length(fieldtypes(Ts))
2322
dict = Dictionary{T,VariableRef}()
2423
return model[Symbol(name)] = IndexedVarArray{N,T}(
25-
model,
26-
name,
24+
(ix...) -> createvar(model, name, ix; lower_bound, kw_args...),
2725
dict,
2826
index_names,
2927
Vector{Dictionary}(undef, 2^N),
@@ -46,8 +44,7 @@ function IndexedVarArray(
4644
(createvar(model, name, k; lower_bound, kw_args...) for k in indices),
4745
)
4846
return model[Symbol(name)] = IndexedVarArray{N,T}(
49-
model,
50-
name,
47+
(ix...) -> createvar(model, name, ix; lower_bound, kw_args...),
5148
dict,
5249
index_names,
5350
Vector{Dictionary}(undef, 2^N),
@@ -109,7 +106,7 @@ function insertvar!(
109106
!valid_index(var, index) && throw(BoundsError(var, index))# "Not a valid index for $(var.name): $index"g
110107
already_defined(var, index) && error("$(var.name): $index already defined")
111108

112-
var[index] = createvar(var.model, var.name, index; lower_bound, kw_args...)
109+
var[index] = var.f(index...)
113110

114111
clear_cache!(var)
115112
return var[index]
@@ -127,19 +124,7 @@ function unsafe_insertvar!(
127124
lower_bound = 0,
128125
kw_args...,
129126
) where {N,T}
130-
return var[index] =
131-
createvar(var.model, var.name, index; lower_bound, kw_args...)
132-
133-
#TODO: Reactivate this later
134-
# If active caches, update with new variable
135-
# cache = _getcache(var.sa, index)
136-
# for ind in keys(cache)
137-
# vred = Tuple(val for (i, val) in enumerate(index) if i in ind)
138-
# if !(vred in keys(var.index_cache[ind]))
139-
# var.index_cache[ind][vred] = []
140-
# end
141-
# push!(var.index_cache[ind][vred], index)
142-
# end
127+
return var[index] = var.f(index...)
143128
end
144129

145130
joinex(ex1, ex2) = :($ex1..., $ex2...)
@@ -251,3 +236,27 @@ function _getcache(sa::IndexedVarArray{N,T}, pat::P) where {N,T,P}
251236
end
252237
return sa.index_cache[t]
253238
end
239+
240+
# Extension for standard JuMP macros
241+
function Containers.container(
242+
f::Function,
243+
indices,
244+
D::Type{IndexedVarArray},
245+
names,
246+
)
247+
iva_names = NamedTuple{tuple(names...)}(indices.prod.iterators)
248+
T = Tuple{eltype.(indices.prod.iterators)...}
249+
N = length(names)
250+
return IndexedVarArray{N,T}(
251+
f,
252+
Dictionary{T,VariableRef}(),
253+
iva_names,
254+
Vector{Dictionary}(undef, 2^N),
255+
)
256+
end
257+
258+
# Fallback when no names are provided
259+
function Containers.container(f::Function, indices, D::Type{IndexedVarArray})
260+
index_vars = Symbol.("i$i" for i in 1:length(indices.prod.iterators))
261+
return Containers.container(f, indices, D, index_vars)
262+
end

src/sparsearray.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ function Base.setindex!(sa::AbstractSparseArray{T,N}, val, idx...) where {T,N}
3333
return setindex!(sa, val, idx)
3434
end
3535

36-
Base.length(sa::AbstractSparseArray) = length(_data(sa))
37-
Base.size(sa::AbstractSparseArray) = length(_data(sa))
3836
Base.keys(sa::AbstractSparseArray) = keys(_data(sa))
3937

4038
function Base.summary(io::IO, sa::AbstractSparseArray)
@@ -47,6 +45,9 @@ function Base.summary(io::IO, sa::AbstractSparseArray)
4745
isone(num_entries) ? " entry" : " entries",
4846
)
4947
end
48+
Base.length(sa::AbstractSparseArray) = length(_data(sa))
49+
Base.sum(sa::AbstractSparseArray) = sum(_data(sa))
50+
5051
function Base.show(io::IO, ::MIME"text/plain", sa::AbstractSparseArray)
5152
summary(io, sa)
5253
if !iszero(length(_data(sa)))

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,15 @@ end
299299
SparseVariables.clear_cache!(z3)
300300
@test length(z3.index_cache[4]) == 0
301301
end
302+
303+
@testset "JuMP extension" begin
304+
305+
# Test JuMP Extension
306+
m = Model()
307+
@variable(m, x[i = 1:3, j = 100:102] >= 0, container = IndexedVarArray)
308+
@test length(x) == 0
309+
insertvar!(x, 1, 100)
310+
@test length(x) == 1
311+
unsafe_insertvar!(x, 2, 102)
312+
@test length(x) == 2
313+
end

0 commit comments

Comments
 (0)