Skip to content

Commit b73ff45

Browse files
committed
update ptreain config
1 parent 50dac0e commit b73ff45

File tree

5 files changed

+9
-7
lines changed

5 files changed

+9
-7
lines changed

config/pretrain/angle_gearnet_edge.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
output_dir: ~/scratch/protein_outputs
1+
output_dir: ~/scratch/protein_output
22

33
dataset:
44
class: AlphaFoldDB
55
path: ~/scratch/protein-datasets/alphafold
66
species_start: 0
77
species_end: 22
8-
# species_id: 0
8+
# species_id: 3
99
# split_id: 1
1010
transform:
1111
class: Compose

config/pretrain/dihedral_gearnet_edge.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
output_dir: ~/scratch/protein_outputs
1+
output_dir: ~/scratch/protein_output
22

33
dataset:
44
class: AlphaFoldDB

config/pretrain/dis_gearnet_edge.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
output_dir: ~/scratch/protein_outputs
1+
output_dir: ~/scratch/protein_output
22

33
dataset:
44
class: AlphaFoldDB

config/pretrain/mc_ESM_gearnet.yaml renamed to config/pretrain/mc_gearnet_edge.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
output_dir: ~/scratch/protein_outputs
1+
output_dir: ~/scratch/protein_output
22

33
dataset:
44
class: AlphaFoldDB
55
path: ~/scratch/protein-datasets/alphafold
66
species_start: 0
77
species_end: 22
8-
# species_id: 0
8+
# species_id: 3
99
# split_id: 1
1010
transform:
1111
class: ProteinView

util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def build_downstream_solver(cfg, dataset):
120120
solver.load(cfg.checkpoint)
121121

122122
if cfg.get("model_checkpoint") is not None:
123+
if comm.get_rank() == 0:
124+
logger.warning("Load checkpoint from %s" % cfg.model_checkpoint)
123125
cfg.model_checkpoint = os.path.expanduser(cfg.model_checkpoint)
124126
model_dict = torch.load(cfg.model_checkpoint, map_location=torch.device('cpu'))
125127
task.model.load_state_dict(model_dict)
@@ -137,4 +139,4 @@ def build_pretrain_solver(cfg, dataset):
137139
optimizer = core.Configurable.load_config_dict(cfg.optimizer)
138140
solver = core.Engine(task, dataset, None, None, optimizer, **cfg.engine)
139141

140-
return solver
142+
return solver

0 commit comments

Comments
 (0)