Skip to content

Commit 00f1b23

Browse files
author
Ashutosh Tiwari
committed
added first dataset for testing
1 parent ea4cfc8 commit 00f1b23

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

datasets/dataset.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,31 @@
11
from typing import List, Tuple
2+
3+
import pandas as pd
4+
import scipy
25
from torch_geometric.data import download_url
36

47

58
class Dataset(object):
    """Base class pairing a tabular data frame with an adjacency structure.

    Subclasses provide the actual data by overriding ``_set_data_df`` and
    ``_set_adj``; both hooks are called from ``__init__`` once every URL has
    been handed to ``download_url``.
    """

    def __init__(self, root: str = '/tmp/', group_col: str = None, urls: List = None) -> None:
        """Download the given files and build ``data`` and ``adj``.

        Args:
            root: Directory passed to ``download_url`` as the download target.
            group_col: Name of the column in ``data`` that ``y`` returns and
                ``X`` excludes.
            urls: URLs to download; ``None`` is treated as an empty list.

        Raises:
            NotImplementedError: If the subclass does not override
                ``_set_data_df`` or ``_set_adj``.
        """
        self.root = root
        self.group_col = group_col
        # One entry per URL, holding whatever download_url returned for it.
        self.file_paths = [download_url(url, root) for url in urls or []]
        # _set_data_df runs first, so _set_adj may rely on state it stored on self.
        self.data = self._set_data_df()
        self.adj = self._set_adj()

    @property
    def X(self):
        """Return the values of every column of ``data`` except ``group_col``."""
        return self.data.loc[:, self.data.columns != self.group_col].values

    @property
    def y(self):
        """Return the values of the ``group_col`` column of ``data``."""
        # Label-based selection needs .loc: plain ``data[:, col]`` is not a
        # valid DataFrame key and raises instead of selecting the column.
        return self.data.loc[:, self.group_col].values

    def _set_data_df(self) -> pd.DataFrame:
        """Build and return the data frame stored in ``self.data``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    # NOTE(review): ``scipy.sparse`` is a module, not a class; a concrete sparse
    # type would be a more precise annotation - confirm what subclasses return.
    def _set_adj(self) -> scipy.sparse:
        """Build and return the adjacency structure stored in ``self.adj``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

datasets/polbooks.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pandas as pd
22
import networkx as nx
3+
from torch_geometric.data import extract_zip
34

45

56
from datasets import dataset
@@ -12,7 +13,16 @@ def __init__(self, root: str = "/tmp/"):
1213
urls=["https://websites.umich.edu/~mejn/netdata/polbooks.zip"])
1314

1415
def _set_data_df(self):
    """Extract the downloaded archive, load the graph and return its node table.

    Stores the graph read from ``polbooks.gml`` on ``self.graph`` and returns
    a data frame with one row per node, built from the node attributes. The
    ``value`` attribute is renamed to ``political_leaning`` and its codes
    ``'n'``, ``'c'``, ``'l'`` are mapped to ``0``, ``1``, ``2``.
    """
    # Local import keeps this method self-contained without touching the
    # module's import block.
    import os

    extract_zip(self.file_paths[0], self.root)
    # Join instead of concatenating so a root without a trailing separator
    # (e.g. "/tmp") still resolves to "<root>/polbooks.gml".
    self.graph = nx.read_gml(os.path.join(self.root, "polbooks.gml"))

    node_attrs = dict(self.graph.nodes(data=True))
    df = (
        pd.DataFrame.from_dict(node_attrs, orient='index')
        .reset_index(drop=True)
        .rename(columns={'value': 'political_leaning'})
    )
    df['political_leaning'] = df['political_leaning'].map(
        {'n': 0, 'c': 1, 'l': 2}
    )
    return df
1625

1726
def _set_adj(self):
    """Return the adjacency matrix of ``self.graph``.

    Reads ``self.graph``, which ``_set_data_df`` assigns, so that method must
    have run first.
    """
    return nx.adjacency_matrix(self.graph)

0 commit comments

Comments
 (0)