Skip to content

Commit a8e715c

Browse files
authored
Merge pull request #485 from datamol-io/make_wandb_upload_optional
make wandb checkpoint upload optional, change mlp_expansion_ratio def…
2 parents ef1cb40 + 06756fc commit a8e715c

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

graphium/cli/train_finetune_test.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,9 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
136136
filename_datetime_suffix = now.strftime("%Y%m%d_%H%M%S")
137137
# Append the datetime string to the existing filename in the cfg dictionary
138138
cfg["trainer"]["model_checkpoint"]["filename"] += f"_{filename_datetime_suffix}"
139+
cfg["trainer"]["model_checkpoint"]["dirpath"] = (
140+
cfg["trainer"]["model_checkpoint"]["dirpath"][:-1] + f"_{filename_datetime_suffix}"
141+
)
139142

140143
dst_dir = cfg["constants"].get("results_dir")
141144
hydra_cfg = HydraConfig.get()
@@ -236,7 +239,7 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
236239
predictor.set_max_nodes_edges_per_graph(datamodule, stages=["train", "val"])
237240

238241
# When resuming training from a checkpoint, we need to provide the path to the checkpoint in the config
239-
resume_ckpt_path = cfg["trainer"].pop("resume_from_checkpoint", None)
242+
resume_ckpt_path = cfg["trainer"].get("resume_from_checkpoint", None)
240243

241244
# Run the model training
242245
with SafeRun(name="TRAINING", raise_error=cfg["constants"]["raise_train_error"], verbose=True):
@@ -266,10 +269,11 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
266269
logger.info("Total compute time:", timeit.default_timer() - st)
267270
logger.info("-" * 50)
268271

269-
if wandb_cfg is not None:
272+
save_checkpoint_to_wandb = cfg["trainer"].get("save_checkpoint_to_wandb")
273+
if save_checkpoint_to_wandb is True:
270274
# Save initial model state - and upload checkpoint to wandb
271275
if cfg["trainer"]["model_checkpoint"]["save_last"] is True:
272-
checkpoint_path = f"{cfg['trainer']['model_checkpoint']['dirpath']}{cfg['trainer']['model_checkpoint']['filename']}.ckpt"
276+
checkpoint_path = f"{cfg['trainer']['model_checkpoint']['dirpath']}/{cfg['trainer']['model_checkpoint']['filename']}-v1.ckpt"
273277
# Log the initial model checkpoint to wandb
274278
wandb.save(checkpoint_path)
275279
wandb.finish()

0 commit comments

Comments (0)