Skip to content

Commit 7296e78

Browse files
committed
Specialize broadcasting adjoint vector
1 parent 1050707 commit 7296e78

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

src/fillbroadcast.jl

+24-15
Original file line numberDiff line numberDiff line change
@@ -87,39 +87,48 @@ Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{2}
8787
Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{1}) = S
8888
Broadcast.BroadcastStyle(S::LinearAlgebra.StructuredMatrixStyle, ::ZerosStyle{0}) = S
8989

90-
_getindex_value(f::AbstractFill) = getindex_value(f)
91-
_getindex_value(x::Number) = x
92-
_getindex_value(x::Ref) = x[]
93-
function _getindex_value(bc::Broadcast.Broadcasted)
94-
bc.f(map(_getindex_value, bc.args)...)
90+
# Obtain the fill value of a broadcasted object by recursively evaluating the fill components
91+
broadcast_getindex_value(f::AbstractFill) = getindex_value(f)
92+
broadcast_getindex_value(f::Transpose{<:Any,<:AbstractFill}) = getindex_value(parent(f))
93+
broadcast_getindex_value(f::Adjoint{<:Any,<:AbstractFill}) = getindex_value(parent(f))
94+
broadcast_getindex_value(x::Number) = x
95+
broadcast_getindex_value(x::Ref) = x[]
96+
function broadcast_getindex_value(bc::Broadcast.Broadcasted)
97+
bc.f(map(broadcast_getindex_value, bc.args)...)
9598
end
9699

97100
has_static_value(x) = false
98101
has_static_value(x::Union{AbstractZeros, AbstractOnes}) = true
99102
has_static_value(x::Broadcast.Broadcasted) = all(has_static_value, x.args)
100103

104+
# _iszeros and _isones are conservative checks for zeros and ones,
105+
# which are used to determine if a broadcasted object is a Fill, Zeros or Ones.
101106
function _iszeros(bc::Broadcast.Broadcasted)
102-
all(has_static_value, bc.args) && _iszero(_getindex_value(bc))
107+
all(has_static_value, bc.args) && _iszero(broadcast_getindex_value(bc))
103108
end
104109
# conservative check for zeros. In most cases, there isn't a zero element to compare with
105110
_iszero(x::Union{Number, AbstractArray}) = iszero(x)
106111
_iszero(_) = false
107112

108113
function _isones(bc::Broadcast.Broadcasted)
109-
all(has_static_value, bc.args) && _isone(_getindex_value(bc))
114+
all(has_static_value, bc.args) && _isone(broadcast_getindex_value(bc))
110115
end
111116
# conservative check for ones. In most cases, there isn't a unit element to compare with
112117
_isone(x::Union{Number, AbstractArray}) = isone(x)
113118
_isone(_) = false
114119

115-
_isfill(bc::Broadcast.Broadcasted) = all(_isfill, bc.args)
116-
_isfill(f::AbstractFill) = true
117-
_isfill(f::Number) = true
118-
_isfill(f::Ref) = true
119-
_isfill(::Any) = false
120+
# wrappers that are equivalent to an `AbstractFill` may opt in to the broadcasting behavior
121+
# of `AbstractFill` by specializing `isfill` and `broadcast_getindex_value`
122+
isfill(bc::Broadcast.Broadcasted) = all(isfill, bc.args)
123+
isfill(f::AbstractFill) = true
124+
isfill(f::Transpose) = isfill(parent(f))
125+
isfill(f::Adjoint) = isfill(parent(f))
126+
isfill(f::Number) = true
127+
isfill(f::Ref) = true
128+
isfill(::Any) = false
120129

121130
function _copy_fill(bc)
122-
v = _getindex_value(bc)
131+
v = broadcast_getindex_value(bc)
123132
if _iszeros(bc)
124133
return Zeros(typeof(v), axes(bc))
125134
elseif _isones(bc)
@@ -130,7 +139,7 @@ end
130139

131140
# recursively copy the purely fill components
132141
function _preprocess_fill(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
133-
_isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
142+
isfill(bc) ? _copy_fill(bc) : Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
134143
end
135144
_preprocess_fill(bc::Broadcast.Broadcasted) = Broadcast.broadcasted(bc.f, map(_preprocess_fill, bc.args)...)
136145
_preprocess_fill(x) = x
@@ -144,7 +153,7 @@ function _fallback_copy(bc)
144153
end
145154

146155
function Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle})
147-
_isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc)
156+
isfill(bc) ? _copy_fill(bc) : _fallback_copy(bc)
148157
end
149158
# make the zero-dimensional case consistent with Base
150159
Base.copy(bc::Broadcast.Broadcasted{<:AbstractFillStyle{0}}) = _fallback_copy(bc)

test/runtests.jl

+6
Original file line numberDiff line numberDiff line change
@@ -1246,6 +1246,12 @@ end
12461246
A = ones(1,5) .+ (ones(1) .+ (_ -> rand()).(Fill("vec", 2)))
12471247
@test all(==(A[1]), A)
12481248
end
1249+
1250+
@testset "wrappers" begin
1251+
f = Fill(3, 4)
1252+
@test f * f' === Fill(9,4,4)
1253+
@test f * transpose(f) === Fill(9,4,4)
1254+
end
12491255
end
12501256

12511257
@testset "map" begin

0 commit comments

Comments
 (0)