Skip to content

Commit c8155f4

Browse files
committed
update md5s for alphafolddb and handle unknown residue types in esm
1 parent 5d21a5f commit c8155f4

File tree

9 files changed

+43
-33
lines changed

9 files changed

+43
-33
lines changed

conda/torchdrug/meta.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ source:
77

88
requirements:
99
host:
10-
- python >=3.7,<3.10
10+
- python >=3.7,<3.11
1111
- pip
1212
run:
13-
- python >=3.7,<3.10
13+
- python >=3.7,<3.11
1414
- pytorch >=1.8.0
1515
- pytorch-scatter >=2.0.8
1616
- pytorch-cluster >=1.5.9

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
"lmdb",
4444
"fair-esm",
4545
],
46-
python_requires=">=3.7,<3.10",
46+
python_requires=">=3.7,<3.11",
4747
classifiers=[
4848
"Development Status :: 4 - Beta",
4949
'Intended Audience :: Developers',

torchdrug/data/feature.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def onehot(x, vocab, allow_unknown=False):
5050
return feature
5151

5252

53+
# TODO: this one is too slow
5354
@R.register("features.atom.default")
5455
def atom_default(atom):
5556
"""Default atom feature.
@@ -331,6 +332,7 @@ def molecule_default(mol):
331332
"""Default molecule feature."""
332333
return ExtendedConnectivityFingerprint(mol)
333334

335+
334336
ECFP = ExtendedConnectivityFingerprint
335337

336338

torchdrug/data/protein.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def from_sequence(cls, sequence, atom_feature="default", bond_feature="default",
295295
"""
296296
if atom_feature is None and bond_feature is None and residue_feature == "default":
297297
return cls._residue_from_sequence(sequence)
298-
298+
299299
mol = Chem.MolFromSequence(sequence)
300300
if mol is None:
301301
raise ValueError("Invalid sequence `%s`" % sequence)

torchdrug/datasets/alphafolddb.py

+27-21
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ class AlphaFoldDB(data.ProteinDataset):
1818
Parameters:
1919
path (str): path to store the dataset
2020
species_id (int, optional): the id of species to be loaded. The species are numbered
21-
by the order appeared on https://alphafold.ebi.ac.uk/download (0-20 for model
21+
by the order appeared on https://alphafold.ebi.ac.uk/download (0-20 for model
2222
organism proteomes, 21 for Swiss-Prot)
23-
split_id (int, optional): the id of split to be loaded. To avoid large memory consumption
24-
for one dataset, we have cut each species into several splits, each of which contains
23+
split_id (int, optional): the id of split to be loaded. To avoid large memory consumption
24+
for one dataset, we have cut each species into several splits, each of which contains
2525
at most 22000 proteins.
2626
verbose (int, optional): output verbose level
2727
**kwargs
@@ -60,46 +60,52 @@ class AlphaFoldDB(data.ProteinDataset):
6060
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000579_71421_HAEIN_v2.tar",
6161
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000429_85962_HELPY_v2.tar",
6262
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000007841_1125630_KLEPH_v2.tar",
63-
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008153_5671_LEIIN_v2.tar",
63+
# "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008153_5671_LEIIN_v2.tar",
6464
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000078237_100816_9PEZI1_v2.tar",
6565
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000806_272631_MYCLE_v2.tar",
66-
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001584_83332_MYCTU_v2.tar",
66+
# "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001584_83332_MYCTU_v2.tar",
6767
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000020681_1299332_MYCUL_v2.tar",
6868
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000535_242231_NEIG1_v2.tar",
6969
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000006304_1133849_9NOCA1_v2.tar",
7070
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000024404_6282_ONCVO_v2.tar",
7171
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002059_502779_PARBA_v2.tar",
72-
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001450_36329_PLAF7_v2.tar",
72+
# "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001450_36329_PLAF7_v2.tar",
7373
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002438_208964_PSEAE_v2.tar",
7474
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000001014_99287_SALTY_v2.tar",
7575
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008854_6183_SCHMA_v2.tar",
7676
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002716_300267_SHIDS_v2.tar",
7777
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000018087_1391915_SPOS1_v2.tar",
78-
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008816_93061_STAA8_v2.tar",
78+
# "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008816_93061_STAA8_v2.tar",
7979
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000000586_171101_STRR6_v2.tar",
8080
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000035681_6248_STRER_v2.tar",
8181
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000030665_36087_TRITR_v2.tar",
8282
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000008524_185431_TRYB2_v2.tar",
83-
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002296_353153_TRYCC_v2.tar",
83+
# "https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000002296_353153_TRYCC_v2.tar",
8484
"https://ftp.ebi.ac.uk/pub/databases/alphafold/v2/UP000270924_6293_WUCBA_v2.tar"
8585
]
8686
md5s = [
87-
"4cd5f596ebfc3d45d9f6b647dc5684af", "9e26602ba2d9f233ef4fcf82703ddb59",
88-
"60a09db1e1c47a98763d09879784f536", "a0ab562b7372f149673c4518f949501f",
89-
"6205138b14fb7e7ec09b366e3e4f294b", "31f31359cd7254f82304e3886440bdd3",
90-
"a590096e65461ed4eb092b2147b97f0b", "8f1e120f372995644a7101ad58e5b2ae",
91-
"9a659c4aed2a8b833478dcd5fffc5fd8", "95d775f2ae271cf50a101c73335cd250",
92-
"e5b12da43f5bd77298ca50e19706bdeb", "90e953abba9c8fe202e0adf825c0dfcc",
93-
"38a11553c7e2d00482281e74f7daf321", "2bcdfe2c37154a355fe4e8150c279c13",
94-
"580a55e56a44fed935f0101c37a8c4ab", "b8d08a9033d111429fadb4e25820f9f7",
95-
"59d1167f414a86cbccfb204791fea0eb", "dfde6b44026f19a88f1abc8ac2798ce6",
96-
"a1c2047a16130d61cac4db23b2f5b560", "e4d4b72df8d075aeb607dcb095210304",
97-
"5cdad48c799ffd723636cae26433f1f9", "98a7c13987f578277bfb66ac48a1e242",
87+
"4cd5f596ebfc3d45d9f6b647dc5684af", "b89bee5507f78f971417cc8fd75b40f7", "a6459a1f1a0a22fbf25f1c05c2889ae3",
88+
"24dfba8ab93dbf3f51e7db6b912dd6b4", "6b81b3086ed9e57e04a54f148ecf974c", "a50f4fd9f581c89e79e1b2857e54b786",
89+
"fdd16245769bf1f7d91a0e285ac00e52", "66b9750c511182bc5f8ee71fe2ab2a17", "5dadeb5aac704025cac33f7557794858",
90+
"99b22e0f050d845782d914becbfe4d2f", "da938dfae4fabf6e144f4b5ede5885ec", "2003c09d437cfb4093552c588a33e06d",
91+
"fba59f386cfa33af3f70ae664b7feac0", "d7a1a6c02213754ee1a1ffb3b41ad4ba", "8a0e8deadffec2aba3b7edd6534b7481",
92+
"1854d0bbcf819de1de7b0cfdb6d32b2e", "d9720e3809db6916405db096b520c236", "6b918e9e4d645b12a80468bcea805f1f",
93+
"ed0eefe927eb8c3b81cf87eaabbb8d6e", "051369e0dc8fed4798c8b2c68e6cbe2e", "b05ff57164167851651c625dca66ed28",
94+
"68e7a6e57bd43cb52e344b3190073387", "75d027ac7833f284fda65ea620353e8a", "7d85bb2ee4130096a6d905ab8d726bcc",
95+
"63498210c88e8bfb1a7346c4ddf73bb1", "5bf2211304ef91d60bb3838ec12d89cd", "4981758eb8980e9df970ac6113e4084c",
96+
"322431789942595b599d2b86670f41b3", "35d7b32e37bcc23d02b12b03b1e0c093", "1b8847dd786fa41b5b38f5e7aa58b813",
97+
"126bdbe59fa82d55bfa098b710bdf650", "6c6d3248ed943dd7137637fc92d7ba37", "532203c6877433df5651b95d27685825",
98+
"6e7112411da5843bec576271c44e0a0a", "0e4f913a9b4672b0ad3cc9c4f2de5c8d", "a138d0060b2e8a0ef1f90cf3ab7b7ca0",
99+
"04d491dd1c679e91b5a2f3b9f14db555", "889c051e39305614accdff00414bfa67", "cd87cf24e5135c9d729940194ccc65c8",
100+
"75eb8bfe866cf3040f4c08a566c32bc1", "fd8e6ddb9c159aab781a11c287c85feb", "b91a2e103980b96f755712f2b559ad66",
101+
"26187d09b093649686d7c158aa4fd113", "62e16894bb4b8951a82befd24ad4ee21", "85c001df1d91788bf3cc1f97230b1dac",
102+
"91a25af808351757b101a8c9c787db9e", "8b3e8645cc4c2484c331759b9d1df5bc", "e8a76a6ab290e6743233510e8d1eb4a5",
103+
"38280bd7804f4c060b0775c4abed9b89"
98104
]
99105
species_nsplit = [
100106
2, 1, 1, 2, 1, 1, 1, 3, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 2, 20,
101107
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
102-
1, 1, 1, 1, 1, 1, 1, 1, 1, 1
108+
1, 1, 1, 1, 1, #1, 1, 1, 1, 1
103109
]
104110
split_length = 22000
105111

@@ -111,7 +117,7 @@ def __init__(self, path, species_id=0, split_id=0, verbose=1, **kwargs):
111117

112118
species_name = os.path.basename(self.urls[species_id])[:-4]
113119
if split_id >= self.species_nsplit[species_id]:
114-
raise ValueError("Split id %d should be less than %d in species %s" %
120+
raise ValueError("Split id %d should be less than %d in species %s" %
115121
(split_id, self.species_nsplit[species_id], species_name))
116122
self.processed_file = "%s_%d.pkl.gz" % (species_name, split_id)
117123
pkl_file = os.path.join(path, self.processed_file)

torchdrug/datasets/gene_ontology.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
@utils.copy_args(data.ProteinDataset.load_pdbs)
1414
class GeneOntology(data.ProteinDataset):
1515
"""
16-
A set of proteins with their 3D structures and GO terms. These terms classify proteins
17-
into hierarchically related functional classes organized into three ontologies: molecular
16+
A set of proteins with their 3D structures and GO terms. These terms classify proteins
17+
into hierarchically related functional classes organized into three ontologies: molecular
1818
function (MF), biological process (BP) and cellular component (CC).
1919
2020
Statistics (test_cutoff=0.95):
@@ -51,7 +51,7 @@ def __init__(self, path, branch="MF", test_cutoff=0.95, verbose=1, **kwargs):
5151
zip_file = utils.download(self.url, path, md5=self.md5)
5252
path = os.path.join(utils.extract(zip_file), "GeneOntology")
5353
pkl_file = os.path.join(path, self.processed_file)
54-
54+
5555
csv_file = os.path.join(path, "nrPDB-GO_test.csv")
5656
pdb_ids = []
5757
with open(csv_file, "r") as fin:

torchdrug/layers/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,10 @@ class SinusoidalPositionEmbedding(nn.Module):
329329
Positional embedding based on sine and cosine functions, proposed in `Attention Is All You Need`_.
330330
331331
.. _Attention Is All You Need:
332-
https://arxiv.org/pdf/1706.03762.pdf
332+
https://arxiv.org/pdf/1706.03762.pdf
333333
334334
Parameters:
335-
output_dim (int): output dimension
335+
output_dim (int): output dimension
336336
"""
337337

338338
def __init__(self, output_dim):

torchdrug/layers/conv.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -810,4 +810,4 @@ def message_and_aggregate(self, graph, input):
810810
dim_size=graph.num_node * graph.num_relation)
811811
update += edge_update
812812

813-
return update.view(graph.num_node, self.num_relation * self.input_dim)
813+
return update.view(graph.num_node, self.num_relation * self.input_dim)

torchdrug/models/esm.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class EvolutionaryScaleModeling(nn.Module, core.Configurable):
7171
"ESM-2-3B": 36,
7272
"ESM-2-15B": 48,
7373
}
74-
74+
7575
max_input_length = 1024 - 2
7676

7777
def __init__(self, path, model="ESM-1b", readout="mean"):
@@ -82,6 +82,7 @@ def __init__(self, path, model="ESM-1b", readout="mean"):
8282
self.path = path
8383

8484
_model, alphabet = self.load_weight(path, model)
85+
self.alphabet = alphabet
8586
mapping = self.construct_mapping(alphabet)
8687
self.output_dim = self.output_dim[model]
8788
self.model = _model
@@ -111,7 +112,7 @@ def load_weight(self, path, model):
111112
return esm.pretrained.load_model_and_alphabet_core(model_name, model_data, regression_data)
112113

113114
def construct_mapping(self, alphabet):
114-
mapping = [0] * len(data.Protein.id2residue_symbol)
115+
mapping = [-1] * max(len(data.Protein.id2residue_symbol), len(self.alphabet))
115116
for i, token in data.Protein.id2residue_symbol.items():
116117
mapping[i] = alphabet.get_idx(token)
117118
mapping = torch.tensor(mapping)
@@ -133,6 +134,7 @@ def forward(self, graph, input, all_loss=None, metric=None):
133134
"""
134135
input = graph.residue_type
135136
input = self.mapping[input]
137+
input[input == -1] = graph.residue_type[input == -1]
136138
size = graph.num_residues
137139
if (size > self.max_input_length).any():
138140
warnings.warn("ESM can only encode proteins within %d residues. Truncate the input to fit into ESM."

0 commit comments

Comments
 (0)