@@ -87,8 +87,8 @@ class Manager:
87
87
def __init__ (
88
88
self ,
89
89
pg : "ProcessGroup" ,
90
- load_state_dict : Callable [[T ], None ],
91
- state_dict : Callable [[], T ],
90
+ load_state_dict : Optional [ Callable [[T ], None ] ],
91
+ state_dict : Optional [ Callable [[], T ] ],
92
92
min_replica_size : int ,
93
93
use_async_quorum : bool = True ,
94
94
timeout : timedelta = timedelta (seconds = 60 ),
@@ -144,7 +144,7 @@ def __init__(
144
144
transfering checkpoints to recovering replicas
145
145
"""
146
146
self ._load_state_dict = load_state_dict
147
- self ._state_dict = state_dict
147
+ self ._user_state_dict = state_dict
148
148
self ._pending_state_dict : Optional [Dict [str , object ]] = None
149
149
self ._use_async_quorum = use_async_quorum
150
150
self ._timeout = timeout
@@ -159,8 +159,6 @@ def __init__(
159
159
world_size = world_size or int (os .environ ["WORLD_SIZE" ])
160
160
self ._min_replica_size = min_replica_size
161
161
162
- self ._user_state_dict = state_dict
163
-
164
162
if checkpoint_transport is None :
165
163
checkpoint_transport = CheckpointServer [Dict [str , T ]](
166
164
timeout = timeout ,
@@ -226,6 +224,12 @@ def __init__(
226
224
self ._participating_rank : Optional [int ] = None
227
225
self ._participating_world_size : int = 0
228
226
227
+ def set_state_dict_fns (
228
+ self , load_state_dict : Callable [[T ], None ], state_dict : Callable [[], T ]
229
+ ) -> None :
230
+ self ._load_state_dict = load_state_dict
231
+ self ._user_state_dict = state_dict
232
+
229
233
def shutdown (self , wait : bool = True ) -> None :
230
234
"""
231
235
Shutdown the manager and checkpoint server.
@@ -531,8 +535,12 @@ def _apply_pending_state_dict(self) -> None:
531
535
self ._logger .info ("applying pending state dict" )
532
536
533
537
assert self ._pending_state_dict is not None , "checkpoint was not staged"
538
+ assert (
539
+ self ._load_state_dict is not None
540
+ ), "user load_state_dict is not initialized."
534
541
self ._load_state_dict (self ._pending_state_dict ["user" ])
535
542
self ._pending_state_dict = None
543
+ self ._logger .info ("Loaded state dict." )
536
544
537
545
def should_commit (self , timeout : Optional [timedelta ] = None ) -> bool :
538
546
"""
@@ -602,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
602
610
self ._batches_committed = state_dict ["batches_committed" ]
603
611
604
612
def _manager_state_dict (self ) -> Dict [str , object ]:
613
+ assert self ._user_state_dict is not None , "user state_dict is not initialized."
605
614
return {
606
615
"user" : self ._user_state_dict (),
607
616
"torchft" : self .state_dict (),
0 commit comments