From af8f439ac05444e550de63e7c478c5c5ae3859db Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 15 Dec 2022 15:37:35 +1300 Subject: [PATCH 1/3] [Containers] add support for view of DenseAxisArray --- src/Containers/DenseAxisArray.jl | 71 +++++++++++++++ test/Containers/test_DenseAxisArray.jl | 117 +++++++++++++++++++++++++ 2 files changed, 188 insertions(+) diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index d61e23341ee..f9ea7d66050 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -348,6 +348,19 @@ function Base.IndexStyle(::Type{DenseAxisArray{T,N,Ax}}) where {T,N,Ax} return IndexAnyCartesian() end +function Base.setindex!( + A::DenseAxisArray{T,N}, + value::DenseAxisArray{T,N}, + args..., +) where {T,N} + @show args + @show Base.to_index(A, args) + for key in Base.product(args...) + A[key...] = value[key...] + end + return A +end + ######## # Keys # ######## @@ -363,6 +376,10 @@ end Base.getindex(k::DenseAxisArrayKey, args...) = getindex(k.I, args...) Base.getindex(a::DenseAxisArray, k::DenseAxisArrayKey) = a[k.I...] +function Base.setindex!(A::DenseAxisArray, value, key::DenseAxisArrayKey) + return setindex!(A, value, key.I...) +end + struct DenseAxisArrayKeys{T<:Tuple,S<:DenseAxisArrayKey,N} <: AbstractArray{S,N} product_iter::Base.Iterators.ProductIterator{T} function DenseAxisArrayKeys(a::DenseAxisArray{TT,N,Ax}) where {TT,N,Ax} @@ -559,3 +576,57 @@ end # but some users may depend on it's functionality so we have a work-around # instead of just breaking code. Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...) + +### +### view +### + +_get_subaxis(::Colon, b) = b + +function _get_subaxis(a, b) + for ai in a + if !(ai in b) + throw(KeyError(ai)) + end + end + return a +end +struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N} + data::D + axes::A + function DenseAxisArrayView( + x::Containers.DenseAxisArray{T,N}, + args..., + ) where {T,N} + axis = tuple([_get_subaxis(a, b) for (a, b) in zip(args, axes(x))]...) + return new{T,N,typeof(x),typeof(axis)}(x, axis) + end +end + +function Base.view(A::Containers.DenseAxisArray, args...) + return DenseAxisArrayView(A, args...) +end + +Base.size(x::DenseAxisArrayView) = length.(x.axes) + +Base.axes(x::DenseAxisArrayView) = x.axes + +function Base.getindex(x::DenseAxisArrayView, args...) + return getindex(x.data, args...) +end + +function Base.setindex!(x::DenseAxisArrayView, args...) + return setindex!(x.data, args...) +end + +function Base.eachindex(A::DenseAxisArrayView) + return DenseAxisArrayKey.(Base.product(A.axes...)) +end + +Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data) + +Base.print_array(io::IO, x::DenseAxisArrayView) = show(io, x) + +function Base.summary(io::IO, x::DenseAxisArrayView) + return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over") +end diff --git a/test/Containers/test_DenseAxisArray.jl b/test/Containers/test_DenseAxisArray.jl index 5c80c79fe16..07ce06bf98f 100644 --- a/test/Containers/test_DenseAxisArray.jl +++ b/test/Containers/test_DenseAxisArray.jl @@ -478,4 +478,121 @@ function test_DenseAxisArray_vector_keys() return end +function test_containers_denseaxisarray_setindex_vector() + A = Containers.DenseAxisArray(zeros(3), 1:3) + A[2:3] .= 1.0 + @test A.data == [0.0, 1.0, 1.0] + A = Containers.DenseAxisArray(zeros(3), 1:3) + A[[2, 3]] .= 1.0 + @test A.data == [0.0, 1.0, 1.0] + A = Containers.DenseAxisArray(zeros(3), 1:3) + A[[1, 3]] .= 1.0 + @test A.data == [1.0, 0.0, 1.0] + A = Containers.DenseAxisArray(zeros(3), 1:3) + A[[2]] .= 1.0 + @test A.data == [0.0, 1.0, 0.0] + A[2:3] = Containers.DenseAxisArray([2.0, 3.0], 2:3) + @test A.data == [0.0, 2.0, 3.0] + A = Containers.DenseAxisArray(zeros(3), 1:3) + A[:] .= 1.0 + @test A.data == [1.0, 1.0, 1.0] + return +end + +function test_containers_denseaxisarray_setindex_matrix() + A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c]) + A[:, [:a, :b]] .= 1.0 + @test A.data == [1.0 1.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0] + A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c]) + A[2:3, [:a, :b]] .= 1.0 + @test A.data == [0.0 0.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0] + A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c]) + A[3:3, [:a, :b]] .= 1.0 + @test A.data == [0.0 0.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0] + A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c]) + A[[1, 3], [:a, :b]] .= 1.0 + @test A.data == [1.0 1.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0] + A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c]) + A[[1, 3], [:a, :c]] .= 1.0 + @test A.data == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0] + return +end + +function test_containers_denseaxisarray_view() + A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c]) + B = view(A, :, [:a, :b]) + @test_throws KeyError view(A, :, [:d]) + @test size(B) == (3, 2) + @test B[1, :a] == A[1, :a] + @test B[3, :a] == A[3, :a] + # Views are weird, because we can still access the underlying array? + @test B[3, :c] == A[3, :c] + @test sprint(show, B) == sprint(show, B.data) + @test sprint(Base.print_array, B) == sprint(show, B.data) + @test sprint(Base.summary, B) == + "view(::DenseAxisArray, 1:3, [:a, :b]), over" + return +end + +function test_containers_denseaxisarray_jump_3151() + D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c]) + E = Containers.DenseAxisArray(ones(3), [:a, :b, :c]) + I = [:a, :b] + D[I] = E[I] + @test D.data == [1.0, 1.0, 0.0] + D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c]) + I = [:b, :c] + D[I] = E[I] + @test D.data == [0.0, 1.0, 1.0] + D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c]) + I = [:a, :c] + D[I] = E[I] + @test D.data == [1.0, 0.0, 1.0] + return +end + +function test_containers_denseaxisarray_view_operations() + c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j) + d = view(c, 2:3, :) + @test sum(c) == 60 + @test sum(d) == 30 + d .= 1 + @test sum(d) == 4 + @test sum(c) == 34 + return +end + +function test_containers_denseaxisarray_view_addition() + c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j) + d = view(c, 2:3, :) + @test_throws MethodError d + d + return +end + +function test_containers_denseaxisarray_setindex_invalid() + c = Containers.@container([i = 1:4, j = 2:3], 0) + d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j) + setindex!(c, d, 1:4, 2:3) + @test c == d + c .= 0 + setindex!(c, d, 1:4, 2:2) + @test c == Containers.@container([i = 1:4, j = 2:3], (4 + i) * (j == 2)) + d = Containers.@container([i = 5:6, j = 2:3], i + 2 * j) + @test_throws KeyError setindex!(c, d, 1:4, 2:3) + return +end + +function test_containers_denseaxisarray_setindex_keys() + c = Containers.@container([i = 1:4, j = 2:3], 0) + for (i, k) in enumerate(keys(c)) + c[k] = c[k] + i + end + @test c == Containers.@container([i = 1:4, j = 2:3], 4 * (j - 2) + i) + for (i, k) in enumerate(keys(c)) + c[k] = c[k] + i + end + @test c == Containers.@container([i = 1:4, j = 2:3], 2 * (4 * (j - 2) + i)) + return +end + end # module From 95d4a548fcb00ea9f026cf11ebb6ef20ac556a7e Mon Sep 17 00:00:00 2001 From: odow Date: Tue, 3 Jan 2023 13:15:18 +1300 Subject: [PATCH 2/3] Fix getindex of DenseAxisArrayView --- src/Containers/DenseAxisArray.jl | 23 +++++++++++++++++------ test/Containers/test_DenseAxisArray.jl | 10 ++++++++-- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index f9ea7d66050..5c055ce0027 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -353,8 +353,6 @@ function Base.setindex!( value::DenseAxisArray{T,N}, args..., ) where {T,N} - @show args - @show Base.to_index(A, args) for key in Base.product(args...) A[key...] = value[key...] end @@ -583,7 +581,7 @@ Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...) _get_subaxis(::Colon, b) = b -function _get_subaxis(a, b) +function _get_subaxis(a::AbstractVector, b) for ai in a if !(ai in b) throw(KeyError(ai)) @@ -591,6 +589,14 @@ function _get_subaxis(a, b) end return a end + +function _get_subaxis(a::T, b::AbstractVector{T}) where {T} + if !(a in b) + throw(KeyError(a)) + end + return a +end + struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N} data::D axes::A @@ -598,7 +604,7 @@ struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N} x::Containers.DenseAxisArray{T,N}, args..., ) where {T,N} - axis = tuple([_get_subaxis(a, b) for (a, b) in zip(args, axes(x))]...) + axis = _get_subaxis.(args, axes(x)) return new{T,N,typeof(x),typeof(axis)}(x, axis) end end @@ -612,15 +618,20 @@ Base.size(x::DenseAxisArrayView) = length.(x.axes) Base.axes(x::DenseAxisArrayView) = x.axes function Base.getindex(x::DenseAxisArrayView, args...) - return getindex(x.data, args...) + y = _get_subaxis.(args, x.axes) + return getindex(x.data, y...) end +Base.getindex(a::DenseAxisArrayView, k::DenseAxisArrayKey) = a[k.I...] + function Base.setindex!(x::DenseAxisArrayView, args...) return setindex!(x.data, args...) end function Base.eachindex(A::DenseAxisArrayView) - return DenseAxisArrayKey.(Base.product(A.axes...)) + # Return a generator so that we lazily evaluate the product instead of + # collecting into a vector. + return (DenseAxisArrayKey(k) for k in Base.product(A.axes...)) end Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data) diff --git a/test/Containers/test_DenseAxisArray.jl b/test/Containers/test_DenseAxisArray.jl index 07ce06bf98f..d95e3eb82f3 100644 --- a/test/Containers/test_DenseAxisArray.jl +++ b/test/Containers/test_DenseAxisArray.jl @@ -525,8 +525,7 @@ function test_containers_denseaxisarray_view() @test size(B) == (3, 2) @test B[1, :a] == A[1, :a] @test B[3, :a] == A[3, :a] - # Views are weird, because we can still access the underlying array? - @test B[3, :c] == A[3, :c] + @test_throws KeyError B[3, :c] @test sprint(show, B) == sprint(show, B.data) @test sprint(Base.print_array, B) == sprint(show, B.data) @test sprint(Base.summary, B) == @@ -569,6 +568,13 @@ function test_containers_denseaxisarray_view_addition() return end +function test_containers_denseaxisarray_view_colon() + c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j) + d = view(c, 2:3, :) + @test d[:, 2] == Containers.@container([i = 2:3], i + 2 * 2) + return +end + function test_containers_denseaxisarray_setindex_invalid() c = Containers.@container([i = 1:4, j = 2:3], 0) d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j) From 3856b82359fabec508f01c7acb1d8943cf0ff64c Mon Sep 17 00:00:00 2001 From: odow Date: Thu, 5 Jan 2023 09:12:56 +1300 Subject: [PATCH 3/3] Add note on future optimization --- src/Containers/DenseAxisArray.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index 5c055ce0027..9f64bbf1d6f 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -631,6 +631,10 @@ end function Base.eachindex(A::DenseAxisArrayView) # Return a generator so that we lazily evaluate the product instead of # collecting into a vector. + # + # In future, we might want to return the appropriate matrix of + # `CartesianIndex` to avoid having to do the lookups with + # `DenseAxisArrayKey`. return (DenseAxisArrayKey(k) for k in Base.product(A.axes...)) end