import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import dgl
from collections import defaultdict as ddict
from ordered_set import OrderedSet

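# Dataset utilities for knowledge-graph link prediction: 1-N training samples
# with optional label smoothing, head/tail evaluation samples whose labels mark
# all known true objects, and a DGL graph holding forward and reverse edges.
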
class TrainDataset(Dataset):
    """
    Training Dataset class.
    Parameters
    ----------
    triples: The triples used for training the model
    num_ent: Number of entities in the knowledge graph
    lbl_smooth: Label smoothing

    Returns
    -------
    A training Dataset class instance used by DataLoader
    """
    def __init__(self, triples, num_ent, lbl_smooth):
        self.triples = triples
        self.num_ent = num_ent
        self.lbl_smooth = lbl_smooth
        self.entities = np.arange(self.num_ent, dtype=np.int32)

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
        trp_label = self.get_label(label)
        # label smoothing
        if self.lbl_smooth != 0.0:
            trp_label = (1.0 - self.lbl_smooth) * trp_label + (1.0 / self.num_ent)

        return triple, trp_label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        trp_label = torch.stack(labels, dim=0)
        return triple, trp_label

    # for edges that exist in the graph, the entry is 1.0; otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)
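    # Note: each label is a num_ent-dimensional multi-hot vector, i.e. the 1-N
    # scoring setup where one (subject, relation) query is scored against all
    # entities at once.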


class TestDataset(Dataset):
    """
    Evaluation Dataset class.
    Parameters
    ----------
    triples: The triples used for evaluating the model
    num_ent: Number of entities in the knowledge graph

    Returns
    -------
    An evaluation Dataset class instance used by DataLoader for model evaluation
    """
    def __init__(self, triples, num_ent):
        self.triples = triples
        self.num_ent = num_ent

    def __len__(self):
        return len(self.triples)

    def __getitem__(self, idx):
        ele = self.triples[idx]
        triple, label = torch.LongTensor(ele['triple']), np.int32(ele['label'])
        label = self.get_label(label)

        return triple, label

    @staticmethod
    def collate_fn(data):
        triples = []
        labels = []
        for triple, label in data:
            triples.append(triple)
            labels.append(label)
        triple = torch.stack(triples, dim=0)
        label = torch.stack(labels, dim=0)
        return triple, label

    # for edges that exist in the graph, the entry is 1.0; otherwise the entry is 0.0
    def get_label(self, label):
        y = np.zeros([self.num_ent], dtype=np.float32)
        for e2 in label:
            y[e2] = 1.0
        return torch.FloatTensor(y)
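    # Unlike TrainDataset, no label smoothing is applied here: the multi-hot
    # labels mark all known true objects so evaluation code can filter them
    # out when ranking.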


class Data(object):

    def __init__(self, dataset, lbl_smooth, num_workers, batch_size):
        """
        Reads in raw triples and converts them into a standard format.
        Parameters
        ----------
        dataset: The name of the dataset
        lbl_smooth: Label smoothing
        num_workers: Number of workers of dataloaders
        batch_size: Batch size of dataloaders

        Returns
        -------
        self.ent2id: Entity to unique identifier mapping
        self.rel2id: Relation to unique identifier mapping
        self.id2ent: Inverse mapping of self.ent2id
        self.id2rel: Inverse mapping of self.rel2id
        self.num_ent: Number of entities in the knowledge graph
        self.num_rel: Number of relations in the knowledge graph

        self.g: The dgl graph constructed from the edges in the training set and all the entities in the knowledge graph
        self.data['train']: Stores the triples corresponding to the training dataset
        self.data['valid']: Stores the triples corresponding to the validation dataset
        self.data['test']: Stores the triples corresponding to the test dataset
        self.data_iter: The dataloaders for the different data splits
        """
        self.dataset = dataset
        self.lbl_smooth = lbl_smooth
        self.num_workers = num_workers
        self.batch_size = batch_size

        # read in raw data and get mappings
        ent_set, rel_set = OrderedSet(), OrderedSet()
        for split in ['train', 'test', 'valid']:
            for line in open('./{}/{}.txt'.format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split('\t'))
                ent_set.add(sub)
                rel_set.add(rel)
                ent_set.add(obj)

        self.ent2id = {ent: idx for idx, ent in enumerate(ent_set)}
        self.rel2id = {rel: idx for idx, rel in enumerate(rel_set)}
        self.rel2id.update({rel + '_reverse': idx + len(self.rel2id) for idx, rel in enumerate(rel_set)})

        self.id2ent = {idx: ent for ent, idx in self.ent2id.items()}
        self.id2rel = {idx: rel for rel, idx in self.rel2id.items()}

        self.num_ent = len(self.ent2id)
        self.num_rel = len(self.rel2id) // 2
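        # a reverse relation's id equals its forward id + num_rel, so num_rel
        # counts only the forward relations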

        # read in ids of subjects, relations, and objects for train/test/valid
        self.data = ddict(list)  # stores the triples
        sr2o = ddict(set)  # the key of sr2o is (subject, relation); the value is the set of all objects following (subject, relation)
        src = []
        dst = []
        rels = []
        inver_src = []
        inver_dst = []
        inver_rels = []

        for split in ['train', 'test', 'valid']:
            for line in open('./{}/{}.txt'.format(self.dataset, split)):
                sub, rel, obj = map(str.lower, line.strip().split('\t'))
                sub_id, rel_id, obj_id = self.ent2id[sub], self.rel2id[rel], self.ent2id[obj]
                self.data[split].append((sub_id, rel_id, obj_id))

                if split == 'train':
                    sr2o[(sub_id, rel_id)].add(obj_id)
                    sr2o[(obj_id, rel_id + self.num_rel)].add(sub_id)  # append the reversed edge
                    src.append(sub_id)
                    dst.append(obj_id)
                    rels.append(rel_id)
                    inver_src.append(obj_id)
                    inver_dst.append(sub_id)
                    inver_rels.append(rel_id + self.num_rel)

        # construct dgl graph
        src = src + inver_src
        dst = dst + inver_dst
        rels = rels + inver_rels
        self.g = dgl.graph((src, dst), num_nodes=self.num_ent)
        self.g.edata['etype'] = torch.Tensor(rels).long()

        # identify in and out edges
        in_edges_mask = [True] * (self.g.num_edges() // 2) + [False] * (self.g.num_edges() // 2)
        out_edges_mask = [False] * (self.g.num_edges() // 2) + [True] * (self.g.num_edges() // 2)
        self.g.edata['in_edges_mask'] = torch.Tensor(in_edges_mask)
        self.g.edata['out_edges_mask'] = torch.Tensor(out_edges_mask)
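        # the first num_edges/2 edges are the original direction and the second
        # half are their reversed copies; the two masks above encode exactly
        # this split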

        # prepare train/valid/test data
        self.data = dict(self.data)
        self.sr2o = {k: list(v) for k, v in sr2o.items()}  # stores only the train data

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                sr2o[(sub, rel)].add(obj)
                sr2o[(obj, rel + self.num_rel)].add(sub)

        self.sr2o_all = {k: list(v) for k, v in sr2o.items()}  # stores all the data
        self.triples = ddict(list)

        for (sub, rel), obj in self.sr2o.items():
            self.triples['train'].append({'triple': (sub, rel, -1), 'label': self.sr2o[(sub, rel)]})

        for split in ['test', 'valid']:
            for sub, rel, obj in self.data[split]:
                rel_inv = rel + self.num_rel
                self.triples['{}_{}'.format(split, 'tail')].append({'triple': (sub, rel, obj), 'label': self.sr2o_all[(sub, rel)]})
                self.triples['{}_{}'.format(split, 'head')].append({'triple': (obj, rel_inv, sub), 'label': self.sr2o_all[(obj, rel_inv)]})

        self.triples = dict(self.triples)
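        # each training sample is a (subject, relation, -1) query whose label
        # lists every known object; -1 is a placeholder for the object slot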

        def get_train_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TrainDataset(self.triples[split], self.num_ent, self.lbl_smooth),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TrainDataset.collate_fn
            )

        def get_test_data_loader(split, batch_size, shuffle=True):
            return DataLoader(
                TestDataset(self.triples[split], self.num_ent),
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=max(0, self.num_workers),
                collate_fn=TestDataset.collate_fn
            )

        # train/valid/test dataloaders
        self.data_iter = {
            'train': get_train_data_loader('train', self.batch_size),
            'valid_head': get_test_data_loader('valid_head', self.batch_size),
            'valid_tail': get_test_data_loader('valid_tail', self.batch_size),
            'test_head': get_test_data_loader('test_head', self.batch_size),
            'test_tail': get_test_data_loader('test_tail', self.batch_size),
        }
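

# A minimal usage sketch (hypothetical, for illustration only): it assumes the
# raw files live at ./<dataset>/{train,valid,test}.txt with one tab-separated
# (subject, relation, object) triple per line. The dataset name and the
# hyperparameters below are illustrative, not prescribed by this module.
if __name__ == '__main__':
    data = Data(dataset='FB15k-237', lbl_smooth=0.1, num_workers=4, batch_size=256)
    print('entities:', data.num_ent, 'relations:', data.num_rel)
    for triple, trp_label in data.data_iter['train']:
        # triple: (batch_size, 3) LongTensor of (subject, relation, -1) queries
        # trp_label: (batch_size, num_ent) FloatTensor of smoothed multi-hot targets
        print(triple.shape, trp_label.shape)
        break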