1 | 1 | ofeltype(x, y) = convert(float(eltype(x)), y)
2 | 2 |
3 | | -# Considers `src` a zero-dimensional object.
4 | | -# Useful for implementing `StatsBase.counts`, `degree`, etc.
5 | | -# function NNlib.scatter!(op, dst::AbstractArray, src::Number, idx::AbstractArray) |
6 | | -# for k in CartesianIndices(idx) |
7 | | -# # dst_v = NNlib._view(dst, idx[k]) |
8 | | -# # dst_v .= (op).(dst_v, src) |
9 | | -# dst[idx[k]] .= (op).(dst[idx[k]], src) |
10 | | -# end |
11 | | -# dst |
12 | | -# end |
13 | | - |
14 | | -# About 10 times faster than the generic version above.
15 | | -# All the speedup comes from not broadcasting `op`, though it is unclear why.
16 | | -function NNlib.scatter!(op, dst::AbstractVector, src::Number, idx::AbstractVector{<:Integer}) |
17 | | - for i in idx |
18 | | - dst[i] = op(dst[i], src) |
19 | | - end |
20 | | -end |
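# Hedged illustration, not part of the original source: with the vector method above,
# scattering the scalar `1` with `+` counts how many times each index occurs, which is
# the `StatsBase.counts` / `degree` pattern mentioned in the comments above.
#
#     dst = zeros(Int, 4)
#     NNlib.scatter!(+, dst, 1, [1, 2, 2, 4])
#     dst == [1, 2, 0, 1]    # index 2 appears twice, index 3 never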
21 | | - |
22 | | -# NNlib._view(X, k) = view(X, k...) |
23 | | -# NNlib._view(X, k::Union{Integer, CartesianIndex}) = view(X, k) |
24 | | - |
25 | | -# Considers `src` as a zero-dimensional object to be scattered.
26 | | -# function NNlib.scatter(op, |
27 | | -# src::Tsrc, |
28 | | -# idx::AbstractArray{Tidx,Nidx}; |
29 | | -# init = nothing, dstsize = nothing) where {Tsrc<:Number,Tidx,Nidx} |
30 | | - |
31 | | -# dstsz = isnothing(dstsize) ? maximum_dims(idx) : dstsize |
32 | | -# dst = similar(src, Tsrc, dstsz) |
33 | | -# xinit = isnothing(init) ? scatter_empty(op, Tsrc) : init |
34 | | -# fill!(dst, xinit) |
35 | | -# scatter!(op, dst, src, idx) |
36 | | -# end |
37 | | - |
38 | | - |
39 | | -function scatter_scalar_kernel!(op, dst, src, idx) |
40 | | - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x |
41 | | - |
42 | | - @inbounds if index <= length(idx) |
43 | | - CUDA.@atomic dst[idx[index]...] = op(dst[idx[index]...], src) |
44 | | - end |
45 | | - return nothing |
46 | | -end |
47 | | - |
48 | | -function NNlib.scatter!(op, dst::AnyCuArray, src::Number, idx::AnyCuArray) |
49 | | - max_idx = length(idx) |
50 | | - args = op, dst, src, idx |
51 | | - |
52 | | - kernel = @cuda launch=false scatter_scalar_kernel!(args...) |
53 | | - config = launch_configuration(kernel.fun; max_threads=256) |
54 | | - threads = min(max_idx, config.threads) |
55 | | - blocks = cld(max_idx, threads) |
56 | | - kernel(args...; threads=threads, blocks=blocks) |
57 | | - return dst |
58 | | -end |
59 | | - |
60 | 3 | """
61 | 4 | reduce_nodes(aggr, g, x) |
62 | 5 |
@@ -157,3 +100,16 @@ function broadcast_edges(g::GNNGraph, x) |
157 | 100 | return gather(x, gi) |
158 | 101 | end |
159 | 102 |
| 103 | +# More generic version of |
| 104 | +# https://github.com/JuliaDiff/ChainRules.jl/pull/586 |
 | 105 | +# that applies to all arrays.
 | 106 | +# Without this, the gradient of `T.(A)` for a dense GPU matrix `A` errors.
| 107 | +function ChainRulesCore.rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractArray) |
| 108 | + proj = ProjectTo(x) |
| 109 | + |
| 110 | + function broadcasted_cast(Δ) |
| 111 | + return NoTangent(), NoTangent(), proj(Δ) |
| 112 | + end |
| 113 | + |
| 114 | + return T.(x), broadcasted_cast |
| 115 | +end |
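# Hedged usage sketch, not part of this commit: with the rrule above, the pullback of a
# type-converting broadcast such as `Float32.(A)` returns `ProjectTo(x)(Δ)` instead of
# erroring for dense GPU matrices. Assumes `Zygote` and `CUDA` are available.
#
#     using Zygote, CUDA
#     A = CUDA.rand(Float64, 3, 3)
#     dA = Zygote.gradient(a -> sum(Float32.(a)), A)[1]
#     # dA is a 3×3 CuArray of Float64 ones, projected back to the type of `A`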