Commit be8763f

frozenbugs and Steve authored
[Misc] Black auto fix. (dmlc#4679)
Co-authored-by: Steve <[email protected]>
1 parent eae6ce2 commit be8763f


106 files changed: +7517 -4313 lines
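This commit is a pure formatting pass: Black rewrites the touched files (double-quoted strings, trailing commas, comment spacing, line wrapping) without changing runtime behavior. Below is a minimal sketch of reproducing one such rewrite through Black's Python API; it assumes Black is installed (pip install black) and that the repository's configured line length is 79 characters, neither of which is stated in the commit itself.

import black

# One of the pre-Black snippets from the diff below (a short excerpt, not the full file).
SRC = (
    "self.project = nn.Sequential(\n"
    "    nn.Linear(in_size, hidden_size),\n"
    "    nn.Tanh(),\n"
    "    nn.Linear(hidden_size, 1, bias=False)\n"
    ")\n"
)

# The call does not fit on one 79-character line, so Black keeps it exploded
# and adds the trailing comma, matching the hunk shown below.
print(black.format_str(SRC, mode=black.Mode(line_length=79)))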

examples/pytorch/TAHIN/TAHIN.py (+144 -90)
@@ -6,187 +6,241 @@
 import dgl.function as fn
 from dgl.nn.pytorch import GATConv
 
-#Semantic attention in the metapath-based aggregation (the same as that in the HAN)
+
+# Semantic attention in the metapath-based aggregation (the same as that in the HAN)
 class SemanticAttention(nn.Module):
     def __init__(self, in_size, hidden_size=128):
         super(SemanticAttention, self).__init__()
 
         self.project = nn.Sequential(
             nn.Linear(in_size, hidden_size),
             nn.Tanh(),
-            nn.Linear(hidden_size, 1, bias=False)
+            nn.Linear(hidden_size, 1, bias=False),
         )
 
     def forward(self, z):
-        '''
+        """
         Shape of z: (N, M , D*K)
         N: number of nodes
         M: number of metapath patterns
         D: hidden_size
         K: number of heads
-        '''
-        w = self.project(z).mean(0) # (M, 1)
-        beta = torch.softmax(w, dim=0) # (M, 1)
-        beta = beta.expand((z.shape[0],) + beta.shape) # (N, M, 1)
+        """
+        w = self.project(z).mean(0)  # (M, 1)
+        beta = torch.softmax(w, dim=0)  # (M, 1)
+        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)
+
+        return (beta * z).sum(1)  # (N, D * K)
 
-        return (beta * z).sum(1) # (N, D * K)
 
-#Metapath-based aggregation (the same as the HANLayer)
+# Metapath-based aggregation (the same as the HANLayer)
 class HANLayer(nn.Module):
-    def __init__(self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout):
+    def __init__(
+        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
+    ):
         super(HANLayer, self).__init__()
 
         # One GAT layer for each meta path based adjacency matrix
         self.gat_layers = nn.ModuleList()
         for i in range(len(meta_path_patterns)):
-            self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
-                                           dropout, dropout, activation=F.elu,
-                                           allow_zero_in_degree=True))
-        self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
-        self.meta_path_patterns = list(tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns)
+            self.gat_layers.append(
+                GATConv(
+                    in_size,
+                    out_size,
+                    layer_num_heads,
+                    dropout,
+                    dropout,
+                    activation=F.elu,
+                    allow_zero_in_degree=True,
+                )
+            )
+        self.semantic_attention = SemanticAttention(
+            in_size=out_size * layer_num_heads
+        )
+        self.meta_path_patterns = list(
+            tuple(meta_path_pattern) for meta_path_pattern in meta_path_patterns
+        )
 
         self._cached_graph = None
         self._cached_coalesced_graph = {}
 
     def forward(self, g, h):
         semantic_embeddings = []
-        #obtain metapath reachable graph
+        # obtain metapath reachable graph
         if self._cached_graph is None or self._cached_graph is not g:
             self._cached_graph = g
             self._cached_coalesced_graph.clear()
             for meta_path_pattern in self.meta_path_patterns:
-                self._cached_coalesced_graph[meta_path_pattern] = dgl.metapath_reachable_graph(
-                    g, meta_path_pattern)
+                self._cached_coalesced_graph[
+                    meta_path_pattern
+                ] = dgl.metapath_reachable_graph(g, meta_path_pattern)
 
         for i, meta_path_pattern in enumerate(self.meta_path_patterns):
             new_g = self._cached_coalesced_graph[meta_path_pattern]
             semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
-        semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
+        semantic_embeddings = torch.stack(
+            semantic_embeddings, dim=1
+        )  # (N, M, D * K)
+
+        return self.semantic_attention(semantic_embeddings)  # (N, D * K)
 
-        return self.semantic_attention(semantic_embeddings) # (N, D * K)
 
-#Relational neighbor aggregation
+# Relational neighbor aggregation
 class RelationalAGG(nn.Module):
     def __init__(self, g, in_size, out_size, dropout=0.1):
         super(RelationalAGG, self).__init__()
         self.in_size = in_size
         self.out_size = out_size
 
-        #Transform weights for different types of edges
-        self.W_T = nn.ModuleDict({
-            name : nn.Linear(in_size, out_size, bias = False) for name in g.etypes
-        })
+        # Transform weights for different types of edges
+        self.W_T = nn.ModuleDict(
+            {
+                name: nn.Linear(in_size, out_size, bias=False)
+                for name in g.etypes
+            }
+        )
 
-        #Attention weights for different types of edges
-        self.W_A = nn.ModuleDict({
-            name : nn.Linear(out_size, 1, bias = False) for name in g.etypes
-        })
+        # Attention weights for different types of edges
+        self.W_A = nn.ModuleDict(
+            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
+        )
 
-        #layernorm
+        # layernorm
         self.layernorm = nn.LayerNorm(out_size)
 
-        #dropout layer
+        # dropout layer
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, g, feat_dict):
-        funcs={}
+        funcs = {}
         for srctype, etype, dsttype in g.canonical_etypes:
-            g.nodes[dsttype].data['h'] = feat_dict[dsttype] #nodes' original feature
-            g.nodes[srctype].data['h'] = feat_dict[srctype]
-            g.nodes[srctype].data['t_h'] = self.W_T[etype](feat_dict[srctype]) #src nodes' transformed feature
-
-            #compute the attention numerator (exp)
-            g.apply_edges(fn.u_mul_v('t_h','h','x'),etype=etype)
-            g.edges[etype].data['x'] = torch.exp(self.W_A[etype](g.edges[etype].data['x']))
-
-            #first update to compute the attention denominator (\sum exp)
-            funcs[etype] = (fn.copy_e('x', 'm'), fn.sum('m', 'att'))
-            g.multi_update_all(funcs, 'sum')
-
-        funcs={}
+            g.nodes[dsttype].data["h"] = feat_dict[
+                dsttype
+            ]  # nodes' original feature
+            g.nodes[srctype].data["h"] = feat_dict[srctype]
+            g.nodes[srctype].data["t_h"] = self.W_T[etype](
+                feat_dict[srctype]
+            )  # src nodes' transformed feature
+
+            # compute the attention numerator (exp)
+            g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
+            g.edges[etype].data["x"] = torch.exp(
+                self.W_A[etype](g.edges[etype].data["x"])
+            )
+
+            # first update to compute the attention denominator (\sum exp)
+            funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
+            g.multi_update_all(funcs, "sum")
+
+        funcs = {}
         for srctype, etype, dsttype in g.canonical_etypes:
-            g.apply_edges(fn.e_div_v('x', 'att', 'att'),etype=etype) #compute attention weights (numerator/denominator)
-            funcs[etype] = (fn.u_mul_e('h', 'att', 'm'), fn.sum('m', 'h')) #\sum(h0*att) -> h1
-            #second update to obtain h1
-            g.multi_update_all(funcs, 'sum')
-
-        #apply activation, layernorm, and dropout
-        feat_dict={}
+            g.apply_edges(
+                fn.e_div_v("x", "att", "att"), etype=etype
+            )  # compute attention weights (numerator/denominator)
+            funcs[etype] = (
+                fn.u_mul_e("h", "att", "m"),
+                fn.sum("m", "h"),
+            )  # \sum(h0*att) -> h1
+            # second update to obtain h1
+            g.multi_update_all(funcs, "sum")
+
+        # apply activation, layernorm, and dropout
+        feat_dict = {}
         for ntype in g.ntypes:
-            feat_dict[ntype] = self.dropout(self.layernorm(F.relu_(g.nodes[ntype].data['h']))) #apply activation, layernorm, and dropout
-
+            feat_dict[ntype] = self.dropout(
+                self.layernorm(F.relu_(g.nodes[ntype].data["h"]))
+            )  # apply activation, layernorm, and dropout
+
         return feat_dict
 
+
 class TAHIN(nn.Module):
-    def __init__(self, g, meta_path_patterns, in_size, out_size, num_heads, dropout):
+    def __init__(
+        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
+    ):
         super(TAHIN, self).__init__()
 
-        #embeddings for different types of nodes, h0
+        # embeddings for different types of nodes, h0
         self.initializer = nn.init.xavier_uniform_
-        self.feature_dict = nn.ParameterDict({
-            ntype: nn.Parameter(self.initializer(torch.empty(g.num_nodes(ntype), in_size))) for ntype in g.ntypes
-        })
+        self.feature_dict = nn.ParameterDict(
+            {
+                ntype: nn.Parameter(
+                    self.initializer(torch.empty(g.num_nodes(ntype), in_size))
+                )
+                for ntype in g.ntypes
+            }
+        )
 
-        #relational neighbor aggregation, this produces h1
+        # relational neighbor aggregation, this produces h1
         self.RelationalAGG = RelationalAGG(g, in_size, out_size)
 
-        #metapath-based aggregation modules for user and item, this produces h2
-        self.meta_path_patterns = meta_path_patterns
-        #one HANLayer for user, one HANLayer for item
-        self.hans = nn.ModuleDict({
-            key: HANLayer(value, in_size, out_size, num_heads, dropout) for key, value in self.meta_path_patterns.items()
-        })
-
-        #layers to combine h0, h1, and h2
-        #used to update node embeddings
-        self.user_layer1 = nn.Linear((num_heads+1)*out_size, out_size, bias=True)
-        self.user_layer2 = nn.Linear(2*out_size, out_size, bias=True)
-        self.item_layer1 = nn.Linear((num_heads+1)*out_size, out_size, bias=True)
-        self.item_layer2 = nn.Linear(2*out_size, out_size, bias=True)
-
-        #layernorm
+        # metapath-based aggregation modules for user and item, this produces h2
+        self.meta_path_patterns = meta_path_patterns
+        # one HANLayer for user, one HANLayer for item
+        self.hans = nn.ModuleDict(
+            {
+                key: HANLayer(value, in_size, out_size, num_heads, dropout)
+                for key, value in self.meta_path_patterns.items()
+            }
+        )
+
+        # layers to combine h0, h1, and h2
+        # used to update node embeddings
+        self.user_layer1 = nn.Linear(
+            (num_heads + 1) * out_size, out_size, bias=True
+        )
+        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
+        self.item_layer1 = nn.Linear(
+            (num_heads + 1) * out_size, out_size, bias=True
+        )
+        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
+
+        # layernorm
         self.layernorm = nn.LayerNorm(out_size)
 
-        #network to score the node pairs
+        # network to score the node pairs
         self.pred = nn.Linear(out_size, out_size)
         self.dropout = nn.Dropout(dropout)
         self.fc = nn.Linear(out_size, 1)
 
     def forward(self, g, user_key, item_key, user_idx, item_idx):
-        #relational neighbor aggregation, h1
+        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)
 
-        #metapath-based aggregation, h2
+        # metapath-based aggregation, h2
         h2 = {}
         for key in self.meta_path_patterns.keys():
             h2[key] = self.hans[key](g, self.feature_dict[key])
 
-        #update node embeddings
+        # update node embeddings
         user_emb = torch.cat((h1[user_key], h2[user_key]), 1)
         item_emb = torch.cat((h1[item_key], h2[item_key]), 1)
         user_emb = self.user_layer1(user_emb)
         item_emb = self.item_layer1(item_emb)
-        user_emb = self.user_layer2(torch.cat((user_emb, self.feature_dict[user_key]), 1))
-        item_emb = self.item_layer2(torch.cat((item_emb, self.feature_dict[item_key]), 1))
+        user_emb = self.user_layer2(
+            torch.cat((user_emb, self.feature_dict[user_key]), 1)
+        )
+        item_emb = self.item_layer2(
+            torch.cat((item_emb, self.feature_dict[item_key]), 1)
+        )
 
-        #Relu
+        # Relu
         user_emb = F.relu_(user_emb)
         item_emb = F.relu_(item_emb)
-
-        #layer norm
+
+        # layer norm
         user_emb = self.layernorm(user_emb)
         item_emb = self.layernorm(item_emb)
-
-        #obtain users/items embeddings and their interactions
+
+        # obtain users/items embeddings and their interactions
         user_feat = user_emb[user_idx]
         item_feat = item_emb[item_idx]
-        interaction = user_feat*item_feat
+        interaction = user_feat * item_feat
 
-        #score the node pairs
+        # score the node pairs
         pred = self.pred(interaction)
-        pred = self.dropout(pred) #dropout
+        pred = self.dropout(pred)  # dropout
         pred = self.fc(pred)
         pred = torch.sigmoid(pred)
 
         return pred.squeeze(1)
-
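The hunk above only reformats examples/pytorch/TAHIN/TAHIN.py; the TAHIN interface is unchanged. The following is a hypothetical toy instantiation, shown only to illustrate the constructor and forward signatures visible in the diff: the graph, node/edge type names, metapaths, and sizes are invented and are not the example's real dataset, and in_size is kept equal to out_size because RelationalAGG multiplies transformed source features element-wise with raw destination features.

import torch
import dgl

from TAHIN import TAHIN  # assumes the working directory is examples/pytorch/TAHIN

# Tiny made-up user-item graph with one relation in each direction.
g = dgl.heterograph(
    {
        ("user", "ui", "item"): (torch.tensor([0, 1]), torch.tensor([1, 0])),
        ("item", "iu", "user"): (torch.tensor([1, 0]), torch.tensor([0, 1])),
    }
)

# One metapath per node type: user-item-user and item-user-item.
meta_path_patterns = {"user": [["ui", "iu"]], "item": [["iu", "ui"]]}

model = TAHIN(
    g, meta_path_patterns, in_size=8, out_size=8, num_heads=2, dropout=0.1
)
scores = model(g, "user", "item", torch.tensor([0, 1]), torch.tensor([1, 0]))
print(scores.shape)  # torch.Size([2]) -- one sigmoid score per (user, item) pair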