@@ -96,7 +96,22 @@ def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
96
96
Raises:
97
97
ValueError: If the path is malformed and can't be parsed.
98
98
"""
99
- path_match = cls .PATH_REGEX .match (checkpoint_path )
99
+ ckpt_path = cls .__new__ (cls )
100
+ ckpt_path ._populate_from_str (checkpoint_path )
101
+ return ckpt_path
102
+
103
+ def _populate_from_str (self , checkpoint_path : str ) -> None :
104
+ """
105
+ Reusable method to parse a checkpoint path string, extract the components, and populate
106
+ a checkpoint path instance.
107
+
108
+ Args:
109
+ checkpoint_path: The checkpoint path string.
110
+
111
+ Raises:
112
+ ValueError: If the path is malformed (either non-parsable, or contains wrong data types)
113
+ """
114
+ path_match = self .PATH_REGEX .match (checkpoint_path )
100
115
if not path_match :
101
116
raise ValueError (
102
117
f"Attempted to parse malformed checkpoint path: { checkpoint_path } ."
@@ -109,12 +124,10 @@ def from_str(cls, checkpoint_path: str) -> "CheckpointPath":
109
124
metric_value_f = float (metric_value )
110
125
metric_data = MetricData (name = metric_name , value = metric_value_f )
111
126
112
- return CheckpointPath (
113
- dirpath = dirpath ,
114
- epoch = int (epoch ),
115
- step = int (step ),
116
- metric_data = metric_data ,
117
- )
127
+ self .dirpath = dirpath .rstrip ("/" )
128
+ self .epoch = int (epoch )
129
+ self .step = int (step )
130
+ self .metric_data = metric_data
118
131
119
132
except ValueError :
120
133
# Should never happen since path matches regex
@@ -205,18 +218,7 @@ def __getstate__(self) -> str:
205
218
206
219
def __setstate__ (self , state : str ) -> None :
207
220
# Match regex directly to avoid creating a new instance with `from_str`
208
- path_match = self .PATH_REGEX .match (state )
209
- assert path_match , f"Malformed checkpoint found when unpickling: { state } "
210
-
211
- dirpath , epoch , step , metric_name , metric_value = path_match .groups ()
212
- self .dirpath = dirpath .rstrip ("/" )
213
- self .epoch = int (epoch )
214
- self .step = int (step )
215
- self .metric_data = (
216
- MetricData (name = metric_name , value = float (metric_value ))
217
- if metric_name and metric_value
218
- else None
219
- )
221
+ self ._populate_from_str (state )
220
222
221
223
222
224
class CheckpointManager :
0 commit comments