@@ -136,6 +136,9 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
136
136
filename_datetime_suffix = now .strftime ("%Y%m%d_%H%M%S" )
137
137
# Append the datetime string to the existing filename in the cfg dictionary
138
138
cfg ["trainer" ]["model_checkpoint" ]["filename" ] += f"_{ filename_datetime_suffix } "
139
+ cfg ["trainer" ]["model_checkpoint" ]["dirpath" ] = (
140
+ cfg ["trainer" ]["model_checkpoint" ]["dirpath" ][:- 1 ] + f"_{ filename_datetime_suffix } "
141
+ )
139
142
140
143
dst_dir = cfg ["constants" ].get ("results_dir" )
141
144
hydra_cfg = HydraConfig .get ()
@@ -236,7 +239,7 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
236
239
predictor .set_max_nodes_edges_per_graph (datamodule , stages = ["train" , "val" ])
237
240
238
241
# When resuming training from a checkpoint, we need to provide the path to the checkpoint in the config
239
- resume_ckpt_path = cfg ["trainer" ].pop ("resume_from_checkpoint" , None )
242
+ resume_ckpt_path = cfg ["trainer" ].get ("resume_from_checkpoint" , None )
240
243
241
244
# Run the model training
242
245
with SafeRun (name = "TRAINING" , raise_error = cfg ["constants" ]["raise_train_error" ], verbose = True ):
@@ -266,10 +269,11 @@ def run_training_finetuning_testing(cfg: DictConfig) -> None:
266
269
logger .info ("Total compute time:" , timeit .default_timer () - st )
267
270
logger .info ("-" * 50 )
268
271
269
- if wandb_cfg is not None :
272
+ save_checkpoint_to_wandb = cfg ["trainer" ].get ("save_checkpoint_to_wandb" )
273
+ if save_checkpoint_to_wandb is True :
270
274
# Save initial model state - and upload checkpoint to wandb
271
275
if cfg ["trainer" ]["model_checkpoint" ]["save_last" ] is True :
272
- checkpoint_path = f"{ cfg ['trainer' ]['model_checkpoint' ]['dirpath' ]} { cfg ['trainer' ]['model_checkpoint' ]['filename' ]} .ckpt"
276
+ checkpoint_path = f"{ cfg ['trainer' ]['model_checkpoint' ]['dirpath' ]} / { cfg ['trainer' ]['model_checkpoint' ]['filename' ]} -v1 .ckpt"
273
277
# Log the initial model checkpoint to wandb
274
278
wandb .save (checkpoint_path )
275
279
wandb .finish ()
0 commit comments