Skip to content

Commit 5f62453

Browse files
authored
prepare release
1 parent 73f020d commit 5f62453

19 files changed

+5138
-1
lines changed

10mh_A.pdb

+2,608
Large diffs are not rendered by default.

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2023 Yangtian Zhang, Zuobai Zhang
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+72-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,72 @@
1-
# DiffPack
1+
# DiffPack: A Torsional Diffusion Model for Autoregressive Protein Side-Chain Packing
2+
**DiffPack** is a novel torsional diffusion model designed for predicting the conformation of protein side-chains based on their backbones, as introduced in [arxiv link](https://arxiv.org/abs/2306.01794). By learning the joint distribution of side-chain torsional angles through a process of diffusing and denoising on the torsional space, DiffPack significantly improves angle accuracy across various benchmarks for protein side-chain packing.
3+
4+
5+
## Installation
6+
You can install DiffPack with the following commands, which will install all the dependencies.
7+
```shell
8+
conda create -n diffpack python=3.8
9+
conda activate diffpack
10+
```
11+
12+
```shell
13+
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
14+
conda install pyg -c pyg
15+
conda install torchdrug -c milagraph -c conda-forge -c pytorch -c pyg
16+
```
17+
18+
```shell
19+
pip install biopython==1.77
20+
pip install pyyaml
21+
pip install easydict
22+
```
23+
![framwork](asset/diffpack.png)
24+
25+
## Model Checkpoints
26+
We provide several versions of DiffPack, each with its own configuration and checkpoint:
27+
28+
| Model | Config | Checkpoint |
29+
|---------------------------------------|--------------------------------------------|-----------------------|
30+
| DiffPack (Vanila) | [Config](config/inference.yaml) | [Google Drive Link](https://drive.google.com/file/d/1tZ9ZOjIxq9SxrkdvbLJyLUBbt2P-mksO/view?usp=sharing) |
31+
| DiffPack (with Confidence Prediction) | [Config](config/inference_confidence.yaml) | [Google Drive Link](https://drive.google.com/file/d/1tZ9ZOjIxq9SxrkdvbLJyLUBbt2P-mksO/view?usp=sharing) |
32+
33+
The Vanilla version of DiffPack is the base model,
34+
while the version with Confidence Prediction includes an additional feature that estimates the confidence score of the predicted side-chain conformation.
35+
36+
Most of the configuration is specified in the configuration file. We list some important configuration hyperparameters here:
37+
- `mode`: Backward mode in diffusion process. We use `ode` or `sde` for DiffPack.
38+
- `annealed_temp`: Annealing temperature in diffusion process. We use `3` for DiffPack. Ideally, higher value corresponds to lower temperature.
39+
- `num_sample`: Number of samples in diffusion process. Confidence model will decide which sample to use.
40+
41+
## Running DiffPack
42+
To use DiffPack for new proteins on your local machine, we provide the necessary configuration files in the config/ folder.
43+
For instance, if you have two pdb files 1a3a.pdb and 1a3b.pdb,
44+
you can run the following command to infer new proteins and save the results in your chosen output folder:
45+
```shell
46+
python script/inference.py -c config/inference_confidence.yaml \
47+
--seed 2023 \
48+
--output_dir path/to/output \
49+
--pdb_files 1a3a.pdb 1a3b.pdb ...
50+
```
51+
This command will generate and save the predicted side-chain conformations for the given proteins.
52+
53+
## Retraining DiffPack
54+
For those interested in training DiffPack on their own datasets, we will soon release the code and instructions for this process.
55+
Stay tuned for updates!
56+
57+
## Visualization of Results
58+
![Visualization](asset/result.png)
59+
60+
## License
61+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
62+
## Citation
63+
If you find DiffPack useful in your research or project, please cite our paper:
64+
```
65+
@article{zhang2023diffpack,
66+
title={DiffPack: A Torsional Diffusion Model for Autoregressive Protein Side-Chain Packing},
67+
author={Zhang, Yangtian and Zhang, Zuobai and Zhong, Bozitao and Misra, Sanchit and Tang, Jian},
68+
journal={arXiv preprint arXiv:2306.01794},
69+
year={2023}
70+
}
71+
```
72+

asset/diffpack.png

1.38 MB
Loading

asset/result.png

9.31 MB
Loading

config/inference.yaml

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
test_set:
2+
class: SideChainDataset
3+
path: null
4+
atom_feature: residue_symbol
5+
bond_feature: null
6+
residue_feature: null
7+
mol_feature: null
8+
sanitize: true
9+
removeHs: true
10+
transform:
11+
class: Compose
12+
transforms: []
13+
14+
15+
task:
16+
class: TorsionalDiffusion
17+
train_chi_id: null
18+
schedule_1pi_periodic:
19+
class: SO2VESchedule
20+
pi_periodic: true
21+
annealed_temp: 3
22+
cache_folder: ~/scratch/output/diffpack
23+
mode: ode
24+
schedule_2pi_periodic:
25+
class: SO2VESchedule
26+
pi_periodic: false
27+
annealed_temp: 3
28+
cache_folder: ~/scratch/output/diffpack
29+
mode: ode
30+
sigma_embedding:
31+
class: SigmaEmbeddingLayer
32+
input_dim: 39
33+
hidden_dims: [ 64, 128 ]
34+
sigma_dim: 64
35+
model:
36+
class: GearNet
37+
input_dim: 128
38+
hidden_dims: [128, 128, 128, 128, 128, 128]
39+
batch_norm: True
40+
concat_hidden: True
41+
short_cut: True
42+
readout: 'sum'
43+
num_relation: 6
44+
edge_input_dim: 58
45+
num_angle_bin: 8
46+
torsion_mlp_hidden_dims: [ 64, 128 ]
47+
graph_construction_model:
48+
class: GraphConstruction
49+
edge_layers:
50+
- class: BondEdge
51+
- class: SpatialEdge
52+
radius: 4.5
53+
min_distance: 2
54+
- class: KNNEdge
55+
k: 10
56+
min_distance: 0
57+
edge_feature: gearnet
58+
59+
optimizer:
60+
class: Adam
61+
lr: 1.0e-4
62+
63+
engine:
64+
gpus: [0] #, 1, 2, 3]
65+
batch_size: 32
66+
log_interval: 1000
67+
68+
model_checkpoint: ~/scratch/trained_model/diffpack/gearnet_edge_confidence_converted.pth
69+
70+
train:
71+
num_epoch: 0

config/inference_confidence.yaml

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
test_set:
2+
class: SideChainDataset
3+
path: null
4+
atom_feature: residue_symbol
5+
bond_feature: null
6+
residue_feature: null
7+
mol_feature: null
8+
sanitize: true
9+
removeHs: true
10+
transform:
11+
class: Compose
12+
transforms: []
13+
14+
15+
task:
16+
class: ConfidencePrediction
17+
num_sample: 4
18+
num_mlp_layer: 3
19+
train_chi_id: null
20+
schedule_1pi_periodic:
21+
class: SO2VESchedule
22+
pi_periodic: true
23+
annealed_temp: 3
24+
cache_folder: ~/scratch/output/diffpack
25+
mode: ode
26+
schedule_2pi_periodic:
27+
class: SO2VESchedule
28+
pi_periodic: false
29+
annealed_temp: 3
30+
cache_folder: ~/scratch/output/diffpack
31+
mode: ode
32+
confidence_model:
33+
class: GearNet
34+
input_dim: 39
35+
hidden_dims: [ 128, 128, 128, 128, 128, 128 ]
36+
batch_norm: True
37+
concat_hidden: True
38+
short_cut: True
39+
readout: 'sum'
40+
num_relation: 6
41+
edge_input_dim: 58
42+
num_angle_bin: 8
43+
sigma_embedding:
44+
class: SigmaEmbeddingLayer
45+
input_dim: 39
46+
hidden_dims: [ 64, 128 ]
47+
sigma_dim: 64
48+
model:
49+
class: GearNet
50+
input_dim: 128
51+
hidden_dims: [128, 128, 128, 128, 128, 128]
52+
batch_norm: True
53+
concat_hidden: True
54+
short_cut: True
55+
readout: 'sum'
56+
num_relation: 6
57+
edge_input_dim: 58
58+
num_angle_bin: 8
59+
torsion_mlp_hidden_dims: [ 64, 128 ]
60+
graph_construction_model:
61+
class: GraphConstruction
62+
edge_layers:
63+
- class: BondEdge
64+
- class: SpatialEdge
65+
radius: 4.5
66+
min_distance: 2
67+
- class: KNNEdge
68+
k: 10
69+
min_distance: 0
70+
edge_feature: gearnet
71+
72+
optimizer:
73+
class: Adam
74+
lr: 1.0e-4
75+
76+
engine:
77+
gpus: [0] #, 1, 2, 3]
78+
batch_size: 32
79+
log_interval: 1000
80+
81+
model_checkpoint: ~/scratch/trained_model/diffpack/gearnet_edge_confidence_converted.pth
82+
83+
train:
84+
num_epoch: 0

diffpack/__init__.py

Whitespace-only changes.

diffpack/dataset.py

+143
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import glob
2+
import logging
3+
import os
4+
5+
import torch
6+
from rdkit import Chem
7+
from torchdrug import data
8+
from torchdrug.core import Registry as R
9+
from torchdrug.layers import functional
10+
from tqdm import tqdm
11+
12+
from diffpack import rotamer
13+
from diffpack.rotamer import get_chi_mask, atom_name_vocab, bb_atom_name
14+
15+
logging.basicConfig(level=logging.DEBUG)
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@R.register("datasets.SideChainDataset")
21+
class SideChainDataset(data.ProteinDataset):
22+
processed_file = None
23+
exclude_pdb_files = []
24+
25+
def __init__(self, path=None, pdb_files=None, verbose=1, **kwargs):
26+
if path is not None:
27+
logger.info("Loading dataset from folder %s" % path)
28+
path = os.path.expanduser(path)
29+
if not os.path.exists(path):
30+
os.makedirs(path)
31+
self.path = path
32+
pkl_file = os.path.join(path, self.processed_file)
33+
34+
if os.path.exists(pkl_file):
35+
logger.info("Found existing pickle file %s" % pkl_file
36+
+ ". Loading from pickle file (this may take a while)")
37+
self.load_pickle(pkl_file, verbose=verbose, **kwargs)
38+
else:
39+
logger.info("No pickle file found. Loading from pdb files (this may take a while)"
40+
+ " and save to pickle file %s" % pkl_file)
41+
pdb_files = sorted(glob.glob(os.path.join(path, "*.pdb")))
42+
self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
43+
self.save_pickle(pkl_file, verbose=verbose)
44+
elif pdb_files is not None:
45+
logger.info("Loading dataset from pdb files")
46+
pdb_files = [os.path.expanduser(pdb_file) for pdb_file in pdb_files]
47+
pdb_files = [pdb_file for pdb_file in pdb_files if pdb_file.endswith(".pdb")]
48+
self.load_pdbs(pdb_files, verbose=verbose, **kwargs)
49+
50+
# Filter out proteins with no residues
51+
indexes = [i for i, (protein, pdb_file) in enumerate(zip(self.data, self.pdb_files))
52+
if (protein.num_residue > 0).all() and os.path.basename(pdb_file) not in self.exclude_pdb_files]
53+
self.data = [self.data[i] for i in indexes]
54+
self.sequences = [self.sequences[i] for i in indexes]
55+
self.pdb_files = [self.pdb_files[i] for i in indexes]
56+
57+
def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, sanitize=True, removeHs=True, **kwargs):
58+
"""
59+
Load the dataset from pdb files.
60+
61+
Parameters:
62+
pdb_files (list of str): pdb file names
63+
transform (Callable, optional): protein sequence transformation function
64+
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
65+
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
66+
verbose (int, optional): output verbose level
67+
**kwargs
68+
"""
69+
num_sample = len(pdb_files)
70+
71+
self.transform = transform
72+
self.lazy = lazy
73+
self.kwargs = kwargs
74+
self.data = []
75+
self.pdb_files = []
76+
self.sequences = []
77+
78+
if verbose:
79+
pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
80+
for i, pdb_file in enumerate(pdb_files):
81+
if not lazy or i == 0:
82+
mol = Chem.MolFromPDBFile(pdb_file, sanitize=sanitize, removeHs=removeHs)
83+
if not mol:
84+
logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file)
85+
continue
86+
protein = data.Protein.from_molecule(mol, **kwargs)
87+
if not protein:
88+
logger.debug("Can't construct protein from pdb file `%s`. Ignore this sample." % pdb_file)
89+
continue
90+
else:
91+
protein = None
92+
if hasattr(protein, "residue_feature"):
93+
with protein.residue():
94+
protein.residue_feature = protein.residue_feature.to_sparse()
95+
self.data.append(protein)
96+
self.pdb_files.append(pdb_file)
97+
self.sequences.append(protein.to_sequence() if protein else None)
98+
99+
def get_item(self, index):
100+
if getattr(self, "lazy", False):
101+
protein = data.Protein.from_pdb(self.pdb_files[index], **self.kwargs)
102+
else:
103+
protein = self.data[index].clone()
104+
protein = protein.subgraph(protein.atom_name < 37)
105+
106+
with protein.atom():
107+
# Init atom14 index map
108+
protein.atom14index = rotamer.restype_atom14_index_map[
109+
protein.residue_type[protein.atom2residue], protein.atom_name
110+
] # [num_atom, 14]
111+
112+
with protein.residue():
113+
# Init residue features
114+
protein.residue_feature = functional.one_hot(protein.residue_type, 21) # [num_residue, 21]
115+
116+
# Init residue masks
117+
chi_mask = get_chi_mask(protein)
118+
chi_1pi_periodic_mask = torch.tensor(rotamer.chi_pi_periodic)[protein.residue_type]
119+
chi_2pi_periodic_mask = ~chi_1pi_periodic_mask
120+
protein.chi_mask = chi_mask
121+
protein.chi_1pi_periodic_mask = torch.logical_and(chi_mask, chi_1pi_periodic_mask) # [num_residue, 4]
122+
protein.chi_2pi_periodic_mask = torch.logical_and(chi_mask, chi_2pi_periodic_mask) # [num_residue, 4]
123+
124+
# Init atom37 features
125+
protein.atom37_mask = torch.zeros(protein.num_residue, len(atom_name_vocab), device=protein.device,
126+
dtype=torch.bool) # [num_residue, 37]
127+
protein.atom37_mask[protein.atom2residue, protein.atom_name] = True
128+
protein.sidechain37_mask = protein.atom37_mask.clone() # [num_residue, 37]
129+
protein.sidechain37_mask[:, bb_atom_name] = False
130+
item = {"graph": protein}
131+
132+
if self.transform:
133+
item = self.transform(item)
134+
return item
135+
136+
@staticmethod
137+
def from_pdb_files(pdb_files, verbose=1, **kwargs):
138+
return SideChainDataset(pdb_files, verbose=verbose, **kwargs)
139+
140+
def __repr__(self):
141+
lines = ["#sample: %d" % len(self)]
142+
return "%s( %s)" % (self.__class__.__name__, "\n ".join(lines))
143+

0 commit comments

Comments
 (0)