Skip to content

Commit 4b027ab

Browse files
committed
[Containers] add support for view of DenseAxisArray
1 parent 34462e4 commit 4b027ab

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed

src/Containers/DenseAxisArray.jl

+71
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,19 @@ function Base.IndexStyle(::Type{DenseAxisArray{T,N,Ax}}) where {T,N,Ax}
348348
return IndexAnyCartesian()
349349
end
350350

351+
function Base.setindex!(
352+
A::DenseAxisArray{T,N},
353+
value::DenseAxisArray{T,N},
354+
args...,
355+
) where {T,N}
356+
@show args
357+
@show Base.to_index(A, args)
358+
for key in Base.product(args...)
359+
A[key...] = value[key...]
360+
end
361+
return A
362+
end
363+
351364
########
352365
# Keys #
353366
########
@@ -363,6 +376,10 @@ end
363376
Base.getindex(k::DenseAxisArrayKey, args...) = getindex(k.I, args...)
364377
Base.getindex(a::DenseAxisArray, k::DenseAxisArrayKey) = a[k.I...]
365378

379+
function Base.setindex!(A::DenseAxisArray, value, key::DenseAxisArrayKey)
380+
return setindex!(A, value, key.I...)
381+
end
382+
366383
struct DenseAxisArrayKeys{T<:Tuple,S<:DenseAxisArrayKey,N} <: AbstractArray{S,N}
367384
product_iter::Base.Iterators.ProductIterator{T}
368385
function DenseAxisArrayKeys(a::DenseAxisArray{TT,N,Ax}) where {TT,N,Ax}
@@ -559,3 +576,57 @@ end
559576
# but some users may depend on it's functionality so we have a work-around
560577
# instead of just breaking code.
561578
Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)
579+
580+
###
581+
### view
582+
###
583+
584+
_get_subaxis(::Colon, b) = b
585+
586+
function _get_subaxis(a, b)
587+
for ai in a
588+
if !(ai in b)
589+
throw(KeyError(ai))
590+
end
591+
end
592+
return a
593+
end
594+
struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
595+
data::D
596+
axes::A
597+
function DenseAxisArrayView(
598+
x::Containers.DenseAxisArray{T,N},
599+
args...,
600+
) where {T,N}
601+
axis = tuple([_get_subaxis(a, b) for (a, b) in zip(args, axes(x))]...)
602+
return new{T,N,typeof(x),typeof(axis)}(x, axis)
603+
end
604+
end
605+
606+
function Base.view(A::Containers.DenseAxisArray, args...)
607+
return DenseAxisArrayView(A, args...)
608+
end
609+
610+
Base.size(x::DenseAxisArrayView) = length.(x.axes)
611+
612+
Base.axes(x::DenseAxisArrayView) = x.axes
613+
614+
function Base.getindex(x::DenseAxisArrayView, args...)
615+
return getindex(x.data, args...)
616+
end
617+
618+
function Base.setindex!(x::DenseAxisArrayView, args...)
619+
return setindex!(x.data, args...)
620+
end
621+
622+
function Base.eachindex(A::DenseAxisArrayView)
623+
return DenseAxisArrayKey.(Base.product(A.axes...))
624+
end
625+
626+
Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data)
627+
628+
Base.print_array(io::IO, x::DenseAxisArrayView) = show(io, x)
629+
630+
function Base.summary(io::IO, x::DenseAxisArrayView)
631+
return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over")
632+
end

test/Containers/DenseAxisArray.jl

+117
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,123 @@ function test_DenseAxisArray_vector_keys()
489489
return
490490
end
491491

492+
function test_containers_denseaxisarray_setindex_vector()
493+
A = Containers.DenseAxisArray(zeros(3), 1:3)
494+
A[2:3] .= 1.0
495+
@test A.data == [0.0, 1.0, 1.0]
496+
A = Containers.DenseAxisArray(zeros(3), 1:3)
497+
A[[2, 3]] .= 1.0
498+
@test A.data == [0.0, 1.0, 1.0]
499+
A = Containers.DenseAxisArray(zeros(3), 1:3)
500+
A[[1, 3]] .= 1.0
501+
@test A.data == [1.0, 0.0, 1.0]
502+
A = Containers.DenseAxisArray(zeros(3), 1:3)
503+
A[[2]] .= 1.0
504+
@test A.data == [0.0, 1.0, 0.0]
505+
A[2:3] = Containers.DenseAxisArray([2.0, 3.0], 2:3)
506+
@test A.data == [0.0, 2.0, 3.0]
507+
A = Containers.DenseAxisArray(zeros(3), 1:3)
508+
A[:] .= 1.0
509+
@test A.data == [1.0, 1.0, 1.0]
510+
return
511+
end
512+
513+
function test_containers_denseaxisarray_setindex_matrix()
514+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
515+
A[:, [:a, :b]] .= 1.0
516+
@test A.data == [1.0 1.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
517+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
518+
A[2:3, [:a, :b]] .= 1.0
519+
@test A.data == [0.0 0.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
520+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
521+
A[3:3, [:a, :b]] .= 1.0
522+
@test A.data == [0.0 0.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
523+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
524+
A[[1, 3], [:a, :b]] .= 1.0
525+
@test A.data == [1.0 1.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
526+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
527+
A[[1, 3], [:a, :c]] .= 1.0
528+
@test A.data == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
529+
return
530+
end
531+
532+
function test_containers_denseaxisarray_view()
533+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
534+
B = view(A, :, [:a, :b])
535+
@test_throws KeyError view(A, :, [:d])
536+
@test size(B) == (3, 2)
537+
@test B[1, :a] == A[1, :a]
538+
@test B[3, :a] == A[3, :a]
539+
# Views are weird, because we can still access the underlying array?
540+
@test B[3, :c] == A[3, :c]
541+
@test sprint(show, B) == sprint(show, B.data)
542+
@test sprint(Base.print_array, B) == sprint(show, B.data)
543+
@test sprint(Base.summary, B) ==
544+
"view(::DenseAxisArray, 1:3, [:a, :b]), over"
545+
return
546+
end
547+
548+
function test_containers_denseaxisarray_jump_3151()
549+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
550+
E = Containers.DenseAxisArray(ones(3), [:a, :b, :c])
551+
I = [:a, :b]
552+
D[I] = E[I]
553+
@test D.data == [1.0, 1.0, 0.0]
554+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
555+
I = [:b, :c]
556+
D[I] = E[I]
557+
@test D.data == [0.0, 1.0, 1.0]
558+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
559+
I = [:a, :c]
560+
D[I] = E[I]
561+
@test D.data == [1.0, 0.0, 1.0]
562+
return
563+
end
564+
565+
function test_containers_denseaxisarray_view_operations()
566+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
567+
d = view(c, 2:3, :)
568+
@test sum(c) == 60
569+
@test sum(d) == 30
570+
d .= 1
571+
@test sum(d) == 4
572+
@test sum(c) == 34
573+
return
574+
end
575+
576+
function test_containers_denseaxisarray_view_addition()
577+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
578+
d = view(c, 2:3, :)
579+
@test_throws MethodError d + d
580+
return
581+
end
582+
583+
function test_containers_denseaxisarray_setindex_invalid()
584+
c = Containers.@container([i = 1:4, j = 2:3], 0)
585+
d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
586+
setindex!(c, d, 1:4, 2:3)
587+
@test c == d
588+
c .= 0
589+
setindex!(c, d, 1:4, 2:2)
590+
@test c == Containers.@container([i = 1:4, j = 2:3], (4 + i) * (j == 2))
591+
d = Containers.@container([i = 5:6, j = 2:3], i + 2 * j)
592+
@test_throws KeyError setindex!(c, d, 1:4, 2:3)
593+
return
594+
end
595+
596+
function test_containers_denseaxisarray_setindex_keys()
597+
c = Containers.@container([i = 1:4, j = 2:3], 0)
598+
for (i, k) in enumerate(keys(c))
599+
c[k] = c[k] + i
600+
end
601+
@test c == Containers.@container([i = 1:4, j = 2:3], 4 * (j - 2) + i)
602+
for (i, k) in enumerate(keys(c))
603+
c[k] = c[k] + i
604+
end
605+
@test c == Containers.@container([i = 1:4, j = 2:3], 2 * (4 * (j - 2) + i))
606+
return
607+
end
608+
492609
end # module
493610

494611
TestContainersDenseAxisArray.runtests()

0 commit comments

Comments
 (0)