
Commit 4e5c337

[WIP][RFC] Required changes for integration with TorchTitan
Summary: We are not going to land this PR as-is; it may be further divided into several PRs.
1 parent ccf74d4 commit 4e5c337

4 files changed: +53 -14 lines

torchft/checkpointing.py (+4 -2)
@@ -134,8 +134,10 @@ def do_GET(self):
                     self.end_headers()

                     state_dict = ckpt_server._state_dict
-
+                    self._logger.warning("Before torch.save ===================.")
                     torch.save(state_dict, self.wfile)
+                    self._logger.warning("After torch.save ===================.")
+
                 except Exception as e:
                     logger.exception(
                         f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -172,7 +174,7 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
             data = f.read()

         reader = io.BytesIO(data)
-        return torch.load(reader, weights_only=True)
+        return torch.load(reader, weights_only=False)

     def address(self) -> str:
         """

torchft/manager.py (+13 -4)
@@ -87,8 +87,8 @@ class Manager:
     def __init__(
         self,
         pg: "ProcessGroup",
-        load_state_dict: Callable[[T], None],
-        state_dict: Callable[[], T],
+        load_state_dict: Optional[Callable[[T], None]],
+        state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
         use_async_quorum: bool = True,
         timeout: timedelta = timedelta(seconds=60),
@@ -144,7 +144,6 @@ def __init__(
             transfering checkpoints to recovering replicas
         """
         self._load_state_dict = load_state_dict
-        self._state_dict = state_dict
         self._pending_state_dict: Optional[Dict[str, object]] = None
         self._use_async_quorum = use_async_quorum
         self._timeout = timeout
@@ -226,6 +225,12 @@ def __init__(
         self._participating_rank: Optional[int] = None
         self._participating_world_size: int = 0

+    def set_state_dict_fns(
+        self, load_state_dict: Callable[T, None], state_dict: Callable[[], T]
+    ) -> None:
+        self._load_state_dict = load_state_dict
+        self._user_state_dict = state_dict
+
     def shutdown(self, wait: bool = True) -> None:
         """
         Shutdown the manager and checkpoint server.
@@ -533,6 +538,7 @@ def _apply_pending_state_dict(self) -> None:
         assert self._pending_state_dict is not None, "checkpoint was not staged"
         self._load_state_dict(self._pending_state_dict["user"])
         self._pending_state_dict = None
+        self._logger.info("Loaded state dict.")

     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         """
@@ -602,10 +608,13 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         self._batches_committed = state_dict["batches_committed"]

     def _manager_state_dict(self) -> Dict[str, object]:
-        return {
+        self._logger.warning("Before state_dict ===================.")
+        ret = {
             "user": self._user_state_dict(),
             "torchft": self.state_dict(),
         }
+        self._logger.warning("After state_dict ===================.")
+        return ret

     def state_dict(self) -> Dict[str, int]:
         """

torchft/optim.py (+9 -1)
@@ -12,7 +12,7 @@

 """

-from typing import TYPE_CHECKING, Optional
+from typing import Any, TYPE_CHECKING, Optional

 from torch.optim import Optimizer

@@ -52,3 +52,11 @@ def step(self, closure: Optional[object] = None) -> None:
         assert closure is None, "optimizers that use closures are not supported"
         if self.manager.should_commit():
             self.optim.step()
+
+    @property
+    def param_groups(self) -> Any:
+        return self.optim.param_groups
+
+    @property
+    def state(self) -> Any:
+        return self.optim.state
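
The pass-through param_groups and state properties let code that introspects the optimizer (learning-rate schedules, checkpoint helpers) reach the wrapped optimizer's real state through the wrapper. A generic, self-contained sketch of the delegation pattern with a stand-in wrapper class, not torchft's own:

from typing import Any

import torch
from torch.optim import SGD


class PassthroughOptimizer:
    """Stand-in wrapper: delegates param_groups/state to the inner optimizer."""

    def __init__(self, optim: torch.optim.Optimizer) -> None:
        self.optim = optim

    def step(self) -> None:
        self.optim.step()

    @property
    def param_groups(self) -> Any:
        return self.optim.param_groups

    @property
    def state(self) -> Any:
        return self.optim.state


param = torch.nn.Parameter(torch.zeros(2))
wrapper = PassthroughOptimizer(SGD([param], lr=0.1))

# Code that reads or mutates param_groups (e.g. a hand-rolled LR decay)
# now operates on the real optimizer through the wrapper.
for group in wrapper.param_groups:
    group["lr"] *= 0.5
print(wrapper.param_groups[0]["lr"])  # 0.05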

torchft/process_group.py (+27 -7)
@@ -21,7 +21,7 @@
 import threading
 from abc import ABC
 from datetime import timedelta
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.distributed as dist
@@ -858,6 +858,8 @@ def extend_device_mesh(


 class ManagedDeviceMesh(DeviceMesh):
+    replicate_pg_singleton: Optional["ManagedProcessGroup"]
+
     def __init__(
         self,
         mesh: Optional[DeviceMesh],
@@ -886,6 +888,15 @@ def __init__(
         self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
         self._thread_id: Optional[int] = None

+    def __getstate__(self) -> Dict[str, Any]:
+        state = self.__dict__.copy()
+        state["replicate_pg"] = None
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]) -> None:
+        self.__dict__.update(state)
+        self.replicate_pg = self.replicate_pg_singleton
+
     def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
         if isinstance(mesh_dim_names, str):
             if mesh_dim_names == self.replicate_dim_name:
@@ -903,13 +914,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
                 return self.mesh[mesh_dim_names]
         else:
             assert isinstance(mesh_dim_names, tuple)
-            if self.replicate_dim_name in mesh_dim_names:
+            if self.replicate_dim_name not in mesh_dim_names:
                 assert self.mesh is not None
                 return self.mesh[mesh_dim_names]
             else:
                 assert self.mesh is not None
+                mesh_dim_names_wo_replicate = tuple(n for n in mesh_dim_names if n != self.replicate_dim_name)
                 return ManagedDeviceMesh(
-                    self.mesh[mesh_dim_names],
+                    self.mesh[mesh_dim_names_wo_replicate],
                     mesh_dim_names,
                     self.replicate_pg,
                     mesh_dim_names.index(self.replicate_dim_name),
@@ -944,14 +956,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
         return flatten_mesh

     def size(self, mesh_dim: Optional[int] = None) -> int:
+        replicate_pg_size = self.replicate_pg.size()
+        replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
         if mesh_dim is None:
             if self.mesh is None:
-                return self.replicate_pg.size()
+                return replicate_pg_size
             else:
                 assert self.mesh is not None
-                return self.mesh.size() * self.replicate_pg.size()
+                return self.mesh.size() * replicate_pg_size
         elif mesh_dim == self.replicate_dim:
-            return self.replicate_pg.size()
+            return replicate_pg_size
         else:
             assert self.mesh is not None
             return self.mesh.size(self._real_mesh_dim(mesh_dim))
@@ -1001,7 +1015,11 @@ def get_coordinate(self) -> Optional[List[int]]:
        dimensions of the mesh. If this rank is not part of the mesh, return None.
        """
        assert self.mesh is not None
-       return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+       ret = self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+       if ret:
+           ret = ret.copy()
+           ret.insert(get_rank(self.replicate_pg), self.replicate_dim)
+       return ret

     def get_all_groups(self) -> List[BaseProcessGroup]:
         raise NotImplementedError
@@ -1076,6 +1094,8 @@ def ft_init_device_mesh(
     # the same backend has been registered.
     replicate_pg.register(mesh_dim_names[replicate_dim])

+    ManagedDeviceMesh.replicate_pg_singleton = replicate_pg
+
     return ManagedDeviceMesh(
         mesh=mesh,
         mesh_dim_names=mesh_dim_names,
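
The replicate_pg_singleton plus the __getstate__/__setstate__ pair make ManagedDeviceMesh survive pickling even though the live process-group handle cannot be serialized: the handle is dropped on serialization and reattached from the class-level singleton (set in ft_init_device_mesh) on deserialization. A self-contained sketch of the same pattern with stand-in classes, not the real torchft types:

import pickle
from typing import Any, Dict, Optional


class FakeProcessGroup:
    """Stand-in for an unpicklable process-group handle."""


class MeshLike:
    """Stand-in for ManagedDeviceMesh: drop the handle when pickling and
    reattach it from a class-level singleton when unpickling."""

    replicate_pg_singleton: Optional[FakeProcessGroup] = None

    def __init__(self, replicate_pg: FakeProcessGroup) -> None:
        self.replicate_pg = replicate_pg

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["replicate_pg"] = None  # never serialize the live handle
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        self.__dict__.update(state)
        # Reattach the process-wide handle registered at mesh-creation time.
        self.replicate_pg = self.replicate_pg_singleton


pg = FakeProcessGroup()
MeshLike.replicate_pg_singleton = pg   # analogous to the assignment in ft_init_device_mesh

mesh = MeshLike(pg)
restored = pickle.loads(pickle.dumps(mesh))
print(restored.replicate_pg is pg)     # True: handle restored from the singleton

The obvious constraint, which the singleton encodes, is one replicate process group per process.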
