Skip to content

Commit f31840c

Browse files
authored
[Containers] add slicing for SparseAxisArray (#3031)
1 parent 7135f54 commit f31840c

File tree

3 files changed

+96
-43
lines changed

3 files changed

+96
-43
lines changed

docs/src/manual/containers.md

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -207,33 +207,32 @@ JuMP.Containers.SparseAxisArray{Tuple{Int64, Int64}, 2, Tuple{Int64, Int64}} wit
207207

208208
The `[indices; condition]` syntax is used:
209209
```jldoctest containers_sparse
210-
julia> x = Containers.@container([i = 1:3, j = [:A, :B]; i > 1 && j == :B], (i, j))
211-
JuMP.Containers.SparseAxisArray{Tuple{Int64, Symbol}, 2, Tuple{Int64, Symbol}} with 2 entries:
210+
julia> x = Containers.@container([i = 1:3, j = [:A, :B]; i > 1], (i, j))
211+
JuMP.Containers.SparseAxisArray{Tuple{Int64, Symbol}, 2, Tuple{Int64, Symbol}} with 4 entries:
212+
[2, A] = (2, :A)
212213
[2, B] = (2, :B)
214+
[3, A] = (3, :A)
213215
[3, B] = (3, :B)
214216
```
215217
Here we have the index sets `i = 1:3, j = [:A, :B]`, followed by `;`, and then a
216-
condition, which evaluates to `true` or `false`: `i > 1 && j == :B`.
218+
condition, which evaluates to `true` or `false`: `i > 1`.
217219

218220
### Slicing
219221

220-
```@meta
221-
# TODO: This is included so we know to update the documentation when this is fixed.
222-
```
223-
224-
Slicing is not supported.
222+
Slicing is supported:
225223
```jldoctest containers_sparse
226-
julia> x[:, :B]
227-
ERROR: ArgumentError: Indexing with `:` is not supported by Containers.SparseAxisArray
228-
[...]
224+
julia> y = x[:, :B]
225+
JuMP.Containers.SparseAxisArray{Tuple{Int64, Symbol}, 1, Tuple{Int64}} with 2 entries:
226+
[2] = (2, :B)
227+
[3] = (3, :B)
229228
```
230229

231230
### Looping
232231

233232
Use `eachindex` to loop over the elements:
234233
```jldoctest containers_sparse
235-
julia> for key in eachindex(x)
236-
println(x[key])
234+
julia> for key in eachindex(y)
235+
println(y[key])
237236
end
238237
(2, :B)
239238
(3, :B)
@@ -247,10 +246,10 @@ Broadcasting over a SparseAxisArray returns a SparseAxisArray
247246
julia> swap(x::Tuple) = (last(x), first(x))
248247
swap (generic function with 1 method)
249248
250-
julia> swap.(x)
251-
JuMP.Containers.SparseAxisArray{Tuple{Symbol, Int64}, 2, Tuple{Int64, Symbol}} with 2 entries:
252-
[2, B] = (:B, 2)
253-
[3, B] = (:B, 3)
249+
julia> swap.(y)
250+
JuMP.Containers.SparseAxisArray{Tuple{Symbol, Int64}, 1, Tuple{Int64}} with 2 entries:
251+
[2] = (:B, 2)
252+
[3] = (:B, 3)
254253
```
255254

256255
## Forcing the container type

src/Containers/SparseAxisArray.jl

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,6 @@ function Base.haskey(sa::SparseAxisArray{T,1,Tuple{I}}, idx::I) where {T,I}
8585
return haskey(sa.data, (idx,))
8686
end
8787

88-
# Error for sa[..., :, ...]
89-
_colon_error() = nothing
90-
91-
_colon_error(::Any, args...) = _colon_error(args...)
92-
93-
function _colon_error(::Colon, args...)
94-
return throw(
95-
ArgumentError(
96-
"Indexing with `:` is not supported by" *
97-
" Containers.SparseAxisArray",
98-
),
99-
)
100-
end
101-
10288
function Base.setindex!(
10389
d::SparseAxisArray{T,N,K},
10490
value,
@@ -107,11 +93,16 @@ function Base.setindex!(
10793
return setindex!(d, value, idx...)
10894
end
10995

110-
function Base.setindex!(d::SparseAxisArray{T,N}, value, idx...) where {T,N}
96+
function Base.setindex!(d::SparseAxisArray{T,N,K}, value, idx...) where {T,N,K}
11197
if length(idx) < N
11298
throw(BoundsError(d, idx))
99+
elseif _sliced_key_type(K, idx...) !== nothing
100+
throw(
101+
ArgumentError(
102+
"Slicing is not when calling setindex! on a SparseAxisArray",
103+
),
104+
)
113105
end
114-
_colon_error(idx...)
115106
return setindex!(d.data, value, idx)
116107
end
117108

@@ -122,14 +113,52 @@ function Base.getindex(
122113
return getindex(d, idx...)
123114
end
124115

125-
function Base.getindex(d::SparseAxisArray{T,N}, idx...) where {T,N}
116+
function Base.getindex(d::SparseAxisArray{T,N,K}, idx...) where {T,N,K}
126117
if length(idx) < N
127118
throw(BoundsError(d, idx))
128119
end
129-
_colon_error(idx...)
120+
K2 = _sliced_key_type(K, idx...)
121+
if K2 !== nothing
122+
new_data = Dict{K2,T}(
123+
_sliced_key(k, idx) => v for (k, v) in d.data if _filter(k, idx)
124+
)
125+
return SparseAxisArray(new_data)
126+
end
130127
return getindex(d.data, idx)
131128
end
132129

130+
# Method to check whether an index is an attempt at a slice.
131+
@generated function _sliced_key_type(::Type{K}, idx...) where {K<:Tuple}
132+
expr = Expr(:curly, :Tuple)
133+
for i in 1:length(idx)
134+
Ki = K.parameters[i]
135+
if idx[i] <: Colon || idx[i] <: AbstractVector{<:Ki}
136+
push!(expr.args, Ki)
137+
end
138+
end
139+
return length(expr.args) == 1 ? :(nothing) : expr
140+
end
141+
142+
# Methods to check whether a key `k` is a valid subset of `idx`.
143+
_filter(::Any, ::Colon) = true
144+
_filter(ki::Any, i::Any) = ki == i
145+
_filter(ki::K, i::AbstractVector{<:K}) where {K} = ki in i
146+
_filter(::Tuple{}, ::Tuple{}) = true
147+
function _filter(k::Tuple, idx::Tuple)
148+
return _filter(k[1], idx[1]) && _filter(Base.tail(k), Base.tail(idx))
149+
end
150+
151+
# Methods to subset the key into a new key, dropping all singleton axes.
152+
_sliced_key(k, ::Any) = (k,)
153+
_sliced_key(::K, ::K) where {K} = ()
154+
_sliced_key(::Tuple{}, ::Tuple{}) = ()
155+
function _sliced_key(k::Tuple, idx::Tuple)
156+
return tuple(
157+
_sliced_key(k[1], idx[1])...,
158+
_sliced_key(Base.tail(k), Base.tail(idx))...,
159+
)
160+
end
161+
133162
Base.eachindex(d::SparseAxisArray) = keys(d.data)
134163

135164
################

test/Containers/SparseAxisArray.jl

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@ using Test
1414
@testset "SparseAxisArray" begin
1515
function sparse_test(d, sum_d, d2, d3, dsqr, d_bads)
1616
sqr(x) = x^2
17-
@testset "Colon indexing" begin
18-
err = ArgumentError(
19-
"Indexing with `:` is not supported by" *
20-
" Containers.SparseAxisArray",
21-
)
22-
@test_throws err d[:, ntuple(one, ndims(d) - 1)...]
23-
@test_throws err d[ntuple(i -> :a, ndims(d) - 1)..., :]
24-
end
2517
@testset "Map" begin
2618
@test d == @inferred map(identity, d)
2719
@test dsqr == @inferred map(sqr, d)
@@ -195,4 +187,37 @@ $(SparseAxisArray{Float64,2,Tuple{Symbol,Char}}) with 2 entries"""
195187
@test y isa SparseAxisArray{Any,2,Tuple{Any,Int}}
196188
@test isempty(y)
197189
end
190+
@testset "Slicing" begin
191+
Containers.@container(x[i = 1:4, j = 1:2; isodd(i + j)], i + j)
192+
@test x[:, :] == x
193+
@test x[1, :] == Containers.@container(y[j = 1:2; isodd(1 + j)], 1 + j)
194+
@test x[:, 1] == Containers.@container(z[i = 1:4; isodd(i + 1)], i + 1)
195+
@test isempty(x[[1, 3], [1, 3]])
196+
@test typeof(x[[1, 3], [1, 3]]) == typeof(x)
197+
@test typeof(x[[1, 3], 1]) ==
198+
Containers.SparseAxisArray{Int,1,Tuple{Int}}
199+
@test isempty(x[[1, 3], 1])
200+
Containers.@container(y[i = 1:4; isodd(i)], i)
201+
@test y[:] == y
202+
Containers.@container(y[i = 1:4; isodd(i)], i)
203+
@test y[[1, 3]] == y
204+
z = Containers.@container([i = 1:3, j = [:A, :B]; i > 1], (i, j))
205+
@test z[2, :] == Containers.@container([j = [:A, :B]; true], (2, j))
206+
@test z[:, :A] == Containers.@container([i = 2:3; true], (i, :A))
207+
@test z[:, :] == z
208+
@test z[1:2, :A] == Containers.@container([i = 2:2; true], (i, :A))
209+
@test z[2, [:A, :B]] ==
210+
Containers.@container([j = [:A, :B]; true], (2, j))
211+
@test z[1:2, [:A, :B]] ==
212+
Containers.@container([i = 2:2, j = [:A, :B]; true], (i, j))
213+
end
214+
@testset "Slicing on set" begin
215+
Containers.@container(x[i = 1:4, j = 1:2; isodd(i + j)], i + j)
216+
err = ArgumentError(
217+
"Slicing is not when calling setindex! on a SparseAxisArray",
218+
)
219+
@test_throws(err, x[:, :] = 1)
220+
@test_throws(err, x[1, :] = 1)
221+
@test_throws(err, x[1, 1:2] = 1)
222+
end
198223
end

0 commit comments

Comments
 (0)