@@ -353,8 +353,6 @@ function Base.setindex!(
353
353
value:: DenseAxisArray{T,N} ,
354
354
args... ,
355
355
) where {T,N}
356
- @show args
357
- @show Base. to_index (A, args)
358
356
for key in Base. product (args... )
359
357
A[key... ] = value[key... ]
360
358
end
@@ -583,22 +581,30 @@ Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)
583
581
584
582
_get_subaxis (:: Colon , b) = b
585
583
586
- function _get_subaxis (a, b)
584
+ function _get_subaxis (a:: AbstractVector , b)
587
585
for ai in a
588
586
if ! (ai in b)
589
587
throw (KeyError (ai))
590
588
end
591
589
end
592
590
return a
593
591
end
592
+
593
+ function _get_subaxis (a:: T , b:: AbstractVector{T} ) where {T}
594
+ if ! (a in b)
595
+ throw (KeyError (a))
596
+ end
597
+ return a
598
+ end
599
+
594
600
struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
595
601
data:: D
596
602
axes:: A
597
603
function DenseAxisArrayView (
598
604
x:: Containers.DenseAxisArray{T,N} ,
599
605
args... ,
600
606
) where {T,N}
601
- axis = tuple ([ _get_subaxis (a, b) for (a, b) in zip ( args, axes (x))] . .. )
607
+ axis = _get_subaxis .( args, axes (x))
602
608
return new {T,N,typeof(x),typeof(axis)} (x, axis)
603
609
end
604
610
end
@@ -612,15 +618,20 @@ Base.size(x::DenseAxisArrayView) = length.(x.axes)
612
618
Base. axes (x:: DenseAxisArrayView ) = x. axes
613
619
614
620
function Base. getindex (x:: DenseAxisArrayView , args... )
615
- return getindex (x. data, args... )
621
+ y = _get_subaxis .(args, x. axes)
622
+ return getindex (x. data, y... )
616
623
end
617
624
625
+ Base. getindex (a:: DenseAxisArrayView , k:: DenseAxisArrayKey ) = a[k. I... ]
626
+
618
627
function Base. setindex! (x:: DenseAxisArrayView , args... )
619
628
return setindex! (x. data, args... )
620
629
end
621
630
622
631
function Base. eachindex (A:: DenseAxisArrayView )
623
- return DenseAxisArrayKey .(Base. product (A. axes... ))
632
+ # Return a generator so that we lazily evaluate the product instead of
633
+ # collecting into a vector.
634
+ return (DenseAxisArrayKey (k) for k in Base. product (A. axes... ))
624
635
end
625
636
626
637
Base. show (io:: IO , x:: DenseAxisArrayView ) = print (io, x. data)
0 commit comments