Skip to content

Commit 843a52d

Browse files
authored
Restrict type broadcast rule to numbers (#1179)
* restrict type broadcast to number * add a test
1 parent 2a2095c commit 843a52d

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.35"
3+
version = "0.6.36"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/lib/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ end
144144
end
145145
end
146146

147-
@adjoint broadcasted(::Type{T}, x::Numeric) where T =
147+
@adjoint broadcasted(::Type{T}, x::Numeric) where {T<:Number} =
148148
T.(x), ȳ -> (nothing, _project(x, ȳ),)
149149

150150
# General Fallback

test/gradcheck.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,9 +1407,17 @@ end
14071407
@test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3]))
14081408
@test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12))
14091409

1410+
# https://github.com/FluxML/Zygote.jl/pull/1171
14101411
sm = sprand(5, 5, 0.5)
14111412
@test gradient(x -> sum(abs2, Float32.(x)), sm)[1] gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1]
14121413
@test gradient(x -> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64}
1414+
1415+
# https://github.com/FluxML/Zygote.jl/issues/1178
1416+
function f1179(x)
1417+
fs = Ref.(x)
1418+
getindex.(fs)
1419+
end
1420+
@test gradient(sumf1179, ones(2)) == ([2.0, 2.0],)
14131421
end
14141422

14151423
using Zygote: Buffer

0 commit comments

Comments
 (0)