Skip to content

Commit 50dac0e

Browse files
committed
add pretrain script & config
1 parent 1e5ab6d commit 50dac0e

File tree

7 files changed

+396
-0
lines changed

7 files changed

+396
-0
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
output_dir: ~/scratch/protein_outputs
2+
3+
dataset:
4+
class: AlphaFoldDB
5+
path: ~/scratch/protein-datasets/alphafold
6+
species_start: 0
7+
species_end: 22
8+
# species_id: 0
9+
# split_id: 1
10+
transform:
11+
class: Compose
12+
transforms:
13+
- class: ProteinView
14+
view: residue
15+
- class: TruncateProtein
16+
max_length: 100
17+
random: True
18+
19+
task:
20+
class: AnglePrediction
21+
num_sample: 512
22+
num_class: 8
23+
num_mlp_layer: 2
24+
model:
25+
class: GearNet
26+
input_dim: 21
27+
hidden_dims: [512, 512, 512, 512, 512, 512]
28+
batch_norm: True
29+
concat_hidden: True
30+
short_cut: True
31+
readout: 'sum'
32+
num_relation: 7
33+
edge_input_dim: 59
34+
num_angle_bin: 8
35+
graph_construction_model:
36+
class: GraphConstruction
37+
node_layers:
38+
- class: AlphaCarbonNode
39+
edge_layers:
40+
- class: SequentialEdge
41+
max_distance: 2
42+
- class: SpatialEdge
43+
radius: 10.0
44+
min_distance: 5
45+
- class: KNNEdge
46+
k: 10
47+
min_distance: 5
48+
edge_feature: gearnet
49+
50+
optimizer:
51+
class: Adam
52+
lr: 1.0e-3
53+
54+
engine:
55+
gpus: {{ gpus }}
56+
batch_size: 96
57+
log_interval: 100
58+
59+
save_interval: 5
60+
61+
train:
62+
num_epoch: 50
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
output_dir: ~/scratch/protein_outputs
2+
3+
dataset:
4+
class: AlphaFoldDB
5+
path: ~/scratch/protein-datasets/alphafold
6+
species_start: 0
7+
species_end: 22
8+
# species_id: 0
9+
# split_id: 1
10+
transform:
11+
class: Compose
12+
transforms:
13+
- class: ProteinView
14+
view: residue
15+
- class: TruncateProtein
16+
max_length: 100
17+
random: True
18+
19+
task:
20+
class: AttributeMasking
21+
mask_rate: 0.15
22+
num_mlp_layer: 2
23+
model:
24+
class: GearNet
25+
input_dim: 21
26+
hidden_dims: [512, 512, 512, 512, 512, 512]
27+
batch_norm: True
28+
concat_hidden: True
29+
short_cut: True
30+
readout: 'sum'
31+
num_relation: 7
32+
edge_input_dim: 59
33+
num_angle_bin: 8
34+
graph_construction_model:
35+
class: GraphConstruction
36+
node_layers:
37+
- class: AlphaCarbonNode
38+
edge_layers:
39+
- class: SequentialEdge
40+
max_distance: 2
41+
- class: SpatialEdge
42+
radius: 10.0
43+
min_distance: 5
44+
- class: KNNEdge
45+
k: 10
46+
min_distance: 5
47+
edge_feature: gearnet
48+
49+
optimizer:
50+
class: Adam
51+
lr: 1.0e-3
52+
53+
engine:
54+
gpus: {{ gpus }}
55+
batch_size: 96
56+
log_interval: 100
57+
58+
save_interval: 5
59+
60+
train:
61+
num_epoch: 50
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
output_dir: ~/scratch/protein_outputs
2+
3+
dataset:
4+
class: AlphaFoldDB
5+
path: ~/scratch/protein-datasets/alphafold
6+
species_start: 0
7+
species_end: 22
8+
# species_id: 0
9+
# split_id: 1
10+
transform:
11+
class: Compose
12+
transforms:
13+
- class: ProteinView
14+
view: residue
15+
- class: TruncateProtein
16+
max_length: 100
17+
random: True
18+
19+
task:
20+
class: DihedralPrediction
21+
num_sample: 512
22+
num_class: 8
23+
num_mlp_layer: 2
24+
model:
25+
class: GearNet
26+
input_dim: 21
27+
hidden_dims: [512, 512, 512, 512, 512, 512]
28+
batch_norm: True
29+
concat_hidden: True
30+
short_cut: True
31+
readout: 'sum'
32+
num_relation: 7
33+
edge_input_dim: 59
34+
num_angle_bin: 8
35+
graph_construction_model:
36+
class: GraphConstruction
37+
node_layers:
38+
- class: AlphaCarbonNode
39+
edge_layers:
40+
- class: SequentialEdge
41+
max_distance: 2
42+
- class: SpatialEdge
43+
radius: 10.0
44+
min_distance: 5
45+
- class: KNNEdge
46+
k: 10
47+
min_distance: 5
48+
edge_feature: gearnet
49+
50+
optimizer:
51+
class: Adam
52+
lr: 1.0e-3
53+
54+
engine:
55+
gpus: {{ gpus }}
56+
batch_size: 96
57+
log_interval: 100
58+
59+
save_interval: 5
60+
61+
train:
62+
num_epoch: 50

config/pretrain/dis_gearnet_edge.yaml

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
output_dir: ~/scratch/protein_outputs
2+
3+
dataset:
4+
class: AlphaFoldDB
5+
path: ~/scratch/protein-datasets/alphafold
6+
species_start: 0
7+
species_end: 22
8+
# species_id: 0
9+
# split_id: 1
10+
transform:
11+
class: Compose
12+
transforms:
13+
- class: ProteinView
14+
view: residue
15+
- class: TruncateProtein
16+
max_length: 100
17+
random: True
18+
19+
task:
20+
class: DistancePrediction
21+
num_sample: 256
22+
num_mlp_layer: 2
23+
model:
24+
class: GearNet
25+
input_dim: 21
26+
hidden_dims: [512, 512, 512, 512, 512, 512]
27+
batch_norm: True
28+
concat_hidden: True
29+
short_cut: True
30+
readout: 'sum'
31+
num_relation: 7
32+
edge_input_dim: 59
33+
num_angle_bin: 8
34+
graph_construction_model:
35+
class: GraphConstruction
36+
node_layers:
37+
- class: AlphaCarbonNode
38+
edge_layers:
39+
- class: SequentialEdge
40+
max_distance: 2
41+
- class: SpatialEdge
42+
radius: 10.0
43+
min_distance: 5
44+
- class: KNNEdge
45+
k: 10
46+
min_distance: 5
47+
edge_feature: gearnet
48+
49+
optimizer:
50+
class: Adam
51+
lr: 1.0e-3
52+
53+
engine:
54+
gpus: {{ gpus }}
55+
batch_size: 128
56+
log_interval: 100
57+
58+
save_interval: 5
59+
60+
train:
61+
num_epoch: 50

config/pretrain/mc_ESM_gearnet.yaml

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
output_dir: ~/scratch/protein_outputs
2+
3+
dataset:
4+
class: AlphaFoldDB
5+
path: ~/scratch/protein-datasets/alphafold
6+
species_start: 0
7+
species_end: 22
8+
# species_id: 0
9+
# split_id: 1
10+
transform:
11+
class: ProteinView
12+
view: residue
13+
14+
task:
15+
class: Unsupervised
16+
model:
17+
class: MultiviewContrast
18+
crop_funcs:
19+
- class: SubsequenceNode
20+
max_length: 50
21+
- class: SubspaceNode
22+
entity_level: residue
23+
min_neighbor: 15
24+
min_radius: 15.0
25+
noise_funcs:
26+
- class: IdentityNode
27+
- class: RandomEdgeMask
28+
mask_rate: 0.15
29+
model:
30+
class: GearNet
31+
input_dim: 21
32+
hidden_dims: [512, 512, 512, 512, 512, 512]
33+
batch_norm: True
34+
concat_hidden: True
35+
short_cut: True
36+
readout: 'sum'
37+
num_relation: 7
38+
edge_input_dim: 59
39+
num_angle_bin: 8
40+
graph_construction_model:
41+
class: GraphConstruction
42+
node_layers:
43+
- class: AlphaCarbonNode
44+
edge_layers:
45+
- class: SequentialEdge
46+
max_distance: 2
47+
- class: SpatialEdge
48+
radius: 10.0
49+
min_distance: 5
50+
- class: KNNEdge
51+
k: 10
52+
min_distance: 5
53+
edge_feature: gearnet
54+
55+
optimizer:
56+
class: Adam
57+
lr: 2.0e-4
58+
59+
engine:
60+
gpus: {{ gpus }}
61+
batch_size: 48
62+
log_interval: 100
63+
64+
save_interval: 5
65+
66+
train:
67+
num_epoch: 50

script/pretrain.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
import sys
3+
import math
4+
import pprint
5+
6+
import torch
7+
8+
from torchdrug import core, models, tasks, datasets, utils
9+
from torchdrug.utils import comm
10+
11+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
12+
import util
13+
14+
15+
def save(solver, path):
    """Write the state dict of the solver's inner model to ``path``.

    The module that gets saved is ``solver.model.model.model`` when
    ``solver.model`` is a ``tasks.Unsupervised`` instance, and
    ``solver.model.model`` otherwise.  Only rank 0 logs and writes the
    file; every rank then calls ``comm.synchronize()``.

    Parameters:
        solver: object whose ``model`` attribute holds the task being trained.
        path (str): destination file; a leading ``~`` is expanded.

    Note: this function uses the module-level ``logger`` that the script's
    entry point creates, so it must be called after that assignment.
    """
    task = solver.model
    model = task.model.model if isinstance(task, tasks.Unsupervised) else task.model

    is_main_rank = comm.get_rank() == 0
    if is_main_rank:
        # The message reports the path as given, before ``~`` expansion.
        logger.warning("Save checkpoint to %s" % path)
    expanded_path = os.path.expanduser(path)
    if is_main_rank:
        torch.save(model.state_dict(), expanded_path)
    comm.synchronize()
27+
28+
29+
if __name__ == "__main__":
    # Entry point: build the dataset and solver from the config, train in
    # chunks of `save_interval` epochs, and save a checkpoint after each chunk.

    # Named `context_vars` rather than `vars` so the builtin is not shadowed.
    args, context_vars = util.parse_args()
    cfg = util.load_config(args.config, context=context_vars)
    working_dir = util.create_working_directory(cfg)

    # The seed is offset by the process rank, so ranks are seeded differently.
    torch.manual_seed(args.seed + comm.get_rank())

    logger = util.get_root_logger()
    if comm.get_rank() == 0:
        logger.warning("Config file: %s" % args.config)
        logger.warning(pprint.pformat(cfg))

    # `species_start` / `species_end` are optional (both default to 0) and
    # define the half-open range of species ids iterated below.  An empty
    # range means the dataset is built once and trained on as-is.
    species_start = cfg.dataset.get("species_start", 0)
    species_end = cfg.dataset.get("species_end", 0)
    if species_end < species_start:
        # Explicit raise instead of `assert`, which is stripped under `-O`.
        raise ValueError(
            "species_end (%s) must be >= species_start (%s)"
            % (species_end, species_start)
        )
    if species_end > species_start:
        # Start from the first split of the first species in the range.
        cfg.dataset.species_id = species_start
        cfg.dataset.split_id = 0
    # The range keys are consumed by this script only, so they are removed
    # before the dataset is instantiated from the config.  Membership is
    # checked first so an absent key is simply skipped, without relying on
    # how the config mapping's pop() treats missing keys.
    for range_key in ("species_start", "species_end"):
        if range_key in cfg.dataset:
            cfg.dataset.pop(range_key)
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    solver = util.build_pretrain_solver(cfg, dataset)

    step = cfg.get("save_interval", 1)
    for i in range(0, cfg.train.num_epoch, step):
        kwargs = cfg.train.copy()
        # The last chunk may be shorter than `step`.
        kwargs["num_epoch"] = min(step, cfg.train.num_epoch - i)

        if species_end == species_start:
            solver.train(**kwargs)
        else:
            # Rebuild the dataset for every (species, split) pair and train
            # the same solver on each in turn.
            for species_id in range(species_start, species_end):
                for split_id in range(dataset.species_nsplit[species_id]):
                    cfg.dataset.species_id = species_id
                    cfg.dataset.split_id = split_id
                    dataset = core.Configurable.load_config_dict(cfg.dataset)
                    logger.warning('Epoch: {}\tSpecies id: {}\tSplit id: {}\tSplit length: {}'.format(
                        i, species_id, split_id, len(dataset)))
                    solver.train_set = dataset
                    solver.train(**kwargs)

        # The checkpoint name carries the number of epochs completed so far.
        save(solver, "model_epoch_%d.pth" % (i + kwargs["num_epoch"]))

0 commit comments

Comments
 (0)