Commit e80ab33: camera ready (initial commit)

48 files changed: +5367 −0 lines

README.md

# ProtST: Multi-Modality Learning of Protein Sequences and Biomedical Texts

ProtST is a pretraining framework for protein sequence understanding and prediction, introduced in our [ICML 2023 oral paper](https://arxiv.org/abs/2301.12040). It enhances protein sequence pre-training and understanding by integrating protein functions and other important properties conveyed in biomedical texts.

ProtST-induced PLMs outperform previous PLMs on diverse downstream representation learning tasks and zero-shot predictions. ProtST also enables functional protein retrieval from large-scale databases, even without any function annotation, as illustrated below.

![ProtST](asset/framework.png)

# Installation

You may install the dependencies of TorchProtein and ProtST as below. Generally, they work with Python 3.7/3.8 and PyTorch >= 1.8.0, though the commands below set up Python 3.9 with PyTorch 2.0.0.

```bash
conda create -n protein python=3.9
conda activate protein

conda install pytorch==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install torchdrug pytorch-sparse pytorch-scatter pytorch-cluster -c pytorch -c pyg -c milagraph

conda install scikit-learn pandas decorator ipython networkx tqdm matplotlib -y
conda install fair-esm transformers easydict pyyaml lmdb -c conda-forge
```
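
As a quick sanity check, the following (a minimal sketch; CUDA availability depends on your local driver setup) verifies that the core packages import:

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import torchdrug, esm, transformers; print('environment OK')"
```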

# Pre-trained Model Zoo

| Model | Config | Ckpt |
|:---------------:|:------------------------------------------------:|:----------------------------------------------------------------------------------:|
| ProtST-ESM-1b | [config](config/pretrain/pretrain_esm.yaml) | [ckpt](https://protsl.s3.us-east-2.amazonaws.com/checkpoints/protst_esm1b.pth) |
| ProtST-ESM-2 | [config](config/pretrain/pretrain_esm.yaml) | [ckpt](https://protsl.s3.us-east-2.amazonaws.com/checkpoints/protst_esm2.pth) |
| ProtST-ProtBert | [config](config/pretrain/pretrain_protbert.yaml) | [ckpt](https://protsl.s3.us-east-2.amazonaws.com/checkpoints/protst_protbert.pth) |
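
The downstream commands below read the pre-trained weights from `$PRETRAIN_CHECKPOINT`. A minimal sketch for fetching the ProtST-ESM-1b checkpoint (the storage directory is an arbitrary choice, not mandated by the code):

```bash
mkdir -p ~/scratch/checkpoints
wget -P ~/scratch/checkpoints https://protsl.s3.us-east-2.amazonaws.com/checkpoints/protst_esm1b.pth
export PRETRAIN_CHECKPOINT=~/scratch/checkpoints/protst_esm1b.pth
```
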
# Usage

To reproduce the experiments in ProtST, we provide all the necessary configuration files under `config/.../*.yaml`, organized by dataset, model architecture, and hyperparameters. When running an experiment, specify the configuration file with the `--config` argument and supply every required argument marked by `{{ }}` in that file.

Note that all datasets are automatically downloaded by the code. If you are using a cluster without Internet access, please run `python ./script/prepare_all_datasets.py` to cache the datasets in advance, as sketched below.
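
For example, on a login node with Internet access (a minimal sketch; it assumes the script caches into the same `~/scratch/protein-datasets/` directory that the configs at the end of this page point at):

```bash
# cache all datasets before moving to offline compute nodes
python ./script/prepare_all_datasets.py
ls ~/scratch/protein-datasets/
```
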
## Pre-training

By default, we pre-train 3 different PLM backbones (ESM-1b, ESM-2 and ProtBert) on 4 V100 GPUs with the following commands. Note that two versions of the text encoder are available: `PubMedBERT-abs`, trained on abstracts only, and `PubMedBERT-full`, trained on full papers.

```bash
alias python4proc='python -m torch.distributed.launch --nproc_per_node=4'

# pretrain ESM-1b
python4proc script/run_pretrain.py --config ./config/pretrain/pretrain_esm.yaml --protein_model ESM-1b --text_model PubMedBERT-abs

# pretrain ESM-2
python4proc script/run_pretrain.py --config ./config/pretrain/pretrain_esm.yaml --protein_model ESM-2-650M --text_model PubMedBERT-abs

# pretrain ProtBert
python4proc script/run_pretrain.py --config ./config/pretrain/pretrain_protbert.yaml --text_model PubMedBERT-abs
```
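
`torch.distributed.launch` is deprecated in recent PyTorch releases in favor of `torchrun`. A hedged alternative for the PyTorch 2.0 environment above, assuming `run_pretrain.py` reads its rank from the `LOCAL_RANK` environment variable (scripts written around a `--local_rank` argument may need adjustment):

```bash
# same ESM-1b pre-training, launched with torchrun instead of torch.distributed.launch
torchrun --nproc_per_node=4 script/run_pretrain.py --config ./config/pretrain/pretrain_esm.yaml --protein_model ESM-1b --text_model PubMedBERT-abs
```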

## Downstream Tasks: Representation Learning

For representation learning, we evaluate our pre-trained multimodal PLMs on 11 standard benchmarks for protein localization prediction, fitness landscape prediction and protein function annotation, under both fix-encoder learning and full-model tuning settings.

We denote the pre-trained checkpoint by `PRETRAIN_CHECKPOINT`. For each PLM backbone, the corresponding configuration files live under `./config/downstream_task/.../*.yaml`. We demonstrate with ProtST-enhanced ESM-1b.

### Protein Localization Prediction

For binary localization prediction, run the following to perform fix-encoder learning and full-model tuning, respectively:

```bash
# fix-encoder learning
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/localization_fix.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset BinaryLocalization --num_class 2

# full-model tuning
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/localization_tune.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset BinaryLocalization --num_class 2
```

**Note that** subcellular localization prediction can be performed in a similar way (please see `./config` for details and the sketch below).
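
For instance, a sketch for 10-class subcellular localization under fix-encoder learning, assuming the same `localization_fix.yaml` applies; the dataset name and class count follow the few-shot commands later in this README and the localization config at the end of this page:

```bash
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/localization_fix.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset SubcellularLocalization --num_class 10
```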

### Fitness Landscape Prediction

For Beta-Lactamase fitness prediction, run the following to perform fix-encoder learning and full-model tuning, respectively:

```bash
# fix-encoder learning
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/fitness_fix.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset BetaLactamase --batch_size 32

# full-model tuning
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/fitness_tune.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset BetaLactamase --batch_size 6
```

**Note that** Fluorescence, Stability, AAV and Thermostability prediction can be performed in a similar way (please see `./config` for details and the sketch below).
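
For instance, a sketch for Fluorescence; per the comments in the fitness configs at the end of this page, Thermostability uses batch size 8 (fix-encoder) and 1 (tuning), while the other datasets use 32 and 6:

```bash
# fix-encoder learning on Fluorescence
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/fitness_fix.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset Fluorescence --batch_size 32
```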

### Protein Function Annotation

For Enzyme Commission (EC) number prediction, run the following to perform full-model tuning:

```bash
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/annotation_tune.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset td_datasets.EnzymeCommission --branch null
```

**Note that** Gene Ontology (GO) term prediction on the Molecular Function (MF), Biological Process (BP) and Cellular Component (CC) branches can be performed in a similar way (please see `./config` for details and the sketch below).
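
For instance, a sketch for the GO-MF branch; the dataset and branch values follow the comments in the annotation config at the end of this page (`GeneOntology` with `branch: MF`):

```bash
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/annotation_tune.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset GeneOntology --branch MF
```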

## Downstream Tasks: Zero-shot Protein Classification

### Zero-shot Predictors

ProtST supports zero-shot protein classification that requires no labeled proteins. This is achieved by comparing representation similarities between a query protein and all candidate labels, thanks to the aligned representation space of protein sequences and label descriptions in ProtST.

We demonstrate zero-shot subcellular localization prediction and zero-shot reaction classification with ProtST-enhanced ESM-1b. We have also explored different prompt templates and description fields, as listed in `./data/zero_shot_classification/`.

```bash
# Subcellular Localization Prediction
python ./script/run_zero_shot.py --config ./config/zero_shot/PretrainESM/zero_shot.yaml --checkpoint $PRETRAIN_CHECKPOINT --prompt_label ./data/zero_shot_classification/subloc_name.tsv --dataset SubcellularLocalization --field "['name']"

# Reaction Classification
python ./script/run_zero_shot.py --config ./config/zero_shot/PretrainESM/zero_shot.yaml --checkpoint $PRETRAIN_CHECKPOINT --prompt_label ./data/zero_shot_classification/reaction_name.tsv --dataset Reaction --field "['name']"
```

### Few-shot and Non-parametric Baselines

ProtST-induced zero-shot classifiers are more data-efficient than various few-shot and non-parametric classifiers. You can run these baselines as below:

```bash
# few-shot classifiers

## Subcellular Localization Prediction
python ./script/run_few_shot.py --config ./config/few_shot/PretrainESM/few_shot.yaml --dataset SubcellularLocalization --num_class 10 --checkpoint $PRETRAIN_CHECKPOINT

## Reaction Classification
python ./script/run_few_shot.py --config ./config/few_shot/PretrainESM/few_shot.yaml --dataset Reaction --num_class 384 --checkpoint $PRETRAIN_CHECKPOINT

# non-parametric few-shot classifiers

## Subcellular Localization Prediction
python ./script/run_few_shot_nonparam.py --config ./config/few_shot/PretrainESM/few_shot.yaml --dataset SubcellularLocalization --num_class 10 --checkpoint $PRETRAIN_CHECKPOINT

## Reaction Classification
python ./script/run_few_shot_nonparam.py --config ./config/few_shot/PretrainESM/few_shot.yaml --dataset Reaction --num_class 384 --checkpoint $PRETRAIN_CHECKPOINT
```

### Predictor Ensemble

We also show that the ProtST-based zero-shot predictor can enhance the performance of supervised learning models via ensembling. Use the following scripts, where `SUPERVISED_CHECKPOINT` refers to a checkpoint obtained by supervised learning on the corresponding downstream task:

```bash
## Subcellular Localization Prediction
python ./script/run_supervised_with_zero.py -sc ./config/downstream_task/PretrainESM/localization_fix.yaml -zc ./config/zero_shot/zero_shot.yaml --dataset SubcellularLocalization --num_class 10 --prompt_label ./data/zero_shot_classification/subloc_name.tsv --field "['name']" --checkpoint $PRETRAIN_CHECKPOINT --supervised_checkpoint $SUPERVISED_CHECKPOINT

## Reaction Classification
python ./script/run_supervised_with_zero.py -sc ./config/downstream_task/PretrainESM/reaction_tune.yaml -zc ./config/zero_shot/zero_shot.yaml --dataset Reaction --num_class 384 --prompt_label ./data/zero_shot_classification/reaction_name.tsv --field "['name']" --checkpoint $PRETRAIN_CHECKPOINT --supervised_checkpoint $SUPERVISED_CHECKPOINT
```

## Downstream Tasks: Text to Protein Retrieval

We illustrate the capability of ProtST-ESM-1b to retrieve functional proteins from a large-scale database, where no function annotation is required:

```bash
python ./script/run_t2p_retrieval.py --config ./config/t2p_retrieval/go_mf.yaml --checkpoint $PRETRAIN_CHECKPOINT
```
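
Other GO branches would follow the same pattern; the config name below is a hypothetical sibling of `go_mf.yaml` and is not confirmed by this commit:

```bash
# hypothetical config name: a go_bp.yaml sibling of go_mf.yaml is assumed to exist
python ./script/run_t2p_retrieval.py --config ./config/t2p_retrieval/go_bp.yaml --checkpoint $PRETRAIN_CHECKPOINT
```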

# Citation

If you find this project helpful, please cite our paper:

```
@article{xu2023protst,
  title={ProtST: Multi-Modality Learning of Protein Sequences and Biomedical Texts},
  author={Xu, Minghao and Yuan, Xinyu and Miret, Santiago and Tang, Jian},
  journal={arXiv preprint arXiv:2301.12040},
  year={2023}
}
```

# Contact

For any questions or issues, please open a GitHub issue or contact Minghao Xu ([email protected]) and Xinyu Yuan ([email protected]).

asset/framework.png (457 KB, binary image)

config/downstream_task/PretrainESM/annotation_tune.yaml (file name inferred from contents)

```yaml
output_dir: ~/scratch/protst_output/
checkpoint: {{ checkpoint }}

dataset:
  class: {{ dataset }} # td_datasets.EnzymeCommission / GeneOntology
  path: ~/scratch/protein-datasets/
  branch: {{ branch }} # EC: null; GO_MF: MF; GO_CC: CC; GO_BP: BP
  test_cutoff: 0.95
  atom_feature: null
  bond_feature: null
  transform:
    class: Compose
    transforms:
      - class: ProteinView
        view: residue
      - class: TruncateProtein
        max_length: 550

task:
  class: MultipleBinaryClassification
  model:
    class: PretrainESM
    path: ~/scratch/esm-model-weights/
    model: ESM-1b
    mask_modeling: False
    output_dim: 512
    readout: mean
    use_proj: False
  criterion: bce
  metric: ['auprc@micro', 'f1_max']
  num_mlp_layer: 2

optimizer:
  class: Adam
  lr: 1.0e-4

engine:
  gpus: [0, 1, 2, 3]
  batch_size: 2
  log_interval: 1000

lr_ratio: 0.1

eval_metric: f1_max

train:
  num_epoch: 50
```
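
The `{{ }}` placeholders above are filled from the command line; e.g., the EC command from the README resolves `{{ dataset }}` to `td_datasets.EnzymeCommission` and `{{ branch }}` to `null`:

```bash
# python4proc is the 4-GPU launch alias defined in the README
python4proc ./script/run_downstream.py --config ./config/downstream_task/PretrainESM/annotation_tune.yaml --checkpoint $PRETRAIN_CHECKPOINT --dataset td_datasets.EnzymeCommission --branch null
```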

config/downstream_task/PretrainESM/fitness_fix.yaml (file name inferred from contents)

```yaml
output_dir: ~/scratch/protst_output/
checkpoint: {{ checkpoint }}

dataset:
  class: {{ dataset }} # BetaLactamase / Fluorescence / Stability / td_datasets.AAV / td_datasets.Thermostability
  path: ~/scratch/protein-datasets/
  atom_feature: null
  bond_feature: null
  transform:
    class: Compose
    transforms:
      - class: ProteinView
        view: "residue"

task:
  class: PropertyPrediction
  model:
    class: PretrainESM
    path: ~/scratch/esm-model-weights/
    model: ESM-1b
    mask_modeling: False
    output_dim: 512
    readout: mean
    use_proj: False
  criterion: mse
  metric: ["mae", "rmse", "spearmanr"]
  normalization: False
  num_mlp_layer: 2

eval_metric: spearmanr

optimizer:
  class: Adam
  lr: 5.0e-5

fix_encoder: True

engine:
  gpus: [0, 1, 2, 3]
  batch_size: {{ batch_size }} # td_datasets.Thermostability: 8; others: 32

train:
  num_epoch: 100
```

config/downstream_task/PretrainESM/fitness_tune.yaml (file name inferred from contents)

```yaml
output_dir: ~/scratch/protst_output/
checkpoint: {{ checkpoint }}

dataset:
  class: {{ dataset }} # BetaLactamase / Fluorescence / Stability / td_datasets.AAV / td_datasets.Thermostability
  path: ~/scratch/protein-datasets/
  atom_feature: null
  bond_feature: null
  transform:
    class: Compose
    transforms:
      - class: ProteinView
        view: "residue"

task:
  class: PropertyPrediction
  model:
    class: PretrainESM
    path: ~/scratch/esm-model-weights/
    model: ESM-1b
    mask_modeling: False
    output_dim: 512
    readout: mean
    use_proj: False
  criterion: mse
  metric: ["mae", "rmse", "spearmanr"]
  normalization: False
  num_mlp_layer: 2

eval_metric: spearmanr

optimizer:
  class: Adam
  lr: 2.0e-4

lr_ratio: 0.1

engine:
  gpus: [0, 1, 2, 3]
  batch_size: {{ batch_size }} # td_datasets.Thermostability: 1; others: 6

train:
  num_epoch: 100
```

config/downstream_task/PretrainESM/localization_fix.yaml (file name inferred from contents)

```yaml
output_dir: ~/scratch/protst_output/
checkpoint: {{ checkpoint }}

dataset:
  class: {{ dataset }} # BinaryLocalization / SubcellularLocalization
  path: ~/scratch/protein-datasets/
  atom_feature: null
  bond_feature: null
  transform:
    class: Compose
    transforms:
      - class: ProteinView
        view: "residue"

task:
  class: PropertyPrediction
  model:
    class: PretrainESM
    path: ~/scratch/esm-model-weights/
    model: ESM-1b
    mask_modeling: False
    output_dim: 512
    readout: mean
    use_proj: False
  criterion: ce
  metric: ["acc", "mcc"]
  num_mlp_layer: 2
  num_class: {{ num_class }}

eval_metric: accuracy

optimizer:
  class: Adam
  lr: 5.0e-5

fix_encoder: True

engine:
  gpus: [0, 1, 2, 3]
  batch_size: 32

train:
  num_epoch: 100
```
