Skip to content

Commit b6d001b

Browse files
authored
Merge branch 'master' into od/ambiguity
2 parents 50d134a + 9f61e95 commit b6d001b

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

src/Containers/DenseAxisArray.jl

+86
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,17 @@ function Base.IndexStyle(::Type{DenseAxisArray{T,N,Ax}}) where {T,N,Ax}
361361
return IndexAnyCartesian()
362362
end
363363

364+
function Base.setindex!(
365+
A::DenseAxisArray{T,N},
366+
value::DenseAxisArray{T,N},
367+
args...,
368+
) where {T,N}
369+
for key in Base.product(args...)
370+
A[key...] = value[key...]
371+
end
372+
return A
373+
end
374+
364375
########
365376
# Keys #
366377
########
@@ -376,6 +387,10 @@ end
376387
Base.getindex(k::DenseAxisArrayKey, args...) = getindex(k.I, args...)
377388
Base.getindex(a::DenseAxisArray, k::DenseAxisArrayKey) = a[k.I...]
378389

390+
function Base.setindex!(A::DenseAxisArray, value, key::DenseAxisArrayKey)
391+
return setindex!(A, value, key.I...)
392+
end
393+
379394
struct DenseAxisArrayKeys{T<:Tuple,S<:DenseAxisArrayKey,N} <: AbstractArray{S,N}
380395
product_iter::Base.Iterators.ProductIterator{T}
381396
function DenseAxisArrayKeys(a::DenseAxisArray{TT,N,Ax}) where {TT,N,Ax}
@@ -572,3 +587,74 @@ end
572587
# but some users may depend on it's functionality so we have a work-around
573588
# instead of just breaking code.
574589
Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)
590+
591+
###
592+
### view
593+
###
594+
595+
_get_subaxis(::Colon, b) = b
596+
597+
function _get_subaxis(a::AbstractVector, b)
598+
for ai in a
599+
if !(ai in b)
600+
throw(KeyError(ai))
601+
end
602+
end
603+
return a
604+
end
605+
606+
function _get_subaxis(a::T, b::AbstractVector{T}) where {T}
607+
if !(a in b)
608+
throw(KeyError(a))
609+
end
610+
return a
611+
end
612+
613+
struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
614+
data::D
615+
axes::A
616+
function DenseAxisArrayView(
617+
x::Containers.DenseAxisArray{T,N},
618+
args...,
619+
) where {T,N}
620+
axis = _get_subaxis.(args, axes(x))
621+
return new{T,N,typeof(x),typeof(axis)}(x, axis)
622+
end
623+
end
624+
625+
function Base.view(A::Containers.DenseAxisArray, args...)
626+
return DenseAxisArrayView(A, args...)
627+
end
628+
629+
Base.size(x::DenseAxisArrayView) = length.(x.axes)
630+
631+
Base.axes(x::DenseAxisArrayView) = x.axes
632+
633+
function Base.getindex(x::DenseAxisArrayView, args...)
634+
y = _get_subaxis.(args, x.axes)
635+
return getindex(x.data, y...)
636+
end
637+
638+
Base.getindex(a::DenseAxisArrayView, k::DenseAxisArrayKey) = a[k.I...]
639+
640+
function Base.setindex!(x::DenseAxisArrayView, args...)
641+
return setindex!(x.data, args...)
642+
end
643+
644+
function Base.eachindex(A::DenseAxisArrayView)
645+
# Return a generator so that we lazily evaluate the product instead of
646+
# collecting into a vector.
647+
#
648+
# In future, we might want to return the appropriate matrix of
649+
# `CartesianIndex` to avoid having to do the lookups with
650+
# `DenseAxisArrayKey`.
651+
return (DenseAxisArrayKey(k) for k in Base.product(A.axes...))
652+
end
653+
654+
Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data)
655+
656+
Base.print_array(io::IO, x::DenseAxisArrayView) = show(io, x)
657+
658+
function Base.summary(io::IO, x::DenseAxisArrayView)
659+
return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over")
660+
end

test/Containers/test_DenseAxisArray.jl

+123
Original file line numberDiff line numberDiff line change
@@ -490,4 +490,127 @@ function test_ambiguity_isassigned()
490490
return
491491
end
492492

493+
function test_containers_denseaxisarray_setindex_vector()
494+
A = Containers.DenseAxisArray(zeros(3), 1:3)
495+
A[2:3] .= 1.0
496+
@test A.data == [0.0, 1.0, 1.0]
497+
A = Containers.DenseAxisArray(zeros(3), 1:3)
498+
A[[2, 3]] .= 1.0
499+
@test A.data == [0.0, 1.0, 1.0]
500+
A = Containers.DenseAxisArray(zeros(3), 1:3)
501+
A[[1, 3]] .= 1.0
502+
@test A.data == [1.0, 0.0, 1.0]
503+
A = Containers.DenseAxisArray(zeros(3), 1:3)
504+
A[[2]] .= 1.0
505+
@test A.data == [0.0, 1.0, 0.0]
506+
A[2:3] = Containers.DenseAxisArray([2.0, 3.0], 2:3)
507+
@test A.data == [0.0, 2.0, 3.0]
508+
A = Containers.DenseAxisArray(zeros(3), 1:3)
509+
A[:] .= 1.0
510+
@test A.data == [1.0, 1.0, 1.0]
511+
return
512+
end
513+
514+
function test_containers_denseaxisarray_setindex_matrix()
515+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
516+
A[:, [:a, :b]] .= 1.0
517+
@test A.data == [1.0 1.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
518+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
519+
A[2:3, [:a, :b]] .= 1.0
520+
@test A.data == [0.0 0.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
521+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
522+
A[3:3, [:a, :b]] .= 1.0
523+
@test A.data == [0.0 0.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
524+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
525+
A[[1, 3], [:a, :b]] .= 1.0
526+
@test A.data == [1.0 1.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
527+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
528+
A[[1, 3], [:a, :c]] .= 1.0
529+
@test A.data == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
530+
return
531+
end
532+
533+
function test_containers_denseaxisarray_view()
534+
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
535+
B = view(A, :, [:a, :b])
536+
@test_throws KeyError view(A, :, [:d])
537+
@test size(B) == (3, 2)
538+
@test B[1, :a] == A[1, :a]
539+
@test B[3, :a] == A[3, :a]
540+
@test_throws KeyError B[3, :c]
541+
@test sprint(show, B) == sprint(show, B.data)
542+
@test sprint(Base.print_array, B) == sprint(show, B.data)
543+
@test sprint(Base.summary, B) ==
544+
"view(::DenseAxisArray, 1:3, [:a, :b]), over"
545+
return
546+
end
547+
548+
function test_containers_denseaxisarray_jump_3151()
549+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
550+
E = Containers.DenseAxisArray(ones(3), [:a, :b, :c])
551+
I = [:a, :b]
552+
D[I] = E[I]
553+
@test D.data == [1.0, 1.0, 0.0]
554+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
555+
I = [:b, :c]
556+
D[I] = E[I]
557+
@test D.data == [0.0, 1.0, 1.0]
558+
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
559+
I = [:a, :c]
560+
D[I] = E[I]
561+
@test D.data == [1.0, 0.0, 1.0]
562+
return
563+
end
564+
565+
function test_containers_denseaxisarray_view_operations()
566+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
567+
d = view(c, 2:3, :)
568+
@test sum(c) == 60
569+
@test sum(d) == 30
570+
d .= 1
571+
@test sum(d) == 4
572+
@test sum(c) == 34
573+
return
574+
end
575+
576+
function test_containers_denseaxisarray_view_addition()
577+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
578+
d = view(c, 2:3, :)
579+
@test_throws MethodError d + d
580+
return
581+
end
582+
583+
function test_containers_denseaxisarray_view_colon()
584+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
585+
d = view(c, 2:3, :)
586+
@test d[:, 2] == Containers.@container([i = 2:3], i + 2 * 2)
587+
return
588+
end
589+
590+
function test_containers_denseaxisarray_setindex_invalid()
591+
c = Containers.@container([i = 1:4, j = 2:3], 0)
592+
d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
593+
setindex!(c, d, 1:4, 2:3)
594+
@test c == d
595+
c .= 0
596+
setindex!(c, d, 1:4, 2:2)
597+
@test c == Containers.@container([i = 1:4, j = 2:3], (4 + i) * (j == 2))
598+
d = Containers.@container([i = 5:6, j = 2:3], i + 2 * j)
599+
@test_throws KeyError setindex!(c, d, 1:4, 2:3)
600+
return
601+
end
602+
603+
function test_containers_denseaxisarray_setindex_keys()
604+
c = Containers.@container([i = 1:4, j = 2:3], 0)
605+
for (i, k) in enumerate(keys(c))
606+
c[k] = c[k] + i
607+
end
608+
@test c == Containers.@container([i = 1:4, j = 2:3], 4 * (j - 2) + i)
609+
for (i, k) in enumerate(keys(c))
610+
c[k] = c[k] + i
611+
end
612+
@test c == Containers.@container([i = 1:4, j = 2:3], 2 * (4 * (j - 2) + i))
613+
return
614+
end
615+
493616
end # module

0 commit comments

Comments
 (0)