Skip to content

Commit 5c0ae3d

Browse files
committed
Update to the latest broadcast implement.
On master `Broadcasted` store style by field. Update accordingly.
1 parent 99f0556 commit 5c0ae3d

File tree

2 files changed

+52
-38
lines changed

2 files changed

+52
-38
lines changed

src/structarray.jl

+49-35
Original file line numberDiff line numberDiff line change
@@ -497,33 +497,53 @@ end
497497
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498498
using Base.Broadcast: combine_styles
499499

500-
struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500+
@static if fieldcount(Base.Broadcast.Broadcasted) == 4
501+
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
502+
style::S
503+
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}(style)
504+
end
505+
StructArrayStyle{N}(style::StructArrayStyle) where {N} = StructArrayStyle{N}(style.style)
506+
parent_style(s::BroadcastStyle) = s
507+
parent_style(s::StructArrayStyle) = s.style
508+
style(bc::Broadcasted) = bc.style
509+
const broadcasted = Broadcasted
510+
else
511+
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
512+
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}()
513+
end
514+
StructArrayStyle{N}(style::StructArrayStyle{M, S}) where {N, M, S} = StructArrayStyle{N}(S())
515+
parent_style(s::BroadcastStyle) = s
516+
parent_style(::StructArrayStyle{N, S}) where {N, S} = S()
517+
style(::Broadcasted{Style}) where {Style} = Style()
518+
broadcasted(s, f, args, axes) = Broadcasted{typeof(s)}(f, args, axes)
519+
end
520+
StructArrayStyle{N, S}() where {N, S} = StructArrayStyle{N}(S())
521+
parent_style(bc::Broadcasted) = parent_style(style(bc))
522+
ofstyle(s, bc::Broadcasted) = broadcasted(s, bc.f, bc.args, bc.axes)
501523

502524
# Here we define the dimension tracking behavior of StructArrayStyle
503-
function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
525+
function StructArrayStyle{M, S}(::Val{N}) where {S, M, N}
504526
T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S
505-
return StructArrayStyle{T, N}()
527+
return StructArrayStyle{N, T}()
506528
end
507529

508530
# StructArrayStyle is a wrapped style.
509531
# Here we try our best to resolve style conflict.
510-
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
532+
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{N, S}) where {S, N, M}
511533
N′ = M === Any || N === Any ? Any : max(M, N)
512-
S′ = Broadcast.result_style(S(), b)
513-
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
534+
return StructArrayStyle{N′}(Broadcast.result_style(parent_style(a), b))
514535
end
515536
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()
516537

517538
@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
518539
combine_style_types(BroadcastStyle(A), args...)
519540
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
520541
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
521-
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
522542
combine_style_types(s::BroadcastStyle) = s
523543

524544
Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)
525545

526-
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
546+
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{ndims(SA)}(cst(SA))
527547

528548
"""
529549
always_struct_broadcast(style::BroadcastStyle)
@@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
551571
"""
552572
try_struct_copy(bc::Broadcasted) = copy(bc)
553573

554-
function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
555-
if always_struct_broadcast(S())
574+
function Base.copy(bc::Broadcasted{<:StructArrayStyle})
575+
if always_struct_broadcast(parent_style(bc))
556576
return invoke(copy, Tuple{Broadcasted}, bc)
557577
else
558578
return try_struct_copy(replace_structarray(bc))
@@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567587
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568588
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569589
"""
570-
function replace_structarray(bc::Broadcasted{Style}) where {Style}
590+
function replace_structarray(bc::Broadcasted)
571591
args = replace_structarray_args(bc.args)
572-
Style′ = parent_style(Style())
573-
return Broadcasted{Style′}(bc.f, args, bc.axes)
592+
style = parent_style(bc)
593+
return broadcasted(style, bc.f, args, bc.axes)
574594
end
575595
function replace_structarray(A::StructArray)
576596
f = Instantiator(eltype(A))
577597
args = Tuple(components(A))
578-
Style = typeof(combine_styles(args...))
579-
return Broadcasted{Style}(f, args, axes(A))
598+
style = combine_styles(args...)
599+
return broadcasted(style, f, args, axes(A))
580600
end
581601
replace_structarray(@nospecialize(A)) = A
582602

583603
replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
584604
replace_structarray_args(::Tuple{}) = ()
585605

586-
parent_style(@nospecialize(x)) = typeof(x)
587-
parent_style(::StructArrayStyle{S, N}) where {S, N} = S
588-
parent_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S
589-
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S
590-
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N)))
591-
592606
# `instantiate` and `_axes` might be overloaded for static axes.
593-
function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <: StructArrayStyle}
594-
Style′ = parent_style(Style())
595-
bc′ = Broadcast.instantiate(convert(Broadcasted{Style′}, bc))
596-
return convert(Broadcasted{Style}, bc′)
607+
function Broadcast.instantiate(bc::Broadcasted{<:StructArrayStyle})
608+
bc′ = Broadcast.instantiate(ofstyle(parent_style(bc), bc))
609+
return ofstyle(style(bc), bc′)
597610
end
598611

599-
function Broadcast._axes(bc::Broadcasted{Style}, ::Nothing) where {Style <: StructArrayStyle}
600-
Style′ = parent_style(Style())
601-
return Broadcast._axes(convert(Broadcasted{Style′}, bc), nothing)
612+
function Broadcast._axes(bc::Broadcasted{<:StructArrayStyle}, ::Nothing)
613+
return Broadcast._axes(ofstyle(parent_style(bc), bc), nothing)
602614
end
603615

604616
# Here we use `similar` defined for `S` to build the dest Array.
605-
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
606-
bc′ = convert(Broadcasted{S}, bc)
617+
function Base.similar(bc::Broadcasted{<:StructArrayStyle}, ::Type{ElType}) where {ElType}
618+
bc′ = ofstyle(parent_style(bc), bc)
607619
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
608620
end
609621

610622
# Unwrapper to recover the behaviour defined by parent style.
611-
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
612-
bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
623+
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle})
624+
ps = parent_style(bc)
625+
bc′ = always_struct_broadcast(ps) ? ofstyle(ps, bc) : replace_structarray(bc)
613626
return copyto!(dest, bc′)
614627
end
615628

616-
@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
617-
bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc)
618-
return Broadcast.materialize!(S(), dest, bc′)
629+
@inline function Broadcast.materialize!(s::StructArrayStyle, dest, bc::Broadcasted)
630+
ps = parent_style(s)
631+
bc′ = always_struct_broadcast(ps) ? bc : replace_structarray(bc)
632+
return Broadcast.materialize!(ps, dest, bc′)
619633
end
620634

621635
# for aliasing analysis during broadcast

test/runtests.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1227,7 +1227,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12271227
ares = map(a->a.re, as)
12281228
aims = map(a->a.im, as)
12291229
style = Broadcast.combine_styles(ares...)
1230-
@test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}()
1230+
@test Broadcast.combine_styles(as...) === StructArrayStyle{1,typeof(style)}()
12311231
if !(style in tested_style)
12321232
push!(tested_style, style)
12331233
if style isa Broadcast.ArrayStyle{MyArray3}
@@ -1249,8 +1249,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
12491249
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
12501250

12511251
#parent_style
1252-
@test StructArrays.parent_style(StructArrayStyle{Broadcast.DefaultArrayStyle{0},2}()) == Broadcast.DefaultArrayStyle{2}
1253-
@test StructArrays.parent_style(StructArrayStyle{Broadcast.Style{Tuple},2}()) == Broadcast.Style{Tuple}
1252+
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.DefaultArrayStyle{0}}()) == Broadcast.DefaultArrayStyle{0}()
1253+
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.Style{Tuple}}()) == Broadcast.Style{Tuple}()
12541254

12551255
# allocation test for overloaded `broadcast_unalias`
12561256
StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{MyArray1}) = false

0 commit comments

Comments
 (0)