Skip to content

Commit d5bd477

Browse files
committed
need to pen out code structure
1 parent fef329f commit d5bd477

8 files changed

+375
-3
lines changed

graph_ml/models/gensim_node2vec.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import numpy as np
2+
import gensim
3+
from ..utils import config, torch_utils

graph_ml/models/node2vec.py

+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import numpy as np
2+
import torch
3+
from torch_geometric.nn import Node2Vec as PyGNode2Vec
4+
5+
6+
from ..utils import config, torch_utils
7+
8+
9+
class Node2Vec:
    """Training/inference wrapper around PyTorch Geometric's ``Node2Vec`` model.

    The wrapper builds the PyG model from an adjacency list, trains it with
    ``fit`` and exposes the learned node embeddings through ``transform``.

    Args:
        adj_list: Graph adjacency list; converted to an edge index by
            ``torch_utils.adj_list_to_edge_index``.
        embedding_dim: Passed through to the PyG model.
        walk_length: Passed through to the PyG model.
        context_size: Passed through to the PyG model.
        device: Device the model is moved to (defaults to ``config.DEVICE``).
        logging: When truthy, ``fit`` prints the mean loss of every epoch.
        **params: Extra keyword arguments forwarded unchanged to the PyG model.
    """

    def __init__(self, adj_list, embedding_dim, walk_length, context_size, device=config.DEVICE,
                 logging=config.LOGGING, **params):
        # NOTE(review): ``adj_list_to_edge_index`` is not among the torch_utils
        # helpers visible in this commit -- confirm it is defined.
        edge_index = torch_utils.adj_list_to_edge_index(adj_list)
        self.model = PyGNode2Vec(
            edge_index, embedding_dim, walk_length, context_size, **params
        ).to(device)
        # Remember whether sparse gradients were requested so that ``fit`` can
        # pick an optimizer that accepts the gradients the model will produce.
        self._sparse = bool(params.get("sparse", False))
        self.num_workers = config.WORKER_COUNT
        self.logging = logging
        self.loader = self.optimizer = None

    def _model_device(self):
        """Return the device that currently holds the model's parameters."""
        # ``torch.nn.Module`` exposes no ``device`` attribute, so read it from
        # a parameter instead.
        return next(self.model.parameters()).device

    def fit(self, epochs=1, learning_rate=.1, batch_size=128):
        """Train the model and return the loss averaged over all epochs.

        Args:
            epochs: Number of passes over the random-walk loader.
            learning_rate: Learning rate handed to the optimizer.
            batch_size: Batch size handed to the model's loader.

        Returns:
            The mean of the per-epoch mean losses, or ``0.0`` when no epoch
            was run.
        """
        # TODO (ashutosh): check if training two times works
        self.loader = self.model.loader(
            batch_size=batch_size, shuffle=True, num_workers=self.num_workers
        )
        # SparseAdam rejects dense gradients, so only use it when the model
        # was explicitly built with ``sparse=True``.
        optimizer_cls = torch.optim.SparseAdam if self._sparse else torch.optim.Adam
        self.optimizer = optimizer_cls(list(self.model.parameters()), lr=learning_rate)

        device = self._model_device()
        self.model.train()
        epoch_losses = []
        for epoch in range(epochs):
            running_loss = 0.0
            for pos_rw, neg_rw in self.loader:
                self.optimizer.zero_grad()
                loss = self.model.loss(pos_rw.to(device), neg_rw.to(device))
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
            # Guard against an empty loader instead of dividing by zero.
            mean_loss = running_loss / max(len(self.loader), 1)
            epoch_losses.append(mean_loss)
            if self.logging:
                print(f"Epoch: {epoch}, Loss: {mean_loss}")

        if not epoch_losses:
            # ``epochs == 0``: nothing was trained, so there is no loss to average.
            return 0.0
        return sum(epoch_losses) / len(epoch_losses)

    def transform(self, nodes=None, type_=np.ndarray):
        """Return the embeddings of ``nodes`` (all nodes when ``nodes`` is None).

        Args:
            nodes: Node indices to look up, or ``None`` for every node.
            type_: ``np.ndarray`` to get a NumPy array on the host; any other
                value returns the detached tensor as produced by the model.

        Returns:
            The selected embeddings, detached from the autograd graph.
        """
        if nodes is None:
            # Calling the model without a batch returns the full embedding table.
            embeddings = self.model()
        else:
            # Index with a tensor living on the same device as the embeddings.
            node_index = torch.as_tensor(nodes, device=self._model_device())
            embeddings = self.model(node_index)
        embeddings = embeddings.detach()
        if type_ is np.ndarray:
            return embeddings.cpu().numpy()
        return embeddings

    def fit_transform(self, epochs=1, learning_rate=.1, batch_size=128, nodes=None, type_=np.ndarray):
        """Run ``fit`` followed by ``transform`` and return the embeddings."""
        self.fit(epochs, learning_rate, batch_size)
        return self.transform(nodes, type_)
51+
52+
53+
54+
55+
56+
57+
58+
59+

graph_ml/utils/config.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ def get_formatted_os():
2525
if PLATFORM == "darwin":
2626
return "MacOS"
2727
assert False, f"Unsupported platform: {PLATFORM}"
28+
return None

graph_ml/utils/gpu_utils.py

Whitespace-only changes.

graph_ml/utils/torch_utils.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
import torch
3+
4+
5+
# TODO (ashutosh): move these assert statements once the code is stable
6+
def convert_to_tensor(data):
    """Return ``data`` as a ``torch.Tensor``.

    NumPy arrays are wrapped with ``torch.from_numpy``; tensors are returned
    unchanged. Any other input trips the assertion.
    """
    if not isinstance(data, np.ndarray):
        assert isinstance(data, torch.Tensor)
        return data
    return torch.from_numpy(data)
11+
12+
13+
def convert_to_numpy(data):
    """Return ``data`` as a NumPy array.

    Tensors are detached from the autograd graph and moved to the CPU before
    conversion; NumPy arrays are returned unchanged. Any other input trips
    the assertion.
    """
    if isinstance(data, torch.Tensor):
        # ``Tensor.numpy()`` refuses tensors that require grad, so detach first.
        return data.detach().cpu().numpy()
    assert isinstance(data, np.ndarray)
    return data
18+
19+
20+
def move_to_device(data, device):
    """Return ``data`` as a tensor placed on ``device``.

    NumPy arrays are first converted with ``convert_to_tensor``; tensors are
    moved directly. Any other input trips the assertion.
    """
    if not isinstance(data, np.ndarray):
        assert isinstance(data, torch.Tensor)
        return data.to(device)
    as_tensor = convert_to_tensor(data)
    return as_tensor.to(device)

0 commit comments

Comments
 (0)