
Commit 18955d3

mrshenli authored and facebook-github-bot committed
Raise warning when calling collectives on non-member group objects (pytorch#67639)
Summary: Pull Request resolved: pytorch#67639

Due to BC considerations, we cannot directly error out, as that might break existing applications. Raise warnings first to improve debuggability.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D32075151

Pulled By: mrshenli

fbshipit-source-id: 5680d420f5f6cd3f74a36616c03350e8a976b363
1 parent 54241a9 commit 18955d3
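For context (illustration only, not part of this commit): after this change, a process whose global rank is not a member of the ProcessGroup passed to a collective still returns early exactly as before, but it now also emits a UserWarning naming the op and the global rank. A minimal two-process sketch of the observable behavior; the script name, the gloo backend, and the torchrun launch are assumptions made for this example:

# warn_demo.py -- hypothetical example, not part of this commit
import torch
import torch.distributed as dist

def main():
    # Assumes the usual env:// variables set by a launcher such as
    # `torchrun --nproc_per_node=2 warn_demo.py`.
    dist.init_process_group(backend="gloo")
    group = dist.new_group(ranks=[0])  # only global rank 0 is a member

    t = torch.zeros(2)
    # On rank 1 this call is still a no-op that returns immediately, but after
    # this commit it also warns, e.g.:
    #   UserWarning: Running all_reduce on global rank 1 which does not
    #   belong to the given group.
    dist.all_reduce(t, group=group)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

The test `_test_warn_not_in_group` added below exercises the same split between member and non-member ranks.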

File tree

4 files changed: +87 / -0 lines changed


test/distributed/test_c10d_common.py

Lines changed: 30 additions & 0 deletions
@@ -656,6 +656,36 @@ def _test_sequence_num_set_new_group(self, backend):
         dist.all_gather_object(obj_list, subgroup_seq, group=subgroup)
         self.assertEqual(len(set(obj_list)), 1)
 
+    def _test_warn_not_in_group(self, backend):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend,
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        in_group_ranks = list(filter(lambda x: x % 2 == 0, range(self.world_size)))
+        group = dist.new_group(in_group_ranks)
+
+        x = torch.zeros(2, 2).cuda(self.rank)
+        xs = [torch.zeros(2, 2).cuda(self.rank) for _ in range(len(in_group_ranks))]
+        if self.rank not in in_group_ranks:
+            msg = ".*{}.*does not belong to.*"
+            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_gather")):
+                dist.all_gather(xs, x, group=group)
+            with self.assertWarnsOnceRegex(UserWarning, msg.format("all_reduce")):
+                dist.all_reduce(x, group=group)
+            with self.assertWarnsOnceRegex(UserWarning, msg.format("barrier")):
+                dist.barrier(group=group)
+            with self.assertWarnsOnceRegex(UserWarning, msg.format("broadcast")):
+                dist.broadcast(x, src=0, group=group)
+        else:
+            dist.all_gather(xs, x, group=group)
+            dist.all_reduce(x, group=group)
+            dist.barrier(group=group)
+            dist.broadcast(x, src=0, group=group)
+
+
 class CommTest(AbstractCommTest, MultiProcessTestCase):
     def setUp(self):
         super(CommTest, self).setUp()

test/distributed/test_c10d_gloo.py

Lines changed: 5 additions & 0 deletions
@@ -2320,6 +2320,11 @@ def test_gloo_barrier_device_ids(self):
         with self.assertRaisesRegex(RuntimeError, "device_ids not supported"):
             c10d.barrier(device_ids=[self.rank])
 
+    @skip_if_lt_x_gpu(2)
+    @requires_gloo()
+    def test_gloo_warn_not_in_group(self):
+        self._test_warn_not_in_group(backend="gloo")
+
 
 if __name__ == "__main__":
     assert (

test/distributed/test_c10d_nccl.py

Lines changed: 17 additions & 0 deletions
@@ -2629,6 +2629,23 @@ def test_nccl_barrier_device_ids_function_argument(self):
         with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
             c10d.barrier(device_ids=self.rank)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["DETAIL"])
+    def test_nccl_warn_not_in_group_debug_detail(self):
+        self._test_warn_not_in_group(backend="nccl")
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["INFO"])
+    def test_nccl_warn_not_in_group_debug_info(self):
+        self._test_warn_not_in_group(backend="nccl")
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @with_dist_debug_levels(levels=["OFF"])
+    def test_nccl_warn_not_in_group_debug_off(self):
+        self._test_warn_not_in_group(backend="nccl")
 
 if __name__ == "__main__":
     assert (

torch/distributed/distributed_c10d.py

Lines changed: 35 additions & 0 deletions
@@ -271,6 +271,14 @@ def _rank_not_in_group(group: ProcessGroup):
     return group == GroupMember.NON_GROUP_MEMBER
 
 
+def _warn_not_in_group(op_name):
+    global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank()
+    warnings.warn(
+        f"Running {op_name} on global rank {global_rank} which does not "
+        "belong to the given group."
+    )
+
+
 def _get_group_rank(group: ProcessGroup, rank):
     """
     Helper that gets a given group's local rank in the group from a given global

@@ -879,6 +887,7 @@ def isend(tensor, dst, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("isend")
         return
 
     if group is None or group is GroupMember.WORLD:

@@ -908,6 +917,7 @@ def irecv(tensor, src=None, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("irecv")
         return
 
     if group is None or group is GroupMember.WORLD:

@@ -939,6 +949,7 @@ def send(tensor, dst, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("send")
         return
 
     if group is None or group is GroupMember.WORLD:

@@ -968,6 +979,7 @@ def recv(tensor, src=None, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("recv")
         return -1
 
     if group is None:

@@ -1119,6 +1131,7 @@ def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=0):
 
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast_multigpu")
         return
 
     opts = BroadcastOptions()

@@ -1160,6 +1173,7 @@ def broadcast(tensor, src, group=None, async_op=False):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast")
         return
 
     opts = BroadcastOptions()

@@ -1283,6 +1297,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_reduce")
         return
 
     if tensor.is_complex():

@@ -1339,6 +1354,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
     """
     _check_tensor_list(tensors, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_reduce_coalesced")
         return
 
     if any([t.is_complex() for t in tensors]) and not supports_complex(op):

@@ -1394,6 +1410,7 @@ def reduce_multigpu(
 
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_multigpu")
         return
 
     opts = ReduceOptions()

@@ -1439,6 +1456,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce")
         return
 
     opts = ReduceOptions()

@@ -1505,6 +1523,7 @@ def all_gather_multigpu(
 
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_multigpu")
         return
 
     output_tensor_lists = [

@@ -1591,6 +1610,7 @@ def all_gather_object(object_list, obj, group=None):
         ['foo', 12, {1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_object")
         return
 
     input_tensor, local_size = _object_to_tensor(obj)

@@ -1684,6 +1704,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
         ['foo', 12, {1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("gather_object")
         return
 
     # Ensure object_gather_list is specified appopriately.

@@ -1792,6 +1813,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         ['foo', 12, {1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast_object_list")
         return
 
     my_rank = get_rank()

@@ -1903,6 +1925,7 @@ def scatter_object_list(
         [{1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("scatter_object_list")
         return
 
     if (

@@ -2003,6 +2026,7 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
     _check_tensor_list(tensor_list, "tensor_list")
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather")
         return
 
     tensor_list = [

@@ -2062,6 +2086,7 @@ def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
     _check_single_tensor(input_tensor, "input_tensor")
     _check_single_tensor(output_tensor, "output_tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("_all_gather_base")
         return
 
     output_tensor = (

@@ -2136,6 +2161,7 @@ def all_gather_coalesced(
     # We only check basic compatibility with C++ params here, C++ code will
     # do shape and type checking.
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_coalesced")
         return
     _check_tensor_list(input_tensor_list, "tensor_list")
     if not isinstance(output_tensor_lists, list):

@@ -2206,6 +2232,7 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
         gather_list = []
 
     if _rank_not_in_group(group):
+        _warn_not_in_group("gather")
         return
 
     my_rank = get_rank()

@@ -2262,6 +2289,7 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
         scatter_list = []
 
     if _rank_not_in_group(group):
+        _warn_not_in_group("scatter")
         return
     scatter_list = [
         t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list

@@ -2347,6 +2375,7 @@ def reduce_scatter_multigpu(
 
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_scatter_multigpu")
         return
 
     opts = ReduceScatterOptions()

@@ -2383,6 +2412,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False):
     _check_single_tensor(output, "output")
     _check_tensor_list(input_list, "input_list")
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_scatter")
         return
 
     opts = ReduceScatterOptions()

@@ -2420,6 +2450,7 @@ def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=False):
     _check_single_tensor(input, "input")
 
     if _rank_not_in_group(group):
+        _warn_not_in_group("_reduce_scatter_base")
         return
 
     opts = ReduceScatterOptions()

@@ -2534,6 +2565,7 @@ def all_to_all_single(
         tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_to_all_single")
         return
 
     opts = AllToAllOptions()

@@ -2655,6 +2687,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False):
 
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_to_all")
         return
 
     opts = AllToAllOptions()

@@ -2700,6 +2733,7 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
         None, if not async_op or if not part of the group
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("barrier")
         return
 
     opts = BarrierOptions()

@@ -2780,6 +2814,7 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=False):
     # Need to call rank not in group before using the group, otherwise
     # "Invalid process group" error is raised.
     if _rank_not_in_group(group):
+        _warn_not_in_group("monitored_barrier")
         return
 
     if get_backend(group) != Backend.GLOO:
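As the summary notes, the change stops at a warning because raising immediately could break existing applications. A project that wants the stricter behavior can escalate these warnings to errors on its own side with Python's standard warnings filter; a small sketch (this snippet is an assumption about user code, not part of the commit, and the regex simply mirrors the message built in `_warn_not_in_group` above):

import warnings

# Turn the non-member-group warnings introduced by this commit into hard
# errors within your own application code.
warnings.filterwarnings(
    "error",
    message=r"Running .* on global rank .* which does not belong to the given group\.",
    category=UserWarning,
)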
