Skip to content

Commit 1072681

Browse files
authored
Tables support for IndexedVarArray (#30)
* Tables support for IndexedVarArray * Implement Tables support along recent JuMP changes * Format fix * Move rowtable out of JuMP.Containers namespace * Test with generic column names until new JuMP ver
1 parent d8ec8e9 commit 1072681

File tree

6 files changed

+63
-143
lines changed

6 files changed

+63
-143
lines changed

Project.toml

-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,8 @@ version = "0.6.2"
77
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
88
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
11-
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1210

1311
[compat]
1412
Dictionaries = "0.3"
1513
JuMP = "1"
16-
Requires = "1.3"
17-
Tables = "1.7"
1814
julia = "1.6"

src/SparseVariables.jl

+1-10
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,19 @@ module SparseVariables
33
using Dictionaries
44
using JuMP
55
using LinearAlgebra
6-
using Tables
7-
using Requires
86

97
include("sparsearray.jl")
108
include("sparsevararray.jl")
119
include("macros.jl")
1210
include("dictionaries.jl")
13-
include("tables.jl")
1411
include("indexedarray.jl")
12+
include("tables.jl")
1513

1614
export SparseArray
1715
export SparseVarArray
1816
export IndexedVarArray
1917
export @sparsevariable
2018
export insertvar!
2119
export unsafe_insertvar!
22-
export table
23-
24-
function __init__()
25-
@require DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" include(
26-
"dataframes.jl",
27-
)
28-
end
2920

3021
end # module

src/dataframes.jl

-13
This file was deleted.

src/tables.jl

+29-100
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,36 @@
1-
abstract type SolutionTable end
2-
3-
Tables.istable(::Type{<:SolutionTable}) = true
4-
Tables.rowaccess(::Type{<:SolutionTable}) = true
5-
6-
rows(t::SolutionTable) = t
7-
names(t::SolutionTable) = getfield(t, :names)
8-
lookup(t::SolutionTable) = getfield(t, :lookup)
9-
10-
Base.eltype(::SolutionTable) = SolutionRow
11-
Base.length(t::SolutionTable) = length(t.var)
12-
13-
struct SolutionRow <: Tables.AbstractRow
14-
index_vals::Any
15-
sol_val::Float64
16-
source::SolutionTable
17-
end
18-
19-
function Tables.getcolumn(s::SolutionRow, i::Int)
20-
if i > length(getfield(s, :index_vals))
21-
return getfield(s, :sol_val)
22-
end
23-
return getfield(s, :index_vals)[i]
24-
end
25-
26-
function Tables.getcolumn(s::SolutionRow, nm::Symbol)
27-
i = lookup(getfield(s, :source))[nm]
28-
if i > length(getfield(s, :index_vals))
29-
return getfield(s, :sol_val)
30-
end
31-
return getfield(s, :index_vals)[i]
32-
end
33-
34-
Tables.columnnames(s::SolutionRow) = names(getfield(s, :source))
35-
36-
struct SolutionTableSparse <: SolutionTable
37-
names::Vector{Symbol}
38-
lookup::Dict{Symbol,Int}
39-
var::SparseVarArray
40-
end
41-
42-
SolutionTableSparse(v::SparseVarArray) = SolutionTableSparse(v, Symbol(v.name))
43-
44-
function SolutionTableSparse(v::SparseVarArray, name)
45-
if length(v) > 0 && !has_values(first(v.data).model)
46-
error("No solution values available for variable")
47-
end
48-
names = vcat(v.index_names, name)
49-
lookup = Dict(nm => i for (i, nm) in enumerate(names))
50-
return SolutionTableSparse(names, lookup, v)
51-
end
52-
53-
function Base.iterate(t::SolutionTableSparse, state = nothing)
54-
next =
55-
isnothing(state) ? iterate(keys(t.var.data)) :
56-
iterate(keys(t.var.data), state)
57-
next === nothing && return nothing
58-
return SolutionRow(next[1], JuMP.value(t.var[next[1]]), t), next[2]
59-
end
60-
61-
table(var::SparseVarArray) = SolutionTableSparse(var)
62-
table(var::SparseVarArray, name) = SolutionTableSparse(var, name)
63-
64-
struct SolutionTableDense <: SolutionTable
65-
names::Vector{Symbol}
66-
lookup::Dict{Symbol,Int}
67-
index_lookup::Dict
68-
var::Containers.DenseAxisArray
69-
end
70-
71-
function SolutionTableDense(
72-
v::Containers.DenseAxisArray{VariableRef,N,Ax,L},
73-
name,
74-
colnames...,
75-
) where {N,Ax,L}
76-
if length(colnames) < length(axes(v))
77-
error("Not enough column names provided")
78-
end
79-
if length(v) > 0 && !has_values(first(v).model)
80-
error("No solution values available for variable")
1+
function _rows(x::Union{SparseArray,SparseVarArray,IndexedVarArray})
2+
return zip(eachindex(x.data), keys(x.data))
3+
end
4+
5+
# The rowtable functions should be moved to the JuMP.Containers namespace
6+
# when Tables support is available in JuMP
7+
function rowtable(
8+
f::Function,
9+
x::AbstractSparseArray;
10+
header::Vector{Symbol} = Symbol[],
11+
)
12+
if isempty(header)
13+
header = Symbol[Symbol("x$i") for i in 1:ndims(x)]
14+
push!(header, :y)
8115
end
82-
names = vcat(colnames..., name)
83-
lookup = Dict(nm => i for (i, nm) in enumerate(names))
84-
index_lookup = Dict()
85-
for (i, ax) in enumerate(v.axes)
86-
index_lookup[i] = collect(ax)
16+
got, want = length(header), ndims(x) + 1
17+
if got != want
18+
error(
19+
"Invalid number of column names provided: Got $got, expected $want.",
20+
)
8721
end
88-
return SolutionTableDense(names, lookup, index_lookup, v)
22+
names = tuple(header...)
23+
return [NamedTuple{names}((args..., f(x[i]))) for (i, args) in _rows(x)]
8924
end
9025

91-
function Base.iterate(t::SolutionTableDense, state = nothing)
92-
next =
93-
isnothing(state) ? iterate(eachindex(t.var)) :
94-
iterate(eachindex(t.var), state)
95-
next === nothing && return nothing
96-
index = next[1]
97-
index_vals = [t.index_lookup[i][index[i]] for i in 1:length(index)]
98-
return SolutionRow(index_vals, JuMP.value(t.var[next[1]]), t), next[2]
26+
function rowtable(f::Function, x::IndexedVarArray, col_header::Symbol)
27+
header = Symbol[k for k in keys(x.index_names)]
28+
push!(header, col_header)
29+
return rowtable(f, x; header = header)
9930
end
10031

101-
function table(
102-
var::Containers.DenseAxisArray{VariableRef,N,Ax,L},
103-
name,
104-
colnames...,
105-
) where {N,Ax,L}
106-
return SolutionTableDense(var, name, colnames...)
32+
function rowtable(f::Function, x::IndexedVarArray)
33+
header = Symbol[k for k in keys(x.index_names)]
34+
push!(header, Symbol(f))
35+
return rowtable(f, x; header = header)
10736
end

test/Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
[deps]
2-
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
32
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
43
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
54
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"

test/runtests.jl

+33-15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Base: product
2-
using DataFrames
32
using Dictionaries
43
using HiGHS
54
using JuMP
@@ -195,23 +194,11 @@ end
195194
set_optimizer_attribute(m, MOI.Silent(), true)
196195
optimize!(m)
197196

198-
tab = table(y)
199-
@test typeof(tab) == SV.SolutionTableSparse
200-
197+
tab = SparseVariables.rowtable(value, y; header = [:car, :year, :value])
201198
@test length(tab) == 5
202199

203-
r = first(tab)
204-
@test typeof(r) == SV.SolutionRow
200+
r = tab[1]
205201
@test r.car == "bmw"
206-
207-
t2 = table(u, :u, :car, :year)
208-
@test typeof(t2) == SV.SolutionTableDense
209-
@test length(t2) == 12
210-
rows = collect(t2)
211-
@test rows[11].year == 2003
212-
213-
df = dataframe(u, :u, :car, :year)
214-
@test first(df.car) == "ford"
215202
end
216203

217204
@testset "IndexedVarArray" begin
@@ -300,6 +287,37 @@ end
300287
@test length(z3.index_cache[4]) == 0
301288
end
302289

290+
@testset "Tables IndexedVarArray" begin
291+
m = Model()
292+
@variable(m, y[car = cars, year = year] >= 0; container = IndexedVarArray)
293+
for c in cars
294+
insertvar!(y, c, 2002)
295+
end
296+
@constraint(m, sum(y[:, :]) <= 300)
297+
@constraint(
298+
m,
299+
[i in year],
300+
sum(car_cost[c, i] * y[c, i] for (c, i) in SV.select(y, :, i)) <= 200
301+
)
302+
303+
@objective(m, Max, sum(y[c, i] for c in cars, i in year))
304+
305+
set_optimizer(m, HiGHS.Optimizer)
306+
set_optimizer_attribute(m, MOI.Silent(), true)
307+
optimize!(m)
308+
309+
tab = SparseVariables.rowtable(value, y)
310+
311+
T = NamedTuple{(:i1, :i2, :value),Tuple{String,Int,Float64}}
312+
@test tab isa Vector{T}
313+
314+
@test length(tab) == 3
315+
r = tab[1]
316+
@test r.i1 == "ford"
317+
@test r.i2 == 2002
318+
@test r.value == 300.0
319+
end
320+
303321
@testset "JuMP extension" begin
304322

305323
# Test JuMP Extension

0 commit comments

Comments
 (0)