Skip to content

Commit

Permalink
Add _get_checkpoint_mapping function (#831)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #831

Reviewed By: JKSenthil

Differential Revision: D57118392

fbshipit-source-id: 319ff905786a181ff302ecb91a61af76561fc942
  • Loading branch information
diego-urgell authored and facebook-github-bot committed May 14, 2024
1 parent 2979b13 commit 5e1db7f
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 4 deletions.
29 changes: 28 additions & 1 deletion tests/framework/callbacks/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@

import unittest

from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
from torch import nn

from torchtnt.framework._test_utils import (
DummyAutoUnit,
DummyTrainUnit,
get_dummy_eval_state,
get_dummy_fit_state,
get_dummy_train_state,
)

from torchtnt.framework.callbacks._checkpoint_utils import (
_get_step_phase_mapping,
_prepare_app_state_for_checkpoint,
)
from torchtnt.utils.checkpoint import Phase


class CheckpointUtilsTest(unittest.TestCase):
Expand All @@ -26,3 +36,20 @@ def test_get_app_state(self) -> None:
app_state.keys(),
["module", "optimizer", "loss_fn", "train_progress"],
)

def test_get_step_phase_mapping(self) -> None:
unit = DummyAutoUnit(module=nn.Linear(2, 2))
unit.train_progress._num_steps_completed = 5
unit.eval_progress._num_steps_completed = 7

fit_state = get_dummy_fit_state()
self.assertEqual(
{Phase.TRAIN: 5, Phase.EVALUATE: 7},
_get_step_phase_mapping(fit_state, unit),
)

train_state = get_dummy_train_state()
self.assertEqual({Phase.TRAIN: 5}, _get_step_phase_mapping(train_state, unit))

eval_state = get_dummy_eval_state()
self.assertEqual({Phase.EVALUATE: 7}, _get_step_phase_mapping(eval_state, unit))
13 changes: 13 additions & 0 deletions torchtnt/framework/_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,19 @@ def get_dummy_train_state(dataloader: Optional[Iterable[object]] = None) -> Stat
)


def get_dummy_eval_state(dataloader: Optional[Iterable[object]] = None) -> State:
return State(
entry_point=EntryPoint.EVALUATE,
eval_state=PhaseState(
dataloader=dataloader or [1, 2, 3, 4],
max_epochs=1,
max_steps=1,
max_steps_per_epoch=1,
),
timer=None,
)


def get_dummy_fit_state() -> State:
return State(
entry_point=EntryPoint.FIT,
Expand Down
27 changes: 24 additions & 3 deletions torchtnt/framework/callbacks/_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
# pyre-strict


from typing import Any, Dict
from typing import Any, cast, Dict, Union

from pyre_extensions import none_throws
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.state import State
from torchtnt.framework.unit import AppStateMixin
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainUnit
from torchtnt.utils.checkpoint import Phase

from torchtnt.utils.stateful import Stateful

Expand All @@ -23,6 +24,26 @@
_EVAL_PROGRESS_STATE_KEY = "eval_progress"


def _get_step_phase_mapping(
state: State, unit: Union[TTrainUnit, TEvalUnit]
) -> Dict[Phase, int]:
"""
Returns a mapping of phase to step, depending on the entrypoint.
For FIT, it always includes train and eval progress.
"""
step_mapping = {}

if state.entry_point in (EntryPoint.TRAIN, EntryPoint.FIT):
train_unit = cast(TTrainUnit, unit)
step_mapping[Phase.TRAIN] = train_unit.train_progress.num_steps_completed

if state.entry_point in (EntryPoint.EVALUATE, EntryPoint.FIT):
eval_unit = cast(TEvalUnit, unit)
step_mapping[Phase.EVALUATE] = eval_unit.eval_progress.num_steps_completed

return step_mapping


def _prepare_app_state(unit: AppStateMixin) -> Dict[str, Any]:
"""Join together all of the tracked stateful entities to simplify registration of snapshottable states, deals with FSDP case"""
app_state = unit.app_state()
Expand Down

0 comments on commit 5e1db7f

Please sign in to comment.