
Commit b7ad52a

zhangxiaoli73 authored and pytorchmergebot committed
Use new group instead of split group on non-CUDA device (pytorch#141469)
Motivation: Currently, `split_group` only works for the NCCL backend (https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py#L4745), so we need to use `new_group` on non-CUDA devices instead.

Pull Request resolved: pytorch#141469
Approved by: https://github.com/kwen2501, https://github.com/gujinghui, https://github.com/albanD
1 parent 57c46af commit b7ad52a
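As a rough illustration of the dispatch described in the commit message, here is a minimal, hypothetical sketch (not the actual `DeviceMesh` code): `split_group` is only attempted when the default group is bound to a CUDA device, and everything else falls back to `new_group`. The helper name `make_subgroup` and its arguments are illustrative only.

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import split_group


def make_subgroup(default_group, subgroup_ranks):
    # `bound_device_id` is only set when the default group was initialized
    # eagerly with a concrete CUDA device (i.e. an NCCL group).
    bound_device_id = getattr(default_group, "bound_device_id", None)
    if bound_device_id is not None and torch.cuda.is_available():
        # Eager split: supported only by the NCCL backend today.
        return split_group(parent_pg=default_group, split_ranks=[subgroup_ranks])
    # Generic path: works for gloo and other non-CUDA backends.
    return dist.new_group(ranks=subgroup_ranks)
```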

File tree

2 files changed: +5 −3 lines changed


torch/distributed/device_mesh.py

Lines changed: 4 additions & 2 deletions
@@ -560,17 +560,19 @@ def _init_process_groups(self):
             # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
             # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups.
             dim_group = None
+            has_split_group = False
             if (
                 bound_device_id := getattr(
                     default_group, "bound_device_id", None
                 )
-            ) is not None:
+            ) is not None and torch.cuda.is_available():
                 dim_group = split_group(
                     parent_pg=default_group,
                     pg_options=pg_options,
                     split_ranks=pg_ranks_by_dim.tolist(),
                     group_desc=group_desc,
                 )
+                has_split_group = True
 
             # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
             # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
@@ -583,7 +585,7 @@ def _init_process_groups(self):
                 # We temporarily revert the re-use subgroup, since it breaks two internal tests.
                 # Temporarily reverting to resolve test timeout while root-causing.
                 # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
-                if bound_device_id is None:
+                if bound_device_id is None or not has_split_group:
                     dim_group = new_group(
                         ranks=subgroup_ranks,
                         backend=backend,
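For context, a minimal usage sketch that exercises the fallback path on a CPU-only backend, assuming a single-process gloo setup (the address and port below are placeholders): with gloo there is no bound CUDA device, so `DeviceMesh` now takes the `new_group` path rather than `split_group`.

```python
import os
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# No bound CUDA device here, so subgroup creation goes through `new_group`.
mesh = init_device_mesh("cpu", (1,), mesh_dim_names=("dp",))
print(mesh)

dist.destroy_process_group()
```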

torch/distributed/tensor/parallel/fsdp.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def _chunk_tensor(
             inner_param,
             rank,
             world_size,
-            torch.cuda.device_count(),
+            torch.accelerator.device_count(),
             pg,
         )
 
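A minimal sketch of the device-agnostic counting this line switches to, assuming a PyTorch build that exposes `torch.accelerator`; `local_device_count` is a hypothetical helper that falls back to the CUDA-specific query on older builds.

```python
import torch


def local_device_count() -> int:
    # Prefer the backend-agnostic accelerator API when present (CUDA, XPU, ...).
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.device_count()
    # Older builds: fall back to the CUDA-specific query.
    return torch.cuda.device_count()


print(local_device_count())
```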

0 commit comments
