-
-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathcomponentarray.jl
341 lines (286 loc) · 12.5 KB
/
componentarray.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""
    x = ComponentArray(nt::NamedTuple)
    x = ComponentArray(;kwargs...)
    x = ComponentArray(data::AbstractVector, ax)
    x = ComponentArray{T}(args...; kwargs...) where T

Array type that can be accessed like an arbitrary nested mutable struct.

# Examples

```jldoctest
julia> using ComponentArrays

julia> x = ComponentArray(a=1, b=[2, 1, 4], c=(a=2, b=[1, 2]))
ComponentVector{Int64}(a = 1, b = [2, 1, 4], c = (a = 2, b = [1, 2]))

julia> x.c.a = 400; x
ComponentVector{Int64}(a = 1, b = [2, 1, 4], c = (a = 400, b = [1, 2]))

julia> x[5]
400

julia> collect(x)
7-element Vector{Int64}:
   1
   2
   1
   4
 400
   1
   2
```
"""
struct ComponentArray{T,N,A<:AbstractArray{T,N},Axes<:Tuple{Vararg{AbstractAxis}}} <: DenseArray{T,N}
    # NOTE: the bound is `Tuple{Vararg{AbstractAxis}}` rather than the deprecated
    # `Tuple{Vararg{<:AbstractAxis}}`, so tuples that mix different concrete axis types
    # are accepted.
    data::A     # wrapped array holding the values
    axes::Axes  # one axis per dimension; `size` is normally derived from their lengths
end
# Entry from type (used for broadcasting): `Axes` is a tuple type of axes, which `getaxes`
# turns back into axis values before the usual data+axes constructor is called.
function ComponentArray{Axes}(data) where {Axes}
    axs = getaxes(Axes)
    return ComponentArray(data, axs...)
end

# Uninitialized constructors: allocate a backing array with one dimension per axis, each
# sized by that axis's last index, and wrap it. With no type parameter the element type is
# `Float64`; with an array type `A` the storage comes from `similar(A, dims)`.
function ComponentArray(::UndefInitializer, ax::Axes) where {Axes<:Tuple}
    dims = last_index.(ax)
    return ComponentArray(similar(Array{Float64}, dims), ax...)
end
function ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:AbstractArray,Axes<:Tuple}
    dims = last_index.(ax)
    return ComponentArray(similar(A, dims), ax...)
end
function ComponentArray{T}(::UndefInitializer, ax::Axes) where {T,Axes<:Tuple}
    dims = last_index.(ax)
    return ComponentArray(similar(Array{T}, dims), ax...)
end
# Entry from a data array plus `AbstractAxis` values. Dispatch on the kinds of axes picks
# how the data is shaped/partitioned, and the axes are finally bundled into a tuple for the
# struct's own constructor.

# Only flat axes: the data is returned as is, without wrapping.
function ComponentArray(data, ::FlatAxis...)
    return data
end
# No shaped or partitioned axes: wrap the data with the axes as a tuple.
function ComponentArray(data, ax::NotShapedOrPartitionedAxis...)
    return ComponentArray(data, ax)
end
# No partitioned axes: reshape the data per `maybe_reshape`, then retry with each axis
# passed through `unshape`.
function ComponentArray(data, ax::NotPartitionedAxis...)
    reshaped = maybe_reshape(data, ax...)
    return ComponentArray(reshaped, unshape.(ax)...)
end
# Remaining case (presumably at least one `PartitionedAxis`): split the data into blocks
# sized by the partitioned axes and lazily wrap every block with `Axis.(ax)`.
function ComponentArray(data, ax::AbstractAxis...)
    partition_sizes = size.(filter_by_type(PartitionedAxis, ax...))
    blocks = partition(data, partition_sizes...)
    block_axes = Axis.(ax)
    return LazyArray(ComponentArray(block, block_axes) for block in blocks)
end
# Entry from NamedTuple, Dict, or kwargs. Empty inputs give an empty array with a single
# `FlatAxis`; the `ComponentArray{T}` forms store the entries with element type `T`.
ComponentArray{T}(nt::NamedTuple) where T = ComponentArray(make_carray_args(T, nt)...)
ComponentArray{T}(::NamedTuple{(), Tuple{}}) where T = ComponentArray(T[], (FlatAxis(),))
ComponentArray(nt::NamedTuple) = ComponentArray(make_carray_args(nt)...)
ComponentArray(::NamedTuple{(), Tuple{}}) = ComponentArray(Any[], (FlatAxis(),))
# Dictionaries are converted to a NamedTuple first; keys may be Symbols or anything
# `Symbol` accepts (e.g. Strings).
ComponentArray(d::AbstractDict) = ComponentArray(_dict_to_namedtuple(d))
ComponentArray{T}(d::AbstractDict) where T = ComponentArray{T}(_dict_to_namedtuple(d))
ComponentArray{T}(;kwargs...) where T = ComponentArray{T}((;kwargs...))
ComponentArray(;kwargs...) = ComponentArray((;kwargs...))
# Re-wrapping: a ComponentArray is returned unchanged, or converted elementwise to `T`.
ComponentArray(x::ComponentArray) = x
ComponentArray{T}(x::ComponentArray) where {T} = T.(x)
(CA::Type{<:ComponentArray{T,N,A,Ax}})(x::ComponentArray) where {T,N,A,Ax} =
    ComponentArray(T.(getdata(x)), getaxes(x))

# Build a NamedTuple from `d`. `keys(d)` and `values(d)` iterate in the same order, so the
# names and values stay paired. Each key goes through `Symbol` because NamedTuple names
# must be Symbols.
function _dict_to_namedtuple(d::AbstractDict)
    names = Tuple(Symbol(k) for k in keys(d))
    return NamedTuple{names}(Tuple(values(d)))
end
## Some aliases
"""
    x = ComponentVector(nt::NamedTuple)
    x = ComponentVector(;kwargs...)
    x = ComponentVector(data::AbstractVector, ax)
    x = ComponentVector{T}(args...; kwargs...) where T
    x = ComponentVector{T}(undef, ax)

A `ComponentVector` is an alias for a one-dimensional `ComponentArray`.

`ComponentVector(x::ComponentVector; kwargs...)` returns `x` with each keyword either
overwriting the existing component of that name or appended as a new component; `x`
itself is not modified.
"""
const ComponentVector{T,A,Axes} = ComponentArray{T,1,A,Axes}
ComponentVector(nt) = ComponentArray(nt)
ComponentVector{T}(nt) where {T} = ComponentArray{T}(nt)
ComponentVector(;kwargs...) = ComponentArray(;kwargs...)
ComponentVector{T}(;kwargs...) where {T} = ComponentArray{T}(;kwargs...)
# Uninitialized constructor. `ComponentArray{T}(undef, ...)` expects a tuple of axes, so a
# lone axis is wrapped in a 1-tuple.
ComponentVector{T}(::UndefInitializer, ax) where {T} = ComponentArray{T}(undef, ax)
ComponentVector{T}(::UndefInitializer, ax::AbstractAxis) where {T} =
    ComponentArray{T}(undef, (ax,))
ComponentVector(data::AbstractVector, ax) = ComponentArray(data, ax)
ComponentVector(data::AbstractArray, ax) = throw(DimensionMismatch("A `ComponentVector` must be initialized with a 1-dimensional array. This array is $(ndims(data))-dimensional."))
# Add new fields to (or overwrite fields of) a component vector, one keyword at a time.
function ComponentArray(x::ComponentVector; kwargs...)
    return foldl(_maybe_add_field, (kwargs...,); init=x)
end
ComponentVector(x::ComponentVector; kwargs...) = ComponentArray(x; kwargs...)
ComponentVector{T}(x::ComponentVector) where {T} = T.(x)
"""
    x = ComponentMatrix(data::AbstractMatrix, ax...)
    x = ComponentMatrix{T}(data::AbstractMatrix, ax...) where T
    x = ComponentMatrix{T}(undef, ax1, ax2)

A `ComponentMatrix` is an alias for a two-dimensional `ComponentArray`.
"""
const ComponentMatrix{T,A,Axes} = ComponentArray{T,2,A,Axes}
# Uninitialized constructor. The axes may be given as one tuple (forwarded as is) or as
# separate axis arguments, which are bundled into the tuple `ComponentArray{T}(undef, ...)`
# expects.
ComponentMatrix{T}(::UndefInitializer, ax...) where {T} = ComponentArray{T}(undef, ax...)
function ComponentMatrix{T}(::UndefInitializer, ax1::AbstractAxis, axs::AbstractAxis...) where {T}
    return ComponentArray{T}(undef, (ax1, axs...))
end
ComponentMatrix(data::AbstractMatrix, ax...) = ComponentArray(data, ax...)
ComponentMatrix(data::AbstractArray, ax...) = throw(DimensionMismatch("A `ComponentMatrix` must be initialized with a 2-dimensional array. This array is $(ndims(data))-dimensional."))
ComponentMatrix(x::ComponentMatrix) = x
ComponentMatrix{T}(x::ComponentMatrix) where {T} = T.(x)
# Empty 0×0 matrices with two flat axes (element type `Any` unless `T` is given).
ComponentMatrix() = ComponentMatrix(Array{Any}(undef, 0, 0), (FlatAxis(), FlatAxis()))
ComponentMatrix{T}() where {T} = ComponentMatrix(Array{T}(undef, 0, 0), (FlatAxis(), FlatAxis()))
# Short aliases for the component array types.
const CArray = ComponentArray
const CVector = ComponentVector
const CMatrix = ComponentMatrix
# `Adjoint` or `Transpose` wrapper around an array of type `A`.
const AdjOrTrans{T, A} = Union{Adjoint{T, A}, Transpose{T, A}}
# Adjoint/transpose wrappers whose parent is a ComponentArray / ComponentVector /
# ComponentMatrix respectively.
# NOTE(review): in `AdjOrTransComponentArray{T, A}` the name `A` is also bound by the inner
# `where A<:ComponentArray`, so the alias's second parameter looks unused — confirm intent.
const AdjOrTransComponentArray{T, A} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentArray
const AdjOrTransComponentVector{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentVector
const AdjOrTransComponentMatrix{T} = Union{Adjoint{T, A}, Transpose{T, A}} where A<:ComponentMatrix
# One- or two-dimensional ComponentArrays, and adjoint/transpose wrappers of those.
const ComponentVecOrMat = Union{ComponentVector, ComponentMatrix}
const AdjOrTransComponentVecOrMat = AdjOrTrans{T, <:ComponentVecOrMat} where T
# Each ComponentArray kind together with its adjoint/transpose wrappers.
const AbstractComponentArray = Union{ComponentArray, AdjOrTransComponentArray}
const AbstractComponentVecOrMat = Union{ComponentVecOrMat, AdjOrTransComponentVecOrMat}
const AbstractComponentVector = Union{ComponentVector, AdjOrTransComponentVector}
const AbstractComponentMatrix = Union{ComponentMatrix, AdjOrTransComponentMatrix}
## Constructor helpers
# For making ComponentArrays from named tuples: each method returns `(data, axis)`.

# Empty NamedTuple: empty data with a flat axis (`Any` element type unless `T` is given).
make_carray_args(::NamedTuple{(), Tuple{}}) = (Any[], FlatAxis())
make_carray_args(::Type{T}, ::NamedTuple{(), Tuple{}}) where {T} = (T[], FlatAxis())

# Untyped entry: collect everything into a `Vector{Any}`, then let `vcat` promote the
# entries to a common element type.
function make_carray_args(nt)
    data, ax = make_carray_args(Vector, nt)
    data = if isempty(data)
        # e.g. `(a = Float64[],)`: `reduce(vcat, ...)` is undefined for an empty collection,
        # so keep the empty `Vector{Any}` (matching the empty-NamedTuple method above).
        data
    elseif length(data) == 1
        # `reduce(vcat, [x])` would return the bare element rather than a vector.
        [data[1]]
    else
        reduce(vcat, data)
    end
    return (data, ax)
end

# Typed entries: store the collected entries in a `Vector{T}` (or the given array type).
make_carray_args(::Type{T}, nt) where {T} = make_carray_args(Vector{T}, nt)
function make_carray_args(A::Type{<:AbstractArray}, nt)
    data, idx = make_idx([], nt, 0)
    return (A(data), Axis(idx))
end
# `make_idx(data, x, last_val)` appends the entries of `x` to `data` and returns
# `(data, ax)`, where `ax` describes where those entries live. One method per input type.

# NamedTuple: fields are laid out back to back. Field axes are computed with offsets that
# start from 0 (relative to this NamedTuple), while the returned `ViewAxis` spans the
# block placed after `last_val`.
function make_idx(data, nt::NamedTuple, last_val)
    total = recursive_length(nt)
    fields = Pair{Symbol,Any}[]
    prev = 0
    for (name, value) in pairs(nt)
        _, field_ax = make_idx(data, value, prev)
        push!(fields, name => field_ax)
        prev = field_ax
    end
    return (data, ViewAxis(last_index(last_val) .+ (1:total), (; fields...)))
end
# Pair: append the entries of `pair.second` and describe them with `Axis(pair.second)`.
# NOTE(review): the intended use of Pair values is not visible in this file — confirm.
function make_idx(data, pair::Pair, last_val)
    n_before = length(data)
    data, _ = make_idx(data, pair.second, last_val)
    # Number of entries the recursive call appended to `data`.
    len = length(data) - n_before
    # Offset from `last_index(last_val)` like the other methods: `last_val` may be an axis
    # rather than an integer, so it cannot be used as a range endpoint directly.
    return (data, ViewAxis(last_index(last_val) .+ (1:len), Axis(pair.second)))
end
# Scalar fallback: push the value itself; its axis is the single index just past `last_val`.
function make_idx(data, x, last_val)
    push!(data, x)
    return (data, ViewAxis(last_index(last_val) + 1))
end

# ComponentVector: append all of its entries and reuse its own first axis as the nested axis.
function make_idx(data, x::ComponentVector, last_val)
    pushcat!(data, x)
    span = last_index(last_val) .+ (1:length(x))
    return (data, ViewAxis(span, getaxes(x)[1]))
end
# Plain array: append its entries in iteration order and record its shape.
function make_idx(data, x::AbstractArray, last_val)
    pushcat!(data, x)
    return (data, ViewAxis(last_index(last_val) .+ (1:length(x)), ShapedAxis(size(x))))
end
# Array of NamedTuples / ComponentArrays: the element type must be concrete. Every
# element's entries are appended; the partitioned axis is built from the index map of the
# last element's axis, with a block length of (total length) ÷ (number of elements).
function make_idx(
    data, x::A, last_val,
) where {A<:AbstractArray{<:Union{NamedTuple, ComponentArray}}}
    total = recursive_length(x)
    if !isconcretetype(eltype(x))
        error("Only homogeneous arrays of inner ComponentArrays are allowed.")
    end
    elem_ax = ()
    for elem in x
        # Every element is indexed from the same `last_val`; only the last axis is kept.
        _, elem_ax = make_idx(data, elem, last_val)
    end
    span = last_index(last_val) .+ (1:total)
    block_len = total ÷ length(x)
    return (data, ViewAxis(span, PartitionedAxis(block_len, indexmap(elem_ax))))
end
# Arrays whose elements are themselves arrays are rejected with an error.
# NOTE(review): an array of `ComponentArray`s matches both this signature and the
# NamedTuple/ComponentArray one; check that this does not cause a method ambiguity.
function make_idx(data, x::A, last_val) where {A<:AbstractArray{<:AbstractArray}}
    msg = "ComponentArrays cannot currently contain arrays of arrays as elements. This one contains: \n $x\n"
    return error(msg)
end
# TODO: Make all internal function names start with underscores

# Set component `pair.first` to `pair.second`: overwrite it if the key already exists,
# otherwise append it as a new component.
function _maybe_add_field(x, pair)
    haskey(x, pair.first) || return _add_field(x, pair)
    return _update_field(x, pair)
end

# Append `pair.second` after a copy of the existing data and merge a one-key axis for it
# into the first axis of `x`.
function _add_field(x, pair)
    buffer = copy(getdata(x))
    offset = length(buffer)
    buffer, field_ax = make_idx(buffer, pair.second, offset)
    key_axis = Axis(NamedTuple{(pair.first,)}((field_ax,)))
    return ComponentArray(buffer, merge(getaxes(x)[1], key_axis))
end

# Return a copy of `x` whose component `pair.first` is set to `pair.second`.
function _update_field(x, pair)
    updated = copy(x)
    setindex!(updated, pair.second, pair.first)
    return updated
end
# Append every element of `b` to `a` (in iteration order, i.e. column-major for matrices)
# and return `a`. `append!` does this directly instead of folding `push!` over `b`.
pushcat!(a, b) = append!(a, b)
# Reshape ComponentArray data according to its `ShapedAxis` axes.
# With no shaped or partitioned axes the data is left as is.
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
# Otherwise concatenate the sizes of all `ShapedAxis` entries (in order) into a single
# dims tuple and reshape the data to it.
function maybe_reshape(data, axs::AbstractAxis...)
    shaped = filter_by_type(ShapedAxis, axs...)
    dims = mapreduce(size, (acc, s) -> (acc..., s...), shaped)
    return reshape(data, dims)
end
# Last flat index covered by an index or axis, unwrapping nested axes recursively:
# a `ViewAxis` through its view index, any other axis through its last index-map entry.
last_index(idx) = last(idx)
last_index(vax::ViewAxis) = vax |> viewindex |> last_index
last_index(ax::AbstractAxis) = ax |> indexmap |> last |> last_index
# Length and size are taken from the axes (so they can be known from the axis types, e.g.
# to keep SVector construction type stable). When any axis is a `NullorFlatAxis`, they
# are taken from the wrapped data instead.
@inline _hasNullOrFlatAxis(ca) = any(ax -> ax isa NullorFlatAxis, getaxes(ca))

function Base.length(ca::ComponentArray)
    # NOTE: open question from the original author — is `vcat(ca2', ca2')`, which has no
    # length, a valid ComponentVector, or rather a Vector of ComponentVectors?
    if _hasNullOrFlatAxis(ca)
        return length(getdata(ca))
    end
    return prod(map(length, getaxes(ca)))
end

function Base.size(ca::ComponentArray)
    return _hasNullOrFlatAxis(ca) ? size(getdata(ca)) : length.(getaxes(ca))
end
# Reduce singleton dimensions: drop every `NullAxis` from the argument list and return the
# remaining arguments as a tuple.
remove_nulls() = ()
remove_nulls(::NullAxis, rest...) = remove_nulls(rest...)
remove_nulls(head, rest...) = (head, remove_nulls(rest...)...)
## Attributes
"""
    getdata(x::ComponentArray)

Return the array wrapped by a `ComponentArray` (its `data` field). Other values are
returned unchanged, and for `Adjoint`/`Transpose` wrappers the parent's data is fetched
and the same wrapper is re-applied to it.
"""
@inline getdata(x::ComponentArray) = getfield(x, :data)
@inline getdata(x) = x
@inline getdata(x::Adjoint) = adjoint(getdata(parent(x)))
@inline getdata(x::Transpose) = transpose(getdata(parent(x)))
"""
    getaxes(x::ComponentArray)

Access the `axes` field of a `ComponentArray`. This is different from `axes(x::ComponentArray)`,
which returns the axes of the contained array.

# Examples

```jldoctest
julia> using ComponentArrays

julia> ax = Axis(a=1:3, b=(4:6, (a=1, b=2:3)))
Axis(a = 1:3, b = (4:6, (a = 1, b = 2:3)))

julia> A = zeros(6,6);

julia> ca = ComponentArray(A, (ax, ax))
6×6 ComponentMatrix{Float64} with axes Axis(a = 1:3, b = (4:6, (a = 1, b = 2:3))) × Axis(a = 1:3, b = (4:6, (a = 1, b = 2:3)))
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0

julia> getaxes(ca)
(Axis(a = 1:3, b = (4:6, (a = 1, b = 2:3))), Axis(a = 1:3, b = (4:6, (a = 1, b = 2:3))))
```
"""
# Field access goes through `getfield` in these functions so that dot syntax (`x.a`)
# stays reserved for component keys.
@inline getaxes(x::ComponentArray) = getfield(x, :axes)
# Adjoint/transpose of a vector: a `FlatAxis` first dimension, then the parent's axis.
@inline getaxes(x::AdjOrTrans{T, <:ComponentVector}) where T = (FlatAxis(), getaxes(x.parent)[1])
# Adjoint/transpose of a matrix: the parent's axes in reverse order.
@inline getaxes(x::AdjOrTrans{T, <:ComponentMatrix}) where T = reverse(getaxes(x.parent))
# From a ComponentArray type: instantiate each axis type in `Axes` by calling it with no
# arguments.
@inline getaxes(::Type{<:ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = map(x->x(), (Axes.types...,))
# NOTE(review): unlike the method above, these two return a `Tuple` *type* (via `typeof`)
# rather than axis instances — confirm that callers expect a type here.
@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentVector} = (FlatAxis(), getaxes(CA)[1]) |> typeof
@inline getaxes(::Type{<:AdjOrTrans{T,CA}}) where {T,CA<:ComponentMatrix} = reverse(getaxes(CA)) |> typeof
# `VarAxes` values (presumably tuples of axes — confirm) and their types: instantiate each
# element type with no arguments.
@inline getaxes(x::VarAxes) = getaxes(typeof(x))
@inline getaxes(Ax::Type{Axes}) where {Axes<:VarAxes} = map(x->x(), (Ax.types...,))
# Fallback: anything else has no component axes.
getaxes(x) = ()
"""
    valkeys(x::ComponentVector)
    valkeys(x::AbstractAxis)

Returns `Val`-wrapped keys of `ComponentVector` for fast iteration over component keys. Also works
directly on an `AbstractAxis`.

# Examples

```julia-repl
julia> using ComponentArrays

julia> ca = ComponentArray(a=1, b=[1,2,3], c=(a=4,))
ComponentVector{Int64}(a = 1, b = [1, 2, 3], c = (a = 4))

julia> [ca[k] for k in valkeys(ca)]
3-element Array{Any,1}:
 1
 [1, 2, 3]
 ComponentVector{Int64,SubArray...}(a = 4)

julia> sum(prod(ca[k]) for k in valkeys(ca))
11
```
"""
@generated function valkeys(ax::AbstractAxis)
    # In a `@generated` body, `ax` is bound to the argument's *type*, so `indexmap` is
    # applied to the axis type and the resulting `Val` keys are interpolated into the
    # generated expression as a compile-time value.
    idxmap = indexmap(ax)
    k = Val.(keys(idxmap))
    return :($k)
end
# For a ComponentVector, use the keys of its first axis.
valkeys(ca::ComponentVector) = valkeys(getaxes(ca)[1])