Skip to content

Commit f48d2da

Browse files
differentiable adjacency_matrix for dense
1 parent d57eec1 commit f48d2da

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

src/GNNGraphs/convert.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,14 @@ function to_dense(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing, weighted=t
120120
if !weighted
121121
val = T(1)
122122
end
123-
A = fill!(similar(s, T, (n, n)), 0)
124-
v = vec(A) # vec view of A
125123
idxs = s .+ n .* (t .- 1)
124+
125+
## using scatter instead of indexing since there could be multiple edges
126+
# A = fill!(similar(s, T, (n, n)), 0)
127+
# v = vec(A) # vec view of A
126128
# A[idxs] .= val # exploiting linear indexing
127-
NNlib.scatter!(+, v, val, idxs) # using scatter instead of indexing since there could be multiple edges
129+
v = NNlib.scatter(+, val, idxs, dstsize=n^2)
130+
A = reshape(v, (n, n))
128131
return A, n, length(s)
129132
end
130133

src/GNNGraphs/query.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -413,13 +413,11 @@ function has_multi_edges(g::GNNGraph)
413413
length(union(idxs)) < length(idxs)
414414
end
415415

416-
417416
@non_differentiable adjacency_list(x...)
418-
@non_differentiable adjacency_matrix(g::GNNGraph{<:ADJMAT_T}) # TODO remove this in the future
419417
@non_differentiable graph_indicator(x...)
420418
@non_differentiable has_multi_edges(x...)
421419
@non_differentiable Graphs.has_self_loops(x...)
422420
@non_differentiable is_bidirected(x...)
423421
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
424422
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
425-
@non_differentiable scaled_laplacian(x...)
423+
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future

test/runtests.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ tests = [
4040

4141
!CUDA.functional() && @warn("CUDA unavailable, not testing GPU support")
4242

43-
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:coo, :dense, :sparse)
43+
@testset "GraphNeuralNetworks: graph format $graph_type" for graph_type in (:dense, :coo, :sparse)
4444
global GRAPH_T = graph_type
45-
global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse)
46-
45+
# global TEST_GPU = CUDA.functional() && (GRAPH_T != :sparse)
46+
global TEST_GPU = false
47+
48+
4749
for t in tests
4850
startswith(t, "examples") && GRAPH_T == :dense && continue # not testing :dense since causes OutOfMememory on github's CI
4951
include("$t.jl")

0 commit comments

Comments
 (0)