
Commit c052e42

initial commit
1 parent 39a40b2 commit c052e42

File tree (4 files changed: +212, −30 lines)

torchft/checkpointing.py
torchft/manager.py
torchft/process_group.py
train_fsdp.py


torchft/checkpointing.py (+12, −1)

@@ -21,6 +21,7 @@
 from typing import Callable, Generic, TypeVar

 import torch
+import torch.distributed as dist

 from torchft.http import _IPv6HTTPServer

@@ -76,6 +77,14 @@ def do_GET(self):

                sd = state_dict()

+                def func(obj):
+                    if isinstance(obj, dist.tensor.DTensor) and hasattr(obj, "device_mesh") and hasattr(obj.device_mesh, "replicate_pg"):
+                        obj.device_mesh.replicate_pg = None
+
+                from torch.utils._pytree import tree_map
+
+                tree_map(func, sd["user"])
+
                torch.save(sd, self.wfile)
            except Exception as e:
                logger.exception(
@@ -113,7 +122,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
            data = f.read()

        reader = io.BytesIO(data)
-        return torch.load(reader, weights_only=True)
+        print(f"{reader.read(100)=}")
+        reader.seek(0)
+        return torch.load(reader, weights_only=False)

    def address(self) -> str:
        """

torchft/manager.py (+9)

@@ -487,6 +487,15 @@ def _async_quorum(
                self._pending_state_dict = CheckpointServer.load_from_address(
                    checkpoint_server_address, timeout=self._timeout
                )
+
+                def func(obj):
+                    if isinstance(obj, dist.tensor.DTensor) and hasattr(obj, "device_mesh") and hasattr(obj.device_mesh, "replicate_pg"):
+                        obj.device_mesh.replicate_pg = self._pg
+
+                from torch.utils._pytree import tree_map
+
+                tree_map(func, self._pending_state_dict["user"])
+
                self.load_state_dict(self._pending_state_dict["torchft"])
                # we apply the user state dict only when safe from the main thread
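
Note (editor's sketch, not part of the commit): together with the change in torchft/checkpointing.py this forms a scrub/restore round trip. The sending checkpoint server nulls out each DTensor's device_mesh.replicate_pg before torch.save, since a live ProcessGroup handle cannot be pickled, and the recovering manager walks the loaded user state dict with tree_map and re-attaches its own process group (self._pg). A minimal sketch of that round trip on stand-in objects (DummyMesh and DummyTensor are made-up names; the real code matches torch.distributed.tensor.DTensor leaves):

from torch.utils._pytree import tree_map


class DummyMesh:
    def __init__(self, replicate_pg):
        self.replicate_pg = replicate_pg  # stands in for a non-picklable ProcessGroup


class DummyTensor:
    def __init__(self, mesh):
        self.device_mesh = mesh  # stands in for DTensor.device_mesh


def scrub(obj):
    # sender side (checkpointing.py): drop the handle before serialization
    if hasattr(obj, "device_mesh") and hasattr(obj.device_mesh, "replicate_pg"):
        obj.device_mesh.replicate_pg = None


def make_restore(pg):
    def restore(obj):
        # receiver side (manager.py): re-attach the local replica process group
        if hasattr(obj, "device_mesh") and hasattr(obj.device_mesh, "replicate_pg"):
            obj.device_mesh.replicate_pg = pg
    return restore


user_state = {"w": DummyTensor(DummyMesh(replicate_pg="sender-pg"))}
tree_map(scrub, user_state)                     # before torch.save on the server
tree_map(make_restore("local-pg"), user_state)  # after torch.load on the recovering replica
print(user_state["w"].device_mesh.replicate_pg)  # "local-pg"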

torchft/process_group.py (+33, −29)

@@ -864,7 +864,7 @@ def __init__(
            raise ValueError(
                "ManagedDeviceMesh doesn't support both mesh and parent are None."
            )
-        self.mesh = mesh
+        self._mesh = mesh
        self.mesh_dim_names = mesh_dim_names
        self.replicate_pg = replicate_pg
        self.replicate_dim = replicate_dim
@@ -893,17 +893,17 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
            elif mesh_dim_names in self.flatten_meshes:
                return self.flatten_meshes[mesh_dim_names]
            else:
-                assert self.mesh is not None
-                return self.mesh[mesh_dim_names]
+                assert self._mesh is not None
+                return self._mesh[mesh_dim_names]
        else:
            assert isinstance(mesh_dim_names, tuple)
            if self.replicate_dim_name in mesh_dim_names:
-                assert self.mesh is not None
-                return self.mesh[mesh_dim_names]
+                assert self._mesh is not None
+                return self._mesh[mesh_dim_names]
            else:
-                assert self.mesh is not None
+                assert self._mesh is not None
                return ManagedDeviceMesh(
-                    self.mesh[mesh_dim_names],
+                    self._mesh[mesh_dim_names],
                    mesh_dim_names,
                    self.replicate_pg,
                    mesh_dim_names.index(self.replicate_dim_name),
@@ -924,8 +924,8 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGroup:
        elif dim == self.replicate_dim:
            return self.replicate_pg
        else:
-            assert self.mesh is not None
-            return self.mesh.get_group(self._real_mesh_dim(dim))
+            assert self._mesh is not None
+            return self._mesh.get_group(self._real_mesh_dim(dim))

    def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
        flatten_mesh = _FlattenDeviceMesh(self)
@@ -939,32 +939,32 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":

    def size(self, mesh_dim: Optional[int] = None) -> int:
        if mesh_dim is None:
-            if self.mesh is None:
+            if self._mesh is None:
                return self.replicate_pg.size()
            else:
-                assert self.mesh is not None
-                return self.mesh.size() * self.replicate_pg.size()
+                assert self._mesh is not None
+                return self._mesh.size() * self.replicate_pg.size()
        elif mesh_dim == self.replicate_dim:
            return self.replicate_pg.size()
        else:
-            assert self.mesh is not None
-            return self.mesh.size(self._real_mesh_dim(mesh_dim))
+            assert self._mesh is not None
+            return self._mesh.size(self._real_mesh_dim(mesh_dim))

    @property
    def ndim(self) -> int:
-        assert self.mesh is not None
-        return self.mesh.ndim + 1
+        assert self._mesh is not None
+        return self._mesh.ndim + 1

    @property
    def shape(self) -> Tuple[int, ...]:
-        assert self.mesh is not None
-        ret: List[int] = list(self.mesh.shape)
+        assert self._mesh is not None
+        ret: List[int] = list(self._mesh.shape)
        ret.insert(self.replicate_dim, self.replicate_pg.size())
        return tuple(ret)

    def get_rank(self) -> int:
-        assert self.mesh is not None
-        return self.mesh.get_rank()
+        assert self._mesh is not None
+        return self._mesh.get_rank()

    def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        if isinstance(mesh_dim, str):
@@ -973,33 +973,37 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
        dim = 0 if mesh_dim is None else int(mesh_dim)

        if mesh_dim is None:
-            if self.mesh is None:
+            if self._mesh is None:
                return get_rank(self.replicate_pg)

            assert self.replicate_dim == 0, "replicate_dim must be the first one"
-            assert self.mesh is not None
-            other_dim_size = self.mesh.size()
-            assert self.mesh is not None
-            other_dim_rank = self.mesh.get_local_rank()
+            assert self._mesh is not None
+            other_dim_size = self._mesh.size()
+            assert self._mesh is not None
+            other_dim_rank = self._mesh.get_local_rank()
            replicate_pg_rank = get_rank(self.replicate_pg)
            return other_dim_size * replicate_pg_rank + other_dim_rank
        elif dim == self.replicate_dim:
            return get_rank(self.replicate_pg)
        else:
-            assert self.mesh is not None
-            return self.mesh.get_local_rank(self._real_mesh_dim(dim))
+            assert self._mesh is not None
+            return self._mesh.get_local_rank(self._real_mesh_dim(dim))

    def get_coordinate(self) -> Optional[List[int]]:
        """
        Return the relative indices of this rank relative to all
        dimensions of the mesh. If this rank is not part of the mesh, return None.
        """
-        assert self.mesh is not None
-        return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
+        assert self._mesh is not None
+        return self._mesh._coordinate_on_dim if self._mesh._coordinate_on_dim else None

    def get_all_groups(self) -> List[BaseProcessGroup]:
        raise NotImplementedError

+    @property
+    def mesh(self):
+        return self._mesh.mesh
+

class _FlattenDeviceMesh(DeviceMesh):
    def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
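
Note (editor's sketch, not part of the commit): the rename from self.mesh to self._mesh frees the public name mesh for the property added at the end of the hunk, so callers that read DeviceMesh.mesh (the underlying layout of ranks) keep working while ManagedDeviceMesh holds the wrapped DeviceMesh object privately. A minimal sketch of that delegation pattern with made-up names (InnerMesh / Wrapper stand in for DeviceMesh / ManagedDeviceMesh):

class InnerMesh:
    """Stand-in for a real DeviceMesh, whose .mesh is the layout of ranks."""

    def __init__(self, mesh):
        self.mesh = mesh


class Wrapper:
    """Stand-in for ManagedDeviceMesh after the rename."""

    def __init__(self, inner):
        self._mesh = inner  # private: the wrapped mesh object itself

    @property
    def mesh(self):
        # public: forward to the inner mesh's rank layout, as the new property does
        return self._mesh.mesh


w = Wrapper(InnerMesh(mesh=[[0, 1], [2, 3]]))
print(w.mesh)  # [[0, 1], [2, 3]]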

train_fsdp.py (new file, +158 lines)

import os
from datasets import load_dataset

import torch
from transformers import LlamaForCausalLM, AutoTokenizer
from torch.distributed._composable.fsdp import fully_shard
import torch.distributed as dist
from tqdm import tqdm
from transformers.data import DataCollatorForSeq2Seq
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

from torchdata.stateful_dataloader import StatefulDataLoader

from torchft import (
    DistributedSampler,
    Manager,
    Optimizer,
    ProcessGroupBabyNCCL,
    ProcessGroupGloo,
)
from torchft.process_group import ft_init_device_mesh

def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None, manager=None):

    if replica_group_size is None or sharding_group_size is None:
        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")

    device = device or f"cuda"

    device_mesh = ft_init_device_mesh(
        device_type=device,
        mesh_shape=(replica_group_size, sharding_group_size),
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )
    if device_mesh is None:
        raise RuntimeError("Failed to create a valid device mesh.")

    return device_mesh

def parallelize_llama(model, mesh):
    sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]

    for m in reversed(list(model.modules())):
        if any(c(m) for c in sharding_conditions):
            # fully_shard(m, mesh=mesh, reshard_after_forward=True)
            fully_shard(m, mesh=mesh)
    # fully_shard([model.model.embed_tokens, model.lm_head], mesh=mesh)
    fully_shard(model, mesh=mesh)

def main():
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
    NUM_REPLICAS = int(os.environ.get("NUM_REPLICAS", 2))

    rank = int(os.environ.get("RANK", 0))

    model_name = "Meta-Llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LlamaForCausalLM.from_pretrained(model_name)

    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    assert len(tokenizer) == model.get_input_embeddings().weight.shape[0]

    train_data = load_dataset("samsum", split="train")

    class SAMSumDataset(torch.utils.data.Dataset):
        def __init__(self, data, tokenizer):
            self.data = data
            self.tokenizer = tokenizer
        def __getitem__(self, idx):
            text = self.data[idx]
            prompt = self.tokenizer.encode(tokenizer.bos_token + f"Summarize this dialog: {text['dialogue']}\n---\nSummary: ", add_special_tokens=False)
            summary = self.tokenizer.encode(text["summary"] + self.tokenizer.eos_token, add_special_tokens=False)
            input_ids = prompt + summary
            labels = len(prompt) * [-100] + summary
            return {"input_ids": input_ids, "labels": labels}
        def __len__(self):
            return len(self.data)


    train_dataset = SAMSumDataset(train_data, tokenizer)

    batch_size = 8

    sampler = DistributedSampler(
        train_dataset,
        replica_group=REPLICA_GROUP_ID,
        num_replica_groups=NUM_REPLICA_GROUPS,
        rank=rank,
        shuffle=True,
        num_replicas=NUM_REPLICAS,
    )

    train_dataloader = StatefulDataLoader(train_dataset, batch_size=batch_size, collate_fn=DataCollatorForSeq2Seq(tokenizer), sampler=sampler)

    def load_state_dict(state_dict):
        set_state_dict(
            model,
            optimizer.optim,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


    def state_dict():
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer.optim)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()

    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
    )

    mesh = hsdp_device_mesh(NUM_REPLICA_GROUPS, NUM_REPLICAS, "cuda" if torch.cuda.is_available() else "cpu", manager=manager)

    parallelize_llama(model, mesh)

    model.to(device)

    optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))

    optimizer.zero_grad()

    while manager.current_step() < 500:
        model.train()
        for batch in tqdm(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            if manager.current_step() % 100 == 0:
                print(f"[{manager.current_step()}] loss = {loss.item()}")


if __name__ == "__main__":
    main()
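
Note (editor's sketch, not part of the commit): in SAMSumDataset the prompt positions are masked with -100 so that the loss returned by model(input_ids, labels=labels) is computed only over the summary tokens (Hugging Face causal-LM models ignore label -100). A tiny worked example with made-up token ids in place of a real tokenizer:

# Made-up token ids standing in for tokenizer.encode(...) output.
prompt = [101, 7, 8, 9]   # "<bos>Summarize this dialog: ...\n---\nSummary: "
summary = [42, 43, 102]   # summary tokens + "<eos>"

input_ids = prompt + summary
labels = len(prompt) * [-100] + summary  # same construction as SAMSumDataset.__getitem__

assert len(input_ids) == len(labels)
print(input_ids)  # [101, 7, 8, 9, 42, 43, 102]
print(labels)     # [-100, -100, -100, -100, 42, 43, 102]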

0 commit comments