2 files changed: 5 additions, 3 deletions
@@ -560,17 +560,19 @@ def _init_process_groups(self):
             # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4
             # mesh, we need to make 2 + 4 = 6 API calls per rank to create all the subgroups.
             dim_group = None
+            has_split_group = False
             if (
                 bound_device_id := getattr(
                     default_group, "bound_device_id", None
                 )
-            ) is not None:
+            ) is not None and torch.cuda.is_available():
                 dim_group = split_group(
                     parent_pg=default_group,
                     pg_options=pg_options,
                     split_ranks=pg_ranks_by_dim.tolist(),
                     group_desc=group_desc,
                 )
+                has_split_group = True

             # If the subgroup has already been created through `split_group`, we simply loop over `pg_ranks_by_dim`
             # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
@@ -583,7 +585,7 @@ def _init_process_groups(self):
             # We temporarily revert the re-use subgroup, since it breaks two internal tests.
             # Temporarily reverting to resolve test timeout while root-causing.
             # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
-            if bound_device_id is None:
+            if bound_device_id is None or not has_split_group:
                 dim_group = new_group(
                     ranks=subgroup_ranks,
                     backend=backend,
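Taken together, the two hunks make the eager split path strictly opt-in: `split_group` is attempted only when the default group has a bound device id and CUDA is actually available, and the new `has_split_group` flag lets the later `new_group` branch fire whenever the fast path was skipped. Without the flag, a rank with a bound device id on a CUDA-less build would skip `split_group` (new guard) and also skip `new_group` (old guard), leaving `dim_group` unset; the flag appears to close exactly that gap. Below is a minimal sketch of the resulting control flow; `make_dim_group`, `default_group`, `split_ranks`, and `subgroup_ranks` are hypothetical stand-ins for the surrounding DeviceMesh state, not the PR's actual code.

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import split_group

def make_dim_group(default_group, split_ranks, subgroup_ranks):
    # Sketch of the guard logic this PR introduces; not the real
    # _init_process_groups body.
    dim_group = None
    has_split_group = False
    bound_device_id = getattr(default_group, "bound_device_id", None)
    # Fast path: splitting the parent communicator requires both a bound
    # device and a working CUDA runtime, hence the extra is_available() check.
    if bound_device_id is not None and torch.cuda.is_available():
        dim_group = split_group(
            parent_pg=default_group,
            split_ranks=split_ranks,
        )
        has_split_group = True
    # Fallback: either no bound device, or the fast path was skipped
    # (e.g. CUDA unavailable), so create the subgroup generically.
    if bound_device_id is None or not has_split_group:
        dim_group = dist.new_group(ranks=subgroup_ranks)
    return dim_group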
@@ -186,7 +186,7 @@ def _chunk_tensor(
         inner_param,
         rank,
         world_size,
-        torch.cuda.device_count(),
+        torch.accelerator.device_count(),
         pg,
     )
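The second file swaps the CUDA-specific device count for the device-agnostic accelerator front end, so `_chunk_tensor` no longer assumes a CUDA backend. A short illustration of the substitution, assuming a PyTorch build recent enough to ship `torch.accelerator` (this snippet is illustrative, not code from the PR):

import torch

if torch.accelerator.is_available():
    # Device-agnostic count: matches torch.cuda.device_count() on a CUDA
    # build, but also covers other accelerator backends such as XPU.
    num_devices = torch.accelerator.device_count()
    print(f"{torch.accelerator.current_accelerator()}: {num_devices} device(s)")
else:
    # CPU-only build, or no accelerator attached.
    print("no accelerator available")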