Skip to content

Commit 583aa76

Browse files
authored
[Example] Pytorch Seal example (dmlc#2638)
* add seal example * 1. add paper infomation in examples/README 2. adjust codes 3. option test * use latest `to_simple` to replace coalesce graph function * remove outdated codes * remove useless comment
1 parent 0526b88 commit 583aa76

File tree

7 files changed

+932
-0
lines changed

7 files changed

+932
-0
lines changed

examples/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ The folder contains example implementations of selected research papers related
7878
| [Dynamic Graph CNN for Learning on Point Clouds](#dgcnnpoint) | | | | | |
7979
| [Supervised Community Detection with Line Graph Neural Networks](#lgnn) | | | | | |
8080
| [Text Generation from Knowledge Graphs with Graph Transformers](#graphwriter) | | | | | |
81+
| [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
8182

8283

8384
## 2020
@@ -239,6 +240,11 @@ The folder contains example implementations of selected research papers related
239240
- Pooling module: [PyTorch](https://docs.dgl.ai/api/python/nn.pytorch.html#sortpooling), [TensorFlow](https://docs.dgl.ai/api/python/nn.tensorflow.html#sortpooling), [MXNet](https://docs.dgl.ai/api/python/nn.mxnet.html#sortpooling)
240241
- Tags: graph classification
241242

243+
- <a name="seal"></a> Zhang et al. Link Prediction Based on Graph Neural Networks. [Paper link](https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf).
244+
- Example code: [pytorch](../examples/pytorch/seal)
245+
- Tags: link prediction, sampling
246+
247+
242248
## 2017
243249

244250
- <a name="gcn"></a> Kipf and Welling. Semi-Supervised Classification with Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1609.02907).

examples/pytorch/seal/README.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# DGL Implementation of the SEAL Paper
2+
This DGL example implements the link prediction model proposed in the paper
3+
[Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf)
4+
and [REVISITING GRAPH NEURAL NETWORKS FOR LINK PREDICTION](https://arxiv.org/pdf/2010.16103.pdf)
5+
The author's codes of implementation is in [SEAL](https://github.com/muhanzhang/SEAL) (pytorch)
6+
and [SEAL_ogb](https://github.com/facebookresearch/SEAL_OGB) (torch_geometric)
7+
8+
Example implementor
9+
----------------------
10+
This example was implemented by [Smile](https://github.com/Smilexuhc) during his intern work at the AWS Shanghai AI Lab.
11+
12+
The graph dataset used in this example
13+
---------------------------------------
14+
15+
ogbl-collab
16+
- NumNodes: 235868
17+
- NumEdges: 2358104
18+
- NumNodeFeats: 128
19+
- NumEdgeWeights: 1
20+
- NumValidEdges: 160084
21+
- NumTestEdges: 146329
22+
23+
Dependencies
24+
--------------------------------
25+
26+
- python 3.6+
27+
- Pytorch 1.5.0+
28+
- dgl 0.6.0 +
29+
- ogb
30+
- pandas
31+
- tqdm
32+
- scipy
33+
34+
35+
How to run example files
36+
--------------------------------
37+
In the seal_dgl folder
38+
run on cpu:
39+
```shell script
40+
python main.py --gpu_id=-1 --subsample_ratio=0.1
41+
```
42+
run on gpu:
43+
```shell script
44+
python main.py --gpu_id=0 --subsample_ratio=0.1
45+
```
46+
47+
Performance
48+
-------------------------
49+
experiment on `ogbl-collab`
50+
51+
| method | valid-hits@50 | test-hits@50 |
52+
| ------ | ------------- | ------------ |
53+
| paper | 63.89(0.49) | 53.71(0.47) |
54+
| ours | 63.56(0.71) | 53.61(0.78) |
55+
56+
Note: We only perform 5 trails in the experiment.

examples/pytorch/seal/logger.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import logging
2+
import time
3+
import os
4+
5+
6+
def _transform_log_level(str_level):
7+
if str_level == 'info':
8+
return logging.INFO
9+
elif str_level == 'warning':
10+
return logging.WARNING
11+
elif str_level == 'critical':
12+
return logging.CRITICAL
13+
elif str_level == 'debug':
14+
return logging.DEBUG
15+
elif str_level == 'error':
16+
return logging.ERROR
17+
else:
18+
raise KeyError('Log level error')
19+
20+
21+
class LightLogging(object):
22+
def __init__(self, log_path=None, log_name='lightlog', log_level='debug'):
23+
24+
log_level = _transform_log_level(log_level)
25+
26+
if log_path:
27+
if not log_path.endswith('/'):
28+
log_path += '/'
29+
if not os.path.exists(log_path):
30+
os.mkdir(log_path)
31+
32+
if log_name.endswith('-') or log_name.endswith('_'):
33+
log_name = log_path+log_name + time.strftime('%Y-%m-%d-%H:%M', time.localtime(time.time())) + '.log'
34+
else:
35+
log_name = log_path+log_name + '_' + time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time())) + '.log'
36+
37+
logging.basicConfig(level=log_level,
38+
format="%(asctime)s %(levelname)s: %(message)s",
39+
datefmt='%Y-%m-%d-%H:%M',
40+
handlers=[
41+
logging.FileHandler(log_name, mode='w'),
42+
logging.StreamHandler()
43+
])
44+
logging.info('Start Logging')
45+
logging.info('Log file path: {}'.format(log_name))
46+
47+
else:
48+
logging.basicConfig(level=log_level,
49+
format="%(asctime)s %(levelname)s: %(message)s",
50+
datefmt='%Y-%m-%d-%H:%M',
51+
handlers=[
52+
logging.StreamHandler()
53+
])
54+
logging.info('Start Logging')
55+
56+
def debug(self, msg):
57+
logging.debug(msg)
58+
59+
def info(self, msg):
60+
logging.info(msg)
61+
62+
def critical(self, msg):
63+
logging.critical(msg)
64+
65+
def warning(self, msg):
66+
logging.warning(msg)
67+
68+
def error(self, msg):
69+
logging.error(msg)

examples/pytorch/seal/main.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import time
2+
from tqdm import tqdm
3+
import numpy as np
4+
import torch
5+
from torch.nn import BCEWithLogitsLoss
6+
from dgl import NID, EID
7+
from dgl.dataloading import GraphDataLoader
8+
from utils import parse_arguments
9+
from utils import load_ogb_dataset, evaluate_hits
10+
from sampler import SEALData
11+
from model import GCN, DGCNN
12+
from logger import LightLogging
13+
14+
'''
15+
Part of the code are adapted from
16+
https://github.com/facebookresearch/SEAL_OGB
17+
'''
18+
19+
20+
def train(model, dataloader, loss_fn, optimizer, device, num_graphs=32, total_graphs=None):
21+
model.train()
22+
23+
total_loss = 0
24+
for g, labels in tqdm(dataloader, ncols=100):
25+
g = g.to(device)
26+
labels = labels.to(device)
27+
optimizer.zero_grad()
28+
logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
29+
loss = loss_fn(logits, labels)
30+
loss.backward()
31+
optimizer.step()
32+
total_loss += loss.item() * num_graphs
33+
34+
return total_loss / total_graphs
35+
36+
37+
@torch.no_grad()
38+
def evaluate(model, dataloader, device):
39+
model.eval()
40+
41+
y_pred, y_true = [], []
42+
for g, labels in tqdm(dataloader, ncols=100):
43+
g = g.to(device)
44+
logits = model(g, g.ndata['z'], g.ndata[NID], g.edata[EID])
45+
y_pred.append(logits.view(-1).cpu())
46+
y_true.append(labels.view(-1).cpu().to(torch.float))
47+
48+
y_pred, y_true = torch.cat(y_pred), torch.cat(y_true)
49+
pos_pred = y_pred[y_true == 1]
50+
neg_pred = y_pred[y_true == 0]
51+
52+
return pos_pred, neg_pred
53+
54+
55+
def main(args, print_fn=print):
56+
print_fn("Experiment arguments: {}".format(args))
57+
58+
if args.random_seed:
59+
torch.manual_seed(args.random_seed)
60+
else:
61+
torch.manual_seed(123)
62+
# Load dataset
63+
if args.dataset.startswith('ogbl'):
64+
graph, split_edge = load_ogb_dataset(args.dataset)
65+
else:
66+
raise NotImplementedError
67+
68+
num_nodes = graph.num_nodes()
69+
70+
# set gpu
71+
if args.gpu_id >= 0 and torch.cuda.is_available():
72+
device = 'cuda:{}'.format(args.gpu_id)
73+
else:
74+
device = 'cpu'
75+
76+
if args.dataset == 'ogbl-collab':
77+
# ogbl-collab dataset is multi-edge graph
78+
use_coalesce = True
79+
else:
80+
use_coalesce = False
81+
82+
# Generate positive and negative edges and corresponding labels
83+
# Sampling subgraphs and generate node labeling features
84+
seal_data = SEALData(g=graph, split_edge=split_edge, hop=args.hop, neg_samples=args.neg_samples,
85+
subsample_ratio=args.subsample_ratio, use_coalesce=use_coalesce, prefix=args.dataset,
86+
save_dir=args.save_dir, num_workers=args.num_workers, print_fn=print_fn)
87+
node_attribute = seal_data.ndata['feat']
88+
edge_weight = seal_data.edata['edge_weight'].float()
89+
90+
train_data = seal_data('train')
91+
val_data = seal_data('valid')
92+
test_data = seal_data('test')
93+
94+
train_graphs = len(train_data.graph_list)
95+
96+
# Set data loader
97+
98+
train_loader = GraphDataLoader(train_data, batch_size=args.batch_size, num_workers=args.num_workers)
99+
val_loader = GraphDataLoader(val_data, batch_size=args.batch_size, num_workers=args.num_workers)
100+
test_loader = GraphDataLoader(test_data, batch_size=args.batch_size, num_workers=args.num_workers)
101+
102+
# set model
103+
if args.model == 'gcn':
104+
model = GCN(num_layers=args.num_layers,
105+
hidden_units=args.hidden_units,
106+
gcn_type=args.gcn_type,
107+
pooling_type=args.pooling,
108+
node_attributes=node_attribute,
109+
edge_weights=edge_weight,
110+
node_embedding=None,
111+
use_embedding=True,
112+
num_nodes=num_nodes,
113+
dropout=args.dropout)
114+
elif args.model == 'dgcnn':
115+
model = DGCNN(num_layers=args.num_layers,
116+
hidden_units=args.hidden_units,
117+
k=args.sort_k,
118+
gcn_type=args.gcn_type,
119+
node_attributes=node_attribute,
120+
edge_weights=edge_weight,
121+
node_embedding=None,
122+
use_embedding=True,
123+
num_nodes=num_nodes,
124+
dropout=args.dropout)
125+
else:
126+
raise ValueError('Model error')
127+
128+
model = model.to(device)
129+
parameters = model.parameters()
130+
optimizer = torch.optim.Adam(parameters, lr=args.lr)
131+
loss_fn = BCEWithLogitsLoss()
132+
print_fn("Total parameters: {}".format(sum([p.numel() for p in model.parameters()])))
133+
134+
# train and evaluate loop
135+
summary_val = []
136+
summary_test = []
137+
for epoch in range(args.epochs):
138+
start_time = time.time()
139+
loss = train(model=model,
140+
dataloader=train_loader,
141+
loss_fn=loss_fn,
142+
optimizer=optimizer,
143+
device=device,
144+
num_graphs=args.batch_size,
145+
total_graphs=train_graphs)
146+
train_time = time.time()
147+
if epoch % args.eval_steps == 0:
148+
val_pos_pred, val_neg_pred = evaluate(model=model,
149+
dataloader=val_loader,
150+
device=device)
151+
test_pos_pred, test_neg_pred = evaluate(model=model,
152+
dataloader=test_loader,
153+
device=device)
154+
155+
val_metric = evaluate_hits(args.dataset, val_pos_pred, val_neg_pred, args.hits_k)
156+
test_metric = evaluate_hits(args.dataset, test_pos_pred, test_neg_pred, args.hits_k)
157+
evaluate_time = time.time()
158+
print_fn("Epoch-{}, train loss: {:.4f}, hits@{}: val-{:.4f}, test-{:.4f}, "
159+
"cost time: train-{:.1f}s, total-{:.1f}s".format(epoch, loss, args.hits_k, val_metric, test_metric,
160+
train_time - start_time,
161+
evaluate_time - start_time))
162+
summary_val.append(val_metric)
163+
summary_test.append(test_metric)
164+
165+
summary_test = np.array(summary_test)
166+
167+
print_fn("Experiment Results:")
168+
print_fn("Best hits@{}: {:.4f}, epoch: {}".format(args.hits_k, np.max(summary_test), np.argmax(summary_test)))
169+
170+
171+
if __name__ == '__main__':
172+
args = parse_arguments()
173+
logger = LightLogging(log_name='SEAL', log_path='./logs')
174+
main(args, logger.info)

0 commit comments

Comments
 (0)