
Projections do not play well with GPUCompiler #429

Open
pabloferz opened this issue Aug 9, 2021 · 2 comments · May be fixed by #430
Labels
ProjectTo related to the projection functionality

Comments

@pabloferz

pabloferz commented Aug 9, 2021

Here's an example that does not play well with GPUCompiler: https://gist.github.com/pabloferz/1390d85383e3243015be7ad5b162bcc4

A possible but probably incomplete fix, discussed with @mcabbott, is to add the following specializations:

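# Float arrays: chosen by dispatch on the eltype; store the element projector and the axes.
function ProjectTo(x::AbstractArray{T}) where {T <: AbstractFloat}
    return ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
end

# Bool arrays are non-differentiable, so they project to NoTangent.
ProjectTo(x::AbstractArray{T}) where {T <: Bool} = ProjectTo{NoTangent}()

# Applying: if dx already has a suitable eltype, return it as-is (no map, no reshape).
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S}) where {S <: Number}
    T = ChainRulesCore.project_type(project.element)
    return S <: T ? dx : map(project.element, dx)
end
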
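To make the intent of these specializations concrete without depending on ChainRulesCore internals, here is a self-contained toy version. ToyElem, ToyArrayProj, toy_project and toy_eltype are hypothetical stand-ins for an element projector, ProjectTo{AbstractArray} and project_type; this is a sketch of the idea, not the library code. It shows the fast path: when the gradient's eltype already fits, the same array comes back untouched.

```julia
# Hypothetical stand-ins for ChainRulesCore's projectors (illustration only).
struct ToyElem{T} end                          # projects one number onto type T
(::ToyElem{T})(dx::Number) where {T} = convert(T, dx)
toy_eltype(::ToyElem{T}) where {T} = T         # plays the role of project_type

struct ToyArrayProj{P,A}
    element::P    # inner element projector
    axes::A       # kept only to mirror the proposal; unused below
end

# Only defined for float eltypes; other eltypes would get their own methods.
toy_project(x::AbstractArray{T}) where {T<:AbstractFloat} =
    ToyArrayProj(ToyElem{T}(), axes(x))

function (p::ToyArrayProj)(dx::AbstractArray{S}) where {S<:Number}
    T = toy_eltype(p.element)
    # S and T both come from types, so this check is typically folded at compile time.
    return S <: T ? dx : map(p.element, dx)
end

p = toy_project(rand(Float32, 3))
dx32 = rand(Float32, 3)
p(dx32) === dx32              # same array, nothing copied
eltype(p(rand(Float64, 3)))   # converted elementwise via map
```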
@pabloferz pabloferz changed the title Projections do no play well with GPUCompiler Projections do not play well with GPUCompiler Aug 9, 2021
@mcabbott
Member

mcabbott commented Aug 9, 2021

I think we should do something like the first fix. The whole constructor can use dispatch instead of branching without any loss; it just happened to be written that way:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L181-L188

Edit: maybe like this commit: 7e5ae8e
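A toy contrast of the two styles follows. BrNoOp, BrDense, make_branching and make_dispatch are hypothetical names standing in for ProjectTo{NoTangent}() and an array projector; the sketch does not reproduce the linked projection.jl lines or the commit, it only shows the same decision written as a branch versus as separate methods.

```julia
using Test: @inferred

# Hypothetical stand-ins (illustration only).
struct BrNoOp end                  # plays the role of ProjectTo{NoTangent}()
struct BrDense{T,A}                # plays the role of an array projector
    axes::A
end
BrDense{T}(axes::A) where {T,A} = BrDense{T,A}(axes)

# Branching style: one method; which type comes back depends on a test inside it.
function make_branching(x::AbstractArray{T}) where {T<:Number}
    if T <: Bool
        return BrNoOp()
    else
        return BrDense{T}(axes(x))
    end
end

# Dispatch style: the method is picked from the argument type,
# and each method body has exactly one return type.
make_dispatch(x::AbstractArray{T}) where {T<:Number} = BrDense{T}(axes(x))
make_dispatch(::AbstractArray{Bool}) = BrNoOp()

make_dispatch(trues(3))          # BrNoOp()
@inferred make_dispatch([1.0])   # a BrDense{Float64, ...}
```

Both versions give the same answers; the dispatch form simply makes each case a separate, single-return-type method.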

The second seems trickier: it avoids this if hasproperty(project, :element) branch:
https://github.com/JuliaDiff/ChainRulesCore.jl/blob/master/src/projection.jl#L216
But also:

  • It avoids the reshape completely. Can we bypass this in more cases?
  • It assumes that the projector being applied to dx::AbstractArray{<:Number} has .element, but that need not be true, e.g. for x = Any[1,2,3].
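The second bullet can be illustrated with a toy model. HpWithElement and HpNoElement are hypothetical stand-ins for the two kinds of array projector that share one tag, and hp_apply only mirrors the shape of the hasproperty branch; it is not the library code. A method that unconditionally reads .element would throw on the second kind.

```julia
# Hypothetical toy projectors (illustration only).
struct HpElem{T} end
(::HpElem{T})(dx::Number) where {T} = convert(T, dx)

struct HpWithElement{P}      # e.g. what a Vector{Float32} might get
    element::P
end
struct HpNoElement end       # e.g. what Any[1, 2, 3] might get: no .element field

# Branch on the presence of a field, in the spirit of `if hasproperty(project, :element)`.
function hp_apply(project, dx::AbstractArray{<:Number})
    if hasproperty(project, :element)
        return map(project.element, dx)
    else
        return dx            # nothing to project elementwise
    end
end
```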

It's possible that we should insist that every array projector has .element, if necessary a trivial one. That might help in-place stuff too.
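One reading of "every array projector has .element, if necessary a trivial one" is sketched below with hypothetical toy types (TrElem, TrTrivial, TrArrayProj, tr_project). Since the field always exists, applying the projector needs no field check, and the trivial case can still skip the map via dispatch on the element projector's type.

```julia
# Hypothetical toy projectors (illustration only).
struct TrElem{T} end
(::TrElem{T})(dx::Number) where {T} = convert(T, dx)

struct TrTrivial end           # trivial element projector: returns its input
(::TrTrivial)(dx) = dx

struct TrArrayProj{P}
    element::P                 # always present, possibly TrTrivial()
end

# Every array projector gets an element projector.
tr_project(x::AbstractArray{T}) where {T<:AbstractFloat} = TrArrayProj(TrElem{T}())
tr_project(x::AbstractArray) = TrArrayProj(TrTrivial())   # e.g. Any[1, 2, 3]

# One application rule; no hasproperty check needed.
(p::TrArrayProj)(dx::AbstractArray) = map(p.element, dx)
# Trivial element projector: return dx without mapping.
(::TrArrayProj{TrTrivial})(dx::AbstractArray) = dx
```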

It's also possible that we should mark the two cases in some way that is easier to dispatch on?
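One hypothetical way to mark the two cases is to expose the stored field names in the projector's type, so a method can require .element in its signature. The sketch assumes toy types (MkProj, MkElem) that keep their fields in a NamedTuple type parameter; it illustrates the idea only and makes no claim about how ProjectTo is laid out.

```julia
# Hypothetical toy (illustration only): the second type parameter is the
# NamedTuple type of the stored fields, so field names are visible to dispatch.
struct MkProj{Tag,F<:NamedTuple}
    fields::F
end
MkProj{Tag}(; kw...) where {Tag} = MkProj{Tag,typeof(values(kw))}(values(kw))

struct MkElem{T} end
(::MkElem{T})(dx::Number) where {T} = convert(T, dx)

# Array projector that stores (:element, :axes): project elementwise.
(p::MkProj{AbstractArray,<:NamedTuple{(:element, :axes)}})(dx::AbstractArray{<:Number}) =
    map(p.fields.element, dx)

# Any other array projector (e.g. for Any[1, 2, 3]): pass dx through.
(::MkProj{AbstractArray})(dx::AbstractArray) = dx

x = Float32[1, 2]
p = MkProj{AbstractArray}(; element = MkElem{Float32}(), axes = axes(x))
q = MkProj{AbstractArray}()   # no element projector stored
```

A likely cost of this approach is that signatures become long and order-sensitive: here the keyword order must produce exactly (:element, :axes) for the first method to match.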

One earlier iteration had ProjectTo{AbstractArray{Float32}}(...), always encoding the eltype in the type. For every number type, the projector is in fact fully described by the element type. But this design of literally storing the inner projector seemed simpler. For instance, with the eltype encoded, do you also make ProjectTo{Diagonal{Float32}}(...), or do wrapper types always delegate this?
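The eltype-encoding alternative can be sketched with a hypothetical field-less type EtProj{T}: since the projector for a number eltype is fully described by T, nothing needs to be stored. This toy only converts elementwise; the real projectors also handle cases such as complex gradients for real inputs, where convert here would simply throw.

```julia
# Hypothetical toy (illustration only): the eltype lives in the type parameter,
# so no inner element projector is stored.
struct EtProj{T<:Number} end
et_project(::AbstractArray{T}) where {T<:Number} = EtProj{T}()

function (::EtProj{T})(dx::AbstractArray{S}) where {T<:Number,S<:Number}
    # Already a suitable eltype: return dx unchanged; otherwise convert elementwise.
    return S <: T ? dx : map(x -> convert(T, x), dx)
end

p = et_project(Float32[1, 2])   # EtProj{Float32}(), a singleton with no fields
```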

(It's also possible that we should @inline many things. No idea if this would help here; was reluctant to add clutter until finding at least one example where it does.)

@mcabbott
Member

mcabbott commented Sep 5, 2021

ChainRulesCore v1.3.1 (the latest release) has the branch-free construction of projectors described above, but still uses if hasproperty(project, :element) when applying them.

If I try the linked gist on the latest versions of everything (CUDA v3.4.2), I get a warning on the first run, but subsequent runs are fine. Can you confirm what you see, and whether you think there are still problems here?

julia> val_and_grad(dihedral_angle, CUDA.rand(Float64, 3, 4))
(┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f5e3878d710.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
...[etc]...
└ @ GPUArrays ~/.julia/packages/GPUArrays/UBzTm/src/host/indexing.jl:56
fill(-0.33681546256995853), fill(SVector{3, Float64}[[0.4074603842207669, -1.441045304947525, 0.5486342400104408], [0.36579849547697396, -0.2231744971894618, 0.43044376309880106], [5.034006482016548, -7.684693458589903, 6.191228243615761], [-5.807265361714289, 9.34891326072689, -7.170306246725003]]))

julia> CUDA.allowscalar(false)

julia> val_and_grad(dihedral_angle, CUDA.rand(Float64, 3, 4))
(fill(-0.780010516490922), fill(SVector{3, Float64}[[0.845104396943803, 1.9836409707095195, 2.047662646035268], [-1.7434321992965327, -0.7401691262870027, -0.25745246549838496], [-0.24669189468086605, -0.0020683788027422434, 0.08506441265560705], [1.1450196970335957, -1.2414034656197745, -1.87527459319249]]))

PR #430 removes the if hasproperty(project, :element) branch, but I'm not certain it won't land us in dispatch hell. It would need careful thought, at least.

@mcabbott mcabbott added the ProjectTo related to the projection functionality label Sep 7, 2021
2 participants