@@ -497,33 +497,53 @@ end
497
497
import Base. Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
498
498
using Base. Broadcast: combine_styles
499
499
500
- struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
500
+ @static if fieldcount (Base. Broadcast. Broadcasted) == 4
501
+ struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
502
+ style:: S
503
+ StructArrayStyle {N} (style) where {N} = new {N, typeof(style)} (style)
504
+ end
505
+ StructArrayStyle {N} (style:: StructArrayStyle ) where {N} = StructArrayStyle {N} (style. style)
506
+ parent_style (s:: BroadcastStyle ) = s
507
+ parent_style (s:: StructArrayStyle ) = s. style
508
+ style (bc:: Broadcasted ) = bc. style
509
+ const broadcasted = Broadcasted
510
+ else
511
+ struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
512
+ StructArrayStyle {N} (style) where {N} = new {N, typeof(style)} ()
513
+ end
514
+ StructArrayStyle {N} (style:: StructArrayStyle{M, S} ) where {N, M, S} = StructArrayStyle {N} (S ())
515
+ parent_style (s:: BroadcastStyle ) = s
516
+ parent_style (:: StructArrayStyle{N, S} ) where {N, S} = S ()
517
+ style (:: Broadcasted{Style} ) where {Style} = Style ()
518
+ broadcasted (s, f, args, axes) = Broadcasted {typeof(s)} (f, args, axes)
519
+ end
520
+ StructArrayStyle {N, S} () where {N, S} = StructArrayStyle {N} (S ())
521
+ parent_style (bc:: Broadcasted ) = parent_style (style (bc))
522
+ ofstyle (s, bc:: Broadcasted ) = broadcasted (s, bc. f, bc. args, bc. axes)
501
523
502
524
# Here we define the dimension tracking behavior of StructArrayStyle
503
- function StructArrayStyle {S, M } (:: Val{N} ) where {S, M, N}
525
+ function StructArrayStyle {M, S } (:: Val{N} ) where {S, M, N}
504
526
T = S <: AbstractArrayStyle{M} ? typeof (S (Val {N} ())) : S
505
- return StructArrayStyle {T, N } ()
527
+ return StructArrayStyle {N, T } ()
506
528
end
507
529
508
530
# StructArrayStyle is a wrapped style.
509
531
# Here we try our best to resolve style conflict.
510
- function BroadcastStyle (b:: AbstractArrayStyle{M} , a:: StructArrayStyle{S, N } ) where {S, N, M}
532
+ function BroadcastStyle (b:: AbstractArrayStyle{M} , a:: StructArrayStyle{N, S } ) where {S, N, M}
511
533
N′ = M === Any || N === Any ? Any : max (M, N)
512
- S′ = Broadcast. result_style (S (), b)
513
- return S′ isa StructArrayStyle ? typeof (S′)(Val {N′} ()) : StructArrayStyle {typeof(S′), N′} ()
534
+ return StructArrayStyle {N′} (Broadcast. result_style (parent_style (a), b))
514
535
end
515
536
BroadcastStyle (:: StructArrayStyle , :: DefaultArrayStyle ) = Unknown ()
516
537
517
538
@inline combine_style_types (:: Type{A} , args... ) where {A<: AbstractArray } =
518
539
combine_style_types (BroadcastStyle (A), args... )
519
540
@inline combine_style_types (s:: BroadcastStyle , :: Type{A} , args... ) where {A<: AbstractArray } =
520
541
combine_style_types (Broadcast. result_style (s, BroadcastStyle (A)), args... )
521
- combine_style_types (:: StructArrayStyle{S} ) where {S} = S () # avoid nested StructArrayStyle
522
542
combine_style_types (s:: BroadcastStyle ) = s
523
543
524
544
Base. @pure cst (:: Type{SA} ) where {SA} = combine_style_types (array_types (SA). parameters... )
525
545
526
- BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {typeof(cst(SA)), ndims(SA)} ()
546
+ BroadcastStyle (:: Type{SA} ) where {SA<: StructArray } = StructArrayStyle {ndims(SA)} (cst (SA) )
527
547
528
548
"""
529
549
always_struct_broadcast(style::BroadcastStyle)
@@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
551
571
"""
552
572
try_struct_copy (bc:: Broadcasted ) = copy (bc)
553
573
554
- function Base. copy (bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
555
- if always_struct_broadcast (S ( ))
574
+ function Base. copy (bc:: Broadcasted{<: StructArrayStyle} )
575
+ if always_struct_broadcast (parent_style (bc ))
556
576
return invoke (copy, Tuple{Broadcasted}, bc)
557
577
else
558
578
return try_struct_copy (replace_structarray (bc))
@@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
567
587
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
568
588
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
569
589
"""
570
- function replace_structarray (bc:: Broadcasted{Style} ) where {Style}
590
+ function replace_structarray (bc:: Broadcasted )
571
591
args = replace_structarray_args (bc. args)
572
- Style′ = parent_style (Style () )
573
- return Broadcasted {Style′} ( bc. f, args, bc. axes)
592
+ style = parent_style (bc )
593
+ return broadcasted (style, bc. f, args, bc. axes)
574
594
end
575
595
function replace_structarray (A:: StructArray )
576
596
f = Instantiator (eltype (A))
577
597
args = Tuple (components (A))
578
- Style = typeof ( combine_styles (args... ) )
579
- return Broadcasted {Style} ( f, args, axes (A))
598
+ style = combine_styles (args... )
599
+ return broadcasted (style, f, args, axes (A))
580
600
end
581
601
replace_structarray (@nospecialize (A)) = A
582
602
583
603
replace_structarray_args (args:: Tuple ) = (replace_structarray (args[1 ]), replace_structarray_args (tail (args))... )
584
604
replace_structarray_args (:: Tuple{} ) = ()
585
605
586
- parent_style (@nospecialize (x)) = typeof (x)
587
- parent_style (:: StructArrayStyle{S, N} ) where {S, N} = S
588
- parent_style (:: StructArrayStyle{S, N} ) where {N, S<: AbstractArrayStyle{N} } = S
589
- parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle{Any} , N} = S
590
- parent_style (:: StructArrayStyle{S, N} ) where {S<: AbstractArrayStyle , N} = typeof (S (Val (N)))
591
-
592
606
# `instantiate` and `_axes` might be overloaded for static axes.
593
- function Broadcast. instantiate (bc:: Broadcasted{Style} ) where {Style <: StructArrayStyle }
594
- Style′ = parent_style (Style ())
595
- bc′ = Broadcast. instantiate (convert (Broadcasted{Style′}, bc))
596
- return convert (Broadcasted{Style}, bc′)
607
+ function Broadcast. instantiate (bc:: Broadcasted{<:StructArrayStyle} )
608
+ bc′ = Broadcast. instantiate (ofstyle (parent_style (bc), bc))
609
+ return ofstyle (style (bc), bc′)
597
610
end
598
611
599
- function Broadcast. _axes (bc:: Broadcasted{Style} , :: Nothing ) where {Style <: StructArrayStyle }
600
- Style′ = parent_style (Style ())
601
- return Broadcast. _axes (convert (Broadcasted{Style′}, bc), nothing )
612
+ function Broadcast. _axes (bc:: Broadcasted{<:StructArrayStyle} , :: Nothing )
613
+ return Broadcast. _axes (ofstyle (parent_style (bc), bc), nothing )
602
614
end
603
615
604
616
# Here we use `similar` defined for `S` to build the dest Array.
605
- function Base. similar (bc:: Broadcasted{StructArrayStyle{S, N}} , :: Type{ElType} ) where {S, N, ElType}
606
- bc′ = convert (Broadcasted{S} , bc)
617
+ function Base. similar (bc:: Broadcasted{<: StructArrayStyle} , :: Type{ElType} ) where {ElType}
618
+ bc′ = ofstyle ( parent_style (bc) , bc)
607
619
return isnonemptystructtype (ElType) ? buildfromschema (T -> similar (bc′, T), ElType) : similar (bc′, ElType)
608
620
end
609
621
610
622
# Unwrapper to recover the behaviour defined by parent style.
611
- @inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{StructArrayStyle{S, N}} ) where {S, N}
612
- bc′ = always_struct_broadcast (S ()) ? convert (Broadcasted{S}, bc) : replace_structarray (bc)
623
+ @inline function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:StructArrayStyle} )
624
+ ps = parent_style (bc)
625
+ bc′ = always_struct_broadcast (ps) ? ofstyle (ps, bc) : replace_structarray (bc)
613
626
return copyto! (dest, bc′)
614
627
end
615
628
616
- @inline function Broadcast. materialize! (:: StructArrayStyle{S} , dest, bc:: Broadcasted ) where {S}
617
- bc′ = always_struct_broadcast (S ()) ? bc : replace_structarray (bc)
618
- return Broadcast. materialize! (S (), dest, bc′)
629
+ @inline function Broadcast. materialize! (s:: StructArrayStyle , dest, bc:: Broadcasted )
630
+ ps = parent_style (s)
631
+ bc′ = always_struct_broadcast (ps) ? bc : replace_structarray (bc)
632
+ return Broadcast. materialize! (ps, dest, bc′)
619
633
end
620
634
621
635
# for aliasing analysis during broadcast
0 commit comments