
Commit e61138b

[Example] CompGCN (dmlc#2768)
* compgcn
* readme
* readme
* update
* readme

Co-authored-by: zhjwy9343 <[email protected]>
1 parent cfe6e70 commit e61138b

8 files changed (+858, -0 lines)

examples/README.md (+5)
```diff
@@ -85,6 +85,7 @@ The folder contains example implementations of selected research papers related
 | [Directional Message Passing for Molecular Graphs](#dimenet) | | | :heavy_check_mark: | | |
 | [Link Prediction Based on Graph Neural Networks](#seal) | | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |
 | [Variational Graph Auto-Encoders](#vgae) | | :heavy_check_mark: | | | |
+| [Composition-based Multi-Relational Graph Convolutional Networks](#compgcn) | | :heavy_check_mark: | | | |
 | [GNNExplainer: Generating Explanations for Graph Neural Networks](#gnnexplainer) | :heavy_check_mark: | | | | |

 ## 2020
@@ -124,6 +125,10 @@ The folder contains example implementations of selected research papers related
 - Example code: [Pytorch](../examples/pytorch/tgn)
 - Tags: over-smoothing, node classification

+- <a name="compgcn"></a> Vashishth, Shikhar, et al. Composition-based Multi-Relational Graph Convolutional Networks. [Paper link](https://arxiv.org/abs/1911.03082).
+- Example code: [Pytorch](../examples/pytorch/compGCN)
+- Tags: multi-relational graphs, graph neural network
+
 ## 2019
```
examples/pytorch/compGCN/README.md (+78, new file)
# DGL Implementation of the CompGCN Paper

This DGL example implements the GNN model proposed in the paper [Composition-based Multi-Relational Graph Convolutional Networks](https://arxiv.org/abs/1911.03082).
The authors' original implementation is available [here](https://github.com/malllabiisc/CompGCN).

Example implementor
----------------------
This example was implemented by [zhjwy9343](https://github.com/zhjwy9343) and [KounianhuaDu](https://github.com/KounianhuaDu) at the AWS Shanghai AI Lab.

Dependencies
----------------------
- pytorch 1.7.1
- dgl 0.6.0
- numpy 1.19.4
- ordered_set 4.0.2

Dataset
---------------------------------------
The datasets used for link prediction are FB15k-237, constructed from Freebase, and WN18RR, constructed from WordNet. Their statistics are summarized as follows:

**FB15k-237**

- Nodes: 14541
- Relation types: 237
- Reversed relation types: 237
- Train: 272115
- Valid: 17535
- Test: 20466

**WN18RR**

- Nodes: 40943
- Relation types: 11
- Reversed relation types: 11
- Train: 86835
- Valid: 3034
- Test: 3134
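Each split is a plain-text file (`train.txt`, `valid.txt`, `test.txt` under the dataset directory) with one tab-separated `(subject, relation, object)` triple per line, which is the format the data loader below expects. An illustrative FB15k-237-style line:

```
/m/027rn	/location/country/form_of_government	/m/06cx9
```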
How to run
--------------------------------
First, download the datasets by running

```bash
sh get_fb15k-237.sh
sh get_wn18rr.sh
```

Then, for FB15k-237, run

```bash
python main.py --score_func conve --opn ccorr --gpu 0 --data FB15k-237
```

For WN18RR, run

```bash
python main.py --score_func conve --opn ccorr --gpu 0 --data wn18rr
```
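Here `--score_func conve` selects the ConvE scoring function and `--opn ccorr` selects circular correlation as the composition operator from the paper. As a reference, a minimal standalone sketch of circular correlation via the FFT identity (an illustration, not this example's own code):

```python
import torch

def ccorr(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Circular correlation along the feature dimension:
    # ccorr(a, b) = IFFT( conj(FFT(a)) * FFT(b) )
    return torch.fft.irfft(
        torch.conj(torch.fft.rfft(a, dim=-1)) * torch.fft.rfft(b, dim=-1),
        n=a.shape[-1],
        dim=-1,
    )
```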
Performance
-------------------------
**Link Prediction Results**

| Dataset | FB15k-237          | WN18RR             |
| ------- | ------------------ | ------------------ |
| Metric  | Paper / ours (dgl) | Paper / ours (dgl) |
| MRR     | 0.355 / 0.349      | 0.479 / 0.471      |
| MR      | 197 / 208          | 3533 / 3550        |
| Hit@10  | 0.535 / 0.526      | 0.546 / 0.532      |
| Hit@3   | 0.390 / 0.381      | 0.494 / 0.480      |
| Hit@1   | 0.264 / 0.260      | 0.443 / 0.438      |
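MR, MRR, and Hit@N all derive from the rank assigned to the gold entity for each evaluation query. As a reference, a small sketch of how the table's metrics follow from those ranks (a hypothetical helper, not part of this example):

```python
import torch

def summarize(ranks: torch.Tensor) -> dict:
    # ranks: 1-based rank of the gold entity for every evaluation query
    ranks = ranks.float()
    return {
        'MRR': (1.0 / ranks).mean().item(),
        'MR': ranks.mean().item(),
        'Hit@10': (ranks <= 10).float().mean().item(),
        'Hit@3': (ranks <= 3).float().mean().item(),
        'Hit@1': (ranks <= 1).float().mean().item(),
    }
```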
examples/pytorch/compGCN — data loader (+243, new file)

```python
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import dgl
from collections import defaultdict as ddict
from ordered_set import OrderedSet


class TrainDataset(Dataset):
    """
    Training Dataset class.

    Parameters
    ----------
    triples: The triples used for training the model
    num_ent: Number of entities in the knowledge graph
    lbl_smooth: Label smoothing

    Returns
    -------
    A training Dataset class instance used by DataLoader
    """
    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
        self.lbl_smooth = lbl_smooth
        self.entities = np.arange(self.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
        trp_label = self.get_label(label)
        # label smoothing: pull the 0/1 targets toward a uniform prior over entities
        if self.lbl_smooth != 0.0:
            trp_label = (1.0 - self.lbl_smooth) * trp_label + (1.0 / self.num_ent)

        return triple, trp_label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        trp_label = torch.stack(labels, dim=0)
        return triple, trp_label

    # for objects that are true targets of (subject, relation), the entry is 1.0; otherwise 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)

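# Illustration of the label smoothing above (made-up numbers):
# with num_ent = 10, lbl_smooth = 0.1, and a single true object at index 2,
# the one-hot target [0, 0, 1, 0, ..., 0] becomes
# 0.9 * one-hot + 1/10, i.e. 0.1 for every entity and 1.0 at index 2.
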
class TestDataset(Dataset):
    """
    Evaluation Dataset class.

    Parameters
    ----------
    triples: The triples used for evaluating the model
    num_ent: Number of entities in the knowledge graph

    Returns
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """
    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
        label = self.get_label(label)

        return triple, label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        label = torch.stack(labels, dim=0)
        return triple, label

    # for objects that are true targets of (subject, relation), the entry is 1.0; otherwise 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)

class Data(object):

    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Read in the raw triples and convert them into a standard format.

        Parameters
        ----------
        dataset: The name of the dataset
        lbl_smooth: Label smoothing
        num_workers: Number of workers of dataloaders
        batch_size: Batch size of dataloaders

        Returns
        -------
        self.ent2id: Entity to unique identifier mapping
        self.rel2id: Relation to unique identifier mapping
        self.id2ent: Inverse mapping of self.ent2id
        self.id2rel: Inverse mapping of self.rel2id
        self.num_ent: Number of entities in the knowledge graph
        self.num_rel: Number of relations in the knowledge graph

        self.g: The dgl graph constructed from the edges in the training set and all the entities in the knowledge graph
        self.data['train']: Stores the triples corresponding to the training dataset
        self.data['valid']: Stores the triples corresponding to the validation dataset
        self.data['test']: Stores the triples corresponding to the test dataset
        self.data_iter: The dataloaders for the different data splits
        """
        self.dataset = dataset
        self.lbl_smooth = lbl_smooth
        self.num_workers = num_workers
        self.batch_size = batch_size

        # read in the raw data and build the id mappings
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in ['train', 'test', 'valid']:
            for line in open('./{}/{}.txt'.format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split('\t'))
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update({rel + '_reverse': idx + len(self.rel2id) for idx, rel in enumerate(rel_set)})

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.num_ent = len(self.ent2id)
        self.num_rel = len(self.rel2id) // 2

        # read in the ids of subjects, relations, and objects for train/test/valid
        self.data = ddict(list)  # stores the triples
        sr2o = ddict(set)  # the key of sr2o is (subject, relation); the value is the set of all objects reachable via (subject, relation)
        src = []
        dst = []
        rels = []
        inver_src = []
        inver_dst = []
        inver_rels = []

        for split in ['train', 'test', 'valid']:
            for line in open('./{}/{}.txt'.format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split('\t'))
                sub_id, rel_id, obj_id = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == 'train':
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(sub_id)  # add the reversed edge
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)
                    inver_src.append(obj_id)
                    inver_dst.append(sub_id)
                    inver_rels.append(rel_id + self.num_rel)

        # construct the dgl graph
        src = src + inver_src
        dst = dst + inver_dst
        rels = rels + inver_rels
        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
        self.g.edata['etype'] = torch.Tensor(rels).long()

        # identify in and out edges: the first half are the original edges, the second half their reverses
        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (self.g.num_edges() // 2)
        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (self.g.num_edges() // 2)
        self.g.edata['in_edges_mask'] = torch.Tensor(in_edges_mask)
        self.g.edata['out_edges_mask'] = torch.Tensor(out_edges_mask)

        # prepare the train/valid/test data
        self.data = dict(self.data)
        self.sr2o = {k: list(v) for k, v in sr2o.items()}  # contains only the train data so far

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {k: list(v) for k, v in sr2o.items()}  # contains all the data
        self.triples = ddict(list)

        for (sub, rel), objs in self.sr2o.items():
            self.triples['train'].append({'triple': (sub, rel, -1), 'label': objs})

        # for evaluation, each triple is ranked twice: once predicting the tail for
        # (sub, rel, ?) and once predicting the head via the reversed relation (obj, rel_inv, ?)
        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.num_rel
                self.triples['{}_{}'.format(split, 'tail')].append({'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
                self.triples['{}_{}'.format(split, 'head')].append({'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})

        self.triples = dict(self.triples)

        def get_train_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TrainDataset(self.triples[split], self.num_ent, self.lbl_smooth),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TrainDataset.collate_fn,
            )

        def get_test_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TestDataset(self.triples[split], self.num_ent),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TestDataset.collate_fn,
            )

        # train/valid/test dataloaders
        self.data_iter = {
            'train': get_train_data_loader('train', self.batch_size),
            'valid_head': get_test_data_loader('valid_head', self.batch_size),
            'valid_tail': get_test_data_loader('valid_tail', self.batch_size),
            'test_head': get_test_data_loader('test_head', self.batch_size),
            'test_tail': get_test_data_loader('test_tail', self.batch_size),
        }
```
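A minimal usage sketch of the loader above (hypothetical values; it assumes the FB15k-237 files have already been downloaded into `./FB15k-237` by the script below):

```python
# Hypothetical usage of the Data class; parameter values are illustrative.
data = Data(dataset='FB15k-237', lbl_smooth=0.1, num_workers=2, batch_size=256)

g = data.g  # DGL graph containing both the original and the reversed edges
print(g.num_nodes(), g.num_edges())  # one node per entity, two edges per training triple

for triples, labels in data.data_iter['train']:
    # triples: (batch, 3) LongTensor of (subject id, relation id, -1)
    # labels:  (batch, num_ent) label-smoothed multi-hot targets over objects
    break
```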
examples/pytorch/compGCN/get_fb15k-237.sh (+2, new file)

```bash
wget https://dgl-data.s3.cn-north-1.amazonaws.com.cn/dataset/FB15k-237.zip
unzip FB15k-237.zip
```
examples/pytorch/compGCN/get_wn18rr.sh (+2, new file)

```bash
wget https://dgl-data.s3.cn-north-1.amazonaws.com.cn/dataset/wn18rr.zip
unzip wn18rr.zip
```
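Each archive is expected to unpack into a directory named after the dataset (`FB15k-237/`, `wn18rr/`), since the loader reads `./{dataset}/{split}.txt` using the name passed via `--data`.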
