@@ -26,6 +26,16 @@ class GraphConstruction(nn.Module, core.Configurable):
26
26
2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j`
27
27
is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij),
28
28
sequential_distance(i,j), spatial_distance(i,j)]``.
29
+
30
+ .. note::
31
+ You may customize your own edge features by inheriting this class and define a member function
32
+ for your features. Use ``edge_feature="my_feature"`` to call the following feature function.
33
+
34
+ .. code:: python
35
+
36
+ def edge_my_feature(self, graph, edge_list, num_relation):
37
+ ...
38
+ return feature # the first dimension must be ``graph.num_edge``
29
39
"""
30
40
31
41
max_seq_dist = 10
@@ -43,7 +53,7 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ
43
53
self .edge_layers = edge_layers
44
54
self .edge_feature = edge_feature
45
55
46
- def edge_residue_type (self , graph , edge_list ):
56
+ def edge_residue_type (self , graph , edge_list , num_relation ):
47
57
node_in , node_out , _ = edge_list .t ()
48
58
residue_in , residue_out = graph .atom2residue [node_in ], graph .atom2residue [node_out ]
49
59
in_residue_type = graph .residue_type [residue_in ]
@@ -103,10 +113,8 @@ def apply_edge_layer(self, graph):
103
113
num_edges = edge2graph .bincount (minlength = graph .batch_size )
104
114
offsets = (graph .num_cum_nodes - graph .num_nodes ).repeat_interleave (num_edges )
105
115
106
- if self .edge_feature == "residue_type" :
107
- edge_feature = self .edge_residue_type (graph , edge_list )
108
- elif self .edge_feature == "gearnet" :
109
- edge_feature = self .edge_gearnet (graph , edge_list , num_relation )
116
+ if hasattr (self , "edge_%s" % self .edge_feature ):
117
+ edge_feature = getattr (self , "edge_%s" % self .edge_feature )(graph , edge_list , num_relation )
110
118
elif self .edge_feature is None :
111
119
edge_feature = None
112
120
else :
0 commit comments