Skip to content

Commit d8ec8e9

Browse files
authored
Allow IndexedVarArray to hold any AbstractVariableRef (#29)
* Allow IndexedVarArray to hold any AbstractVariableRef * Minor fix in type usage
1 parent 65473ed commit d8ec8e9

File tree

2 files changed

+64
-20
lines changed

2 files changed

+64
-20
lines changed

Diff for: src/indexedarray.jl

+24-20
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""
2-
IndexedVarArray{N,T}
2+
IndexedVarArray{V,N,T}
33
44
Structure for holding an optimization variable with a sparse structure with extra indexing
55
"""
6-
struct IndexedVarArray{N,T} <: AbstractSparseArray{VariableRef,N}
6+
struct IndexedVarArray{V<:AbstractVariableRef,N,T} <: AbstractSparseArray{V,N}
77
f::Function
8-
data::Dictionary{T,VariableRef}
8+
data::Dictionary{T,V}
99
index_names::NamedTuple
1010
index_cache::Vector{Dictionary}
1111
end
@@ -20,7 +20,7 @@ function IndexedVarArray(
2020
T = Tuple{eltype.(fieldtypes(Ts))...}
2121
N = length(fieldtypes(Ts))
2222
dict = Dictionary{T,VariableRef}()
23-
return model[Symbol(name)] = IndexedVarArray{N,T}(
23+
return model[Symbol(name)] = IndexedVarArray{VariableRef,N,T}(
2424
(ix...) -> createvar(model, name, ix; lower_bound, kw_args...),
2525
dict,
2626
index_names,
@@ -43,7 +43,7 @@ function IndexedVarArray(
4343
indices,
4444
(createvar(model, name, k; lower_bound, kw_args...) for k in indices),
4545
)
46-
return model[Symbol(name)] = IndexedVarArray{N,T}(
46+
return model[Symbol(name)] = IndexedVarArray{VariableRef,N,T}(
4747
(ix...) -> createvar(model, name, ix; lower_bound, kw_args...),
4848
dict,
4949
index_names,
@@ -93,18 +93,18 @@ function clear_cache!(var)
9393
end
9494

9595
"""
96-
insertvar!(var::IndexedVarArray{N,T}, index...; lower_bound = 0, kw_args...)
96+
insertvar!(var::IndexedVarArray{V,N,T}, index...; lower_bound = 0, kw_args...)
9797
9898
Insert a new variable with the given index only after checking if keys are valid and not already defined.
9999
"""
100100
function insertvar!(
101-
var::IndexedVarArray{N,T},
101+
var::IndexedVarArray{V,N,T},
102102
index...;
103103
lower_bound = 0,
104104
kw_args...,
105-
) where {N,T}
105+
) where {V,N,T}
106106
!valid_index(var, index) && throw(BoundsError(var, index))# "Not a valid index for $(var.name): $index"g
107-
already_defined(var, index) && error("$(var.name): $index already defined")
107+
already_defined(var, index) && error("$index already defined for array")
108108

109109
var[index] = var.f(index...)
110110

@@ -113,17 +113,17 @@ function insertvar!(
113113
end
114114

115115
"""
116-
unsafe_insertvar!(var::indexedVarArray{N,T}, index...; lower_bound = 0, kw_args...)
116+
unsafe_insertvar!(var::indexedVarArray{V,N,T}, index...; lower_bound = 0, kw_args...)
117117
118118
Insert a new variable with the given index withouth checking if the index is valid or
119119
already assigned.
120120
"""
121121
function unsafe_insertvar!(
122-
var::IndexedVarArray{N,T},
122+
var::IndexedVarArray{V,N,T},
123123
index...;
124124
lower_bound = 0,
125125
kw_args...,
126-
) where {N,T}
126+
) where {V,N,T}
127127
return var[index] = var.f(index...)
128128
end
129129

@@ -147,7 +147,7 @@ joinex(ex1, ex2) = :($ex1..., $ex2...)
147147
return :(tuple($(exs[end])...))
148148
end
149149

150-
function build_cache!(cache, pat, sa::IndexedVarArray{N,T}) where {N,T}
150+
function build_cache!(cache, pat, sa::IndexedVarArray{V,N,T}) where {V,N,T}
151151
if isempty(cache)
152152
for v in keys(sa)
153153
vred = _active(v, pat)
@@ -158,7 +158,7 @@ function build_cache!(cache, pat, sa::IndexedVarArray{N,T}) where {N,T}
158158
return cache
159159
end
160160

161-
function _select_cached(sa::IndexedVarArray{N,T}, pat) where {N,T}
161+
function _select_cached(sa::IndexedVarArray{V,N,T}, pat) where {V,N,T}
162162
# TODO: Benchmark to find good cutoff-value for caching
163163
# TODO: Return same type for type stability
164164
length(_data(sa)) < 100 && return _select_gen(keys(_data(sa)), pat)
@@ -208,26 +208,29 @@ Non-colons count as 1, colons as 0, which are binary encoded to an integer.
208208
return :($i)
209209
end
210210

211-
function _decode_nonslices(::IndexedVarArray{N,T}, v::Integer) where {N,T}
211+
function _decode_nonslices(::IndexedVarArray{V,N,T}, v::Integer) where {V,N,T}
212212
fts = fieldtypes(T)
213213
return Tuple{
214214
(fts[i] for (i, c) in enumerate(last(bitstring(v), N)) if c == '1')...,
215215
}
216216
end
217217

218218
"""
219-
_decode_nonslices(::IndexedVarArray{N,T}, ::P)
219+
_decode_nonslices(::IndexedVarArray{V,N,T}, ::P)
220220
221221
Reconstruct types of a pattern from the array types and the pattern type
222222
"""
223-
@generated function _decode_nonslices(::IndexedVarArray{N,T}, ::P) where {N,T,P}
223+
@generated function _decode_nonslices(
224+
::IndexedVarArray{V,N,T},
225+
::P,
226+
) where {V,N,T,P}
224227
fts = fieldtypes(T)
225228
fts2 = fieldtypes(P)
226229
t = Tuple{(fts[i] for (i, v) in enumerate(fts2) if v != Colon)...}
227230
return :($t)
228231
end
229232

230-
function _getcache(sa::IndexedVarArray{N,T}, pat::P) where {N,T,P}
233+
function _getcache(sa::IndexedVarArray{V,N,T}, pat::P) where {V,N,T,P}
231234
t = _get_cache_index(pat)
232235
if isassigned(sa.index_cache, t)
233236
return sa.index_cache[t]
@@ -247,9 +250,10 @@ function Containers.container(
247250
iva_names = NamedTuple{tuple(names...)}(indices.prod.iterators)
248251
T = Tuple{eltype.(indices.prod.iterators)...}
249252
N = length(names)
250-
return IndexedVarArray{N,T}(
253+
V = first(Base.return_types(f))
254+
return IndexedVarArray{V,N,T}(
251255
f,
252-
Dictionary{T,VariableRef}(),
256+
Dictionary{T,V}(),
253257
iva_names,
254258
Vector{Dictionary}(undef, 2^N),
255259
)

Diff for: test/runtests.jl

+40
Original file line numberDiff line numberDiff line change
@@ -311,3 +311,43 @@ end
311311
unsafe_insertvar!(x, 2, 102)
312312
@test length(x) == 2
313313
end
314+
315+
# Mockup of custom variable type
316+
struct MockVariable <: JuMP.AbstractVariable
317+
var::JuMP.ScalarVariable
318+
end
319+
320+
struct MockVariableRef <: JuMP.AbstractVariableRef
321+
v::VariableRef
322+
end
323+
324+
JuMP.name(mv::MockVariableRef) = JuMP.name(mv.v)
325+
326+
struct Mocking end
327+
328+
function JuMP.build_variable(::Function, info::JuMP.VariableInfo, _::Mocking)
329+
return MockVariable(JuMP.ScalarVariable(info))
330+
end
331+
332+
function JuMP.add_variable(model::Model, x::MockVariable, name::String)
333+
variable = JuMP.add_variable(model, x.var, name)
334+
return MockVariableRef(variable)
335+
end
336+
337+
@testset "Custom VariableRef" begin
338+
m = Model()
339+
@variable(
340+
m,
341+
x[i = 1:3, j = 100:102] >= 0,
342+
Mocking(),
343+
container = IndexedVarArray
344+
)
345+
@test length(x) == 0
346+
insertvar!(x, 1, 101)
347+
@test length(x) == 1
348+
@test typeof(first(x[:, :])) <: MockVariableRef
349+
insertvar!(x, 1, 100)
350+
@test length(x) == 2
351+
@test sum(x) == sum(x[:, :])
352+
@test typeof(sum(x)) <: GenericAffExpr{Float64,MockVariableRef}
353+
end

0 commit comments

Comments
 (0)