Skip to content

Commit 4c097fd

Browse files
author
Ashutosh Tiwari
committed
Need to add Sphinx for doc generation
1 parent 99be972 commit 4c097fd

File tree

6 files changed

+198
-281
lines changed

6 files changed

+198
-281
lines changed

examples/sbm_sampler.ipynb

+79-277
Large diffs are not rendered by default.

graph_ml/datasets/dataset.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ def __init__(
2020

2121
@property
2222
def X(self):
23-
return self.data.loc[:, self.data.columns != self.group_col].values
23+
# can potentially be overloaded by child class
24+
return self.data.values
2425

2526
@property
2627
def y(self):
27-
return self.data[:, self.group_col].values
28+
return self.data[self.group_col].to_numpy()
2829

2930
def _set_data_df(self) -> pd.DataFrame:
3031
raise NotImplementedError

graph_ml/datasets/polbooks.py

-1
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,3 @@ def _set_data_df(self):
2828

2929
def _set_adj(self):
3030
return nx.adjacency_matrix(self.graph)
31-
pass

graph_ml/utils/utils.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
from scipy import sparse
3+
from numba import njit
4+
5+
def to_member_matrix(group_ids):
    """
    Create a sparse member matrix U such that U[i, k] = 1 if sample i belongs
    to group k, otherwise U[i, k] = 0.

    :param group_ids: 1-D array-like of non-negative integer group labels,
        one entry per sample (lists and tuples are accepted as well as arrays)
    :return: scipy.sparse.csr_matrix of shape (n_samples, max(group_ids) + 1);
        an empty input yields an empty (0, 0) matrix
    """
    # Coerce once so that plain Python sequences work, not only ndarrays.
    labels = np.asarray(group_ids)
    n_samples = labels.shape[0]  # equal to number of samples
    if n_samples == 0:
        # np.max raises on an empty array, so return the empty matrix directly.
        return sparse.csr_matrix((0, 0), dtype=labels.dtype)
    n_groups = int(np.max(labels) + 1)  # number of classes
    # Every row index occurs exactly once, so no duplicate entries get summed
    # during construction and the stored values are already all ones.
    ones = np.ones_like(labels)
    return sparse.csr_matrix(
        (ones, (np.arange(n_samples), labels)),
        shape=(n_samples, n_groups),
    )
16+
17+
18+
def matrix_sum_power(A, T):
    """
    Compute the sum of the first T powers of the square matrix A, i.e.,
    sum_{t=1}^{T} A^t

    :param A: square matrix exposing ``shape``, ``dtype`` and the ``@`` operator
    :param T: number of powers to accumulate; ``T <= 0`` gives the zero matrix
    :return: dense (n, n) array holding A + A^2 + ... + A^T; the result is
        float64 for real input and complex for complex input
    """
    n = A.shape[0]
    # The running power starts from a float64 identity, so real inputs always
    # produce float64. Widening the accumulator to the common dtype keeps that
    # behaviour while allowing complex matrices to be added in place.
    acc_dtype = np.result_type(A.dtype, np.float64)
    power = np.eye(n)
    total = np.zeros((n, n), dtype=acc_dtype)
    for _ in range(T):
        power = A @ power  # A^t from A^(t-1)
        total += power
    return total
32+
33+
34+
@njit(nogil=True)
def csr_row_cumsum(indptr, data):
    """
    Cumulative sum of ``data`` computed independently inside each segment
    delimited by consecutive entries of ``indptr``.

    :param indptr: segment boundaries; segment i spans data[indptr[i]:indptr[i + 1]]
    :param data: values to accumulate
    :return: array shaped like ``data`` holding the per-segment running sums;
        positions not covered by any segment stay zero
    """
    running = np.zeros_like(data)
    n_segments = len(indptr) - 1
    for seg in range(n_segments):
        lo, hi = indptr[seg], indptr[seg + 1]
        # The accumulation restarts at every segment boundary.
        running[lo:hi] = np.cumsum(data[lo:hi])
    return running
48+
49+
50+

poetry.lock

+65-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ torch-geometric = "^2.5.0"
4141
# url = "https://data.pyg.org/whl/torch-2.1.2+cu121.html"
4242
# priority = "primary"
4343
gensim = "^4.3.2"
44+
numba = "^0.59.0"
4445

4546
[build-system]
4647
requires = ["poetry_core>=1.9.0"]

0 commit comments

Comments
 (0)