RFC: accept Tuple Tangent for arrays? #444

Closed
mcabbott wants to merge 1 commit into master

Conversation

mcabbott
Member

This PR does the following:

julia> ProjectTo([1 2; 3 4])(Tangent{Tuple}(0.1,2,30,400))
2×2 Matrix{Float64}:
 0.1   30.0
 2.0  400.0

julia> ProjectTo([1,2,3,4]')(Tangent{Tuple}(0.1,2,30,400))
ERROR: DimensionMismatch("array with ndims(x) == 1 >  0 cannot have dx::Number")

as a step towards solving this:

julia> Zygote.gradient(x -> +(x...), [1 2; 3 4])[1]
(1, 1, 1, 1)

I'm not entirely sure this should be handled here rather than in Zygote. Thoughts?
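
For readers who want the core idea without the ChainRulesCore machinery, here is a minimal sketch in plain Julia. The helper name `tuple_to_array` is hypothetical and is not the method added by this PR; it assumes that the tuple entries arrive in the same column-major order in which splatting walks the array, and that promoting every entry to `Float64` is acceptable.

```julia
# Hypothetical helper, not part of ChainRulesCore: reshape a splatted
# tuple of gradient entries back into the shape of the primal array `x`.
function tuple_to_array(x::AbstractArray, dx::Tuple)
    length(dx) == length(x) || throw(DimensionMismatch(
        "tuple of length $(length(dx)) cannot fill an array of length $(length(x))"))
    # Assumption: promote every entry to Float64, as in the matrix example above.
    vals = Float64[d for d in dx]
    # Splatting walks the array in column-major order, so `reshape` undoes it.
    return reshape(vals, size(x))
end

tuple_to_array([1 2; 3 4], (0.1, 2, 30, 400))  # 2×2 Matrix{Float64}: [0.1 30.0; 2.0 400.0]
```

The real projector would presumably also apply the per-element projection (`project.element`) instead of a fixed `Float64` conversion; that part is left out of this sketch.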

@codecov-commenter

codecov-commenter commented Aug 19, 2021

Codecov Report

Merging #444 (abde997) into master (2208660) will decrease coverage by 0.35%.
The diff coverage is 0.00%.

@@            Coverage Diff             @@
##           master     #444      +/-   ##
==========================================
- Coverage   92.37%   92.02%   -0.36%     
==========================================
  Files          14       14              
  Lines         787      790       +3     
==========================================
  Hits          727      727              
- Misses         60       63       +3     
Impacted Files Coverage Δ
src/projection.jl 94.44% <0.00%> (-1.34%) ⬇️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mzgubic added the labels "ProjectTo" (related to the projection functionality) and "Structural Tangent" (related to the `Tangent` type for structured (composite) values) on Aug 19, 2021
@@ -241,6 +241,12 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
    return fill(project.element(dx))
end

# Accept the Tangent corresponding to a Tuple -- Zygote's splats produce these
function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple})
Member

I think I'd prefer this to be written as:

Suggested change
- function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Any, <:Tuple})
+ function (project::ProjectTo{AbstractArray})(dx::Tangent{<:Tuple})

so we don't need to mention how this is backed.

Member Author

I think Zygote managed to produce a Tangent{Any, Tuple{...}}, which was why my pirate method specified that, here:

https://github.com/FluxML/Zygote.jl/pull/1044/files#diff-e0bc7da8f1a33a59f5ecfa67257c04038f0b4915b3f74bdf39780818fd0010a2R162

But I was squashing bugs as fast as I could & didn't track it down any further. Can circle back now that things basically pass.

Member

Right, fair enough.
If we have to do it this way, that is also fine.
Maybe just a comment is worth adding to clarify that this is the tangent for something with primal type Tuple.

Zygote does lose primal type information; that is a known thing.
We might well be able to teach it a bit more about tuples more easily than about the general case.

Member Author

BTW, easy examples of this Tangent{Any} problem:

julia> using Zygote

julia> function Zygote.ChainRulesCore.rrule(::typeof(identity), x::AbstractArray)
         x, dx -> (Zygote.NoTangent(), @show dx)
       end

julia> Zygote.gradient(x -> max(identity(x)...), [1,2,3])
dx = Tangent{Any}(0, 0, 1)
((0, 0, 1),)

julia> Zygote.gradient(x -> sum(identity(x).parent), [1,2,3]')
dx = Tangent{Any}(parent = 3-element Fill{Int64}: entries equal to 1,)
((parent = 3-element Fill{Int64}: entries equal to 1,),)
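
To make the dispatch question in this thread concrete, here is a small sketch using a hypothetical stand-in type `MyTangent{P,T}` (not the real `ChainRulesCore.Tangent`), where `P` plays the role of the primal type and `T` the backing type. It only illustrates how the two candidate signatures would match; it assumes the real type's parameters behave the same way.

```julia
# Hypothetical stand-in: P is the primal type, T is the type of the backing.
struct MyTangent{P,T}
    backing::T
end
# Convenience constructor mirroring the `Tangent{P}(xs...)` calling style.
MyTangent{P}(xs...) where {P} = MyTangent{P,typeof(xs)}(xs)

# Signature from the diff: any primal type, tuple backing.
matches_backing(::MyTangent{<:Any,<:Tuple}) = true
matches_backing(::Any) = false

# Suggested signature: primal type must itself be a Tuple type.
matches_primal(::MyTangent{<:Tuple}) = true
matches_primal(::Any) = false

t_tuple = MyTangent{Tuple{Int,Int,Int}}(0, 0, 1)  # primal type known
t_any   = MyTangent{Any}(0, 0, 1)                 # primal type lost

matches_primal(t_tuple)   # true
matches_primal(t_any)     # false, since Any is not a subtype of Tuple
matches_backing(t_any)    # true, since the backing is still a tuple
```

Under these assumptions, the shorter signature would miss a tangent whose primal type has degraded to `Any`, which is the case shown in the `@show dx` output above.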

@oxinabox
Member

oxinabox commented Aug 19, 2021

I think I am ok with this.
Sometimes you end up with the wrong kind of iterator.
A common incorrect type of iterator is a tuple.

We might want to demand that it only goes into an AbstractVector.
But that might be too restrictive, since we might want to handle 1-row matrices etc. as well?

@mcabbott
Member Author

Thanks for thinking. For Zygote at least I do think it wants to allow matrices, which also splat to tuples.

It's possible that this should be (project::ProjectTo{<:AbstractArray})(dx::Tangent{...}), so that it also applies to splatted row vectors, for example. But this might land us in dispatch hell, and might be best explored after #430.

Dispatch hell might be another reason to do this within Zygote rather than here. Diffractor also has difficulty with splats, but not the same difficulty, so it's not obvious whether this is more widely useful.

@mcabbott mcabbott closed this Oct 29, 2021

4 participants