Skip to content

Commit b69b48f

Browse files
authored
fix broadcast_dims with groupby (#684)
* fix broadcastdims after groupby * use OpaqueArray
1 parent c09557d commit b69b48f

File tree

4 files changed

+51
-11
lines changed

4 files changed

+51
-11
lines changed

src/Dimensions/format.jl

+6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ function format(dims::Tuple{<:Pair,Vararg{Pair}}, A::AbstractArray)
2626
end
2727
return format(dims, A)
2828
end
29+
# Make a dummy array that assumes the dims are the correct length and don't hold `Colon`s
30+
function format(dims::DimTuple)
31+
ax = map(parent ∘ first ∘ axes, dims)
32+
A = CartesianIndices(ax)
33+
return format(dims, A)
34+
end
2935
format(dims::Tuple{Vararg{Any,N}}, A::AbstractArray{<:Any,N}) where N = format(dims, axes(A))
3036
@noinline format(dims::Tuple{Vararg{Any,M}}, A::AbstractArray{<:Any,N}) where {N,M} =
3137
throw(DimensionMismatch("Array A has $N axes, while the number of dims is $M: $(map(basetypeof, dims))"))

src/dimindices.jl

+11-3
Original file line numberDiff line numberDiff line change
@@ -294,12 +294,20 @@ struct DimSlices{T,N,D<:Tuple{Vararg{Dimension}},P} <: AbstractDimArrayGenerator
294294
end
295295
DimSlices(x; dims, drop=true) = DimSlices(x, dims; drop)
296296
function DimSlices(x, dims; drop=true)
297-
newdims = length(dims) == 0 ? map(d -> rebuild(d, :), DD.dims(x)) : dims
298-
inds = map(d -> rebuild(d, first(axes(x, d))), newdims)
297+
newdims = if length(dims) == 0
298+
map(d -> rebuild(d, :), DD.dims(x))
299+
else
300+
dims
301+
end
302+
inds = map(basedims(newdims)) do d
303+
rebuild(d, first(axes(x, d)))
304+
end
305+
# `getindex` returns these views
299306
T = typeof(view(x, inds...))
300307
N = length(newdims)
301308
D = typeof(newdims)
302-
return DimSlices{T,N,D,typeof(x)}(x, newdims)
309+
P = typeof(x)
310+
return DimSlices{T,N,D,P}(x, newdims)
303311
end
304312

305313
rebuild(ds::A; dims) where {A<:DimSlices{T,N}} where {T,N} =

src/groupby.jl

+17-5
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ end
4646
rebuild(A, data, dims, refdims, name, metadata) # Rebuild as a regular DimArray
4747
end
4848

49-
function Base.summary(io::IO, A::DimGroupByArray{T,N}) where {T,N}
49+
function Base.summary(io::IO, A::DimGroupByArray{T,N}) where {T<:AbstractArray{T1,N1},N} where {T1,N1}
5050
print_ndims(io, size(A))
51-
print(io, string(nameof(typeof(A)), "{$(nameof(T)),$N}"))
51+
print(io, string(nameof(typeof(A)), "{$(nameof(T)){$T1,$N1},$N}"))
5252
end
5353

5454
function show_after(io::IO, mime, A::DimGroupByArray)
@@ -80,6 +80,17 @@ function Base.show(io::IO, s::DimSummariser)
8080
end
8181
Base.alignment(io::IO, s::DimSummariser) = (textwidth(sprint(show, s)), 0)
8282

83+
# An array that doesn't know what it holds, to simplify dispatch
84+
struct OpaqueArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
85+
parent::A
86+
end
87+
Base.parent(A::OpaqueArray) = A.parent
88+
Base.size(A::OpaqueArray) = size(parent(A))
89+
for f in (:getindex, :view, :dotview)
90+
@eval Base.$f(A::OpaqueArray, args...) = Base.$f(parent(A), args...)
91+
end
92+
Base.setindex!(A::OpaqueArray, args...) = Base.setindex!(parent(A), args...)
93+
8394

8495
abstract type AbstractBins <: Function end
8596

@@ -331,9 +342,11 @@ function DataAPI.groupby(A::DimArrayOrStack, dimfuncs::DimTuple)
331342
end
332343
# Separate lookups dims from indices
333344
group_dims = map(first, dim_groups_indices)
334-
indices = map(rebuild, dimfuncs, map(last, dim_groups_indices))
345+
# Get indices for each group wrapped with dims for indexing
346+
indices = map(rebuild, group_dims, map(last, dim_groups_indices))
335347

336-
views = DimSlices(A, indices)
348+
# Hide that the parent is a DimSlices
349+
views = OpaqueArray(DimSlices(A, indices))
337350
# Put the groupby query in metadata
338351
meta = map(d -> dim2key(d) => val(d), dimfuncs)
339352
metadata = Dict{Symbol,Any}(:groupby => length(meta) == 1 ? only(meta) : meta)
@@ -394,7 +407,6 @@ function _group_indices(dim::Dimension, bins::AbstractBins; labels=bins.labels)
394407
return _group_indices(transformed_lookup, group_lookup; labels)
395408
end
396409

397-
398410
# Get a vector of intervals for the bins
399411
_groups_from(_, bins::Bins{<:Any,<:AbstractArray}) = bins.bins
400412
function _groups_from(transformed, bins::Bins{<:Any,<:Integer})

test/groupby.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,22 @@ end
6767
end
6868
end
6969
@test all(collect(mean.(gb)) .=== manualmeans)
70-
@test all(
71-
mean.(gb) .=== manualmeans
72-
)
70+
@test all(mean.(gb) .=== manualmeans)
7371
end
7472

73+
@testset "broadcastdims runs after groupby" begin
74+
dimlist = (
75+
Ti(Date("2021-12-01"):Day(1):Date("2022-12-31")),
76+
X(range(1, 10, length=10)),
77+
Y(range(1, 5, length=15)),
78+
Dim{:Variable}(["var1", "var2"])
79+
)
80+
data = rand(396, 10, 15, 2)
81+
A = DimArray(data, dimlist)
82+
month_length = DimArray(daysinmonth, dims(A, Ti))
83+
g_tempo = DimensionalData.groupby(month_length, Ti=>seasons(; start=December))
84+
sum_days = sum.(g_tempo, dims=Ti)
85+
weights = map(./, g_tempo, sum_days)
86+
G = DimensionalData.groupby(A, Ti=>seasons(; start=December))
87+
G_w = broadcast_dims.(*, weights, G)
88+
end

0 commit comments

Comments
 (0)