RFC: accept Tuple Tangent for arrays? #444
Codecov Report
@@ Coverage Diff @@
## master #444 +/- ##
==========================================
- Coverage 92.37% 92.02% -0.36%
==========================================
Files 14 14
Lines 787 790 +3
==========================================
Hits 727 727
- Misses 60 63 +3
@@ -241,6 +241,12 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
    return fill(project.element(dx))
end

# Accept the Tangent corresponding to a Tuple -- Zygote's splats produce these
function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple})
I think I prefer this to be written as:
- function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple})
+ function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Tuple})
so we don't need to mention how this is backed.
I think Zygote managed to produce a Tangent{Any, Tuple{...}}, which was why my pirate method specified that, here:
But I was squashing bugs as fast as I could & didn't track it down any further. Can circle back now that things basically pass.
Right, fair enough.
If we have to do it this way, that is also fine.
Maybe just a comment is worth adding to clarify that this is the tangent for something with primal type Tuple.
Zygote does lose primal type information; that is a known issue.
We might well be able to teach it a bit more about tuples more easily than the general case.
BTW, easy examples of this Tangent{Any} problem:
julia> using Zygote
julia> function Zygote.ChainRulesCore.rrule(::typeof(identity), x::AbstractArray)
           x, dx -> (Zygote.NoTangent(), @show dx)
       end
julia> Zygote.gradient(x -> max(identity(x)...), [1,2,3])
dx = Tangent{Any}(0, 0, 1)
((0, 0, 1),)
julia> Zygote.gradient(x -> sum(identity(x).parent), [1,2,3]')
dx = Tangent{Any}(parent = 3-element Fill{Int64}: entries equal to 1,)
((parent = 3-element Fill{Int64}: entries equal to 1,),)
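To make the dispatch consequence of this concrete, here is a minimal sketch, assuming ChainRulesCore 1.x, where `Tangent{P}(args...)` builds a tangent with primal type `P` and a tuple backing. The names `lossy` and `typed` are illustrative only. A Tangent that has lost its primal type, like the first one printed above, matches `Tangent{<:Any, <:Tuple}` but not `Tangent{<:Tuple}`:

```julia
using ChainRulesCore

# Tangent as printed in the first example above: primal type lost, tuple backing kept.
lossy = Tangent{Any}(0, 0, 1)

# Tangent that still records a Tuple primal type (hypothetical, for comparison).
typed = Tangent{Tuple{Int,Int,Int}}(0, 0, 1)

# Constraining only the primal type misses the lossy case ...
lossy isa Tangent{<:Tuple}          # false
typed isa Tangent{<:Tuple}          # true

# ... while constraining the backing type accepts both.
lossy isa Tangent{<:Any, <:Tuple}   # true
typed isa Tangent{<:Any, <:Tuple}   # true
```

So a method written against `Tangent{<:Tuple}` alone would presumably never be hit by the tangents that Zygote's splats produce today.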
I think I am ok with this. We might want to demand that it is only going into a
Thanks for thinking. For Zygote at least I do think it wants to allow matrices, which also splat to tuples. It's possible that this should be
Dispatch hell might be another reason to do this within Zygote rather than here. Diffractor also has difficulty with splats, but not the same difficulty, so it's not obvious whether this is more widely useful.
This does this:
as a step towards solving this:
I'm not entirely sure this should be handled here rather than in Zygote. Thoughts?