Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Containers] add support for view of DenseAxisArray #3152

Merged
merged 3 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,17 @@ function Base.IndexStyle(::Type{DenseAxisArray{T,N,Ax}}) where {T,N,Ax}
return IndexAnyCartesian()
end

function Base.setindex!(
A::DenseAxisArray{T,N},
value::DenseAxisArray{T,N},
args...,
) where {T,N}
for key in Base.product(args...)
A[key...] = value[key...]
end
return A
end

########
# Keys #
########
Expand All @@ -363,6 +374,10 @@ end
Base.getindex(k::DenseAxisArrayKey, args...) = getindex(k.I, args...)
Base.getindex(a::DenseAxisArray, k::DenseAxisArrayKey) = a[k.I...]

function Base.setindex!(A::DenseAxisArray, value, key::DenseAxisArrayKey)
return setindex!(A, value, key.I...)
end

struct DenseAxisArrayKeys{T<:Tuple,S<:DenseAxisArrayKey,N} <: AbstractArray{S,N}
product_iter::Base.Iterators.ProductIterator{T}
function DenseAxisArrayKeys(a::DenseAxisArray{TT,N,Ax}) where {TT,N,Ax}
Expand Down Expand Up @@ -559,3 +574,74 @@ end
# but some users may depend on it's functionality so we have a work-around
# instead of just breaking code.
Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)

###
### view
###

_get_subaxis(::Colon, b) = b

function _get_subaxis(a::AbstractVector, b)
for ai in a
if !(ai in b)
throw(KeyError(ai))
end
end
return a
end

function _get_subaxis(a::T, b::AbstractVector{T}) where {T}
if !(a in b)
throw(KeyError(a))
end
return a
end

struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
data::D
axes::A
function DenseAxisArrayView(
x::Containers.DenseAxisArray{T,N},
args...,
) where {T,N}
axis = _get_subaxis.(args, axes(x))
return new{T,N,typeof(x),typeof(axis)}(x, axis)
end
end

function Base.view(A::Containers.DenseAxisArray, args...)
return DenseAxisArrayView(A, args...)
end

Base.size(x::DenseAxisArrayView) = length.(x.axes)

Base.axes(x::DenseAxisArrayView) = x.axes

function Base.getindex(x::DenseAxisArrayView, args...)
y = _get_subaxis.(args, x.axes)
return getindex(x.data, y...)
end

Base.getindex(a::DenseAxisArrayView, k::DenseAxisArrayKey) = a[k.I...]

function Base.setindex!(x::DenseAxisArrayView, args...)
return setindex!(x.data, args...)
end

function Base.eachindex(A::DenseAxisArrayView)
# Return a generator so that we lazily evaluate the product instead of
# collecting into a vector.
#
# In future, we might want to return the appropriate matrix of
# `CartesianIndex` to avoid having to do the lookups with
# `DenseAxisArrayKey`.
return (DenseAxisArrayKey(k) for k in Base.product(A.axes...))
end

Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data)

Base.print_array(io::IO, x::DenseAxisArrayView) = show(io, x)

function Base.summary(io::IO, x::DenseAxisArrayView)
return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over")
end
123 changes: 123 additions & 0 deletions test/Containers/test_DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,4 +478,127 @@ function test_DenseAxisArray_vector_keys()
return
end

function test_containers_denseaxisarray_setindex_vector()
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[2:3] .= 1.0
@test A.data == [0.0, 1.0, 1.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[[2, 3]] .= 1.0
@test A.data == [0.0, 1.0, 1.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[[1, 3]] .= 1.0
@test A.data == [1.0, 0.0, 1.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[[2]] .= 1.0
@test A.data == [0.0, 1.0, 0.0]
A[2:3] = Containers.DenseAxisArray([2.0, 3.0], 2:3)
@test A.data == [0.0, 2.0, 3.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[:] .= 1.0
@test A.data == [1.0, 1.0, 1.0]
return
end

function test_containers_denseaxisarray_setindex_matrix()
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[:, [:a, :b]] .= 1.0
@test A.data == [1.0 1.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[2:3, [:a, :b]] .= 1.0
@test A.data == [0.0 0.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[3:3, [:a, :b]] .= 1.0
@test A.data == [0.0 0.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[[1, 3], [:a, :b]] .= 1.0
@test A.data == [1.0 1.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[[1, 3], [:a, :c]] .= 1.0
@test A.data == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
return
end

function test_containers_denseaxisarray_view()
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
B = view(A, :, [:a, :b])
@test_throws KeyError view(A, :, [:d])
@test size(B) == (3, 2)
@test B[1, :a] == A[1, :a]
@test B[3, :a] == A[3, :a]
@test_throws KeyError B[3, :c]
@test sprint(show, B) == sprint(show, B.data)
@test sprint(Base.print_array, B) == sprint(show, B.data)
@test sprint(Base.summary, B) ==
"view(::DenseAxisArray, 1:3, [:a, :b]), over"
return
end

function test_containers_denseaxisarray_jump_3151()
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
E = Containers.DenseAxisArray(ones(3), [:a, :b, :c])
I = [:a, :b]
D[I] = E[I]
@test D.data == [1.0, 1.0, 0.0]
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
I = [:b, :c]
D[I] = E[I]
@test D.data == [0.0, 1.0, 1.0]
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
I = [:a, :c]
D[I] = E[I]
@test D.data == [1.0, 0.0, 1.0]
return
end

function test_containers_denseaxisarray_view_operations()
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
d = view(c, 2:3, :)
@test sum(c) == 60
@test sum(d) == 30
d .= 1
@test sum(d) == 4
@test sum(c) == 34
return
end

function test_containers_denseaxisarray_view_addition()
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
d = view(c, 2:3, :)
@test_throws MethodError d + d
return
end

function test_containers_denseaxisarray_view_colon()
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
d = view(c, 2:3, :)
@test d[:, 2] == Containers.@container([i = 2:3], i + 2 * 2)
return
end

function test_containers_denseaxisarray_setindex_invalid()
c = Containers.@container([i = 1:4, j = 2:3], 0)
d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
setindex!(c, d, 1:4, 2:3)
@test c == d
c .= 0
setindex!(c, d, 1:4, 2:2)
@test c == Containers.@container([i = 1:4, j = 2:3], (4 + i) * (j == 2))
d = Containers.@container([i = 5:6, j = 2:3], i + 2 * j)
@test_throws KeyError setindex!(c, d, 1:4, 2:3)
return
end

function test_containers_denseaxisarray_setindex_keys()
c = Containers.@container([i = 1:4, j = 2:3], 0)
for (i, k) in enumerate(keys(c))
c[k] = c[k] + i
end
@test c == Containers.@container([i = 1:4, j = 2:3], 4 * (j - 2) + i)
for (i, k) in enumerate(keys(c))
c[k] = c[k] + i
end
@test c == Containers.@container([i = 1:4, j = 2:3], 2 * (4 * (j - 2) + i))
return
end

end # module