Skip to content

Commit 6e4ae38

Browse files
authored
Change how TorchFT manages user_state_dict (#87)
* Change how TorchFT manages user_state_dict. This PR closes some state_dict gaps when integrating with TorchTitan: (1) the user state_dict() and load_state_dict() functions can now be initialized lazily; (2) weights_only is changed to False for torch.load, as we may have to load some non-tensor states.
1 parent fa1630d commit 6e4ae38

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

Diff for: torchft/checkpointing.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
172172
data = f.read()
173173

174174
reader = io.BytesIO(data)
175-
return torch.load(reader, weights_only=True)
175+
# We have to set weights_only to False as there are some non-tensor
176+
# states like lr_scheduler.
177+
return torch.load(reader, weights_only=False)
176178

177179
def address(self) -> str:
178180
"""

Diff for: torchft/manager.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ class Manager:
8787
def __init__(
8888
self,
8989
pg: "ProcessGroup",
90-
load_state_dict: Callable[[T], None],
91-
state_dict: Callable[[], T],
90+
load_state_dict: Optional[Callable[[T], None]],
91+
state_dict: Optional[Callable[[], T]],
9292
min_replica_size: int,
9393
use_async_quorum: bool = True,
9494
timeout: timedelta = timedelta(seconds=60),
@@ -144,7 +144,7 @@ def __init__(
144144
transfering checkpoints to recovering replicas
145145
"""
146146
self._load_state_dict = load_state_dict
147-
self._state_dict = state_dict
147+
self._user_state_dict = state_dict
148148
self._pending_state_dict: Optional[Dict[str, object]] = None
149149
self._use_async_quorum = use_async_quorum
150150
self._timeout = timeout
@@ -159,8 +159,6 @@ def __init__(
159159
world_size = world_size or int(os.environ["WORLD_SIZE"])
160160
self._min_replica_size = min_replica_size
161161

162-
self._user_state_dict = state_dict
163-
164162
if checkpoint_transport is None:
165163
checkpoint_transport = CheckpointServer[Dict[str, T]](
166164
timeout=timeout,
@@ -226,6 +224,12 @@ def __init__(
226224
self._participating_rank: Optional[int] = None
227225
self._participating_world_size: int = 0
228226

227+
def set_state_dict_fns(
228+
self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
229+
) -> None:
230+
self._load_state_dict = load_state_dict
231+
self._user_state_dict = state_dict
232+
229233
def shutdown(self, wait: bool = True) -> None:
230234
"""
231235
Shutdown the manager and checkpoint server.
@@ -531,8 +535,12 @@ def _apply_pending_state_dict(self) -> None:
531535
self._logger.info("applying pending state dict")
532536

533537
assert self._pending_state_dict is not None, "checkpoint was not staged"
538+
assert (
539+
self._load_state_dict is not None
540+
), "user load_state_dict is not initialized."
534541
self._load_state_dict(self._pending_state_dict["user"])
535542
self._pending_state_dict = None
543+
self._logger.info("Loaded state dict.")
536544

537545
def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
538546
"""
@@ -602,6 +610,7 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
602610
self._batches_committed = state_dict["batches_committed"]
603611

604612
def _manager_state_dict(self) -> Dict[str, object]:
613+
assert self._user_state_dict is not None, "user state_dict is not initialized."
605614
return {
606615
"user": self._user_state_dict(),
607616
"torchft": self.state_dict(),

Diff for: torchft/manager_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,37 @@ def test_state_dict(self, client_mock: MagicMock) -> None:
9595
self.assertEqual(manager.current_step(), 1234)
9696
self.assertEqual(manager.batches_committed(), 2345)
9797

98+
@patch("torchft.manager.ManagerClient", autospec=True)
99+
def test_user_state_dict(self, client_mock: MagicMock) -> None:
100+
manager = self._create_manager()
101+
102+
self.assertEqual(
103+
manager._manager_state_dict(),
104+
{
105+
"user": {},
106+
"torchft": {
107+
"step": 0,
108+
"batches_committed": 0,
109+
},
110+
},
111+
)
112+
113+
manager.set_state_dict_fns(
114+
self.load_state_dict,
115+
lambda: {"new_state": 1},
116+
)
117+
118+
self.assertEqual(
119+
manager._manager_state_dict(),
120+
{
121+
"user": {"new_state": 1},
122+
"torchft": {
123+
"step": 0,
124+
"batches_committed": 0,
125+
},
126+
},
127+
)
128+
98129
@patch("torchft.manager.ManagerClient", autospec=True)
99130
def test_quorum_happy(self, client_mock: MagicMock) -> None:
100131
manager = self._create_manager()

0 commit comments

Comments
 (0)