Skip to content

Commit 9f61e95

Browse files
authored
[Containers] add support for view of DenseAxisArray (#3152)
1 parent 8f26202 commit 9f61e95

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

src/Containers/DenseAxisArray.jl

+86
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,17 @@ function Base.IndexStyle(::Type{DenseAxisArray{T,N,Ax}}) where {T,N,Ax}
348348
return IndexAnyCartesian()
349349
end
350350

351+
function Base.setindex!(
352+
A::DenseAxisArray{T,N},
353+
value::DenseAxisArray{T,N},
354+
args...,
355+
) where {T,N}
356+
for key in Base.product(args...)
357+
A[key...] = value[key...]
358+
end
359+
return A
360+
end
361+
351362
########
352363
# Keys #
353364
########
@@ -363,6 +374,10 @@ end
363374
Base.getindex(k::DenseAxisArrayKey, args...) = getindex(k.I, args...)
364375
Base.getindex(a::DenseAxisArray, k::DenseAxisArrayKey) = a[k.I...]
365376

377+
function Base.setindex!(A::DenseAxisArray, value, key::DenseAxisArrayKey)
378+
return setindex!(A, value, key.I...)
379+
end
380+
366381
struct DenseAxisArrayKeys{T<:Tuple,S<:DenseAxisArrayKey,N} <: AbstractArray{S,N}
367382
product_iter::Base.Iterators.ProductIterator{T}
368383
function DenseAxisArrayKeys(a::DenseAxisArray{TT,N,Ax}) where {TT,N,Ax}
@@ -559,3 +574,74 @@ end
559574
# but some users may depend on it's functionality so we have a work-around
560575
# instead of just breaking code.
561576
Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)
577+
578+
###
579+
### view
580+
###
581+
582+
_get_subaxis(::Colon, b) = b
583+
584+
function _get_subaxis(a::AbstractVector, b)
585+
for ai in a
586+
if !(ai in b)
587+
throw(KeyError(ai))
588+
end
589+
end
590+
return a
591+
end
592+
593+
function _get_subaxis(a::T, b::AbstractVector{T}) where {T}
594+
if !(a in b)
595+
throw(KeyError(a))
596+
end
597+
return a
598+
end
599+
600+
struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
601+
data::D
602+
axes::A
603+
function DenseAxisArrayView(
604+
x::Containers.DenseAxisArray{T,N},
605+
args...,
606+
) where {T,N}
607+
axis = _get_subaxis.(args, axes(x))
608+
return new{T,N,typeof(x),typeof(axis)}(x, axis)
609+
end
610+
end
611+
612+
function Base.view(A::Containers.DenseAxisArray, args...)
613+
return DenseAxisArrayView(A, args...)
614+
end
615+
616+
Base.size(x::DenseAxisArrayView) = length.(x.axes)
617+
618+
Base.axes(x::DenseAxisArrayView) = x.axes
619+
620+
function Base.getindex(x::DenseAxisArrayView, args...)
621+
y = _get_subaxis.(args, x.axes)
622+
return getindex(x.data, y...)
623+
end
624+
625+
Base.getindex(a::DenseAxisArrayView, k::DenseAxisArrayKey) = a[k.I...]
626+
627+
function Base.setindex!(x::DenseAxisArrayView, args...)
628+
return setindex!(x.data, args...)
629+
end
630+
631+
function Base.eachindex(A::DenseAxisArrayView)
632+
# Return a generator so that we lazily evaluate the product instead of
633+
# collecting into a vector.
634+
#
635+
# In future, we might want to return the appropriate matrix of
636+
# `CartesianIndex` to avoid having to do the lookups with
637+
# `DenseAxisArrayKey`.
638+
return (DenseAxisArrayKey(k) for k in Base.product(A.axes...))
639+
end
640+
641+
Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data)
642+
643+
Base.print_array(io::IO, x::DenseAxisArrayView) = show(io, x)
644+
645+
function Base.summary(io::IO, x::DenseAxisArrayView)
646+
return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over")
647+
end

test/Containers/test_DenseAxisArray.jl

+123
Original file line numberDiff line numberDiff line change
@@ -478,4 +478,127 @@ function test_DenseAxisArray_vector_keys()
478478
return
479479
end
480480

481+
function test_containers_denseaxisarray_setindex_vector()
482+
A = Containers.DenseAxisArray(zeros(3), 1:3)
483+
A[2:3] .= 1.0
484+
@test A.data == [0.0, 1.0, 1.0]
485+
A = Containers.DenseAxisArray(zeros(3), 1:3)
486+
A[[2, 3]] .= 1.0
487+
@test A.data == [0.0, 1.0, 1.0]
488+
A = Containers.DenseAxisArray(zeros(3), 1:3)
489+
A[[1, 3]] .= 1.0
490+
@test A.data == [1.0, 0.0, 1.0]
491+
A = Containers.DenseAxisArray(zeros(3), 1:3)
492+
A[[2]] .= 1.0
493+
@test A.data == [0.0, 1.0, 0.0]
494+
A[2:3] = Containers.DenseAxisArray([2.0, 3.0], 2:3)
495+
@test A.data == [0.0, 2.0, 3.0]
496+
A = Containers.DenseAxisArray(zeros(3), 1:3)
497+
A[:] .= 1.0
498+
@test A.data == [1.0, 1.0, 1.0]
499+
return
500+
end
501+
502+
function test_containers_denseaxisarray_setindex_matrix()
503+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
504+
A[:, [:a, :b]] .= 1.0
505+
@test A.data == [1.0 1.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
506+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
507+
A[2:3, [:a, :b]] .= 1.0
508+
@test A.data == [0.0 0.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
509+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
510+
A[3:3, [:a, :b]] .= 1.0
511+
@test A.data == [0.0 0.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
512+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
513+
A[[1, 3], [:a, :b]] .= 1.0
514+
@test A.data == [1.0 1.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
515+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
516+
A[[1, 3], [:a, :c]] .= 1.0
517+
@test A.data == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
518+
return
519+
end
520+
521+
function test_containers_denseaxisarray_view()
522+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
523+
B = view(A, :, [:a, :b])
524+
@test_throws KeyError view(A, :, [:d])
525+
@test size(B) == (3, 2)
526+
@test B[1, :a] == A[1, :a]
527+
@test B[3, :a] == A[3, :a]
528+
@test_throws KeyError B[3, :c]
529+
@test sprint(show, B) == sprint(show, B.data)
530+
@test sprint(Base.print_array, B) == sprint(show, B.data)
531+
@test sprint(Base.summary, B) ==
532+
"view(::DenseAxisArray, 1:3, [:a, :b]), over"
533+
return
534+
end
535+
536+
function test_containers_denseaxisarray_jump_3151()
537+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
538+
E = Containers.DenseAxisArray(ones(3), [:a, :b, :c])
539+
I = [:a, :b]
540+
D[I] = E[I]
541+
@test D.data == [1.0, 1.0, 0.0]
542+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
543+
I = [:b, :c]
544+
D[I] = E[I]
545+
@test D.data == [0.0, 1.0, 1.0]
546+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
547+
I = [:a, :c]
548+
D[I] = E[I]
549+
@test D.data == [1.0, 0.0, 1.0]
550+
return
551+
end
552+
553+
function test_containers_denseaxisarray_view_operations()
554+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
555+
d = view(c, 2:3, :)
556+
@test sum(c) == 60
557+
@test sum(d) == 30
558+
d .= 1
559+
@test sum(d) == 4
560+
@test sum(c) == 34
561+
return
562+
end
563+
564+
function test_containers_denseaxisarray_view_addition()
565+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
566+
d = view(c, 2:3, :)
567+
@test_throws MethodError d + d
568+
return
569+
end
570+
571+
function test_containers_denseaxisarray_view_colon()
572+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
573+
d = view(c, 2:3, :)
574+
@test d[:, 2] == Containers.@container([i = 2:3], i + 2 * 2)
575+
return
576+
end
577+
578+
function test_containers_denseaxisarray_setindex_invalid()
579+
c = Containers.@container([i = 1:4, j = 2:3], 0)
580+
d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
581+
setindex!(c, d, 1:4, 2:3)
582+
@test c == d
583+
c .= 0
584+
setindex!(c, d, 1:4, 2:2)
585+
@test c == Containers.@container([i = 1:4, j = 2:3], (4 + i) * (j == 2))
586+
d = Containers.@container([i = 5:6, j = 2:3], i + 2 * j)
587+
@test_throws KeyError setindex!(c, d, 1:4, 2:3)
588+
return
589+
end
590+
591+
function test_containers_denseaxisarray_setindex_keys()
592+
c = Containers.@container([i = 1:4, j = 2:3], 0)
593+
for (i, k) in enumerate(keys(c))
594+
c[k] = c[k] + i
595+
end
596+
@test c == Containers.@container([i = 1:4, j = 2:3], 4 * (j - 2) + i)
597+
for (i, k) in enumerate(keys(c))
598+
c[k] = c[k] + i
599+
end
600+
@test c == Containers.@container([i = 1:4, j = 2:3], 2 * (4 * (j - 2) + i))
601+
return
602+
end
603+
481604
end # module

0 commit comments

Comments
 (0)