Skip to content

Commit 92e8e95

Browse files
committed
finalize desgin
1 parent 71a8068 commit 92e8e95

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

graph_ml/models/node2vec.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import numpy as np
2-
1+
import torch
32

43
from ..utils import config
54

@@ -26,11 +25,16 @@ def __init__(
2625
def fit(self, epochs=1, learning_rate=0.1, batch_size=128):
2726
return self._fit(epochs, learning_rate, batch_size)
2827

29-
def transform(self, nodes=None, type_=np.ndarray):
28+
def transform(self, nodes=None, type_=torch.Tensor):
3029
return self._transform(nodes, type_)
3130

3231
def fit_transform(
33-
self, epochs=1, learning_rate=0.1, batch_size=128, nodes=None, type_=np.ndarray
32+
self,
33+
epochs=1,
34+
learning_rate=0.1,
35+
batch_size=128,
36+
nodes=None,
37+
type_=torch.Tensor,
3438
):
3539
self.fit(epochs, learning_rate, batch_size)
3640
return self.transform(nodes, type_)

graph_ml/models/torch_node2vec.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from torch_geometric.nn import Node2Vec as PyGNode2Vec
33
import torch
44
from ..models.node2vec import Node2Vec
5-
from ..utils import torch_utils
65

76

87
class TorchNode2Vec(Node2Vec):
@@ -21,7 +20,7 @@ def __init__(self, **params):
2120
@property
2221
def edge_index(self):
2322
# should not be called too often, no caching here
24-
return torch_utils.adj_list_to_edge_index(self.adj_list)
23+
return self.adj_list.nonzero().t().contiguous()
2524

2625
def _fit(self, epochs, learning_rate, batch_size, shuffle=True):
2726
# TODO (ashutosh): check if training two times works
@@ -47,7 +46,7 @@ def _fit(self, epochs, learning_rate, batch_size, shuffle=True):
4746
print(f"Epoch: {epoch}, Loss: {total_loss[epoch]}")
4847
return self
4948

50-
def _transform(self, nodes=None, type_=np.ndarray):
49+
def _transform(self, nodes=None, type_=torch.Tensor):
5150
self.model.eval()
5251
if nodes is None:
5352
nodes = torch.arange(self.num_nodes, device=self.device)

0 commit comments

Comments
 (0)