
Commit 8d5efcb

Merge pull request #1001 from mcabbott/broadget

Faster generic broadcasting

2 parents: b170521 + 8424c3e

4 files changed, +87 -32 lines

src/lib/array.jl (+16 -7)

@@ -194,6 +194,7 @@ end
 
 struct StaticGetter{i} end
 (::StaticGetter{i})(v) where {i} = v[i]
+(::StaticGetter{i})(::Nothing) where {i} = nothing
 @generated function _unzip(tuples, ::Val{N}) where {N}
   Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i in 1:N)...)
 end
@@ -214,19 +215,27 @@ _tryreverse(m, x) = x
 _tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
 
 for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
-  @eval function $∇mapfunc(cx, f, args...)
+  @eval function $∇mapfunc(cx, f::F, args...) where {F}
     ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
     if isempty(ys_and_backs)
       ys_and_backs, _ -> nothing
     else
-      ys, backs = unzip(ys_and_backs)
+      ys = map(first, ys_and_backs)
       ys, function (Δ)
         isnothing(Δ) && return nothing
-        # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
-        Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), _tryreverse($mapfunc, backs, Δ)...)
-        Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
-        Δf = reduce(accum, Δf_and_args[1])
-        (Δf, Δf_and_args[2:end]...)
+        if Base.issingletontype(F) && length(args) == 1
+          Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ)  # No unzip needed
+          (nothing, Δarg)
+        elseif Base.issingletontype(F)  # Ensures `f` is pure: nothing captured & no state
+          Δargs = unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ))
+          (nothing, Δargs...)
+        else
+          # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
+          Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
+          Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
+          Δf = reduce(accum, Δf_and_args[1])
+          (Δf, Δf_and_args[2:end]...)
+        end
       end
     end
   end
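
The gate for the two new branches is Base.issingletontype(F): a plain function such as sin has a singleton type with no fields, so it cannot carry state, while a closure that captures a local variable stores it as a field of its type and so keeps the reverse-order path. A quick illustrative sketch (not part of the diff; StaticGetter is Zygote-internal):

    Base.issingletontype(typeof(sin))    # true: no fields, cannot close over state
    f = let y = [4, 5]
        z -> z^2 + y[1]                  # captures the local y as a field
    end
    Base.issingletontype(typeof(f))      # false: such an f takes the reverse-order path

    getter = Zygote.StaticGetter{2}()
    getter((10, 20, 30))                 # 20
    getter(nothing)                      # nothing, via the method added above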

src/lib/broadcast.jl (+28 -14)

@@ -164,31 +164,48 @@ end
 # Avoid hitting special cases for `Adjoint` etc.
 _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
 
-_get(x::Tuple, i) = x[i]
-_get(::Nothing, i) = nothing
 collapse_nothings(xs::AbstractArray{Nothing}) = nothing
 collapse_nothings(xs) = xs
 
-@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
+_dual_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
+_dual_purefun(::Type) = false
+_dual_purefun(::Type{typeof(^)}) = false  # avoid DomainError from negative powers
+
+_dual_safearg(x::Numeric{<:Real}) = true
+_dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
+_dual_safearg(x::Union{Type,Val,Symbol}) = true  # non-differentiable types
+_dual_safearg(x) = false
+
+@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
+  T = Broadcast.combine_eltypes(f, args)
+  # Avoid generic broadcasting in two easy cases:
+  if T == Bool
+    return f.(args...), _->nothing
+  elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
+    y, back = broadcast_forward(f, args...)
+    return y, ȳ -> (nothing, nothing, back(ȳ)...)
+  end
   len = inclen(args)
   y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
-  y = map(x -> x[1], y∂b)
-  ∂b = map(x -> x[2], y∂b)
-  y, function (ȳ)
-    dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ)
-    dxs = collapse_nothings.(ntuple(i -> map(x -> _get(x, i), dxs_zip), len))
+  y = map(first, y∂b)
+  function ∇broadcasted(ȳ)
+    dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
+    dxs = ntuple(len) do i
+      collapse_nothings(map(StaticGetter{i}(), dxs_zip))
+    end
     (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
   end
+  y, ∇broadcasted
 end
 
 @adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
-  len = inclen(args)
   y, ∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
-  y, function (ȳ)
+  function ∇broadcasted0(ȳ)
     dxs = ∂b(ȳ)
     dxs === nothing && return nothing
    (nothing, dxs...)
   end
+  y, ∇broadcasted0
 end
 
 # Use the `map` adjoint in this special case, which is the same but applies
@@ -202,17 +219,14 @@ end
 
 @adjoint! (b::typeof(broadcast))(f, args...) = _pullback(__context__, broadcasted, f, args...)
 
-# Forward Mode (mainly necessary for CUDA)
+# Forward Mode -- necessary for CUDA, also used as a fast path above
 
 import ForwardDiff
 using ForwardDiff: Dual
 
 dual(x, p) = x
 dual(x::Real, p) = Dual(x, p)
 
-dualtype(::Type{Dual{G,T,P}}) where {G,T,P} = T
-dualtype(T) = T
-
 function dual_function(f::F) where F
   function (args::Vararg{Any,N}) where N
     ds = map(args, ntuple(identity,Val(N))) do x, i
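
The fast path above calls broadcast_forward, which builds on dual_function: every real element is seeded with a ForwardDiff.Dual carrying a partial, so one forward broadcast yields both the values and the elementwise derivatives, and the pullback is then just an elementwise multiply by ȳ. A minimal one-argument sketch of the idea (not the actual implementation):

    using ForwardDiff: Dual, value, partials

    xs = [1.0, 2.0, 3.0]
    ys = map(x -> exp(Dual(x, one(x))), xs)  # seed each element with partial 1
    value.(ys)                               # forward values, same as exp.(xs)
    partials.(ys, 1)                         # elementwise derivatives, here also exp.(xs)

Duals only flow through the explicit arguments, so anything captured by a closure would silently get zero gradient; hence the issingletontype check behind _dual_purefun.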

test/features.jl (+41 -10)

@@ -500,14 +500,45 @@ end
   @test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
 end
 
-@testset "tuples & broadcasting" begin
-  @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
-  @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
-  @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)
-
-  # https://github.com/FluxML/Zygote.jl/issues/975
-  gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
-  gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
-  @test gt[1] == gv[1]
-  @test collect(gt[2]) ≈ gv[2]
+@testset "tricky broadcasting" begin
+  @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
+  @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
+  @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)
+
+  # https://github.com/FluxML/Zygote.jl/issues/975
+  gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
+  gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
+  @test gt[1] == gv[1]
+  @test collect(gt[2]) ≈ gv[2]
+
+  # closure captures y -- can't use ForwardDiff
+  @test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+  @test gradient((x,y) -> sum((z->z^2+y[1]), x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+  @test gradient((x,y) -> sum(map((z->z^2+y[1]), x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+  @test gradient((x,y) -> mapreduce((z->z^2+y[1]), +, x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+
+  # type unstable
+  @test gradient(xs -> sum((x -> x<2 ? false : x^2).(xs)), [1,2,3])[1][2:3] == [4, 6]
+  @test gradient(xs -> sum((x -> x<2 ? false : x^2), xs), [1,2,3])[1][2:3] == [4, 6]
+  @test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
+  @test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]
+
+  # with Ref, Val, Symbol
+  @test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],)
+  @test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)
+  @test gradient(x -> sum((first∘tuple).(x, :ignore)), [1,2,3]) == ([1,1,1],)
+  @test gradient(x -> sum((first∘tuple).(x, Symbol)), [1,2,3]) == ([1,1,1],)
+  _f(x,::Val{y}=Val(2)) where {y} = x/y
+  @test gradient(x -> sum(_f.(x, Val(2))), [1,2,3]) == ([0.5, 0.5, 0.5],)
+  @test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
+  @test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
+
+  @test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
+  @test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)
+
+  # negative powers
+  @test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], [1,-1,2])[1] ≈ [1.0, -0.25, 8.0]
+  @test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
+  @test gradient((x,p) -> sum(z -> z^p, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
+  @test gradient((x,p) -> mapreduce(z -> z^p, +, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
 end
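
Each block above targets one branch of the new broadcasted adjoint. One way to see which branch a given call would take is to probe the predicates from src/lib/broadcast.jl by hand; an illustrative sketch using Zygote internals:

    using Zygote: _dual_purefun, _dual_safearg  # internal helpers from this PR

    _dual_purefun(typeof(sin))  # true: singleton type, eligible for the Dual fast path
    _dual_purefun(typeof(^))    # false: opted out to dodge DomainError on negative powers
    _dual_safearg(Val(2))       # true: non-differentiable argument, passed straight through
    _dual_safearg("hello")      # false: a String argument forces the generic pullback path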

test/gradcheck.jl (+2 -1)

@@ -1295,7 +1295,8 @@ end
 end
 
 @testset "broadcast" begin
-  @test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
+  # Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
+  @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]
 
   a = rand(3)
   b = rand(2,2)
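
The intent of the new test: for a dense matrix every entry is an independent parameter, so the gradient is cos at all nine positions, which is exactly the [1 1 1; 1 0 1; 1 1 -1] that the Diagonal input wrongly produced before this PR; the structural zeros of a Diagonal are not parameters and must get zero cotangent. A cross-check sketch:

    using LinearAlgebra, Zygote

    D = Diagonal([0, pi/2, pi])
    gradient(x -> sum(sin.(x)), collect(D))[1]  # dense: cos.(collect(D)) == [1 1 1; 1 0 1; 1 1 -1]
    gradient(x -> sum(sin.(x)), D)[1]           # Diagonal: off-diagonal cotangents are 0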
