Skip to content

Commit b078564

Browse files
authored
Merge pull request #164 from JuliaArrays/teh/unknown_index_types
Try converting indexes once and then defer to the parent
2 parents 7d55ab1 + 286baf6 commit b078564

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

src/indexing.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,22 @@ end
114114
@propagate_inbounds Base.setindex!(A::AxisArray, v, idx::AbstractArray{Bool}) = (A.data[idx] = v)
115115

116116
### Fancier indexing capabilities provided only by AxisArrays ###
117-
@propagate_inbounds Base.getindex(A::AxisArray, idxs...) = A[to_index(A,idxs...)...]
118-
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs...) = (A[to_index(A,idxs...)...] = v)
117+
# To avoid StackOverflowErrors on indexes that we don't know how to convert, we
118+
# give AxisArrays once chance to convert into known format and then defer to the parent
119+
@propagate_inbounds Base.getindex(A::AxisArray, idxs...) = getindex_converted(A, to_index(A,idxs...)...)
120+
@propagate_inbounds Base.setindex!(A::AxisArray, v, idxs...) = setindex!_converted(A, v, to_index(A,idxs...)...)
119121
# Deal with lots of ambiguities here
120122
@propagate_inbounds Base.view(A::AxisArray, idxs::ViewIndex...) = view(A, to_index(A,idxs...)...)
121123
@propagate_inbounds Base.view(A::AxisArray, idxs::Union{ViewIndex,AbstractCartesianIndex}...) = view(A, to_index(A,Base.IteratorsMD.flatten(idxs)...)...)
122124
@propagate_inbounds Base.view(A::AxisArray, idxs...) = view(A, to_index(A,idxs...)...)
123125

126+
@propagate_inbounds getindex_converted(A, idxs::Idx...) = A[idxs...]
127+
@propagate_inbounds setindex!_converted(A, v, idxs::Idx...) = (A[idxs...] = v)
128+
129+
@propagate_inbounds getindex_converted(A, idxs...) = A.data[idxs...]
130+
@propagate_inbounds setindex!_converted(A, v, idxs...) = (A.data[idxs...] = v)
131+
132+
124133
# First is indexing by named axis. We simply sort the axes and re-dispatch.
125134
# When indexing by named axis the shapes of omitted dimensions are preserved
126135
# TODO: should we handle multidimensional Axis indexes? It could be interpreted
@@ -179,7 +188,7 @@ axisindexes(t, ax, idx) = error("cannot index $(typeof(ax)) with $(typeof(idx));
179188
# Maybe extend error message to all <: Numbers if Base allows it?
180189
axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::Real) =
181190
throw(ArgumentError("invalid index: $idx. Use `atvalue` when indexing by value."))
182-
function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx)
191+
function axisindexes(::Type{Dimensional}, ax::AbstractVector{T}, idx::T) where T
183192
idxs = searchsorted(ax, ClosedInterval(idx,idx))
184193
length(idxs) > 1 && error("more than one datapoint lies on axis value $idx; use an interval to return all values")
185194
if length(idxs) == 1
@@ -218,6 +227,8 @@ function axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::ExactValue)
218227
throw(BoundsError(ax, idx))
219228
end
220229
end
230+
# For index types that AxisArrays doesn't know about
231+
axisindexes(::Type{Dimensional}, ax::AbstractVector, idx) = idx
221232

222233
# Dimensional axes may be indexed by intervals to select a range
223234
axisindexes(::Type{Dimensional}, ax::AbstractVector, idx::ClosedInterval) = searchsorted(ax, idx)

test/indexing.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,20 @@ A = AxisArray(1:365, Date(2017,1,1):Day(1):Date(2017,12,31))
285285
@test A[(-Day(13)..Day(14)) + Date(2017,2,14)] == collect(31 .+ (1:28))
286286
@test A[(-Day(14)..Day(14)) + DateTime(2017,2,14,12)] == collect(31 .+ (1:28))
287287
@test A[(Day(0)..Day(6)) + (Date(2017,1,1):Month(1):Date(2017,4,12))] == [1:7 32:38 60:66 91:97]
288+
289+
# Test using index types that AxisArrays doesn't understand
290+
# This example is inspired from Interpolations.jl and would implement linear interpolation
291+
struct WeightedIndex{T}
292+
idx::Int
293+
weights::Tuple{T,T}
294+
end
295+
Base.to_indices(A, I::Tuple{Vararg{Union{Int,WeightedIndex}}}) = I
296+
@inline function Base._getindex(::IndexStyle, A::AbstractVector, I::WeightedIndex)
297+
A[I.idx]*I.weights[1] + A[I.idx+1]*I.weights[2]
298+
end
299+
idx = WeightedIndex(2, (0.2, 0.8))
300+
301+
a = [2, 3, 7]
302+
@test a[idx] 6.2
303+
aa = AxisArray(a, :x)
304+
@test aa[idx] 6.2

0 commit comments

Comments
 (0)