Skip to content

Commit 91e9bd9

Browse files
committed
change chain id to accept lowercase and digits
1 parent c8155f4 commit 91e9bd9

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

torchdrug/data/protein.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import string
23
import warnings
34
from collections import defaultdict
45

@@ -50,18 +51,16 @@ class Protein(Molecule):
5051
residue2id = {"GLY": 0, "ALA": 1, "SER": 2, "PRO": 3, "VAL": 4, "THR": 5, "CYS": 6, "ILE": 7, "LEU": 8,
5152
"ASN": 9, "ASP": 10, "GLN": 11, "LYS": 12, "GLU": 13, "MET": 14, "HIS": 15, "PHE": 16,
5253
"ARG": 17, "TYR": 18, "TRP": 19}
54+
residue_symbol2id = {"G": 0, "A": 1, "S": 2, "P": 3, "V": 4, "T": 5, "C": 6, "I": 7, "L": 8, "N": 9,
55+
"D": 10, "Q": 11, "K": 12, "E": 13, "M": 14, "H": 15, "F": 16, "R": 17, "Y": 18, "W": 19}
5356
atom_name2id = {"C": 0, "CA": 1, "CB": 2, "CD": 3, "CD1": 4, "CD2": 5, "CE": 6, "CE1": 7, "CE2": 8,
5457
"CE3": 9, "CG": 10, "CG1": 11, "CG2": 12, "CH2": 13, "CZ": 14, "CZ2": 15, "CZ3": 16,
5558
"N": 17, "ND1": 18, "ND2": 19, "NE": 20, "NE1": 21, "NE2": 22, "NH1": 23, "NH2": 24,
5659
"NZ": 25, "O": 26, "OD1": 27, "OD2": 28, "OE1": 29, "OE2": 30, "OG": 31, "OG1": 32,
5760
"OH": 33, "OXT": 34, "SD": 35, "SG": 36, "UNK": 37}
58-
alphabet2id = {" ": 0, "A": 1, "B": 2, "C": 3, "D": 4, "E": 5, "F": 6, "G": 7, "H": 8, "I": 9, "J": 10,
59-
"K": 11, "L": 12, "M": 13, "N": 14, "O": 15, "P": 16, "Q": 17, "R": 18, "S": 19, "T": 20,
60-
"U": 21, "V": 22, "W": 23, "X": 24, "Y": 25, "Z": 26}
61+
alphabet2id = {c: i for i, c in enumerate(" " + string.ascii_uppercase + string.ascii_lowercase + string.digits)}
6162
id2residue = {v: k for k, v in residue2id.items()}
62-
id2residue_symbol = {0: "G", 1: "A", 2: "S", 3: "P", 4: "V", 5: "T", 6: "C", 7: "I", 8: "L", 9: "N",
63-
10: "D", 11: "Q", 12: "K", 13: "E", 14: "M", 15: "H", 16: "F", 17: "R", 18: "Y", 19: "W"}
64-
residue_symbol2id = {v: k for k, v in id2residue_symbol.items()}
63+
id2residue_symbol = {v: k for k, v in residue_symbol2id.items()}
6564
id2atom_name = {v: k for k, v in atom_name2id.items()}
6665
id2alphabet = {v: k for k, v in alphabet2id.items()}
6766

@@ -214,7 +213,11 @@ def from_molecule(cls, mol, atom_feature="default", bond_feature="default", resi
214213
type = "GLY"
215214
residue_type.append(cls.residue2id[type])
216215
residue_number.append(number)
217-
if pdbinfo.GetInsertionCode() not in cls.alphabet2id or pdbinfo.GetChainId() not in cls.alphabet2id:
216+
if pdbinfo.GetInsertionCode() not in cls.alphabet2id:
217+
warnings.warn(f"Fail to create the protein. Unknown insertion code {pdbinfo.GetInsertionCode()}.")
218+
return None
219+
if pdbinfo.GetChainId() not in cls.alphabet2id:
220+
warnings.warn(f"Fail to create the protein. Unknown chain id {pdbinfo.GetChainId()}.")
218221
return None
219222
insertion_code.append(cls.alphabet2id[pdbinfo.GetInsertionCode()])
220223
chain_id.append(cls.alphabet2id[pdbinfo.GetChainId()])

0 commit comments

Comments
 (0)