@@ -16,53 +16,27 @@ ArrayInterfaceCore.indices_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) wher
16
16
ArrayInterfaceCore. instances_do_not_alias (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = ArrayInterfaceCore. instances_do_not_alias (A)
17
17
18
18
# Cats
19
- # TODO : Make this a little less copy-pastey
20
- function Base. hcat (x:: AbstractComponentVecOrMat , y:: AbstractComponentVecOrMat )
21
- ax_x, ax_y = second_axis .((x,y))
22
- if reduce ((accum, key) -> accum || (key in keys (ax_x)), keys (ax_y); init= false ) || getaxes (x)[1 ] != getaxes (y)[1 ]
23
- return hcat (getdata (x), getdata (y))
19
+ function Base. cat (inputs:: ComponentArray... ; dims:: Int )
20
+ combined_data = cat (getdata .(inputs)... ; dims= dims)
21
+ axes_to_merge = [(getaxes (i)... , FlatAxis ())[dims] for i in inputs]
22
+ rest_axes = [getaxes (i)[1 : end .!= dims] for i in inputs]
23
+ no_duplicate_keys = (length (inputs) == 1 || isempty (intersect (keys .(axes_to_merge)... )))
24
+ if no_duplicate_keys && length (Set (rest_axes)) == 1
25
+ offsets = cumsum (size .(inputs, 1 ) .- size (first (inputs), 1 ))
26
+ merged_axis = Axis (merge (indexmap .(reindex .(axes_to_merge, offsets))... ))
27
+ result_axes = (first (rest_axes)[1 : (dims - 1 )]. .. , merged_axis, first (rest_axes)[dims: end ]. .. )
28
+ return ComponentArray (combined_data, result_axes... )
24
29
else
25
- data_x, data_y = getdata .((x, y))
26
- ax_y = reindex (ax_y, size (x,2 ))
27
- idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
28
- axs = getaxes (x)
29
- return ComponentArray (hcat (data_x, data_y), axs[1 ], Axis ((;idxmap_x... , idxmap_y... )), axs[3 : end ]. .. )
30
+ return combined_data
30
31
end
31
32
end
32
33
33
- second_axis (ca:: AbstractComponentVecOrMat ) = getaxes (ca)[2 ]
34
- second_axis (:: ComponentVector ) = FlatAxis ()
35
-
36
- # Are all these methods necessary?
37
- # TODO : See what we can reduce down to without getting ambiguity errors
38
- Base. vcat (x:: ComponentVector , y:: AbstractVector ) = vcat (getdata (x), y)
39
- Base. vcat (x:: AbstractVector , y:: ComponentVector ) = vcat (x, getdata (y))
40
- function Base. vcat (x:: ComponentVector , y:: ComponentVector )
41
- if reduce ((accum, key) -> accum || (key in keys (x)), keys (y); init= false )
42
- return vcat (getdata (x), getdata (y))
43
- else
44
- data_x, data_y = getdata .((x, y))
45
- ax_x, ax_y = getindex .(getaxes .((x, y)), 1 )
46
- ax_y = reindex (ax_y, length (x))
47
- idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
48
- return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )))
49
- end
34
+ function Base. _typed_hcat (:: Type{T} , inputs:: Base.AbstractVecOrTuple{ComponentArray} ) where {T}
35
+ return Base. cat (map (i -> T .(i), inputs)... ; dims= 2 )
50
36
end
51
- function Base. vcat (x:: AbstractComponentVecOrMat , y:: AbstractComponentVecOrMat )
52
- ax_x, ax_y = getindex .(getaxes .((x, y)), 1 )
53
- if reduce ((accum, key) -> accum || (key in keys (ax_x)), keys (ax_y); init= false ) || getaxes (x)[2 : end ] != getaxes (y)[2 : end ]
54
- return vcat (getdata (x), getdata (y))
55
- else
56
- data_x, data_y = getdata .((x, y))
57
- ax_y = reindex (ax_y, size (x,1 ))
58
- idxmap_x, idxmap_y = indexmap .((ax_x, ax_y))
59
- return ComponentArray (vcat (data_x, data_y), Axis ((;idxmap_x... , idxmap_y... )), getaxes (x)[2 : end ]. .. )
60
- end
37
+ function Base. _typed_vcat (:: Type{T} , inputs:: Base.AbstractVecOrTuple{ComponentArray} ) where {T}
38
+ return Base. cat (map (i -> T .(i), inputs)... ; dims= 1 )
61
39
end
62
- Base. vcat (x:: CV... ) where {CV<: AdjOrTransComponentArray } = ComponentArray (reduce (vcat, map (y-> getdata (y. parent)' , x)), getaxes (x[1 ]))
63
- Base. vcat (x:: ComponentVector , args... ) = vcat (getdata (x), getdata .(args)... )
64
- Base. vcat (x:: ComponentVector , args:: Union{Number, UniformScaling, AbstractVecOrMat} ...) = vcat (getdata (x), getdata .(args)... )
65
- Base. vcat (x:: ComponentVector , args:: Vararg{AbstractVector{T}, N} ) where {T,N} = vcat (getdata (x), getdata .(args)... )
66
40
67
41
function Base. hvcat (row_lengths:: NTuple{N,Int} , xs:: AbstractComponentVecOrMat... ) where {N}
68
42
i = 1
145
119
Base. stride (x:: ComponentArray , k) = stride (getdata (x), k)
146
120
Base. stride (x:: ComponentArray , k:: Int64 ) = stride (getdata (x), k)
147
121
148
- ArrayInterfaceCore. parent_type (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = A
122
+ ArrayInterfaceCore. parent_type (:: Type{ComponentArray{T,N,A,Axes}} ) where {T,N,A,Axes} = A
0 commit comments