Skip to content

Commit a1e78a4

Browse files
committed
Lots of changes; no cleanup on setup; new Node2Vec class structure; still need to set up the sampler
1 parent e990ec0 commit a1e78a4

File tree

7 files changed

+94
-31
lines changed

7 files changed

+94
-31
lines changed

Makefile

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ clean:
99

1010

1111
.PHONY: setup
12-
setup: clean
12+
setup:
1313
@echo "setting up..."
1414
ifeq ($(OS),Darwin)
1515
@echo "Mac"
@@ -20,7 +20,7 @@ endif
2020
@poetry install --only main -vvv
2121

2222
.PHONY: setup_all
23-
setup_all: clean
23+
setup_all:
2424
@echo "setting up..."
2525
@poetry config virtualenvs.in-project true
2626
@poetry install -vvv

graph_ml/models/gensim_node2vec.py

+18
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
11
import numpy as np
22
import gensim
33
from ..utils import config, torch_utils
4+
5+
from ..models.node2vec import Node2Vec
6+
7+
class GensimNode2Vec(Node2Vec):
    """Node2Vec variant whose embeddings are trained with gensim's Word2Vec.

    Only assembles the Word2Vec hyperparameter set here; the actual
    training/sampling wiring is expected to live in the (not yet written)
    `_fit` / `_transform` hooks required by the `Node2Vec` base class.
    """

    def __init__(self, **params):
        super().__init__(**params)
        # Skip-gram (sg=1) with negative sampling (hs=0); vector/window
        # sizes mirror the base node2vec hyperparameters stored by super().
        self.model_params = dict(
            vector_size=self.embedding_dim,
            window=self.context_size,
            min_count=0,
            sg=1,
            hs=0,
            negative=1,
            ns_exponent=0.5,
            epochs=1,
            workers=self.num_workers,
        )
21+

graph_ml/models/node2vec.py

+7-29
Original file line numberDiff line numberDiff line change
@@ -9,41 +9,19 @@
99
class Node2Vec(object):
1010
    def __init__(self, adj_list, embedding_dim, walk_length, context_size, device=config.DEVICE,
                 logging=config.LOGGING, **params):
        """Store the node2vec hyperparameters shared by all backends.

        Args:
            adj_list: graph adjacency list (converted to an edge index by
                subclasses when they build their concrete model).
            embedding_dim: dimensionality of the learned node embeddings.
            walk_length: length of each random walk.
            context_size: window size used when forming (node, context) pairs.
            device: torch device string/object; defaults to the project config.
            logging: when truthy, subclasses print per-epoch progress.
            **params: extra keyword args; accepted so subclasses can forward
                their own options, ignored by this base class.
        """
        self.adj_list = adj_list
        self.num_workers = config.WORKER_COUNT
        self.logging = logging
        self.embedding_dim = embedding_dim
        self.walk_length = walk_length
        self.context_size = context_size
        self.device = device
1919

2020
    def fit(self, epochs=1, learning_rate=.1, batch_size=128):
        """Train the embedding model.

        Thin dispatcher: the real training loop lives in the subclass
        `_fit` hook (e.g. TorchNode2Vec._fit). Arguments are forwarded
        positionally and unchanged.
        """
        return self._fit(epochs, learning_rate, batch_size)
4022

4123
    def transform(self, nodes=None, type_=np.ndarray):
        """Return node embeddings via the subclass `_transform` hook.

        Args:
            nodes: optional collection/tensor of node ids; subclasses embed
                every node when this is None.
            type_: np.ndarray requests a NumPy result; anything else yields
                the backend's native tensor type.
        """
        return self._transform(nodes, type_)
4725

4826
def fit_transform(self, epochs=1, learning_rate=.1, batch_size=128, nodes=None, type_=np.ndarray):
4927
self.fit(epochs, learning_rate, batch_size)

graph_ml/models/torch_node2vec.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import numpy as np
2+
from torch_geometric.nn import Node2Vec as PyGNode2Vec
3+
import torch
4+
from ..models.node2vec import Node2Vec
5+
from ..utils import config, torch_utils
6+
7+
8+
class TorchNode2Vec(Node2Vec):
    """Node2Vec backed by torch_geometric's ``Node2Vec`` module.

    Builds the PyG model on construction and implements the `_fit` /
    `_transform` hooks the `Node2Vec` base class dispatches to.
    """

    # Keyword args consumed by the Node2Vec base class __init__; everything
    # else in **params is forwarded to the underlying PyG model.
    _BASE_KEYS = frozenset(
        {"adj_list", "embedding_dim", "walk_length", "context_size", "device", "logging"}
    )

    def __init__(self, **params):
        super().__init__(**params)
        # Bug fix: forward only the *extra* kwargs. Re-sending the base-class
        # kwargs (embedding_dim, walk_length, ...) alongside the explicit
        # arguments below would raise "got multiple values for argument".
        extra = {k: v for k, v in params.items() if k not in self._BASE_KEYS}
        self.model = PyGNode2Vec(
            edge_index=self.edge_index,
            embedding_dim=self.embedding_dim,
            walk_length=self.walk_length,
            context_size=self.context_size,
            sparse=True,
            **extra,
        ).to(self.device)
        self.loader = self.optimizer = None

    @property
    def edge_index(self):
        # Rebuilt on every access (no caching); should not be called often —
        # currently only read once, during construction.
        return torch_utils.adj_list_to_edge_index(self.adj_list)

    def _fit(self, epochs, learning_rate, batch_size, shuffle=True):
        """Run SGD over random-walk batches; returns self for chaining."""
        # TODO (ashutosh): check if training two times works
        self.loader = self.model.loader(
            batch_size=batch_size, shuffle=shuffle, num_workers=self.num_workers
        )
        # SparseAdam is required because the embedding is built with sparse=True.
        self.optimizer = torch.optim.SparseAdam(self.model.parameters(), lr=learning_rate)
        self.model.train()
        total_loss = [0] * epochs
        for epoch in range(epochs):
            for pos_rw, neg_rw in self.loader:
                self.optimizer.zero_grad()
                # Bug fix: nn.Module has no `.device` attribute — the device
                # is stored on self by the base class.
                loss = self.model.loss(pos_rw.to(self.device), neg_rw.to(self.device))
                loss.backward()
                self.optimizer.step()
                total_loss[epoch] += loss.item()
            total_loss[epoch] /= len(self.loader)
            if self.logging:
                print(f"Epoch: {epoch}, Loss: {total_loss[epoch]}")
        return self

    def _transform(self, nodes=None, type_=np.ndarray):
        """Return embeddings for `nodes` (all nodes when None)."""
        self.model.eval()
        if nodes is None:
            # Bug fix: num_nodes lives on the PyG model; `self.num_nodes`
            # is never set anywhere in this class hierarchy.
            nodes = torch.arange(self.model.num_nodes, device=self.device)
        with torch.no_grad():
            # as_tensor avoids the copy (and UserWarning) that torch.tensor()
            # emits when `nodes` is already a tensor — the default path above.
            emb = self.model(torch.as_tensor(nodes, device=self.device))
        if type_ is np.ndarray:
            return emb.cpu().numpy()
        return emb

graph_ml/transformations/samplers/random_walk_sampler.py

Whitespace-only changes.

graph_ml/transformations/samplers/sampler.py

Whitespace-only changes.

graph_ml/utils/config.py

+11
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@
1818

1919
GPU_AVAILABLE = DEVICE_TYPE in ["cuda", "mps"]
2020

21+
# Optional acceleration backends for torch_geometric random-walk sampling.
# Bug fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit;
# only a failed import should mark the backend as unavailable.
try:
    import pyg_lib  # noqa
    WITH_PYG_LIB = True
except ImportError:
    WITH_PYG_LIB = False

try:
    import torch_cluster  # noqa
    WITH_TORCH_CLUSTER = True
except ImportError:
    WITH_TORCH_CLUSTER = False
2132

2233
def get_formatted_os():
2334
if PLATFORM == "linux":

0 commit comments

Comments
 (0)