|
1 | 1 | using ChainRulesTestUtils
|
2 | 2 | using LinearAlgebra
|
3 |
| -using Zygote: ZygoteRuleConfig |
| 3 | +using Zygote: ZygoteRuleConfig, _pullback |
4 | 4 |
|
5 | 5 | # issue 897
|
6 | 6 |
|
7 | 7 | test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), ones(2); rrule_f=rrule_via_ad, check_inferred=false)
|
8 | 8 | test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_via_ad, check_inferred=false)
|
| 9 | + |
| 10 | +@testset "adjoints of Iterators.product" begin |
| 11 | + y, back = _pullback(Iterators.product, 1:5, 1:3, 1:2) |
| 12 | + @test back(collect(y)) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], [15.0, 30.0]) |
| 13 | + @test back([(nothing, j, k) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, nothing, [10.0, 20.0, 30.0], [15.0, 30.0]) |
| 14 | + @test back([(i, nothing, k) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], nothing, [15.0, 30.0]) |
| 15 | + @test back([(i, j, nothing) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], nothing) |
| 16 | + |
| 17 | + # This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170 |
| 18 | + @test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] ≈ [320, 320, 320, 320] |
| 19 | + @test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] ≈ [320, 320, 320, 320] |
| 20 | +end |
0 commit comments