Commit 508350b

[FT] Change the integration from using ManagedDeviceMesh to set_all_r… (#1109)
Fixes #1105. While using `ManagedDeviceMesh` makes the integration code cleaner, `ManagedDeviceMesh` currently suffers from a composability issue with TP due to the limitations of `DeviceMesh`. This PR changes the integration to use `FSDP.set_all_reduce_hook()`. We will revisit `ManagedDeviceMesh` once `DeviceMesh` becomes friendlier to inheritance use cases.
1 parent 5078e92 commit 508350b

File tree

3 files changed (+74, -43 lines)
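As described above, the integration now registers a gradient all-reduce hook on each FSDP2 module so that the FSDP-reduced gradients are additionally averaged across TorchFT replicas. The sketch below distills that mechanism from the `torchtitan/components/ft.py` diff further down; `register_ft_all_reduce_hook` is a hypothetical standalone variant of the new `FTManager.set_all_reduce_hook` method, and `replicate_pg` stands in for the TorchFT-managed replica process group.

import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
from torch.distributed.distributed_c10d import ReduceOp


def register_ft_all_reduce_hook(
    model_parts: list[torch.nn.Module], replicate_pg: dist.ProcessGroup
) -> None:
    # Average each FSDP-reduced gradient across TorchFT replicas.
    def all_reduce_hook(output: torch.Tensor) -> None:
        dist.all_reduce(output, group=replicate_pg, op=ReduceOp.AVG)

    # Register the hook on every FSDP2-wrapped submodule.
    def apply_set_all_reduce_hook(m: torch.nn.Module) -> None:
        if isinstance(m, FSDPModule):
            m.set_all_reduce_hook(all_reduce_hook)

    for part in model_parts:
        part.apply(apply_set_all_reduce_hook)

The hook runs on gradient output that FSDP has already reduced within a replica, so per-replica FSDP communication is untouched and no `DeviceMesh` subclassing is needed.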

torchtitan/components/ft.py

Lines changed: 19 additions & 2 deletions
@@ -10,8 +10,11 @@
 from typing import Optional

 import torch
+import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
+from torch.distributed._composable.fsdp.fully_shard import FSDPModule
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.distributed_c10d import ReduceOp
 from torch.distributed.tensor import DTensor
 from torchtitan.config_manager import JobConfig
 from torchtitan.distributed import ParallelDims
@@ -34,6 +37,9 @@ def __init__(
         self._manager = manager
         self.group_size = group_size
         self.replica_id = replica_id
+        if has_torchft and manager is not None:
+            self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager)
+            self.replicate_pg.register("dp_replicate")

     @property
     def enabled(self) -> bool:
@@ -47,6 +53,17 @@ def manager(self) -> "ft.Manager":
     def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
         return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank

+    def set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
+        def all_reduce_hook(output):
+            dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
+
+        def apply_set_all_reduce_hook(m):
+            if isinstance(m, FSDPModule):
+                m.set_all_reduce_hook(all_reduce_hook)
+
+        for part in model_parts:
+            part.apply(apply_set_all_reduce_hook)
+

 def init_ft_manager(job: JobConfig) -> FTManager:
     """Initialize the FT manager if TorchFT is enabled.
@@ -55,7 +72,7 @@ def init_ft_manager(job: JobConfig) -> FTManager:
         job (JobConfig): The job configuration.

     Returns:
-        Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None.
+        FTManager: A wrapper around TorchFT.Manager
     """
     if not job.fault_tolerance.enable:
         return FTManager(None)
@@ -66,7 +83,7 @@ def init_ft_manager(job: JobConfig) -> FTManager:
     if job.fault_tolerance.min_replica_size < 1:
         raise ValueError("At least one FT replica is required.")

-    pg = ft.ProcessGroupBabyNCCL()
+    pg = ft.ProcessGroupNCCL()

     return FTManager(
         ft.Manager(
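For orientation, a minimal call-site sketch of the new method, assuming `model_parts` already contains the `fully_shard`-wrapped modules and `job_config` is the usual `JobConfig` (this mirrors the `torchtitan/train.py` change below):

ft_manager = ft.init_ft_manager(job_config)
if ft_manager.enabled:
    # Register the replica-averaging all-reduce hook on every FSDP module.
    ft_manager.set_all_reduce_hook(model_parts)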

torchtitan/distributed/utils.py

Lines changed: 31 additions & 10 deletions
@@ -17,28 +17,51 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor

-from torchtitan.components.ft import ft_clip_grad_norm_util, ft_dist_reduce
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type


-def _dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
-    # Remove FT replicate dimension if it exists.
-    x, reduceOp, mesh = ft_dist_reduce(x, reduceOp, mesh)
+def _dist_reduce(
+    x: torch.Tensor,
+    reduceOp: str,
+    mesh: DeviceMesh,
+    extra_pg: dist.ProcessGroup | None = None,
+) -> float:
+    """Perform distributed reduction on a tensor.

+    Args:
+        x (torch.Tensor): Input tensor.
+        reduceOp (str): Reduce operation to perform.
+        mesh (DeviceMesh): Device mesh to use for reduction.
+        extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction.
+            Defaults to None. If provided, this all_reduce will be called for the extra
+            process group, and then the result will be all_reduced for the mesh.
+    """
     if isinstance(x, DTensor):
         # functional collectives do not support DTensor inputs
         x = x.full_tensor()
+
+    if extra_pg is not None:
+        x = funcol.all_reduce(x, reduceOp=reduceOp, group=extra_pg)
+
     assert x.numel() == 1  # required by `.item()`
     return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()


-def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
-    return _dist_reduce(x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh)
+def dist_max(
+    x: torch.Tensor, mesh: DeviceMesh, extra_pg: dist.ProcessGroup | None
+) -> float:
+    return _dist_reduce(
+        x, reduceOp=c10d.ReduceOp.MAX.name, mesh=mesh, extra_pg=extra_pg
+    )


-def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
-    return _dist_reduce(x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh)
+def dist_mean(
+    x: torch.Tensor, mesh: DeviceMesh, extra_pg: dist.ProcessGroup | None
+) -> float:
+    return _dist_reduce(
+        x, reduceOp=c10d.ReduceOp.AVG.name, mesh=mesh, extra_pg=extra_pg
+    )


 def set_determinism(
@@ -301,8 +324,6 @@ def clip_grad_norm_(
     # Will reach here if any non-PP parallelism is used.
    # If only using PP, total_norm will be a local tensor.

-    # Remove FT replicate dimension if it exists.
-    total_norm = ft_clip_grad_norm_util(total_norm)
     total_norm = total_norm.full_tensor()

     if pp_mesh is not None:
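A short usage sketch of the new `extra_pg` parameter, assuming `ft_manager` is the `FTManager` from the previous file and `loss` is a detached scalar tensor (this mirrors the `train_step` change below): the value is first all-reduced over the TorchFT replicate group, then over the `dp_cp` mesh.

ft_pg = ft_manager.replicate_pg if ft_manager.enabled else None
# Two-stage reduction: across TorchFT replicas first (if enabled), then across the mesh.
global_avg_loss = dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg)
global_max_loss = dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg)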

torchtitan/train.py

Lines changed: 24 additions & 31 deletions
@@ -78,32 +78,19 @@ def __init__(self, job_config: JobConfig):
         self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
         # Device has to be set before creating TorchFT manager.
         device_module.set_device(self.device)
-        ft_manager = ft.init_ft_manager(job_config)

         # init distributed
         world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
-        if not ft_manager.enabled:
-            self.parallel_dims = parallel_dims = ParallelDims(
-                dp_shard=parallelism_config.data_parallel_shard_degree,
-                dp_replicate=parallelism_config.data_parallel_replicate_degree,
-                cp=parallelism_config.context_parallel_degree,
-                tp=parallelism_config.tensor_parallel_degree,
-                pp=parallelism_config.pipeline_parallel_degree,
-                world_size=world_size,
-                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
-            )
-        else:
-            self.parallel_dims = parallel_dims = ft.FTParallelDims(
-                dp_shard=parallelism_config.data_parallel_shard_degree,
-                dp_replicate=parallelism_config.data_parallel_replicate_degree,
-                cp=parallelism_config.context_parallel_degree,
-                tp=parallelism_config.tensor_parallel_degree,
-                pp=parallelism_config.pipeline_parallel_degree,
-                world_size=world_size,
-                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
-                ft_manager=ft_manager,
-            )
+        self.parallel_dims = parallel_dims = ParallelDims(
+            dp_shard=parallelism_config.data_parallel_shard_degree,
+            dp_replicate=parallelism_config.data_parallel_replicate_degree,
+            cp=parallelism_config.context_parallel_degree,
+            tp=parallelism_config.tensor_parallel_degree,
+            pp=parallelism_config.pipeline_parallel_degree,
+            world_size=world_size,
+            enable_loss_parallel=not parallelism_config.disable_loss_parallel,
+        )
         dist_utils.init_distributed(job_config)

         # build meshes
@@ -114,6 +101,12 @@ def __init__(self, job_config: JobConfig):
         else:
             dp_degree, dp_rank = 1, 0

+        self.ft_manager = ft.init_ft_manager(job_config)
+        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
+        # dataloader must be changed.
+        if self.ft_manager.enabled:
+            dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank)
+
         # Set random seed, and maybe enable deterministic mode
         # (mainly for debugging, expect perf loss).
         dist_utils.set_determinism(
@@ -131,11 +124,6 @@ def __init__(self, job_config: JobConfig):
             else None
         )

-        # If TorchFT is enabled, the dp_rank and dp_degree, which are used for
-        # dataloader must be changed.
-        if ft_manager.enabled:
-            dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)
-
         self.dataloader = self.train_spec.build_dataloader_fn(
             dp_world_size=dp_degree,
             dp_rank=dp_rank,
@@ -241,6 +229,9 @@ def __init__(self, job_config: JobConfig):

         self.model_parts = [model]

+        if self.ft_manager.enabled:
+            self.ft_manager.set_all_reduce_hook(self.model_parts)
+
         # initialize device memory monitor and get peak flops for MFU calculation
         device_memory_monitor = self.metrics_processor.device_memory_monitor
         gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
@@ -254,7 +245,7 @@ def __init__(self, job_config: JobConfig):

         # build optimizer after applying parallelisms to the model
         self.optimizers = self.train_spec.build_optimizers_fn(
-            self.model_parts, job_config, ft_manager
+            self.model_parts, job_config, self.ft_manager
         )
         self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
             self.optimizers, job_config
@@ -280,7 +271,7 @@ def __init__(self, job_config: JobConfig):
             lr_schedulers=self.lr_schedulers,
             states={"train_state": self},
             job_config=job_config,
-            ft_manager=ft_manager,
+            ft_manager=self.ft_manager,
         )

         self.train_context = dist_utils.get_train_context(
@@ -384,11 +375,13 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
             parallel_dims.dp_replicate_enabled
             or parallel_dims.dp_shard_enabled
             or parallel_dims.cp_enabled
+            or self.ft_manager.enabled
         ):
             loss = loss.detach()
+            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
             global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
+                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()