Skip to content

Commit 5e1db7f

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Add _get_checkpoint_mapping function (#831)
Summary: Pull Request resolved: #831 Reviewed By: JKSenthil Differential Revision: D57118392 fbshipit-source-id: 319ff905786a181ff302ecb91a61af76561fc942
1 parent 2979b13 commit 5e1db7f

File tree

3 files changed

+65
-4
lines changed

3 files changed

+65
-4
lines changed

tests/framework/callbacks/test_checkpoint_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,21 @@
88

99
import unittest
1010

11-
from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state
11+
from torch import nn
12+
13+
from torchtnt.framework._test_utils import (
14+
DummyAutoUnit,
15+
DummyTrainUnit,
16+
get_dummy_eval_state,
17+
get_dummy_fit_state,
18+
get_dummy_train_state,
19+
)
1220

1321
from torchtnt.framework.callbacks._checkpoint_utils import (
22+
_get_step_phase_mapping,
1423
_prepare_app_state_for_checkpoint,
1524
)
25+
from torchtnt.utils.checkpoint import Phase
1626

1727

1828
class CheckpointUtilsTest(unittest.TestCase):
@@ -26,3 +36,20 @@ def test_get_app_state(self) -> None:
2636
app_state.keys(),
2737
["module", "optimizer", "loss_fn", "train_progress"],
2838
)
39+
40+
def test_get_step_phase_mapping(self) -> None:
41+
unit = DummyAutoUnit(module=nn.Linear(2, 2))
42+
unit.train_progress._num_steps_completed = 5
43+
unit.eval_progress._num_steps_completed = 7
44+
45+
fit_state = get_dummy_fit_state()
46+
self.assertEqual(
47+
{Phase.TRAIN: 5, Phase.EVALUATE: 7},
48+
_get_step_phase_mapping(fit_state, unit),
49+
)
50+
51+
train_state = get_dummy_train_state()
52+
self.assertEqual({Phase.TRAIN: 5}, _get_step_phase_mapping(train_state, unit))
53+
54+
eval_state = get_dummy_eval_state()
55+
self.assertEqual({Phase.EVALUATE: 7}, _get_step_phase_mapping(eval_state, unit))

torchtnt/framework/_test_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,19 @@ def get_dummy_train_state(dataloader: Optional[Iterable[object]] = None) -> Stat
3333
)
3434

3535

36+
def get_dummy_eval_state(dataloader: Optional[Iterable[object]] = None) -> State:
37+
return State(
38+
entry_point=EntryPoint.EVALUATE,
39+
eval_state=PhaseState(
40+
dataloader=dataloader or [1, 2, 3, 4],
41+
max_epochs=1,
42+
max_steps=1,
43+
max_steps_per_epoch=1,
44+
),
45+
timer=None,
46+
)
47+
48+
3649
def get_dummy_fit_state() -> State:
3750
return State(
3851
entry_point=EntryPoint.FIT,

torchtnt/framework/callbacks/_checkpoint_utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
# pyre-strict
88

99

10-
from typing import Any, Dict
10+
from typing import Any, cast, Dict, Union
1111

1212
from pyre_extensions import none_throws
1313
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
14-
from torchtnt.framework.state import State
15-
from torchtnt.framework.unit import AppStateMixin
14+
from torchtnt.framework.state import EntryPoint, State
15+
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainUnit
16+
from torchtnt.utils.checkpoint import Phase
1617

1718
from torchtnt.utils.stateful import Stateful
1819

@@ -23,6 +24,26 @@
2324
_EVAL_PROGRESS_STATE_KEY = "eval_progress"
2425

2526

27+
def _get_step_phase_mapping(
28+
state: State, unit: Union[TTrainUnit, TEvalUnit]
29+
) -> Dict[Phase, int]:
30+
"""
31+
Returns a mapping of phase to step, depending on the entrypoint.
32+
For FIT, it always includes train and eval progress.
33+
"""
34+
step_mapping = {}
35+
36+
if state.entry_point in (EntryPoint.TRAIN, EntryPoint.FIT):
37+
train_unit = cast(TTrainUnit, unit)
38+
step_mapping[Phase.TRAIN] = train_unit.train_progress.num_steps_completed
39+
40+
if state.entry_point in (EntryPoint.EVALUATE, EntryPoint.FIT):
41+
eval_unit = cast(TEvalUnit, unit)
42+
step_mapping[Phase.EVALUATE] = eval_unit.eval_progress.num_steps_completed
43+
44+
return step_mapping
45+
46+
2647
def _prepare_app_state(unit: AppStateMixin) -> Dict[str, Any]:
2748
"""Join together all of the tracked stateful entities to simplify registration of snapshottable states, deals with FSDP case"""
2849
app_state = unit.app_state()

0 commit comments

Comments
 (0)