@@ -497,33 +497,53 @@ end
497497import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498498using Base. Broadcast: combine_styles
499499
500- struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500+ @static if fieldcount(Base. Broadcast. Broadcasted) == 4
501+ struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
502+ style:: S
503+ StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}(style)
504+ end
505+ StructArrayStyle{N}(style:: StructArrayStyle ) where {N} = StructArrayStyle{N}(style. style)
506+ parent_style(s:: BroadcastStyle ) = s
507+ parent_style(s:: StructArrayStyle ) = s. style
508+ style(bc:: Broadcasted ) = bc. style
509+ const broadcasted = Broadcasted
510+ else
511+ struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
512+ StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}()
513+ end
514+ StructArrayStyle{N}(style:: StructArrayStyle{M, S} ) where {N, M, S} = StructArrayStyle{N}(S())
515+ parent_style(s:: BroadcastStyle ) = s
516+ parent_style(:: StructArrayStyle{N, S} ) = S()
517+ style(:: Broadcasted{Style} ) where {Style} = Style()
518+ broadcasted(s, f, args, axes) = Broadcasted{typeof(s)}(f, args, axes)
519+ end
520+ StructArrayStyle{N, S}() where {N, S} = StructArrayStyle{N}(S())
521+ parent_style(bc:: Broadcasted ) = parent_style(style(bc))
522+ ofstyle(s, bc:: Broadcasted ) = broadcasted(s, bc. f, bc. args, bc. axes)
501523
502524# Here we define the dimension tracking behavior of StructArrayStyle
503- function StructArrayStyle{S, M }(:: Val{N} ) where {S, M, N}
525+ function StructArrayStyle{M, S }(:: Val{N} ) where {S, M, N}
504526 T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S
505- return StructArrayStyle{T, N }()
527+ return StructArrayStyle{N, T }()
506528end
507529
508530# StructArrayStyle is a wrapped style.
509531# Here we try our best to resolve style conflict.
510- function BroadcastStyle(b:: AbstractArrayStyle{M} , a:: StructArrayStyle{S, N } ) where {S, N, M}
532+ function BroadcastStyle(b:: AbstractArrayStyle{M} , a:: StructArrayStyle{N, S } ) where {S, N, M}
511533 N′ = M === Any || N === Any ? Any : max(M, N)
512- S′ = Broadcast. result_style(S(), b)
513- return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
534+ return StructArrayStyle{N′}(Broadcast. result_style(parent_style(a), b))
514535end
515536BroadcastStyle(:: StructArrayStyle , :: DefaultArrayStyle ) = Unknown()
516537
517538@inline combine_style_types(:: Type{A} , args... ) where {A<: AbstractArray } =
518539 combine_style_types(BroadcastStyle(A), args... )
519540@inline combine_style_types(s:: BroadcastStyle , :: Type{A} , args... ) where {A<: AbstractArray } =
520541 combine_style_types(Broadcast. result_style(s, BroadcastStyle(A)), args... )
521- combine_style_types(:: StructArrayStyle{S} ) where {S} = S() # avoid nested StructArrayStyle
522542combine_style_types(s:: BroadcastStyle ) = s
523543
524544Base. @pure cst(:: Type{SA} ) where {SA} = combine_style_types(array_types(SA). parameters... )
525545
526- BroadcastStyle(:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
546+ BroadcastStyle(:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle{ndims(SA)}(cst(SA) )
527547
528548"""
529549 always_struct_broadcast(style::BroadcastStyle)
@@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
551571"""
552572try_struct_copy(bc:: Broadcasted ) = copy(bc)
553573
554- function Base. copy(bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
555- if always_struct_broadcast(S( ))
574+ function Base. copy(bc:: Broadcasted{<: StructArrayStyle} )
575+ if always_struct_broadcast(parent_style(bc ))
556576 return invoke(copy, Tuple{Broadcasted}, bc)
557577 else
558578 return try_struct_copy(replace_structarray(bc))
@@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567587supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568588e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569589"""
570- function replace_structarray(bc:: Broadcasted{Style} ) where {Style}
590+ function replace_structarray(bc:: Broadcasted )
571591 args = replace_structarray_args(bc. args)
572- Style′ = parent_style(Style() )
573- return Broadcasted{Style′}( bc. f, args, bc. axes)
592+ style = parent_style(bc )
593+ return broadcasted(style, bc. f, args, bc. axes)
574594end
575595function replace_structarray(A:: StructArray )
576596 f = Instantiator(eltype(A))
577597 args = Tuple(components(A))
578- Style = typeof( combine_styles(args... ) )
579- return Broadcasted{Style}( f, args, axes(A))
598+ style = combine_styles(args... )
599+ return broadcasted(style, f, args, axes(A))
580600end
581601replace_structarray(@nospecialize(A)) = A
582602
583603replace_structarray_args(args:: Tuple ) = (replace_structarray(args[1 ]), replace_structarray_args(tail(args)). .. )
584604replace_structarray_args(:: Tuple{} ) = ()
585605
586- parent_style(@nospecialize(x)) = typeof(x)
587- parent_style(:: StructArrayStyle{S, N} ) where {S, N} = S
588- parent_style(:: StructArrayStyle{S, N} ) where {N, S<: AbstractArrayStyle{N} } = S
589- parent_style(:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle{Any} , N} = S
590- parent_style(:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle , N} = typeof(S(Val(N)))
591-
592606# `instantiate` and `_axes` might be overloaded for static axes.
593- function Broadcast. instantiate(bc:: Broadcasted{Style} ) where {Style <: StructArrayStyle }
594- Style′ = parent_style(Style())
595- bc′ = Broadcast. instantiate(convert(Broadcasted{Style′}, bc))
596- return convert(Broadcasted{Style}, bc′)
607+ function Broadcast. instantiate(bc:: Broadcasted{<:StructArrayStyle} )
608+ bc′ = Broadcast. instantiate(ofstyle(parent_style(bc), bc))
609+ return ofstyle(style(bc), bc′)
597610end
598611
599- function Broadcast. _axes(bc:: Broadcasted{Style} , :: Nothing ) where {Style <: StructArrayStyle }
600- Style′ = parent_style(Style())
601- return Broadcast. _axes(convert(Broadcasted{Style′}, bc), nothing )
612+ function Broadcast. _axes(bc:: Broadcasted{<:StructArrayStyle} , :: Nothing )
613+ return Broadcast. _axes(ofstyle(parent_style(bc), bc), nothing )
602614end
603615
604616# Here we use `similar` defined for `S` to build the dest Array.
605- function Base. similar(bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where {S, N, ElType}
606- bc′ = convert(Broadcasted{S} , bc)
617+ function Base. similar(bc:: Broadcasted{<: StructArrayStyle} , :: Type{ElType} ) where {ElType}
618+ bc′ = ofstyle(parent_style(bc) , bc)
607619 return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
608620end
609621
610622# Unwrapper to recover the behaviour defined by parent style.
611- @inline function Base. copyto!(dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
612- bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
623+ @inline function Base. copyto!(dest:: AbstractArray , bc:: Broadcasted{<:StructArrayStyle} )
624+ ps = parent_style(bc)
625+ bc′ = always_struct_broadcast(ps) ? ofstyle(ps, bc) : replace_structarray(bc)
613626 return copyto!(dest, bc′)
614627end
615628
616- @inline function Broadcast. materialize!(:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where {S}
617- bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc)
618- return Broadcast. materialize!(S(), dest, bc′)
629+ @inline function Broadcast. materialize!(s:: StructArrayStyle , dest, bc:: Broadcasted )
630+ ps = parent_style(s)
631+ bc′ = always_struct_broadcast(ps) ? bc : replace_structarray(bc)
632+ return Broadcast. materialize!(ps, dest, bc′)
619633end
620634
621635# for aliasing analysis during broadcast
0 commit comments