Skip to content

Commit 31d90e9

Browse files
committed
Fix getindex of DenseAxisArrayView
1 parent cf35a8a commit 31d90e9

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

src/Containers/DenseAxisArray.jl

+17-6
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,6 @@ function Base.setindex!(
353353
value::DenseAxisArray{T,N},
354354
args...,
355355
) where {T,N}
356-
@show args
357-
@show Base.to_index(A, args)
358356
for key in Base.product(args...)
359357
A[key...] = value[key...]
360358
end
@@ -583,22 +581,30 @@ Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)
583581

584582
_get_subaxis(::Colon, b) = b
585583

586-
function _get_subaxis(a, b)
584+
function _get_subaxis(a::AbstractVector, b)
587585
for ai in a
588586
if !(ai in b)
589587
throw(KeyError(ai))
590588
end
591589
end
592590
return a
593591
end
592+
593+
function _get_subaxis(a::T, b::AbstractVector{T}) where {T}
594+
if !(a in b)
595+
throw(KeyError(a))
596+
end
597+
return a
598+
end
599+
594600
struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
595601
data::D
596602
axes::A
597603
function DenseAxisArrayView(
598604
x::Containers.DenseAxisArray{T,N},
599605
args...,
600606
) where {T,N}
601-
axis = tuple([_get_subaxis(a, b) for (a, b) in zip(args, axes(x))]...)
607+
axis = _get_subaxis.(args, axes(x))
602608
return new{T,N,typeof(x),typeof(axis)}(x, axis)
603609
end
604610
end
@@ -612,15 +618,20 @@ Base.size(x::DenseAxisArrayView) = length.(x.axes)
612618
Base.axes(x::DenseAxisArrayView) = x.axes
613619

614620
function Base.getindex(x::DenseAxisArrayView, args...)
615-
return getindex(x.data, args...)
621+
y = _get_subaxis.(args, x.axes)
622+
return getindex(x.data, y...)
616623
end
617624

625+
Base.getindex(a::DenseAxisArrayView, k::DenseAxisArrayKey) = a[k.I...]
626+
618627
function Base.setindex!(x::DenseAxisArrayView, args...)
619628
return setindex!(x.data, args...)
620629
end
621630

622631
function Base.eachindex(A::DenseAxisArrayView)
623-
return DenseAxisArrayKey.(Base.product(A.axes...))
632+
# Return a generator so that we lazily evaluate the product instead of
633+
# collecting into a vector.
634+
return (DenseAxisArrayKey(k) for k in Base.product(A.axes...))
624635
end
625636

626637
Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data)

test/Containers/DenseAxisArray.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -536,8 +536,7 @@ function test_containers_denseaxisarray_view()
536536
@test size(B) == (3, 2)
537537
@test B[1, :a] == A[1, :a]
538538
@test B[3, :a] == A[3, :a]
539-
# Views are weird, because we can still access the underlying array?
540-
@test B[3, :c] == A[3, :c]
539+
@test_throws KeyError B[3, :c]
541540
@test sprint(show, B) == sprint(show, B.data)
542541
@test sprint(Base.print_array, B) == sprint(show, B.data)
543542
@test sprint(Base.summary, B) ==
@@ -580,6 +579,13 @@ function test_containers_denseaxisarray_view_addition()
580579
return
581580
end
582581

582+
function test_containers_denseaxisarray_view_colon()
583+
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
584+
d = view(c, 2:3, :)
585+
@test d[:, 2] == Containers.@container([i = 2:3], i + 2 * 2)
586+
return
587+
end
588+
583589
function test_containers_denseaxisarray_setindex_invalid()
584590
c = Containers.@container([i = 1:4, j = 2:3], 0)
585591
d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)

0 commit comments

Comments
 (0)