@@ -271,6 +271,14 @@ def _rank_not_in_group(group: ProcessGroup):
     return group == GroupMember.NON_GROUP_MEMBER


+def _warn_not_in_group(op_name):
+    global_rank = -1 if GroupMember.WORLD is None else GroupMember.WORLD.rank()
+    warnings.warn(
+        f"Running {op_name} on global rank {global_rank} which does not "
+        "belong to the given group."
+    )
+
+
 def _get_group_rank(group: ProcessGroup, rank):
     """
     Helper that gets a given group's local rank in the group from a given global
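The helper above relies on the module-level `warnings` import already present in this file; when the default process group is uninitialized, the reported global rank falls back to -1. For illustration, a minimal standalone sketch of the message it produces (the `_demo` name is hypothetical, not part of the patch):

    import warnings

    def _demo_warn_not_in_group(op_name, global_rank=-1):
        # Mirrors the message format of the helper added above.
        warnings.warn(
            f"Running {op_name} on global rank {global_rank} which does not "
            "belong to the given group."
        )

    _demo_warn_not_in_group("all_reduce")
    # UserWarning: Running all_reduce on global rank -1 which does not
    # belong to the given group.
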
@@ -879,6 +887,7 @@ def isend(tensor, dst, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("isend")
         return

     if group is None or group is GroupMember.WORLD:
@@ -908,6 +917,7 @@ def irecv(tensor, src=None, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("irecv")
         return

     if group is None or group is GroupMember.WORLD:
@@ -939,6 +949,7 @@ def send(tensor, dst, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("send")
         return

     if group is None or group is GroupMember.WORLD:
@@ -968,6 +979,7 @@ def recv(tensor, src=None, group=None, tag=0):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("recv")
         return -1

     if group is None:
@@ -1119,6 +1131,7 @@ def broadcast_multigpu(tensor_list, src, group=None, async_op=False, src_tensor=

     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast_multigpu")
         return

     opts = BroadcastOptions()
@@ -1160,6 +1173,7 @@ def broadcast(tensor, src, group=None, async_op=False):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast")
         return

     opts = BroadcastOptions()
@@ -1283,6 +1297,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_reduce")
         return

     if tensor.is_complex():
@@ -1339,6 +1354,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):
     """
     _check_tensor_list(tensors, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_reduce_coalesced")
         return

     if any([t.is_complex() for t in tensors]) and not supports_complex(op):
@@ -1394,6 +1410,7 @@ def reduce_multigpu(

     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_multigpu")
         return

     opts = ReduceOptions()
@@ -1439,6 +1456,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
     """
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce")
         return

     opts = ReduceOptions()
@@ -1505,6 +1523,7 @@ def all_gather_multigpu(

     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_multigpu")
         return

     output_tensor_lists = [
@@ -1591,6 +1610,7 @@ def all_gather_object(object_list, obj, group=None):
         ['foo', 12, {1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_object")
         return

     input_tensor, local_size = _object_to_tensor(obj)
@@ -1684,6 +1704,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
         ['foo', 12, {1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("gather_object")
         return

     # Ensure object_gather_list is specified appopriately.
@@ -1792,6 +1813,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
         ['foo', 12, {1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("broadcast_object_list")
         return

     my_rank = get_rank()
@@ -1903,6 +1925,7 @@ def scatter_object_list(
         [{1: 2}]
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("scatter_object_list")
         return

     if (
@@ -2003,6 +2026,7 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
     _check_tensor_list(tensor_list, "tensor_list")
     _check_single_tensor(tensor, "tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather")
         return

     tensor_list = [
@@ -2062,6 +2086,7 @@ def _all_gather_base(output_tensor, input_tensor, group=None, async_op=False):
     _check_single_tensor(input_tensor, "input_tensor")
     _check_single_tensor(output_tensor, "output_tensor")
     if _rank_not_in_group(group):
+        _warn_not_in_group("_all_gather_base")
         return

     output_tensor = (
@@ -2136,6 +2161,7 @@ def all_gather_coalesced(
     # We only check basic compatibility with C++ params here, C++ code will
     # do shape and type checking.
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_gather_coalesced")
         return
     _check_tensor_list(input_tensor_list, "tensor_list")
     if not isinstance(output_tensor_lists, list):
@@ -2206,6 +2232,7 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
         gather_list = []

     if _rank_not_in_group(group):
+        _warn_not_in_group("gather")
         return

     my_rank = get_rank()
@@ -2262,6 +2289,7 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
         scatter_list = []

     if _rank_not_in_group(group):
+        _warn_not_in_group("scatter")
         return
     scatter_list = [
         t if not t.is_complex() else torch.view_as_real(t) for t in scatter_list
@@ -2347,6 +2375,7 @@ def reduce_scatter_multigpu(

     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_scatter_multigpu")
         return

     opts = ReduceScatterOptions()
@@ -2383,6 +2412,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
     _check_single_tensor(output, "output")
     _check_tensor_list(input_list, "input_list")
     if _rank_not_in_group(group):
+        _warn_not_in_group("reduce_scatter")
         return

     opts = ReduceScatterOptions()
@@ -2420,6 +2450,7 @@ def _reduce_scatter_base(output, input, op=ReduceOp.SUM, group=None, async_op=Fa
     _check_single_tensor(input, "input")

     if _rank_not_in_group(group):
+        _warn_not_in_group("_reduce_scatter_base")
         return

     opts = ReduceScatterOptions()
@@ -2534,6 +2565,7 @@ def all_to_all_single(
         tensor([4+4j, 8+8j, 12+12j, 16+16j]) # Rank 3
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_to_all_single")
         return

     opts = AllToAllOptions()
@@ -2655,6 +2687,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False

     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("all_to_all")
         return

     opts = AllToAllOptions()
@@ -2700,6 +2733,7 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None):
         None, if not async_op or if not part of the group
     """
     if _rank_not_in_group(group):
+        _warn_not_in_group("barrier")
         return

     opts = BarrierOptions()
@@ -2780,6 +2814,7 @@ def monitored_barrier(group=GroupMember.WORLD, timeout=None, wait_all_ranks=Fals
     # Need to call rank not in group before using the group, otherwise
     # "Invalid process group" error is raised.
     if _rank_not_in_group(group):
+        _warn_not_in_group("monitored_barrier")
         return

     if get_backend(group) != Backend.GLOO:
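Taken together, these hunks make every early return on a non-member rank observable instead of silent. A sketch of the user-visible effect, assuming a 2-rank gloo job launched via torchrun --nproc_per_node=2, with a subgroup that deliberately excludes rank 1:

    import torch
    import torch.distributed as dist

    dist.init_process_group("gloo")

    # new_group(ranks=[0]) means rank 1 is not a member of `subgroup`.
    subgroup = dist.new_group(ranks=[0])

    t = torch.ones(1)
    # On rank 1 this call is still a no-op, but with this patch it now emits:
    #   UserWarning: Running all_reduce on global rank 1 which does not
    #   belong to the given group.
    dist.all_reduce(t, group=subgroup)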