|
| 1 | +import glob |
| 2 | +import logging |
| 3 | +import os |
| 4 | + |
| 5 | +import torch |
| 6 | +from rdkit import Chem |
| 7 | +from torchdrug import data |
| 8 | +from torchdrug.core import Registry as R |
| 9 | +from torchdrug.layers import functional |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | +from diffpack import rotamer |
| 13 | +from diffpack.rotamer import get_chi_mask, atom_name_vocab, bb_atom_name |
| 14 | + |
| 15 | +logging.basicConfig(level=logging.DEBUG) |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
@R.register("datasets.SideChainDataset")
class SideChainDataset(data.ProteinDataset):
    """
    Protein dataset for side-chain modelling, built from PDB files.

    The dataset is populated either from a folder (``path``), in which case a
    pickle cache named ``processed_file`` inside that folder is used when
    available, or from an explicit list of PDB files (``pdb_files``).

    Parameters:
        path (str, optional): folder holding ``*.pdb`` files and the pickle cache.
            The folder is created if it does not exist.
        pdb_files (list of str, optional): explicit pdb file names.
            Only consulted when ``path`` is None; entries not ending in ``.pdb`` are dropped.
        verbose (int, optional): output verbose level
        **kwargs: forwarded to ``load_pickle`` / ``load_pdbs``
    """

    # File name of the pickle cache inside ``path``. When left as None, no cache
    # is read or written and the pdb files are always parsed.
    processed_file = None
    # Base names of pdb files that are removed from the dataset after loading.
    exclude_pdb_files = []

    def __init__(self, path=None, pdb_files=None, verbose=1, **kwargs):
        if path is not None:
            logger.info("Loading dataset from folder %s", path)
            path = os.path.expanduser(path)
            # exist_ok avoids the check-then-create race of a separate os.path.exists test.
            os.makedirs(path, exist_ok=True)
            self.path = path
            # os.path.join(path, None) would raise TypeError, so only build the
            # cache path when a cache file name has actually been configured.
            pkl_file = os.path.join(path, self.processed_file) if self.processed_file else None

            if pkl_file is not None and os.path.exists(pkl_file):
                logger.info("Found existing pickle file %s. Loading from pickle file (this may take a while)",
                            pkl_file)
                self.load_pickle(pkl_file, verbose=verbose, **kwargs)
            else:
                if pkl_file is not None:
                    logger.info("No pickle file found. Loading from pdb files (this may take a while)"
                                " and save to pickle file %s", pkl_file)
                else:
                    logger.info("No pickle file configured. Loading from pdb files (this may take a while)")
                # Sorted so that the sample order is reproducible across runs.
                pdb_files = sorted(glob.glob(os.path.join(path, "*.pdb")))
                self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
                if pkl_file is not None:
                    self.save_pickle(pkl_file, verbose=verbose)
        elif pdb_files is not None:
            logger.info("Loading dataset from pdb files")
            expanded = (os.path.expanduser(pdb_file) for pdb_file in pdb_files)
            pdb_files = [pdb_file for pdb_file in expanded if pdb_file.endswith(".pdb")]
            self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
        else:
            # Without a source nothing initialises self.data, and the filter
            # below would fail with an opaque AttributeError.
            raise ValueError("Either `path` or `pdb_files` must be provided")

        # Filter out proteins with no residues and explicitly excluded files.
        # A `None` entry is a lazy-mode placeholder (see `load_pdbs`): it has not
        # been parsed yet, so it cannot be inspected and is kept.
        excluded = set(self.exclude_pdb_files)
        indexes = [
            i for i, (protein, pdb_file) in enumerate(zip(self.data, self.pdb_files))
            if os.path.basename(pdb_file) not in excluded
            and (protein is None or bool(torch.as_tensor(protein.num_residue > 0).all()))
        ]
        self.data = [self.data[i] for i in indexes]
        self.sequences = [self.sequences[i] for i in indexes]
        self.pdb_files = [self.pdb_files[i] for i in indexes]

    def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, sanitize=True, removeHs=True, **kwargs):
        """
        Load the dataset from pdb files.

        Parameters:
            pdb_files (list of str): pdb file names
            transform (Callable, optional): protein sequence transformation function
            lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
                This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
            verbose (int, optional): output verbose level
            sanitize (bool, optional): passed to ``Chem.MolFromPDBFile``
            removeHs (bool, optional): passed to ``Chem.MolFromPDBFile``
            **kwargs: passed to ``data.Protein.from_molecule`` and stored for lazy loading in ``get_item``
        """
        self.transform = transform
        self.lazy = lazy
        self.kwargs = kwargs
        self.data = []
        self.pdb_files = []
        self.sequences = []

        if verbose:
            pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
        for i, pdb_file in enumerate(pdb_files):
            # In lazy mode only the first file is parsed here; every other
            # sample is stored as a `None` placeholder and parsed in `get_item`.
            if not lazy or i == 0:
                mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize, removeHs=removeHs)
                if not mol:
                    logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample.", pdb_file)
                    continue
                protein = data.Protein.from_molecule(mol, **kwargs)
                if not protein:
                    logger.debug("Can't construct protein from pdb file `%s`. Ignore this sample.", pdb_file)
                    continue
            else:
                protein = None
            if hasattr(protein, "residue_feature"):
                with protein.residue():
                    protein.residue_feature = protein.residue_feature.to_sparse()
            self.data.append(protein)
            self.pdb_files.append(pdb_file)
            self.sequences.append(protein.to_sequence() if protein else None)

    def get_item(self, index):
        """
        Build the sample at ``index``.

        The protein is parsed from its pdb file in lazy mode, otherwise a clone of
        the preloaded protein is used. Atom- and residue-level attributes
        (atom14 index, one-hot residue feature, chi masks, atom37 masks) are
        attached before the optional ``transform`` is applied.

        Parameters:
            index (int): sample index

        Returns:
            dict: ``{"graph": protein}``, or whatever ``self.transform`` returns for it
        """
        if getattr(self, "lazy", False):
            protein = data.Protein.from_pdb(self.pdb_files[index], **self.kwargs)
        else:
            protein = self.data[index].clone()
        # Drop atoms whose name id is 37 or above
        # (presumably those outside the atom37 vocabulary -- confirm against `atom_name_vocab`).
        protein = protein.subgraph(protein.atom_name < 37)

        with protein.atom():
            # Init atom14 index map
            protein.atom14index = rotamer.restype_atom14_index_map[
                protein.residue_type[protein.atom2residue], protein.atom_name
            ]  # [num_atom, 14]

        with protein.residue():
            # Init residue features
            protein.residue_feature = functional.one_hot(protein.residue_type, 21)  # [num_residue, 21]

            # Init residue masks
            chi_mask = get_chi_mask(protein)
            chi_1pi_periodic_mask = torch.tensor(rotamer.chi_pi_periodic)[protein.residue_type]
            chi_2pi_periodic_mask = ~chi_1pi_periodic_mask
            protein.chi_mask = chi_mask
            protein.chi_1pi_periodic_mask = torch.logical_and(chi_mask, chi_1pi_periodic_mask)  # [num_residue, 4]
            protein.chi_2pi_periodic_mask = torch.logical_and(chi_mask, chi_2pi_periodic_mask)  # [num_residue, 4]

            # Init atom37 features
            protein.atom37_mask = torch.zeros(protein.num_residue, len(atom_name_vocab), device=protein.device,
                                              dtype=torch.bool)  # [num_residue, 37]
            protein.atom37_mask[protein.atom2residue, protein.atom_name] = True
            # Side-chain mask = all present atoms minus the backbone atom columns.
            protein.sidechain37_mask = protein.atom37_mask.clone()  # [num_residue, 37]
            protein.sidechain37_mask[:, bb_atom_name] = False
        item = {"graph": protein}

        if self.transform:
            item = self.transform(item)
        return item

    @staticmethod
    def from_pdb_files(pdb_files, verbose=1, **kwargs):
        """
        Construct a dataset directly from a list of pdb files.

        Parameters:
            pdb_files (list of str): pdb file names
            verbose (int, optional): output verbose level
            **kwargs: forwarded to the constructor
        """
        # Must be passed by keyword: the first positional parameter of the
        # constructor is `path`, not `pdb_files`.
        return SideChainDataset(pdb_files=pdb_files, verbose=verbose, **kwargs)

    def __repr__(self):
        lines = ["#sample: %d" % len(self)]
        return "%s( %s)" % (self.__class__.__name__, "\n ".join(lines))
| 143 | + |
0 commit comments