Add signpost interval for _generate_checkpoint_and_upkeep (#979)
Summary: Pull Request resolved: #979

Reviewed By: anshulverma, JKSenthil

Differential Revision: D70585216

fbshipit-source-id: 0adeafe0269027bce2579ded61846e072afc455b
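
Note on the change: log_interval wraps the entire body of _generate_checkpoint_and_upkeep so that a "signpost" event is emitted when checkpointing starts and ends, tagged with string metadata such as category, active phase, hook, epoch, and step (see the diff below). The sketch that follows is a minimal, illustrative stand-in for such a context manager, assuming a plain-logging backend; torchtnt's actual implementation in torchtnt.utils.event_handlers may differ.

    # Minimal sketch of a signpost-style interval (illustration only; the real
    # torchtnt.utils.event_handlers.log_interval may use a different backend).
    import logging
    import time
    import uuid
    from contextlib import contextmanager
    from typing import Dict, Generator, Optional

    logger = logging.getLogger(__name__)

    @contextmanager
    def log_interval(
        name: str, metadata: Optional[Dict[str, str]] = None
    ) -> Generator[None, None, None]:
        # A unique id correlates the start and end signposts of one interval.
        interval_id = uuid.uuid4().hex
        start = time.perf_counter()
        logger.info("interval start: %s id=%s metadata=%s", name, interval_id, metadata)
        try:
            yield
        finally:
            # Runs on every exit path: normal return, early return, or exception,
            # so the interval always records how long the wrapped block took.
            duration_s = time.perf_counter() - start
            logger.info("interval end: %s id=%s duration_s=%.3f", name, interval_id, duration_s)

Because every early `return False` in the diff below sits inside the `with` block, the end signpost still fires on those paths.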
diego-urgell authored and facebook-github-bot committed Mar 5, 2025
1 parent 714ae04 commit fb2f350
Showing 1 changed file with 71 additions and 55 deletions.

torchtnt/framework/callbacks/base_checkpointer.py
@@ -10,7 +10,7 @@
 import logging
 import math
 from datetime import timedelta
-from typing import Any, cast, Iterable, List, Literal, Optional, Union
+from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union
 
 import fsspec
 
@@ -39,6 +39,7 @@
     Phase,
 )
 from torchtnt.utils.distributed import get_world_size, PGWrapper
+from torchtnt.utils.event_handlers import log_interval
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -201,70 +202,85 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        # 1) generate checkpoint name
-        epoch = _get_epoch(state, unit)
-        step_mapping = _get_step_phase_mapping(state, unit)
-
-        # 1.1) append metric data only if best_checkpoint_config is defined
-        metric_data: Optional[MetricData] = None
-        if self._best_checkpoint_config and (
-            metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
-        ):
-            metric_data = MetricData(
-                name=none_throws(self._best_checkpoint_config).monitored_metric,
-                value=metric_value,
-            )
-
-        checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
-            epoch,
-            step_mapping,
-            metric_data,
-            process_group=self._process_group,
-        )
-
-        # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
-        # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
-        if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
-            return False
-
-        if hook == "on_train_end":
-            # 2.1) Make sure that last checkpoint does not already exist
-            if self._checkpoint_manager.does_checkpoint_exist(
-                checkpoint_path, self._process_group
-            ):
-                rank_zero_warn(
-                    "Final checkpoint already exists, skipping.", logger=logger
-                )
-                return False
-
-            # 2.2) If doing fit without eval checkpointing, only consider training progress when
-            # checking if last checkpoint exists.
-            if (
-                state.entry_point == EntryPoint.FIT
-                and self._save_every_n_eval_epochs is None
-                and self._checkpoint_manager._ckpt_paths
-                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
-                == cast(TTrainUnit, unit).train_progress.num_steps_completed
-            ):
-                rank_zero_info(
-                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
-                    logger=logger,
-                )
-                return False
-
-        # 3) try to save checkpoint
-        if not self._checkpoint_impl(
-            state, unit, checkpoint_id=checkpoint_path.path, hook=hook
-        ):
-            return False
-
-        # 4) track checkpoint and clean up surplus if needed
-        self._checkpoint_manager.append_checkpoint(checkpoint_path)
-
-        # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
-        unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
-
-        return True
+        log_interval_metadata: Dict[str, str] = {
+            "category": "checkpointing",
+            "active_phase": str(state.active_phase),
+            "hook": hook,
+            "epoch": str(_get_epoch(state, unit)),
+            "step": str(
+                _get_step_phase_mapping(state, unit).get(
+                    state.active_phase.into_phase(), 0
+                )
+            ),
+        }
+
+        with log_interval(
+            "_generate_checkpoint_and_upkeep", metadata=log_interval_metadata
+        ):
+            # 1) generate checkpoint name
+            epoch = _get_epoch(state, unit)
+            step_mapping = _get_step_phase_mapping(state, unit)
+
+            # 1.1) append metric data only if best_checkpoint_config is defined
+            metric_data: Optional[MetricData] = None
+            if self._best_checkpoint_config and (
+                metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
+            ):
+                metric_data = MetricData(
+                    name=none_throws(self._best_checkpoint_config).monitored_metric,
+                    value=metric_value,
+                )
+
+            checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
+                epoch,
+                step_mapping,
+                metric_data,
+                process_group=self._process_group,
+            )
+
+            # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
+            # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
+            if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
+                return False
+
+            if hook == "on_train_end":
+                # 2.1) Make sure that last checkpoint does not already exist
+                if self._checkpoint_manager.does_checkpoint_exist(
+                    checkpoint_path, self._process_group
+                ):
+                    rank_zero_warn(
+                        "Final checkpoint already exists, skipping.", logger=logger
+                    )
+                    return False
+
+                # 2.2) If doing fit without eval checkpointing, only consider training progress when
+                # checking if last checkpoint exists.
+                if (
+                    state.entry_point == EntryPoint.FIT
+                    and self._save_every_n_eval_epochs is None
+                    and self._checkpoint_manager._ckpt_paths
+                    and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                    == cast(TTrainUnit, unit).train_progress.num_steps_completed
+                ):
+                    rank_zero_info(
+                        "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                        logger=logger,
+                    )
+                    return False
+
+            # 3) try to save checkpoint
+            if not self._checkpoint_impl(
+                state, unit, checkpoint_id=checkpoint_path.path, hook=hook
+            ):
+                return False
+
+            # 4) track checkpoint and clean up surplus if needed
+            self._checkpoint_manager.append_checkpoint(checkpoint_path)
+
+            # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
+            unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
+
+            return True
 
     def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
         """
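
The `step` field in the interval metadata picks the step count for whichever phase is active when the hook fires, defaulting to 0 when that phase has not recorded any steps yet. A small illustration of that lookup with stand-in types (torchtnt's real Phase enum and _get_step_phase_mapping helper are not reproduced here):

    # Stand-in types, for illustration of the dict.get fallback used in the diff.
    from enum import Enum
    from typing import Dict

    class Phase(Enum):
        TRAIN = "train"
        EVALUATE = "evaluate"

    # Example of what _get_step_phase_mapping(state, unit) might return mid-fit.
    step_mapping: Dict[Phase, int] = {Phase.TRAIN: 1200}

    print(step_mapping.get(Phase.TRAIN, 0))     # -> 1200
    print(step_mapping.get(Phase.EVALUATE, 0))  # -> 0 (no eval steps recorded yet)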
