import dgl.function as fn
from dgl.nn.pytorch import GATConv


# Semantic attention in the metapath-based aggregation (the same as that in the HAN)
class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        """
        Shape of z: (N, M, D * K)
        N: number of nodes
        M: number of metapath patterns
        D: hidden_size
        K: number of heads
        """
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)

        return (beta * z).sum(1)  # (N, D * K)
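
# Note (added for clarity; not part of the original file): the forward pass above
# is the HAN-style semantic attention. With z_{n,p} the embedding of node n under
# metapath p, it roughly computes
#     w_p    = mean_n  q^T tanh(W z_{n,p} + b)    (self.project, then mean over nodes)
#     beta_p = softmax_p(w_p)                     (attention over metapaths)
#     out_n  = sum_p beta_p * z_{n,p}             (weighted fusion of metapath views)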


# Metapath-based aggregation (the same as the HANLayer)
class HANLayer(nn.Module):
    def __init__(
        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each metapath-based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_path_patterns)):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.meta_path_patterns = list(
            tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns
        )

        self._cached_graph = None
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        semantic_embeddings = []
        # obtain the metapath-reachable graphs
        if self._cached_graph is None or self._cached_graph is not g:
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path_pattern in self.meta_path_patterns:
                self._cached_coalesced_graph[
                    meta_path_pattern
                ] = dgl.metapath_reachable_graph(g, meta_path_pattern)

        for i, meta_path_pattern in enumerate(self.meta_path_patterns):
            new_g = self._cached_coalesced_graph[meta_path_pattern]
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)
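
# Note (added for clarity; not part of the original file): for each metapath pattern p,
# HANLayer builds the metapath-reachable homogeneous graph G_p (cached per input graph),
# runs one GATConv on it, and fuses the per-metapath embeddings with the semantic
# attention above:  h2 = sum_p beta_p * GAT_p(G_p, h).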


# Relational neighbor aggregation
class RelationalAGG(nn.Module):
    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Transform weights for different types of edges
        self.W_T = nn.ModuleDict(
            {
                name: nn.Linear(in_size, out_size, bias=False)
                for name in g.etypes
            }
        )

        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict(
            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
        )

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.nodes[dsttype].data["h"] = feat_dict[
                dsttype
            ]  # nodes' original features
            g.nodes[srctype].data["h"] = feat_dict[srctype]
            g.nodes[srctype].data["t_h"] = self.W_T[etype](
                feat_dict[srctype]
            )  # src nodes' transformed features

            # compute the attention numerator (exp)
            g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
            g.edges[etype].data["x"] = torch.exp(
                self.W_A[etype](g.edges[etype].data["x"])
            )

            # first update to compute the attention denominator (\sum exp)
            funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
        g.multi_update_all(funcs, "sum")

        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.apply_edges(
                fn.e_div_v("x", "att", "att"), etype=etype
            )  # compute attention weights (numerator / denominator)
            funcs[etype] = (
                fn.u_mul_e("h", "att", "m"),
                fn.sum("m", "h"),
            )  # \sum(h0 * att) -> h1
        # second update to obtain h1
        g.multi_update_all(funcs, "sum")

        # apply activation, layernorm, and dropout
        feat_dict = {}
        for ntype in g.ntypes:
            feat_dict[ntype] = self.dropout(
                self.layernorm(F.relu_(g.nodes[ntype].data["h"]))
            )

        return feat_dict
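
# Note (added for clarity; not part of the original file): per destination node v and
# incoming relation r, RelationalAGG roughly computes
#     e_{u,v}   = exp( a_r^T (W_r h0_u * h0_v) )          (un-normalized attention, "*" elementwise)
#     att_{u,v} = e_{u,v} / sum_{u' in N(v)} e_{u',v}     (normalized over all in-edges of v, across relations)
#     h1_v      = sum_{u in N(v)} att_{u,v} * h0_u        (relational neighbor aggregation)
# followed by ReLU, LayerNorm, and dropout.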


class TAHIN(nn.Module):
    def __init__(
        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
    ):
        super(TAHIN, self).__init__()

        # embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        self.feature_dict = nn.ParameterDict(
            {
                ntype: nn.Parameter(
                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
                )
                for ntype in g.ntypes
            }
        )

        # relational neighbor aggregation, this produces h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)

        # metapath-based aggregation modules for user and item, this produces h2
        self.meta_path_patterns = meta_path_patterns
        # one HANLayer for user, one HANLayer for item
        self.hans = nn.ModuleDict(
            {
                key: HANLayer(value, in_size, out_size, num_heads, dropout)
                for key, value in self.meta_path_patterns.items()
            }
        )

        # layers to combine h0, h1, and h2
        # used to update node embeddings
        self.user_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
        self.item_layer1 = nn.Linear(
            (num_heads + 1) * out_size, out_size, bias=True
        )
        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)

        # metapath-based aggregation, h2
        h2 = {}
        for key in self.meta_path_patterns.keys():
            h2[key] = self.hans[key](g, self.feature_dict[key])

        # update node embeddings
        user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
        item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
        user_emb = self.user_layer1(user_emb)
        item_emb = self.item_layer1(item_emb)
        user_emb = self.user_layer2(
            torch.cat((user_emb, self.feature_dict[user_key]), 1)
        )
        item_emb = self.item_layer2(
            torch.cat((item_emb, self.feature_dict[item_key]), 1)
        )

        # ReLU
        user_emb = F.relu_(user_emb)
        item_emb = F.relu_(item_emb)

        # layer norm
        user_emb = self.layernorm(user_emb)
        item_emb = self.layernorm(item_emb)

        # obtain users'/items' embeddings and their interactions
        user_feat = user_emb[user_idx]
        item_feat = item_emb[item_idx]
        interaction = user_feat * item_feat

        # score the node pairs
        pred = self.pred(interaction)
        pred = self.dropout(pred)
        pred = self.fc(pred)
        pred = torch.sigmoid(pred)

        return pred.squeeze(1)
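

# A minimal usage sketch (added; not part of the original file). It builds a tiny
# synthetic user/item heterograph purely to illustrate the expected inputs of TAHIN;
# the node/edge type names, metapath patterns, and sizes below are made up, and
# in_size is kept equal to out_size, which RelationalAGG implicitly assumes. It relies
# on the module-level torch/dgl/nn/F imports at the top of the original file; torch
# and dgl are re-imported here only to keep the block self-contained.
if __name__ == "__main__":
    import dgl
    import torch

    g = dgl.heterograph(
        {
            ("user", "ui", "item"): (torch.tensor([0, 1]), torch.tensor([1, 0])),
            ("item", "iu", "user"): (torch.tensor([1, 0]), torch.tensor([0, 1])),
        }
    )
    meta_path_patterns = {
        "user": [["ui", "iu"]],  # user -> item -> user
        "item": [["iu", "ui"]],  # item -> user -> item
    }
    model = TAHIN(
        g, meta_path_patterns, in_size=16, out_size=16, num_heads=2, dropout=0.1
    )
    scores = model(
        g, "user", "item", torch.tensor([0, 1]), torch.tensor([1, 0])
    )
    print(scores.shape)  # torch.Size([2]), one score per (user, item) pair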