-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathshortest_path.py
72 lines (53 loc) · 2.31 KB
/
shortest_path.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import annotations
from typing import Tuple, Dict, List
from torch.multiprocessing import spawn
import networkx as nx
from torch_geometric.data import Data, Batch
from torch_geometric.utils.convert import to_networkx
def floyd_warshall_source_to_all(G, source, cutoff=None):
    """Minimum-hop paths from ``source`` to every reachable node of ``G``.

    Despite the name, this is not Floyd-Warshall: it is a level-by-level
    breadth-first search from a single source. Edge weights are never read,
    so every returned path is a minimum-hop path.

    Args:
        G: A networkx graph. Neighbours are taken from ``G[v]``; for a
            directed networkx graph that means only out-edges are followed.
        source: Node to start the search from.
        cutoff: If not None, only nodes at most ``cutoff`` hops from
            ``source`` are returned (``cutoff=0`` returns just ``source``).

    Returns:
        ``(node_paths, edge_paths)``, two dicts keyed by reachable node ``w``:
        ``node_paths[w]`` is the node list from ``source`` to ``w`` inclusive,
        and ``edge_paths[w]`` is the list of edge ids along that path, where an
        edge id is the edge's position in ``enumerate(G.edges())``.
        ``node_paths[source] == [source]`` and ``edge_paths[source] == []``.
        Unreachable nodes are absent from both dicts.

    Raises:
        nx.NodeNotFound: If ``source`` is not a node of ``G``.
    """
    if source not in G:
        raise nx.NodeNotFound("Source {} not in G".format(source))

    edges = {edge: i for i, edge in enumerate(G.edges())}
    if not G.is_directed():
        # G.edges() lists each undirected edge once, in one orientation, but
        # the BFS may cross it in the other direction; register the reverse
        # orientation under the same id so the lookup below cannot KeyError.
        for (u, v), i in list(edges.items()):
            edges.setdefault((v, u), i)

    level = 0  # number of BFS levels expanded so far
    nextlevel = [source]  # frontier: nodes discovered at the current level
    node_paths = {source: [source]}
    edge_paths = {source: []}
    # Check the cutoff *before* expanding, so at most `cutoff` levels are
    # explored (the frontier is consumed level by level).
    while nextlevel and (cutoff is None or level < cutoff):
        thislevel = nextlevel
        nextlevel = []
        for v in thislevel:
            for w in G[v]:
                if w not in node_paths:
                    node_paths[w] = node_paths[v] + [w]
                    edge_paths[w] = edge_paths[v] + [edges[(v, w)]]
                    nextlevel.append(w)
        level += 1
    return node_paths, edge_paths
def all_pairs_shortest_path(
    G,
) -> Tuple[Dict[int, Dict[int, List[int]]], Dict[int, Dict[int, List[int]]]]:
    """Minimum-hop paths between every ordered pair of connected nodes of ``G``.

    Runs ``floyd_warshall_source_to_all`` once from every node of ``G``
    (in ``G``'s node iteration order).

    Args:
        G: A networkx graph.

    Returns:
        ``(node_paths, edge_paths)`` where ``node_paths[s][t]`` is the node
        list of a minimum-hop path from ``s`` to ``t`` and ``edge_paths[s][t]``
        the matching list of edge ids. A pair ``(s, t)`` with no path from
        ``s`` to ``t`` is simply absent from ``node_paths[s]``/``edge_paths[s]``.
    """
    node_paths = {}
    edge_paths = {}
    for n in G:
        node_paths[n], edge_paths[n] = floyd_warshall_source_to_all(G, n)
    return node_paths, edge_paths
def shortest_path_distance(
    data: Data,
) -> Tuple[Dict[int, Dict[int, List[int]]], Dict[int, Dict[int, List[int]]]]:
    """All-pairs minimum-hop paths for a single PyG graph.

    Note that, despite the name, this returns the paths themselves rather
    than their lengths; a hop distance is ``len(edge_paths[s][t])``.

    Args:
        data: A single graph; it is converted with ``to_networkx`` using that
            function's defaults.

    Returns:
        ``(node_paths, edge_paths)`` as produced by ``all_pairs_shortest_path``:
        ``node_paths[s][t]`` is the node list from ``s`` to ``t`` and
        ``edge_paths[s][t]`` the edge ids along it.
    """
    G = to_networkx(data)
    return all_pairs_shortest_path(G)
def batched_shortest_path_distance(
    data,
) -> Tuple[Dict[int, Dict[int, List[int]]], Dict[int, Dict[int, List[int]]]]:
    """All-pairs minimum-hop paths for every graph in a batch, with batch-global ids.

    Each graph from ``data.to_data_list()`` is converted with ``to_networkx``
    and solved independently (paths never cross graphs). Node ids are shifted
    by the number of nodes in the preceding graphs. Edge ids are likewise
    shifted by the number of edges (``num_edges``) in the preceding graphs.
    Without that shift, the edge ids of different graphs would collide while
    their node keys do not.

    NOTE(review): within a graph, an edge id is its position in
    ``enumerate(G.edges())``. That coincides with the column of the batched
    ``edge_index`` only when each graph's ``edge_index`` has no duplicate
    edges and is ordered by source node — confirm for your data.

    Args:
        data: A batch exposing ``to_data_list()`` (e.g. a PyG ``Batch``).

    Returns:
        ``(node_paths, edge_paths)`` keyed by batch-global source node id:
        ``node_paths[s][t]`` is the node list from ``s`` to ``t`` and
        ``edge_paths[s][t]`` the batch-global edge ids along that path.
    """
    node_paths = {}
    edge_paths = {}
    node_offset = 0
    edge_offset = 0
    for sub_data in data.to_data_list():
        G = to_networkx(sub_data)
        num_nodes = G.number_of_nodes()
        # to_networkx numbers nodes 0..num_nodes-1; shift them so node ids
        # are unique across the whole batch.
        G = nx.relabel_nodes(G, {n: n + node_offset for n in range(num_nodes)})
        sub_node_paths, sub_edge_paths = all_pairs_shortest_path(G)
        node_paths.update(sub_node_paths)
        for src, paths in sub_edge_paths.items():
            edge_paths[src] = {
                dst: [e + edge_offset for e in path] for dst, path in paths.items()
            }
        node_offset += num_nodes
        # Advance by the graph's edge count in the batch layout, so the next
        # graph's edge ids start where its edges start in the batch.
        edge_offset += sub_data.num_edges
    return node_paths, edge_paths