
Commit d427bef

allenwang28 and Allen Wang authored
Adds support for allgather_into_tensor_coalesced and reduce_scatter_tensor_coalesced (#114)
* initial commit to add final collectives
* adds tests, modifies process group creation to register a backend
* slight cleanups

---------

Co-authored-by: Allen Wang <[email protected]>
1 parent c782f4e commit d427bef
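
The two new collectives follow the same list-of-tensors calling convention as the existing wrappers in this file. Below is a minimal sketch of the call shapes only, assuming `pg` is an already-configured torchft process group whose underlying backend supports both operations (the Gloo wrappers in this diff deliberately raise for the reduce-scatter variant); the helper name, tensor sizes, and device placement are illustrative, not part of the commit.

from typing import List

import torch
from torch.distributed.distributed_c10d import AllgatherOptions, ReduceScatterOptions


def coalesced_collectives_example(pg, world_size: int) -> None:
    # allgather_into_tensor_coalesced: each output holds world_size shards of
    # the matching input tensor.
    ag_inputs: List[torch.Tensor] = [torch.rand(4), torch.rand(8)]
    ag_outputs: List[torch.Tensor] = [
        torch.empty(world_size * t.numel()) for t in ag_inputs
    ]
    pg.allgather_into_tensor_coalesced(ag_outputs, ag_inputs, AllgatherOptions()).wait()

    # reduce_scatter_tensor_coalesced: the inverse layout, each input carries
    # world_size shards that are reduced and scattered into the matching output.
    rs_inputs: List[torch.Tensor] = [
        torch.rand(world_size * 4),
        torch.rand(world_size * 8),
    ]
    rs_outputs: List[torch.Tensor] = [torch.empty(4), torch.empty(8)]
    pg.reduce_scatter_tensor_coalesced(rs_outputs, rs_inputs, ReduceScatterOptions()).wait()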

File tree

2 files changed: +313 −7 lines changed


torchft/process_group.py

Lines changed: 178 additions & 5 deletions
@@ -124,6 +124,20 @@ def allgather(
         """
         raise NotImplementedError("not implemented")
 
+    # pyre-fixme[14]: inconsistent override
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        """
+        Performs an allgather operation on coalesced tensors.
+
+        See torch.distributed.allgather_coalesced for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     # pyre-fixme[14]: inconsistent override
     def allreduce(
         self,
@@ -212,6 +226,20 @@ def reduce_scatter(
         """
         raise NotImplementedError("not implemented")
 
+    # pyre-fixme[14]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        """
+        Performs a reduce-scatter operation on coalesced tensors.
+
+        See torch.distributed.reduce_scatter_tensor for more details.
+        """
+        raise NotImplementedError("not implemented")
+
     # pyre-fixme[14]: inconsistent override
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         """
@@ -336,10 +364,20 @@ def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         return self.parent.allgather(output_tensors, input_tensor, opts)
 
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        return self.parent.allgather_into_tensor_coalesced(
+            output_tensors, input_tensors, opts
+        )
+
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self.parent.allreduce(tensors, opts)
 
@@ -377,6 +415,16 @@ def reduce_scatter(
     ) -> Work:
         return self.parent.reduce_scatter(output_tensors, input_tensors, opts)
 
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        return self.parent.reduce_scatter_tensor_coalesced(
+            output_tensors, input_tensors, opts
+        )
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         return self.parent.send(tensors, dst_rank, tag)
 
@@ -402,8 +450,15 @@ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
         self._timeout = timeout
 
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
         # pyre-fixme[16]: no attribute ProcessGroupGloo
-        return BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+        backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-gloo"
@@ -427,6 +482,28 @@ def reduce_scatter(
         """
         raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
 
+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter_tensor_coalesced
+        operation in the ProcessGroupGloo class.
+        However, this operation is not supported by the
+        Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupGloo.
+        """
+        raise RuntimeError(
+            "ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
+        )
+
 
 class ProcessGroupNCCL(ProcessGroupWrapper):
     """
@@ -440,8 +517,15 @@ class ProcessGroupNCCL(ProcessGroupWrapper):
     """
 
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
-        return BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class = BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-nccl"
@@ -499,6 +583,19 @@ def allgather(
         self._work.append(res)
         return res
 
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         res = _DummyWork(tensors)
         self._work.append(res)
@@ -548,6 +645,19 @@ def reduce_scatter(
         self._work.append(res)
         return res
 
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        for o, i in zip(output_tensors, input_tensors):
+            o.copy_(i)
+
+        res = _DummyWork(output_tensors)
+        self._work.append(res)
+        return res
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         return _DummyWork(None)
 
@@ -1134,6 +1244,20 @@ def allgather(
         _maybe_share_tensors(input_tensor)
         return self._run_func("allgather", output_tensors, input_tensor, opts)
 
+    def allgather_into_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        _assert_list(output_tensors)
+        _assert_list(input_tensors)
+        _maybe_share_tensors(output_tensors)
+        _maybe_share_tensors(input_tensors)
+        return self._run_func(
+            "allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
+        )
+
     def allreduce(
         self,
         tensors: List[torch.Tensor],
@@ -1200,6 +1324,20 @@ def reduce_scatter(
         _maybe_share_tensors(input_tensors)
         return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
 
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> Work:
+        _assert_list(output_tensors)
+        _assert_list(input_tensors)
+        _maybe_share_tensors(output_tensors)
+        _maybe_share_tensors(input_tensors)
+        return self._run_func(
+            "reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
+        )
+
     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         _assert_list(tensors)
         _maybe_share_tensors(tensors)
@@ -1278,8 +1416,14 @@ class ProcessGroupBabyGloo(ProcessGroupBaby):
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
         # pyre-fixme[16]: no attribute ProcessGroupGloo
-        return BaseProcessGroupGloo(store, rank, world_size)
+        backend_class = BaseProcessGroupGloo(store, rank, world_size)
+        pg._register_backend(
+            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-baby-gloo"
@@ -1303,6 +1447,28 @@ def reduce_scatter(
         """
         raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
 
+    # pyre-fixme[15]: inconsistent override
+    def reduce_scatter_tensor_coalesced(
+        self,
+        output_tensors: List[torch.Tensor],
+        input_tensors: List[torch.Tensor],
+        opts: ReduceScatterOptions,
+    ) -> None:
+        """
+        This function is a placeholder for the reduce_scatter_tensor_coalesced
+        operation in the ProcessGroupBabyGloo class.
+        However, this operation is not supported by the
+        Gloo backend, and thus, calling this function will raise a
+        RuntimeError.
+
+        Raises:
+            RuntimeError: Always raised since reduce_scatter is not
+                supported by ProcessGroupBabyGloo.
+        """
+        raise RuntimeError(
+            "ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
+        )
+
 
 class ProcessGroupBabyNCCL(ProcessGroupBaby):
     """
@@ -1322,8 +1488,15 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
+        pg = BaseProcessGroup(store, rank, world_size)
+        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
-        return BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class = BaseProcessGroupNCCL(store, rank, world_size)
+        backend_class._set_sequence_number_for_group()
+        pg._register_backend(
+            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
+        )
+        return pg
 
     def getBackendName(self) -> str:
         return "torchft-baby-nccl"