Skip to content

Commit d718410

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Encapsulate checkpoint str parsing in a method (#825)
Summary: Pull Request resolved: #825 Reviewed By: JKSenthil Differential Revision: D57055916 fbshipit-source-id: 84beac98e4319445a3e2f463729038d4670a4684
1 parent 697193b commit d718410

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

torchtnt/utils/checkpoint.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,22 @@ def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
9696
Raises:
9797
ValueError: If the path is malformed and can't be parsed.
9898
"""
99-
path_match = cls.PATH_REGEX.match(checkpoint_path)
99+
ckpt_path = cls.__new__(cls)
100+
ckpt_path._populate_from_str(checkpoint_path)
101+
return ckpt_path
102+
103+
def _populate_from_str(self, checkpoint_path: str) -> None:
104+
"""
105+
Reusable method to parse a checkpoint path string, extract the components, and populate
106+
a checkpoint path instance.
107+
108+
Args:
109+
checkpoint_path: The checkpoint path string.
110+
111+
Raises:
112+
ValueError: If the path is malformed (either non-parsable, or contains wrong data types)
113+
"""
114+
path_match = self.PATH_REGEX.match(checkpoint_path)
100115
if not path_match:
101116
raise ValueError(
102117
f"Attempted to parse malformed checkpoint path: {checkpoint_path}."
@@ -109,12 +124,10 @@ def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
109124
metric_value_f = float(metric_value)
110125
metric_data = MetricData(name=metric_name, value=metric_value_f)
111126

112-
return CheckpointPath(
113-
dirpath=dirpath,
114-
epoch=int(epoch),
115-
step=int(step),
116-
metric_data=metric_data,
117-
)
127+
self.dirpath = dirpath.rstrip("/")
128+
self.epoch = int(epoch)
129+
self.step = int(step)
130+
self.metric_data = metric_data
118131

119132
except ValueError:
120133
# Should never happen since path matches regex
@@ -205,18 +218,7 @@ def __getstate__(self) -> str:
205218

206219
def __setstate__(self, state: str) -> None:
207220
# Match regex directly to avoid creating a new instance with `from_str`
208-
path_match = self.PATH_REGEX.match(state)
209-
assert path_match, f"Malformed checkpoint found when unpickling: {state}"
210-
211-
dirpath, epoch, step, metric_name, metric_value = path_match.groups()
212-
self.dirpath = dirpath.rstrip("/")
213-
self.epoch = int(epoch)
214-
self.step = int(step)
215-
self.metric_data = (
216-
MetricData(name=metric_name, value=float(metric_value))
217-
if metric_name and metric_value
218-
else None
219-
)
221+
self._populate_from_str(state)
220222

221223

222224
class CheckpointManager:

0 commit comments

Comments
 (0)