Skip to content

Commit 8540bde

Browse files
authored
Fix *cat inconsistencies
1 parent cbb24ef commit 8540bde

File tree

1 file changed

+16
-42
lines changed

1 file changed

+16
-42
lines changed

src/array_interface.jl

+16-42
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,27 @@ ArrayInterfaceCore.indices_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) wher
1616
ArrayInterfaceCore.instances_do_not_alias(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = ArrayInterfaceCore.instances_do_not_alias(A)
1717

1818
# Cats
19-
# TODO: Make this a little less copy-pastey
20-
function Base.hcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat)
21-
ax_x, ax_y = second_axis.((x,y))
22-
if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[1] != getaxes(y)[1]
23-
return hcat(getdata(x), getdata(y))
19+
function Base.cat(inputs::ComponentArray...; dims::Int)
20+
combined_data = cat(getdata.(inputs)...; dims=dims)
21+
axes_to_merge = [(getaxes(i)..., FlatAxis())[dims] for i in inputs]
22+
rest_axes = [getaxes(i)[1:end .!= dims] for i in inputs]
23+
no_duplicate_keys = (length(inputs) == 1 || isempty(intersect(keys.(axes_to_merge)...)))
24+
if no_duplicate_keys && length(Set(rest_axes)) == 1
25+
offsets = cumsum(size.(inputs, 1) .- size(first(inputs), 1))
26+
merged_axis = Axis(merge(indexmap.(reindex.(axes_to_merge, offsets))...))
27+
result_axes = (first(rest_axes)[1:(dims - 1)]..., merged_axis, first(rest_axes)[dims:end]...)
28+
return ComponentArray(combined_data, result_axes...)
2429
else
25-
data_x, data_y = getdata.((x, y))
26-
ax_y = reindex(ax_y, size(x,2))
27-
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
28-
axs = getaxes(x)
29-
return ComponentArray(hcat(data_x, data_y), axs[1], Axis((;idxmap_x..., idxmap_y...)), axs[3:end]...)
30+
return combined_data
3031
end
3132
end
3233

33-
second_axis(ca::AbstractComponentVecOrMat) = getaxes(ca)[2]
34-
second_axis(::ComponentVector) = FlatAxis()
35-
36-
# Are all these methods necessary?
37-
# TODO: See what we can reduce down to without getting ambiguity errors
38-
Base.vcat(x::ComponentVector, y::AbstractVector) = vcat(getdata(x), y)
39-
Base.vcat(x::AbstractVector, y::ComponentVector) = vcat(x, getdata(y))
40-
function Base.vcat(x::ComponentVector, y::ComponentVector)
41-
if reduce((accum, key) -> accum || (key in keys(x)), keys(y); init=false)
42-
return vcat(getdata(x), getdata(y))
43-
else
44-
data_x, data_y = getdata.((x, y))
45-
ax_x, ax_y = getindex.(getaxes.((x, y)), 1)
46-
ax_y = reindex(ax_y, length(x))
47-
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
48-
return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)))
49-
end
34+
function Base._typed_hcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T}
35+
return Base.cat(map(i -> T.(i), inputs)...; dims=2)
5036
end
51-
function Base.vcat(x::AbstractComponentVecOrMat, y::AbstractComponentVecOrMat)
52-
ax_x, ax_y = getindex.(getaxes.((x, y)), 1)
53-
if reduce((accum, key) -> accum || (key in keys(ax_x)), keys(ax_y); init=false) || getaxes(x)[2:end] != getaxes(y)[2:end]
54-
return vcat(getdata(x), getdata(y))
55-
else
56-
data_x, data_y = getdata.((x, y))
57-
ax_y = reindex(ax_y, size(x,1))
58-
idxmap_x, idxmap_y = indexmap.((ax_x, ax_y))
59-
return ComponentArray(vcat(data_x, data_y), Axis((;idxmap_x..., idxmap_y...)), getaxes(x)[2:end]...)
60-
end
37+
function Base._typed_vcat(::Type{T}, inputs::Base.AbstractVecOrTuple{ComponentArray}) where {T}
38+
return Base.cat(map(i -> T.(i), inputs)...; dims=1)
6139
end
62-
Base.vcat(x::CV...) where {CV<:AdjOrTransComponentArray} = ComponentArray(reduce(vcat, map(y->getdata(y.parent)', x)), getaxes(x[1]))
63-
Base.vcat(x::ComponentVector, args...) = vcat(getdata(x), getdata.(args)...)
64-
Base.vcat(x::ComponentVector, args::Union{Number, UniformScaling, AbstractVecOrMat}...) = vcat(getdata(x), getdata.(args)...)
65-
Base.vcat(x::ComponentVector, args::Vararg{AbstractVector{T}, N}) where {T,N} = vcat(getdata(x), getdata.(args)...)
6640

6741
function Base.hvcat(row_lengths::NTuple{N,Int}, xs::AbstractComponentVecOrMat...) where {N}
6842
i = 1
@@ -145,4 +119,4 @@ end
145119
Base.stride(x::ComponentArray, k) = stride(getdata(x), k)
146120
Base.stride(x::ComponentArray, k::Int64) = stride(getdata(x), k)
147121

148-
ArrayInterfaceCore.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A
122+
ArrayInterfaceCore.parent_type(::Type{ComponentArray{T,N,A,Axes}}) where {T,N,A,Axes} = A

0 commit comments

Comments
 (0)