Skip to content

Commit 8b9d67d

Browse files
lxvmmcabbott
andauthored
Fix adjoint Iterators.product behavior with nothing (#1170)
* Fix adjoint Iterators.product behavior with nothing * Apply suggestions from code review Co-authored-by: Michael Abbott <[email protected]> * add Iterators.product adjoint tests * Update test/lib/array.jl Co-authored-by: Michael Abbott <[email protected]> Co-authored-by: Michael Abbott <[email protected]>
1 parent e56375e commit 8b9d67d

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/lib/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)
286286
function back(dy::AbstractArray)
287287
d = 1
288288
ntuple(length(xs)) do n
289-
first(dy)[n] === nothing && return nothing
290289
nd = _ndims(xs[n])
291290
dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
292291
d += nd
292+
first(dy)[n] === nothing && return nothing
293293
init = zero.(first(dy)[n]) # allows for tuples, which accum can add:
294294
red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
295295
return _project(xs[n], reshape(red, axes(xs[n])))

test/lib/array.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
using ChainRulesTestUtils
22
using LinearAlgebra
3-
using Zygote: ZygoteRuleConfig
3+
using Zygote: ZygoteRuleConfig, _pullback
44

55
# issue 897
66

77
test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), ones(2); rrule_f=rrule_via_ad, check_inferred=false)
88
test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_via_ad, check_inferred=false)
9+
10+
@testset "adjoints of Iterators.product" begin
11+
y, back = _pullback(Iterators.product, 1:5, 1:3, 1:2)
12+
@test back(collect(y)) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], [15.0, 30.0])
13+
@test back([(nothing, j, k) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, nothing, [10.0, 20.0, 30.0], [15.0, 30.0])
14+
@test back([(i, nothing, k) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], nothing, [15.0, 30.0])
15+
@test back([(i, j, nothing) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], nothing)
16+
17+
# This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170
18+
@test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] [320, 320, 320, 320]
19+
@test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] [320, 320, 320, 320]
20+
end

0 commit comments

Comments
 (0)