Skip to content

Commit 07ea82b

Browse files
authored
Merge pull request #24 from choderalab/fix-issue-23
Update bias assignment in DGL GAT model
2 parents 42893ce + d113733 commit 07ea82b

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

mtenn/conversion_utils/gat.py

+16-12
Original file line number | Diff line number | Diff line change
@@ -19,10 +19,11 @@
1919

2020
class GAT(torch.nn.Module):
2121
def __init__(self, *args, model=None, **kwargs):
22+
super().__init__()
23+
2224
## If no model is passed, construct model based on passed args, otherwise copy
2325
## all parameters and weights over
2426
if model is None:
25-
super().__init__()
2627
self.gnn = GAT_dgl(*args, **kwargs)
2728
else:
2829
# Parameters that are conveniently accessible from the top level
@@ -31,22 +32,25 @@ def __init__(self, *args, model=None, **kwargs):
3132
num_heads = model.num_heads
3233
agg_modes = model.agg_modes
3334
# Parameters that can only be accessed layer-wise
34-
layer_params = [
35-
(
36-
l.gat_conv.feat_drop.p,
37-
l.gat_conv.attn_drop.p,
38-
l.gat_conv.leaky_relu.negative_slope,
39-
bool(l.gat_conv.res_fc),
40-
l.gat_conv.activation,
41-
bool(l.gat_conv.bias),
35+
layer_params = []
36+
for l in model.gnn_layers:
37+
gc = l.gat_conv
38+
new_params = (
39+
gc.feat_drop.p,
40+
gc.attn_drop.p,
41+
gc.leaky_relu.negative_slope,
42+
gc.activation,
43+
bool(gc.res_fc),
44+
(gc.res_fc.bias is not None)
45+
if gc.has_linear_res
46+
else gc.has_explicit_bias,
4247
)
43-
for l in model.gnn_layers
44-
]
48+
layer_params += [new_params]
49+
4550
(
4651
feat_drops,
4752
attn_drops,
4853
alphas,
49-
residuals,
5054
activations,
5155
residuals,
5256
biases,

0 commit comments

Comments (0)