Skip to content

Commit cbe4dbc

Browse files
author
Michael Abbott
committed
add keyword indexing + constructor
1 parent 032327a commit cbe4dbc

File tree

4 files changed

+32
-1
lines changed

4 files changed

+32
-1
lines changed

src/core.jl

+5
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ function AxisArray(A::AbstractArray{T,N}, names::NTuple{N,Symbol}, steps::NTuple
235235
AxisArray(A, axs...)
236236
end
237237

238+
# Alternative constructor, takes names as keywords:
239+
AxisArray(A; kw...) = AxisArray(A, nt_to_axes(kw.data)...)
240+
@generated nt_to_axes(nt::NamedTuple) =
241+
Expr(:tuple, (:(Axis{$(QuoteNode(n))}(getfield(nt, $(QuoteNode(n))))) for n in nt.names)...)
242+
238243
AxisArray(A::AxisArray) = A
239244
AxisArray(A::AxisArray, ax::Vararg{Axis, N}) where N =
240245
AxisArray(A.data, ax..., last(Base.IteratorsMD.split(axes(A), Val(N)))...)

src/indexing.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ end
129129
@propagate_inbounds getindex_converted(A, idxs...) = A.data[idxs...]
130130
@propagate_inbounds setindex!_converted(A, v, idxs...) = (A.data[idxs...] = v)
131131

132-
133132
# First is indexing by named axis. We simply sort the axes and re-dispatch.
134133
# When indexing by named axis the shapes of omitted dimensions are preserved
135134
# TODO: should we handle multidimensional Axis indexes? It could be interpreted
@@ -155,6 +154,19 @@ function Base.reshape(A::AxisArray, ::Val{N}) where N
155154
AxisArray(reshape(A.data, Val(N)), Base.front(axN))
156155
end
157156

157+
# Keyword indexing, reconstructs the Axis{}() objects
158+
@propagate_inbounds Base.view(A::AxisArray; kw...) =
159+
view(A, kw_to_axes(parent(A), kw.data)...)
160+
@propagate_inbounds Base.getindex(A::AxisArray; kw...) =
161+
getindex(A, kw_to_axes(parent(A), kw.data)...)
162+
@propagate_inbounds Base.setindex!(A::AxisArray, val; kw...) =
163+
setindex!(A, val, kw_to_axes(parent(A), kw.data)...)
164+
165+
function kw_to_axes(A::AbstractArray, nt::NamedTuple)
166+
length(nt) == 0 && throw(BoundsError(A, ())) # Trivial case A[] lands here
167+
nt_to_axes(nt)
168+
end
169+
158170
### Indexing along values of the axes ###
159171

160172
# Default axes indexing throws an error

test/core.jl

+5
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,11 @@ B = AxisArray([1 4; 2 5; 3 6], (:x, :y), (0.2, 100))
132132
B = AxisArray([1 4; 2 5; 3 6], (:x, :y), (0.2, 100), (-3,14))
133133
@test axisnames(B) == (:x, :y)
134134
@test axisvalues(B) == (-3:0.2:-2.6, 14:100:114)
135+
# Keyword constructor
136+
C = AxisArray([1 4; 2 5; 3 6], x=10:10:30, y=[:a, :b])
137+
@test axisnames(C) == (:x, :y)
138+
@test axisvalues(C) == (10:10:30, [:a, :b])
139+
@test @inferred(AxisArray(parent(C), x=1:3, y=1:2)) isa AxisArray
135140

136141
@test AxisArrays.HasAxes(A) == AxisArrays.HasAxes{true}()
137142
@test AxisArrays.HasAxes([1]) == AxisArrays.HasAxes{false}()

test/indexing.jl

+9
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,12 @@ a = [2, 3, 7]
302302
@test a[idx] 6.2
303303
aa = AxisArray(a, :x)
304304
@test aa[idx] 6.2
305+
306+
# Keyword indexing
307+
A = AxisArray([1 2; 3 4], Axis{:x}(10:10:20), Axis{:y}(["c", "d"]))
308+
@test @inferred(A[x=1, y=1]) == 1
309+
@test @inferred(A[x=1]) == [1, 2]
310+
@test axisnames(A[x=1]) == (:y,)
311+
@test @inferred(view(A, x=1)) == [1,2]
312+
@test parent(view(A, x=1)) isa SubArray
313+
@test @inferred(A[x=atvalue(20), y=atvalue("d")]) == 4

0 commit comments

Comments
 (0)