diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py
index 6b4e825823..a11dbfa70e 100644
--- a/torchtnt/framework/callbacks/base_checkpointer.py
+++ b/torchtnt/framework/callbacks/base_checkpointer.py
@@ -10,7 +10,7 @@
 import logging
 import math
 from datetime import timedelta
-from typing import Any, cast, Iterable, List, Literal, Optional, Union
+from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union
 
 import fsspec
 
@@ -39,6 +39,7 @@
     Phase,
 )
 from torchtnt.utils.distributed import get_world_size, PGWrapper
+from torchtnt.utils.event_handlers import log_interval
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -201,70 +202,85 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        # 1) generate checkpoint name
-        epoch = _get_epoch(state, unit)
-        step_mapping = _get_step_phase_mapping(state, unit)
-
-        # 1.1) append metric data only if best_checkpoint_config is defined
-        metric_data: Optional[MetricData] = None
-        if self._best_checkpoint_config and (
-            metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
-        ):
-            metric_data = MetricData(
-                name=none_throws(self._best_checkpoint_config).monitored_metric,
-                value=metric_value,
-            )
-
-        checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
-            epoch,
-            step_mapping,
-            metric_data,
-            process_group=self._process_group,
-        )
-
-        # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
-        # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
-        if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
-            return False
+        log_interval_metadata: Dict[str, str] = {
+            "category": "checkpointing",
+            "active_phase": str(state.active_phase),
+            "hook": hook,
+            "epoch": str(_get_epoch(state, unit)),
+            "step": str(
+                _get_step_phase_mapping(state, unit).get(
+                    state.active_phase.into_phase(), 0
+                )
+            ),
+        }
 
-        if hook == "on_train_end":
-            # 2.1) Make sure that last checkpoint does not already exist
-            if self._checkpoint_manager.does_checkpoint_exist(
-                checkpoint_path, self._process_group
+        with log_interval(
+            "_generate_checkpoint_and_upkeep", metadata=log_interval_metadata
+        ):
+            # 1) generate checkpoint name
+            epoch = _get_epoch(state, unit)
+            step_mapping = _get_step_phase_mapping(state, unit)
+
+            # 1.1) append metric data only if best_checkpoint_config is defined
+            metric_data: Optional[MetricData] = None
+            if self._best_checkpoint_config and (
+                metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
             ):
-                rank_zero_warn(
-                    "Final checkpoint already exists, skipping.", logger=logger
+                metric_data = MetricData(
+                    name=none_throws(self._best_checkpoint_config).monitored_metric,
+                    value=metric_value,
                 )
+
+            checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
+                epoch,
+                step_mapping,
+                metric_data,
+                process_group=self._process_group,
+            )
+
+            # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
+            # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
+            if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
                 return False
 
-            # 2.2) If doing fit without eval checkpointing, only consider training progress when
-            # checking if last checkpoint exists.
-            if (
-                state.entry_point == EntryPoint.FIT
-                and self._save_every_n_eval_epochs is None
-                and self._checkpoint_manager._ckpt_paths
-                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
-                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            if hook == "on_train_end":
+                # 2.1) Make sure that last checkpoint does not already exist
+                if self._checkpoint_manager.does_checkpoint_exist(
+                    checkpoint_path, self._process_group
+                ):
+                    rank_zero_warn(
+                        "Final checkpoint already exists, skipping.", logger=logger
+                    )
+                    return False
+
+                # 2.2) If doing fit without eval checkpointing, only consider training progress when
+                # checking if last checkpoint exists.
+                if (
+                    state.entry_point == EntryPoint.FIT
+                    and self._save_every_n_eval_epochs is None
+                    and self._checkpoint_manager._ckpt_paths
+                    and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                    == cast(TTrainUnit, unit).train_progress.num_steps_completed
+                ):
+                    rank_zero_info(
+                        "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                        logger=logger,
+                    )
+                    return False
+
+            # 3) try to save checkpoint
+            if not self._checkpoint_impl(
+                state, unit, checkpoint_id=checkpoint_path.path, hook=hook
             ):
-                rank_zero_info(
-                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
-                    logger=logger,
-                )
                 return False
 
-        # 3) try to save checkpoint
-        if not self._checkpoint_impl(
-            state, unit, checkpoint_id=checkpoint_path.path, hook=hook
-        ):
-            return False
+            # 4) track checkpoint and clean up surplus if needed
+            self._checkpoint_manager.append_checkpoint(checkpoint_path)
 
-        # 4) track checkpoint and clean up surplus if needed
-        self._checkpoint_manager.append_checkpoint(checkpoint_path)
+            # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
+            unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
 
-        # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
-        unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
-
-        return True
+            return True
 
     def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
         """
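
For context, here is a minimal sketch of the pattern this diff introduces: `log_interval` from `torchtnt.utils.event_handlers` is used as a context manager around the whole generate/save/upkeep path, with string-valued metadata describing the attempt. Only the `with log_interval(...)` call and the metadata keys come from the diff; the helper `checkpoint_with_interval_logging`, the `save_fn` callable, and the assumption that the utility times the enclosed block and attaches the metadata to an emitted interval event are illustrative, not part of the change.

```python
# Sketch only -- not part of the diff. Assumes log_interval is a context manager
# that records the duration of the enclosed block and attaches the given
# string-valued metadata to the emitted interval event.
from typing import Callable, Dict

from torchtnt.utils.event_handlers import log_interval


def checkpoint_with_interval_logging(
    save_fn: Callable[[], bool], metadata: Dict[str, str]
) -> bool:
    # save_fn is a hypothetical stand-in for the body of
    # _generate_checkpoint_and_upkeep; the interval covers checkpoint name
    # generation, the save itself, and checkpoint upkeep.
    with log_interval("_generate_checkpoint_and_upkeep", metadata=metadata):
        return save_fn()


# Example metadata mirroring what the callback now records per attempt
# (values are stringified to satisfy Dict[str, str]; contents are illustrative).
example_metadata: Dict[str, str] = {
    "category": "checkpointing",
    "active_phase": "ActivePhase.TRAIN",  # actual value is str(state.active_phase)
    "hook": "on_train_epoch_end",
    "epoch": "3",
    "step": "1500",
}
```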