Skip to content

Commit 17ab229

Browse files
committed
support customized edge feature in graph construction
1 parent 91e9bd9 commit 17ab229

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

torchdrug/core/core.py

-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def __setattr__(self, key, value):
8686
def __delattr__(self, key):
8787
if hasattr(self, "meta_dict") and key in self.meta_dict:
8888
del self.meta_dict[key]
89-
del self.data_dict[key]
9089
super(_MetaContainer, self).__delattr__(self, key)
9190

9291
def _setattr(self, key, value):

torchdrug/core/engine.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -225,20 +225,21 @@ def evaluate(self, split, log=True):
225225

226226
return metric
227227

228-
def load(self, checkpoint, load_optimizer=True):
228+
def load(self, checkpoint, load_optimizer=True, strict=True):
229229
"""
230230
Load a checkpoint from file.
231231
232232
Parameters:
233233
checkpoint (file-like): checkpoint file
234234
load_optimizer (bool, optional): load optimizer state or not
235+
strict (bool, optional): whether to strictly check the checkpoint matches the model parameters
235236
"""
236237
if comm.get_rank() == 0:
237238
logger.warning("Load checkpoint from %s" % checkpoint)
238239
checkpoint = os.path.expanduser(checkpoint)
239240
state = torch.load(checkpoint, map_location=self.device)
240241

241-
self.model.load_state_dict(state["model"])
242+
self.model.load_state_dict(state["model"], strict=strict)
242243

243244
if load_optimizer:
244245
self.optimizer.load_state_dict(state["optimizer"])

torchdrug/layers/geometry/graph.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ class GraphConstruction(nn.Module, core.Configurable):
2626
2. For ``gearnet``, the feature of the edge :math:`e_{ij}` between residue :math:`i` and residue :math:`j`
2727
is the concatenation ``[residue_type(i), residue_type(j), edge_type(e_ij),
2828
sequential_distance(i,j), spatial_distance(i,j)]``.
29+
30+
.. note::
31+
You may customize your own edge features by inheriting this class and define a member function
32+
for your features. Use ``edge_feature="my_feature"`` to call the following feature function.
33+
34+
.. code:: python
35+
36+
def edge_my_feature(self, graph, edge_list, num_relation):
37+
...
38+
return feature # the first dimension must be ``graph.num_edge``
2939
"""
3040

3141
max_seq_dist = 10
@@ -43,7 +53,7 @@ def __init__(self, node_layers=None, edge_layers=None, edge_feature="residue_typ
4353
self.edge_layers = edge_layers
4454
self.edge_feature = edge_feature
4555

46-
def edge_residue_type(self, graph, edge_list):
56+
def edge_residue_type(self, graph, edge_list, num_relation):
4757
node_in, node_out, _ = edge_list.t()
4858
residue_in, residue_out = graph.atom2residue[node_in], graph.atom2residue[node_out]
4959
in_residue_type = graph.residue_type[residue_in]
@@ -103,10 +113,8 @@ def apply_edge_layer(self, graph):
103113
num_edges = edge2graph.bincount(minlength=graph.batch_size)
104114
offsets = (graph.num_cum_nodes - graph.num_nodes).repeat_interleave(num_edges)
105115

106-
if self.edge_feature == "residue_type":
107-
edge_feature = self.edge_residue_type(graph, edge_list)
108-
elif self.edge_feature == "gearnet":
109-
edge_feature = self.edge_gearnet(graph, edge_list, num_relation)
116+
if hasattr(self, "edge_%s" % self.edge_feature):
117+
edge_feature = getattr(self, "edge_%s" % self.edge_feature)(graph, edge_list, num_relation)
110118
elif self.edge_feature is None:
111119
edge_feature = None
112120
else:

0 commit comments

Comments
 (0)