From 2979b13278de80fff9e796d33a69c3306e4bd5b0 Mon Sep 17 00:00:00 2001 From: Diego Urgell Date: Mon, 13 May 2024 17:16:14 -0700 Subject: [PATCH] Enable multi-phase handling in CheckpointPath (#811) Summary: Pull Request resolved: https://github.com/pytorch/tnt/pull/811 Reviewed By: JKSenthil Differential Revision: D56275780 fbshipit-source-id: 08b85887f13e231d38409d71be4ba9ce7bdb8d48 --- tests/utils/test_checkpoint.py | 347 +++++++++++++++++++++++++++++++-- torchtnt/utils/checkpoint.py | 114 +++++++++-- 2 files changed, 424 insertions(+), 37 deletions(-) diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py index 8e4f7e9a29..d9929a6b38 100644 --- a/tests/utils/test_checkpoint.py +++ b/tests/utils/test_checkpoint.py @@ -33,6 +33,7 @@ get_checkpoint_dirpaths, get_latest_checkpoint_path, MetricData, + Phase, ) from torchtnt.utils.distributed import ( PGWrapper, @@ -56,6 +57,8 @@ def test_from_str(self) -> None: "foo/epoch_20_step_30_val_loss=1a", "foo/epoch_2_step_15_mean=hello", "foo/epoch_2.6_step_23", + "foo/epoch_3_pred_step_3", + "foo/epoch_3__step_3", ] for path in malformed_paths: with self.assertRaisesRegex( @@ -66,6 +69,10 @@ def test_from_str(self) -> None: # valid paths valid_paths = [ ("foo/epoch_0_step_1", CheckpointPath("foo", epoch=0, step=1)), + ( + "foo_bar/fizz_buzz/epoch_0_step_1", + CheckpointPath("foo_bar/fizz_buzz", epoch=0, step={Phase.NONE: 1}), + ), ( "foo/epoch_14_step_3_mean=15.0", CheckpointPath( @@ -98,10 +105,54 @@ def test_from_str(self) -> None: CheckpointPath( "file://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61", epoch=2, - step=1, + step={Phase.NONE: 1}, + metric_data=MetricData("acc", 0.98), + ), + ), + ( + "foo/bar/epoch_23_train_step_31_mean_loss_squared=0.0", + CheckpointPath( + "foo/bar/", + epoch=23, + step={Phase.TRAIN: 31}, + metric_data=MetricData("mean_loss_squared", 0.0), + ), + ), + ( + 
"file://path/some/checkpoints_dir/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_eval_step_1_acc=0.98", + CheckpointPath( + "file://path/some/checkpoints_dir/0b20e70f-9ad2-4904-b7d6-e8da48087d61", + epoch=2, + step={Phase.EVALUATE: 1}, metric_data=MetricData("acc", 0.98), ), ), + ( + "foo/bar/epoch_23_train_step_31_eval_step_15_mean_loss_squared=-23.6", + CheckpointPath( + "foo/bar/", + epoch=23, + step={Phase.TRAIN: 31, Phase.EVALUATE: 15}, + metric_data=MetricData("mean_loss_squared", -23.6), + ), + ), + ( + "file://path/some/checkpoints_dir/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_56_train_step_60_eval_step_1", + CheckpointPath( + "file://path/some/checkpoints_dir/0b20e70f-9ad2-4904-b7d6-e8da48087d61", + epoch=56, + step={Phase.TRAIN: 60, Phase.EVALUATE: 1}, + ), + ), + ( + "foo/bar/epoch_0_train_step_0_eval_step_0_mean_loss_squared=0.0", + CheckpointPath( + "foo/bar/", + epoch=0, + step={Phase.TRAIN: 0, Phase.EVALUATE: 0}, + metric_data=MetricData("mean_loss_squared", 0.0), + ), + ), ] for path, expected_ckpt in valid_paths: parsed_ckpt = CheckpointPath.from_str(path) @@ -142,10 +193,52 @@ def test_compare_by_recency(self) -> None: ) self.assertTrue(twin1 == twin2) + legacy1 = CheckpointPath("foo", epoch=3, step=3) + old = CheckpointPath("foo", epoch=3, step={Phase.TRAIN: 5}) + legacy2 = CheckpointPath("foo", epoch=3, step=7) + new = CheckpointPath("foo", epoch=3, step={Phase.TRAIN: 5, Phase.EVALUATE: 5}) + legacy3 = CheckpointPath("foo", epoch=3, step=10) + legacy4 = CheckpointPath("foor", epoch=4, step=12) + + self.assertTrue(old.newer_than(legacy1)) + self.assertTrue(new.newer_than(legacy1)) + self.assertTrue(new.newer_than(old)) + self.assertFalse(old.newer_than(new)) + self.assertTrue(legacy2.newer_than(old)) + self.assertTrue(new.newer_than(legacy2)) + self.assertTrue(new.newer_than(legacy3)) + self.assertFalse(legacy3.newer_than(new)) + self.assertTrue(legacy4.newer_than(legacy3)) + self.assertFalse(new == old) + + old = CheckpointPath("foo", epoch=3, 
step={Phase.TRAIN: 5}) + new = CheckpointPath("foo", epoch=3, step={Phase.TRAIN: 6}) + self.assertTrue(new.newer_than(old)) + self.assertFalse(old.newer_than(new)) + + train_only = CheckpointPath("foo", epoch=3, step={Phase.TRAIN: 10}) + eval_only = CheckpointPath("foo", epoch=3, step={Phase.EVALUATE: 10}) + multiphase_1 = CheckpointPath( + "foo", epoch=3, step={Phase.TRAIN: 5, Phase.EVALUATE: 5} + ) + multiphase_2 = CheckpointPath( + "foo", epoch=4, step={Phase.TRAIN: 15, Phase.EVALUATE: 10} + ) + multiphase_3 = CheckpointPath( + "foo", epoch=4, step={Phase.TRAIN: 20, Phase.EVALUATE: 10} + ) + + self.assertTrue(eval_only > train_only) + self.assertFalse(eval_only < train_only) + self.assertTrue(train_only < multiphase_1) + self.assertTrue(eval_only > multiphase_1) + self.assertTrue(eval_only < multiphase_2) + self.assertTrue(multiphase_2 < multiphase_3) + def test_compare_by_optimality(self) -> None: # not both metric aware ckpt1 = CheckpointPath("foo", epoch=0, step=1) - ckpt2 = CheckpointPath("foo", epoch=1, step=1) + ckpt2 = CheckpointPath("foo", epoch=1, step={Phase.TRAIN: 5}) ckpt3 = CheckpointPath( "foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0) ) @@ -170,7 +263,10 @@ def test_compare_by_optimality(self) -> None: "foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0) ) larger = CheckpointPath( - "foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0) + "foo", + epoch=0, + step={Phase.TRAIN: 1}, + metric_data=MetricData("foo", 2.0), ) self.assertTrue(larger.more_optimal_than(smaller, mode="max")) self.assertFalse(smaller.more_optimal_than(larger, mode="max")) @@ -181,6 +277,7 @@ def test_pickling(self) -> None: for path in ( "foo/epoch_0_step_1", "file://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98", + "foo/epoch_0_train_step_2_eval_step_5", ): ckpt = CheckpointPath.from_str(path) @@ -205,6 +302,49 @@ def test_checkpoint_ordering(self) -> None: sorted_paths = [str(x) for x in sorted(ckpts)] 
self.assertEqual(sorted_paths, [paths[2], paths[0], paths[1]]) + paths = [ + "foo/epoch_0_step_1", + "foo/epoch_1_train_step_20_val_loss=0.09", + "foo/epoch_0_train_step_10_val_loss=10.0", + "foo/epoch_1_step_30_val_loss=29.0", + "foo/epoch_0_eval_step_2_val_loss=13.0", + "foo/epoch_1_train_step_25_val_loss=0.06", + ] + ckpts = [CheckpointPath.from_str(path) for path in paths] + self.assertEqual( + [str(path) for path in sorted(ckpts)], + [paths[0], paths[2], paths[4], paths[1], paths[5], paths[3]], + ) + paths = [ + "foo/epoch_1_train_step_20_eval_step_10_val_loss=0.09", + "foo/epoch_0_train_step_10_eval_step_0_val_loss=10.0", + "foo/epoch_1_step_32_val_loss=29.0", # phase naive + "foo/epoch_1_train_step_20_eval_step_15_val_loss=0.02", + "foo/epoch_0_train_step_15_eval_step_0_val_loss=13.0", + "foo/epoch_1_train_step_25_val_loss=0.06", + "foo/epoch_0_train_step_15_eval_step_5_val_loss=18.0", + "foo/epoch_0_eval_step_25_val_loss=18.0", + "foo/epoch_0_step_10", # phase naive + ] + ckpts = [CheckpointPath.from_str(path) for path in paths] + for ckpt in ckpts: + print(ckpt.__repr__()) + + self.assertEqual( + [str(path) for path in sorted(ckpts)], + [ + paths[8], + paths[1], + paths[4], + paths[6], + paths[7], + paths[5], + paths[0], + paths[2], + paths[3], + ], + ) + class CheckpointManagerTest(unittest.TestCase): def test_create_checkpoint_manager(self) -> None: @@ -217,6 +357,10 @@ def test_create_checkpoint_manager(self) -> None: f"{temp_dir}/epoch_1_step_2_loss=0.5", f"{temp_dir}/epoch_2_step_5_loss=0.3", f"{temp_dir}/epoch_0_step_2_acc=0.7", + f"{temp_dir}/epoch_2_train_step_2_eval_step_5", + f"{temp_dir}/epoch_3_train_step_5_loss=-0.2", + f"{temp_dir}/epoch_3_eval_step_2_loss=0.2", + f"{temp_dir}/epoch_3_train_step_1_eval_step_5_loss=0.1", ] for path in paths: os.mkdir(path) @@ -237,6 +381,10 @@ def test_create_checkpoint_manager(self) -> None: f"{temp_dir}/epoch_1_step_2_loss=0.5", f"{temp_dir}/epoch_1_step_3", f"{temp_dir}/epoch_2_step_5_loss=0.3", + 
f"{temp_dir}/epoch_2_train_step_2_eval_step_5", + f"{temp_dir}/epoch_3_train_step_5_loss=-0.2", + f"{temp_dir}/epoch_3_eval_step_2_loss=0.2", + f"{temp_dir}/epoch_3_train_step_1_eval_step_5_loss=0.1", ], ) @@ -253,6 +401,9 @@ def test_create_checkpoint_manager(self) -> None: [ f"{temp_dir}/epoch_1_step_2_loss=0.5", f"{temp_dir}/epoch_2_step_5_loss=0.3", + f"{temp_dir}/epoch_3_eval_step_2_loss=0.2", + f"{temp_dir}/epoch_3_train_step_1_eval_step_5_loss=0.1", + f"{temp_dir}/epoch_3_train_step_5_loss=-0.2", f"{temp_dir}/epoch_0_step_5_loss=-0.3", ], ) @@ -269,6 +420,9 @@ def test_create_checkpoint_manager(self) -> None: [x.path for x in ckpt_manager._ckpt_paths], [ f"{temp_dir}/epoch_0_step_5_loss=-0.3", + f"{temp_dir}/epoch_3_train_step_5_loss=-0.2", + f"{temp_dir}/epoch_3_train_step_1_eval_step_5_loss=0.1", + f"{temp_dir}/epoch_3_eval_step_2_loss=0.2", f"{temp_dir}/epoch_2_step_5_loss=0.3", f"{temp_dir}/epoch_1_step_2_loss=0.5", ], @@ -284,6 +438,42 @@ def test_create_checkpoint_manager(self) -> None: ) self.assertEqual(ckpt_manager._ckpt_paths, []) + # More intense metric sorting test + with tempfile.TemporaryDirectory() as temp_dir: + paths = [ + f"{temp_dir}/epoch_1_train_step_20_val_loss=0.09", + f"{temp_dir}/epoch_0_train_step_10_eval_step_53412092_val_loss=10.0", + f"{temp_dir}/epoch_4_step_130_val_loss=29.0", + f"{temp_dir}/epoch_0_eval_step_2_val_loss=13.0", + f"{temp_dir}/epoch_1_train_step_25_val_loss=0.06", + ] + for path in paths: + os.mkdir(path) + + ckpt_manager = CheckpointManager( + temp_dir, + best_checkpoint_config=BestCheckpointConfig( + monitored_metric="val_loss", mode="min" + ), + keep_last_n_checkpoints=3, + ) + self.assertEqual( + [str(path) for path in ckpt_manager._ckpt_paths], + [paths[2], paths[3], paths[1], paths[0], paths[4]], + ) + + ckpt_manager = CheckpointManager( + temp_dir, + best_checkpoint_config=BestCheckpointConfig( + monitored_metric="val_loss", mode="max" + ), + keep_last_n_checkpoints=3, + ) + self.assertEqual( + [str(path) 
for path in ckpt_manager._ckpt_paths], + [paths[4], paths[0], paths[1], paths[3], paths[2]], + ) + @skip_if_not_distributed def test_create_checkpoint_manager_distributed(self) -> None: spawn_multi_process( @@ -380,6 +570,13 @@ def test_generate_checkpoint_path(self) -> None: "foo/epoch_1_step_3", ) + self.assertEqual( + ckpt_manager.generate_checkpoint_path( + 1, {Phase.TRAIN: 5, Phase.EVALUATE: 7} + ).path, + "foo/epoch_1_train_step_5_eval_step_7", + ) + ckpt_manager._best_checkpoint_config = BestCheckpointConfig( monitored_metric="val_loss", mode="min" ) @@ -412,7 +609,7 @@ def test_generate_checkpoint_path(self) -> None: ckpt_manager.generate_checkpoint_path(1, 2, MetricData("val_loss", 3.5)) def test_append_checkpoint_by_recency(self) -> None: - ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=2) + ckpt_manager = CheckpointManager("foo", keep_last_n_checkpoints=3) ckpt_manager._ckpt_paths = [CheckpointPath("foo", 0, 0)] # without need to remove old by recency @@ -422,12 +619,28 @@ def test_append_checkpoint_by_recency(self) -> None: [CheckpointPath("foo", 0, 0), CheckpointPath("foo", 0, 1)], ) + ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, {Phase.TRAIN: 5})) + self.assertEqual( + ckpt_manager._ckpt_paths, + [ + CheckpointPath("foo", 0, 0), + CheckpointPath("foo", 0, 1), + CheckpointPath("foo", 0, {Phase.TRAIN: 5}), + ], + ) + # removing old by recency with patch("fsspec.implementations.local.LocalFileSystem.rm") as mock_rm: - ckpt_manager.append_checkpoint(CheckpointPath("foo", 0, 2)) + ckpt_manager.append_checkpoint( + CheckpointPath("foo", 0, {Phase.EVALUATE: 10}) + ) self.assertEqual( ckpt_manager._ckpt_paths, - [CheckpointPath("foo", 0, 1), CheckpointPath("foo", 0, 2)], + [ + CheckpointPath("foo", 0, 1), + CheckpointPath("foo", 0, {Phase.TRAIN: 5}), + CheckpointPath("foo", 0, {Phase.EVALUATE: 10}), + ], ) mock_rm.assert_called_once_with("foo/epoch_0_step_0", recursive=True) @@ -597,6 +810,76 @@ def 
test_latest_checkpoint_path(self) -> None: path_2, ) + with tempfile.TemporaryDirectory() as temp_dir: + path_1 = os.path.join(temp_dir, "epoch_0_train_step_0") + os.mkdir(path_1) + self._create_snapshot_metadata(path_1) + path_2 = os.path.join(temp_dir, "epoch_0_train_step_100_val_loss=0.002") + os.mkdir(path_2) + self._create_snapshot_metadata(path_2) + + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2 + ) + + path_3 = os.path.join(temp_dir, "epoch_0_eval_step_0") + os.mkdir(path_3) + self._create_snapshot_metadata(path_3) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_3 + ) + + # Missing metadata file + path_4 = os.path.join(temp_dir, "epoch_1_train_step_1") + os.mkdir(path_4) + self.assertEqual(get_latest_checkpoint_path(temp_dir), path_4) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_3 + ) + self._create_snapshot_metadata(path_4) + + # Ill-formatted name + path_5 = os.path.join(temp_dir, "epoch_700") + os.mkdir(path_5) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_4 + ) + + # With both phases + path_6 = os.path.join(temp_dir, "epoch_1_train_step_5_eval_step_5") + os.mkdir(path_6) + self._create_snapshot_metadata(path_6) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_6 + ) + + path_9 = os.path.join(temp_dir, "epoch_1_train_step_5_eval_step_10") + os.mkdir(path_9) + self._create_snapshot_metadata(path_9) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_9 + ) + + path_10 = os.path.join(temp_dir, "epoch_2_train_step_10_eval_step_10") + os.mkdir(path_10) + self._create_snapshot_metadata(path_10) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_10 + ) + + path_11 = os.path.join(temp_dir, "epoch_2_train_step_15_eval_step_10") + os.mkdir(path_11) + self._create_snapshot_metadata(path_11) + self.assertEqual( + 
get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_11 + ) + + # Test legacy path with fewer steps + + # Test legacy path with equal steps + + # Test legacy path with more steps + @skip_if_not_distributed def test_latest_checkpoint_path_distributed(self) -> None: spawn_multi_process( @@ -686,7 +969,9 @@ def test_best_checkpoint_path(self) -> None: ) # handle "max" mode correctly - best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1") + best_path_3 = os.path.join( + temp_dir, "epoch_0_train_step_100_eval_step_25_val_loss=0.1" + ) os.mkdir(best_path_3) self.assertEqual( get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), @@ -714,11 +999,16 @@ def test_retrieve_checkpoint_dirpaths(self) -> None: with tempfile.TemporaryDirectory() as temp_dir: paths = [ "epoch_0_step_10", - "epoch_1_step_10", "epoch_2_step_10", - "epoch_0_step_5", + "epoch_0_eval_step_5", "epoch_0_step_6", - "epoch_0_step_3", + "epoch_0_train_step_3", + "epoch_0_step_10_train_loss=13.0", + "epoch_34_eval_step_10_val_loss=10.0", + "epoch_3_train_step_10_val_loss=5.1", + "epoch_1_step_10", + "epoch_25_train_step_10_eval_step_15", + "epoch_1_train_step_10_eval_step_0_train_loss=10.0", ] for path in paths[:-1]: os.mkdir(os.path.join(temp_dir, path)) @@ -743,10 +1033,13 @@ def test_retrieve_checkpoint_dirpaths(self) -> None: ) # check metadata file is correct filtered for - # by creating metadata for 3rd path in list + # by creating metadata for 3rd and 6th path in list with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"): pass + with open(os.path.join(temp_dir, paths[5], ".metadata"), "w"): + pass + self.assertEqual( { str(x) @@ -754,7 +1047,7 @@ def test_retrieve_checkpoint_dirpaths(self) -> None: temp_dir, metadata_fname=".metadata" ) }, - {os.path.join(temp_dir, paths[2])}, + {os.path.join(temp_dir, paths[2]), os.path.join(temp_dir, paths[5])}, ) def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: @@ -763,11 +1056,12 @@ def 
test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: """ with tempfile.TemporaryDirectory() as temp_dir: paths = [ - "epoch_0_step_10_val_loss=10.0", + "epoch_0_train_step_10_val_loss=10.0", "epoch_1_step_10_val_loss=5.0", - "epoch_2_step_10", - "epoch_0_step_5", - "epoch_0_step_6_train_loss=13.0", + "epoch_0_train_step_7_eval_step_12_val_loss=12.9", + "epoch_2_eval_step_10", + "epoch_0_train_step_5", + "epoch_0_train_step_6_train_loss=13.0", ] for path in paths: os.mkdir(os.path.join(temp_dir, path)) @@ -794,7 +1088,7 @@ def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: ) }, { - os.path.join(temp_dir, path) for path in paths[:2] + os.path.join(temp_dir, path) for path in paths[:3] }, # since last path is a file ) self.assertEqual( @@ -861,21 +1155,32 @@ def test_get_checkpoint_dirpaths(self) -> None: """ with tempfile.TemporaryDirectory() as temp_dir: path1 = os.path.join(temp_dir, "epoch_1_step_20") - path2 = os.path.join(temp_dir, "epoch_4_step_130") + path2 = os.path.join(temp_dir, "epoch_4_eval_step_130") path3 = os.path.join(temp_dir, "epoch_0_step_10") + path4 = os.path.join( + temp_dir, "epoch_0_train_step_10_eval_step_15_train_loss=13.0" + ) + malformed_path = os.path.join( + temp_dir, "epoch_train_0_step_10_val_loss=10.0" + ) os.mkdir(path1) os.mkdir(path2) os.mkdir(path3) + os.mkdir(path4) + os.mkdir(malformed_path) self.assertEqual( {str(x) for x in get_checkpoint_dirpaths(temp_dir)}, - {path1, path2, path3}, + {path1, path2, path3, path4}, ) with tempfile.TemporaryDirectory() as temp_dir: path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01") - path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2") - path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12") + path2 = os.path.join(temp_dir, "epoch_4_train_step_130_val_loss=-0.2") + path3 = os.path.join(temp_dir, "epoch_0_eval_step_10_val_loss=0.12") + path4 = os.path.join( + temp_dir, "epoch_0_train_step_10_eval_step_15_val_loss=13.0" + ) os.mkdir(path1) 
os.mkdir(path2) os.mkdir(path3) diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 6d1188b6cb..825a0a0ab1 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -10,8 +10,10 @@ import os import re from dataclasses import dataclass +from enum import Enum from functools import total_ordering -from typing import List, Literal, Optional, Pattern +from operator import xor +from typing import Dict, List, Literal, Optional, Pattern, Tuple, Union import fsspec import torch.distributed as dist @@ -46,39 +48,62 @@ class BestCheckpointConfig: mode: Literal["min", "max"] = "min" +class Phase(Enum): + NONE = 0 # Only used for backwards compatibility + TRAIN = 1 + EVALUATE = 2 + + @total_ordering class CheckpointPath: """ Representation of a checkpoint path. Handles parsing and serialization of the specific path format. - Currently, the basic compliant path format is: /epoch__step_ - If a metric is being tracked, it's added to the name: /epoch__step__= - This class is well-ordered by checkpoint recency, so any comparisons will operate using the epoch + step. Sorting by - metric can be done by extracting the metric value from the metric_data attribute. + A CheckpointPath can be metric aware/naive, and/or phase aware/naive. This means: + - A metric aware checkpoint stores the value of a particular metric name. It is possible to compare optimality of two + checkpoints that are tracking the same metric by using the `more_optimal_than` method. + - A phase aware checkpoint stores the step number for a particular phase. This provides a better UX when doing manual + exploration of the generated checkpoints. The phase will also be considered to determine recency. However, note that + it is possible to compare phase naive and phase aware checkpoints by recency: within the same epoch the total number of steps is compared, with the latter being newer on a tie. For + two phase aware checkpoints, first the evaluation steps (if any) are compared, then the train steps.
It is always assumed + that evaluation happens after training for a particular epoch. + + Examples of compliant checkpoint paths: + - phase-naive and metric-naive- <dirpath>/epoch_<epoch>_step_<step> + - phase-naive and metric-aware- <dirpath>/epoch_<epoch>_step_<step>_<metric_name>=<metric_value> + - phase-aware (train only) and metric-aware- <dirpath>/epoch_<epoch>_train_step_<train_step>_<metric_name>=<metric_value> + - phase-aware and metric-aware- <dirpath>/epoch_<epoch>_train_step_<train_step>_eval_step_<eval_step>_<metric_name>=<metric_value> """ - PATH_REGEX: Pattern = re.compile( - r"^(.+)epoch_(\d+)_step_(\d+)(?:_(.+)=(-?\d+\.?\d*))?\/?$" + PHASE_NAIVE_REGEX: Pattern = re.compile( + r"^(.+)epoch_(\d+)_step_(\d+)(?:_(\w+)=(-?\d+\.?\d*))?\/?$" + ) + + PHASE_AWARE_REGEX: Pattern = re.compile( + r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$" ) def __init__( self, dirpath: str, epoch: int, - step: int, + step: Union[Dict[Phase, int], int], metric_data: Optional[MetricData] = None, ) -> None: """ Args: dirpath: The base directory path that checkpoints are saved in. epoch: The epoch number of this checkpoint. - step: The step number of this checkpoint. + step: The step number of this checkpoint. This may be an integer for a phase-naive checkpoint. To + create a phase-aware path, a dict can be used to map each phase to its step number. metric_data: Optional data about the metric being tracked. Should contain both metric name and value.
""" self.dirpath: str = dirpath.rstrip("/") self.epoch = epoch - self.step = step self.metric_data = metric_data + self.step: Dict[Phase, int] = ( + step if isinstance(step, dict) else {Phase.NONE: step} + ) @classmethod def from_str(cls, checkpoint_path: str) -> "CheckpointPath": @@ -111,14 +136,33 @@ def _populate_from_str(self, checkpoint_path: str) -> None: Raises: ValueError: If the path is malformed (either non-parsable, or contains wrong data types) """ - path_match = self.PATH_REGEX.match(checkpoint_path) + is_phase_aware = ( + "train_step" in checkpoint_path or "eval_step" in checkpoint_path + ) + regex = self.PHASE_AWARE_REGEX if is_phase_aware else self.PHASE_NAIVE_REGEX + path_match = regex.match(checkpoint_path) if not path_match: raise ValueError( f"Attempted to parse malformed checkpoint path: {checkpoint_path}." ) - dirpath, epoch, step, metric_name, metric_value = path_match.groups() try: + step_mapping: Dict[Phase, int] = {} + if is_phase_aware: + dirpath, epoch, train_steps, eval_steps, metric_name, metric_value = ( + path_match.groups() + ) + if train_steps is not None: + step_mapping[Phase.TRAIN] = int(train_steps) + if eval_steps is not None: + step_mapping[Phase.EVALUATE] = int(eval_steps) + + else: + dirpath, epoch, naive_steps, metric_name, metric_value = ( + path_match.groups() + ) + step_mapping[Phase.NONE] = int(naive_steps) + metric_data: Optional[MetricData] = None if metric_name: metric_value_f = float(metric_value) @@ -126,7 +170,7 @@ def _populate_from_str(self, checkpoint_path: str) -> None: self.dirpath = dirpath.rstrip("/") self.epoch = int(epoch) - self.step = int(step) + self.step = step_mapping self.metric_data = metric_data except ValueError: @@ -141,23 +185,58 @@ def path(self) -> str: Returns: The full path to the checkpoint directory. 
""" - name = f"epoch_{self.epoch}_step_{self.step}" + name = f"epoch_{self.epoch}" + + if not self._is_phase_aware(): + name += f"_step_{self.step[Phase.NONE]}" + else: + if Phase.TRAIN in self.step: + name += f"_train_step_{self.step[Phase.TRAIN]}" + if Phase.EVALUATE in self.step: + name += f"_eval_step_{self.step[Phase.EVALUATE]}" + if self.metric_data: name += f"_{self.metric_data.name}={self.metric_data.value}" return os.path.join(self.dirpath, name) + def _is_phase_aware(self) -> bool: + return Phase.NONE not in self.step + def newer_than(self, other: "CheckpointPath") -> bool: """ Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other. + Note that recency is determined in terms of the epoch, phase, and number of steps. It is NOT related to + the timestamp at which the checkpoint was saved. Returns: True if this checkpoint is newer than the other, otherwise False. """ + # Always check the epoch first if self.epoch != other.epoch: return self.epoch > other.epoch - return self.step > other.step + # If comparing a phase-aware vs phase-naive checkpoint, determine recency by the total number of steps. + # In case both are the same, the phase aware checkpoint is considered newer.
+ if xor(self._is_phase_aware(), other._is_phase_aware()): + if sum(self.step.values()) != sum(other.step.values()): + return sum(self.step.values()) > sum(other.step.values()) + return self._is_phase_aware() + + # For two phase-naive checkpoints, we only need to look at the step + if not self._is_phase_aware(): + return self.step[Phase.NONE] > other.step[Phase.NONE] + + # If one checkpoint has eval steps and the other doesn't, the one with eval steps is always newer + if xor(Phase.EVALUATE in self.step, Phase.EVALUATE in other.step): + return Phase.EVALUATE in self.step + + # Otherwise, compare first by eval and then train steps + return self._get_phase_steps() > other._get_phase_steps() + + def _get_phase_steps(self) -> Tuple[int, int]: + """Tuple with the phase steps ordered by phase priority in comparison (first eval, then train).""" + return self.step.get(Phase.EVALUATE, 0), self.step.get(Phase.TRAIN, 0) def more_optimal_than( self, other: "CheckpointPath", mode: Literal["min", "max"] @@ -322,7 +401,10 @@ def prune_surplus_checkpoints(self) -> None: self.remove_checkpoint() def generate_checkpoint_path( - self, epoch: int, step: int, metric_data: Optional[MetricData] = None + self, + epoch: int, + step: Union[int, Dict[Phase, int]], + metric_data: Optional[MetricData] = None, ) -> CheckpointPath: """ Given the current epoch, step, and possibly a metric_data value, determine the path