@@ -124,6 +124,20 @@ def allgather(
         """
         raise NotImplementedError("not implemented")
 
+    # pyre-fixme[14]: inconsistent override
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        """
+        Performs an allgather operation on coalesced tensors.
+
+        See torch.distributed.allgather_coalesced for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     # pyre-fixme[14]: inconsistent override
     def allreduce(
         self,
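
Not part of the diff: a minimal usage sketch of the coalesced allgather entry point added above, assuming `pg` is a process group implementing this interface and that each output tensor holds every rank's contribution to its paired input (world_size * input.numel() elements). The tensor sizes and the helper name are illustrative only.

import torch
from torch.distributed.distributed_c10d import AllgatherOptions

def coalesced_allgather_sketch(pg, world_size: int) -> None:
    # Two differently sized inputs gathered in one coalesced call.
    inputs = [torch.ones(4), torch.ones(8)]
    outputs = [torch.empty(world_size * t.numel()) for t in inputs]
    work = pg.allgather_into_tensor_coalesced(outputs, inputs, AllgatherOptions())
    work.wait()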
@@ -212,6 +226,20 @@ def reduce_scatter(
         """
         raise NotImplementedError("not implemented")
 
+    # pyre-fixme[14]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        """
+        Performs a reduce-scatter operation on coalesced tensors.
+
+        See torch.distributed.reduce_scatter_tensor for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     # pyre-fixme[14]: inconsistent override
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         """
@@ -336,10 +364,20 @@ def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         return self.parent.allgather(output_tensors, input_tensor, opts)
 
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        return self.parent.allgather_into_tensor_coalesced(
+            output_tensors, input_tensors, opts
+        )
+
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self.parent.allreduce(tensors, opts)
 
@@ -377,6 +415,16 @@ def reduce_scatter(
     ) -> Work:
         return self.parent.reduce_scatter(output_tensors, input_tensors, opts)
 
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        return self.parent.reduce_scatter_tensor_coalesced(
+            output_tensors, input_tensors, opts
+        )
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         return self.parent.send(tensors, dst_rank, tag)
 
@@ -402,8 +450,15 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
         self._timeout = timeout
 
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
         # pyre-fixme[16]: no attribute ProcessGroupGloo
-        return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+        backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-gloo"
@@ -427,6 +482,28 @@ def reduce_scatter(
         """
         raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
 
+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter_tensor_coalesced
+        operation in the ProcessGroupGloo class.
+        However, this operation is not supported by the
+        Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupGloo.
+        """
+        raise RuntimeError(
+            "ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
+        )
+
 
 class ProcessGroupNCCL(ProcessGroupWrapper):
     """
@@ -440,8 +517,15 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     """
 
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
-        return BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class = BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -499,6 +583,19 @@ def allgather(
         self._work.append(res)
         return res
 
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         res = _DummyWork(tensors)
         self._work.append(res)
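
The fake implementation above mirrors a single-participant collective: each input is simply copied into its paired output. A standalone illustration of that semantics, not invoking the class in this file:

import torch

inputs = [torch.arange(4, dtype=torch.float32), torch.ones(2, 3)]
outputs = [torch.empty_like(t) for t in inputs]

# Same per-tensor copy loop as the dummy allgather_into_tensor_coalesced above.
for o, i in zip(outputs, inputs):
    o.copy_(i)

assert all(torch.equal(o, i) for o, i in zip(outputs, inputs))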
@@ -548,6 +645,19 @@ def reduce_scatter(
         self._work.append(res)
         return res
 
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         return _DummyWork(None)
 
@@ -1134,6 +1244,20 @@ def allgather(
         _maybe_share_tensors(input_tensor)
         return self._run_func("allgather", output_tensors, input_tensor, opts)
 
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        _assert_list(output_tensors)
+        _assert_list(input_tensors)
+        _maybe_share_tensors(output_tensors)
+        _maybe_share_tensors(input_tensors)
+        return self._run_func(
+            "allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
+        )
+
     def allreduce(
         self,
         tensors: List[torch.Tensor],
@@ -1200,6 +1324,20 @@ def reduce_scatter(
         _maybe_share_tensors(input_tensors)
         return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
 
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        _assert_list(output_tensors)
+        _assert_list(input_tensors)
+        _maybe_share_tensors(output_tensors)
+        _maybe_share_tensors(input_tensors)
+        return self._run_func(
+            "reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
+        )
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         _assert_list(tensors)
         _maybe_share_tensors(tensors)
@@ -1278,8 +1416,14 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
         # pyre-fixme[16]: no attribute ProcessGroupGloo
-        return BaseProcessGroupGloo(store, rank, world_size)
+        backend_class = BaseProcessGroupGloo(store, rank, world_size)
+        pg._register_backend(
+            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -1303,6 +1447,28 @@ def reduce_scatter(
         """
         raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
 
+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter_tensor_coalesced
+        operation in the ProcessGroupBabyGloo class.
+        However, this operation is not supported by the
+        Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupBabyGloo.
+        """
+        raise RuntimeError(
+            "ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
+        )
+
 
 class ProcessGroupBabyNCCL(ProcessGroupBaby):
     """
@@ -1322,8 +1488,15 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
-        return BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class = BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-baby-nccl"