Skip to content

Commit 2090d72

Browse files
committed
feat: graph saint sampler
1 parent b2fb118 commit 2090d72

File tree

4 files changed

+45
-23
lines changed

4 files changed

+45
-23
lines changed

biomedkg/data_module.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch_geometric.transforms as T
22
from lightning import LightningDataModule
3-
from torch_geometric.loader import NeighborLoader
3+
from torch_geometric.loader import GraphSAINTRandomWalkSampler, NeighborLoader
44

55
from biomedkg.data import dataset, node
66

@@ -86,22 +86,44 @@ def all_dataloader(self):
8686
num_workers=0,
8787
)
8888

89-
def train_dataloader(self):
90-
return NeighborLoader(
91-
data=self.train_data,
92-
batch_size=self.batch_size,
93-
num_neighbors=[30] * 3,
94-
num_workers=0,
95-
shuffle=True,
96-
)
97-
98-
def val_dataloader(self):
99-
return NeighborLoader(
100-
data=self.val_data,
101-
batch_size=self.batch_size,
102-
num_neighbors=[30] * 3,
103-
num_workers=0,
104-
)
89+
def train_dataloader(self, loader_type: str = "neighbor"):
90+
assert loader_type in ["neighbor", "saint"]
91+
92+
if loader_type == "neighbor":
93+
return NeighborLoader(
94+
data=self.train_data,
95+
batch_size=self.batch_size,
96+
num_neighbors=[30] * 3,
97+
num_workers=0,
98+
shuffle=True,
99+
)
100+
elif loader_type == "saint":
101+
return GraphSAINTRandomWalkSampler(
102+
data=self.train_data,
103+
batch_size=self.batch_size,
104+
walk_length=10,
105+
num_steps=1000,
106+
num_workers=0,
107+
)
108+
109+
def val_dataloader(self, loader_type: str = "neighbor"):
110+
assert loader_type in ["neighbor", "saint"]
111+
112+
if loader_type == "neighbor":
113+
return NeighborLoader(
114+
data=self.val_data,
115+
batch_size=self.batch_size,
116+
num_neighbors=[30] * 3,
117+
num_workers=0,
118+
)
119+
elif loader_type == "saint":
120+
return GraphSAINTRandomWalkSampler(
121+
data=self.val_data,
122+
batch_size=self.batch_size,
123+
walk_length=10,
124+
num_steps=100,
125+
num_workers=0,
126+
)
105127

106128
def test_dataloader(self):
107129
return NeighborLoader(

train_dpi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def main(cfg: DictConfig):
9494

9595
trainer.fit(
9696
model=model,
97-
train_dataloaders=data_module.train_dataloader(),
98-
val_dataloaders=data_module.val_dataloader(),
97+
train_dataloaders=data_module.train_dataloader(loader_type="saint"),
98+
val_dataloaders=data_module.val_dataloader(loader_type="saint"),
9999
)
100100

101101
test_args = {

train_gcl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def main(cfg: DictConfig):
107107

108108
trainer.fit(
109109
model=model,
110-
train_dataloaders=data_module.train_dataloader(),
111-
val_dataloaders=data_module.val_dataloader(),
110+
train_dataloaders=data_module.train_dataloader(loader_type="neighbor"),
111+
val_dataloaders=data_module.val_dataloader(loader_type="neighbor"),
112112
)
113113

114114
test_args = {

train_kge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def main(cfg: DictConfig):
8787

8888
trainer.fit(
8989
model=model,
90-
train_dataloaders=data_module.train_dataloader(),
91-
val_dataloaders=data_module.val_dataloader(),
90+
train_dataloaders=data_module.train_dataloader(loader_type="saint"),
91+
val_dataloaders=data_module.val_dataloader(loader_type="saint"),
9292
)
9393

9494
test_args = {

0 commit comments

Comments
 (0)