Commit 30296e9

ProcessGroupBabyNCCL: support multiple streams and use event on start
1 parent 68e1d28 commit 30296e9

2 files changed: +168 -49 lines changed

2 files changed

+168
-49
lines changed
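With this change, collectives issued while a non-default CUDA stream is current are replayed on a matching stream inside the worker subprocess, and wait() blocks the caller's stream on a CUDA event instead of blocking the CPU thread. A minimal usage sketch under those assumptions (single rank, one GPU; the store address and tensor shape are illustrative and mirror the tests added below):

# Usage sketch: drive ProcessGroupBabyNCCL from a non-default CUDA stream.
from datetime import timedelta

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupBabyNCCL

store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    t = torch.ones(2, 3, device="cuda")
    # The parent records a CUDA event on this stream and ships it with the op,
    # so the worker does not start the allreduce before t is ready.
    work = pg.allreduce([t], ReduceOp.SUM)

# wait() makes the current stream wait on the worker's completion event
# rather than blocking the CPU thread.
work.wait()
torch.cuda.synchronize()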

torchft/process_group.py

+137-46
@@ -19,17 +19,18 @@
 import logging
 import queue
 import threading
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -58,9 +59,9 @@
     BroadcastOptions,
     ReduceOp,
     Work,
-    _world,
 )
 from torch.futures import Future
+from torch.utils._pytree import tree_any
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -586,29 +587,52 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side so no real need for a
+        # separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns true if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
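The _is_any_cuda helper above relies on torch.utils._pytree.tree_any, which applies a predicate over arbitrarily nested lists, tuples, and dicts; that is what lets it spot CUDA tensors anywhere in the op arguments. A small illustration (values are made up; the second tensor needs a GPU):

import torch
from torch.utils._pytree import tree_any

# Mixed CPU/CUDA arguments nested in containers, as collective args often are.
args = ([torch.zeros(2)], {"grad": torch.ones(2, device="cuda")})

# True: at least one leaf is a CUDA tensor, so the op is treated as a CUDA op.
print(tree_any(lambda x: isinstance(x, torch.Tensor) and x.is_cuda, args))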
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     share able and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +744,70 @@ def _worker(
                 return
             tx.put(None)
 
-            work = {}
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0
 
             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
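The worker path above is the standard CUDA record/wait handshake, stretched across the process boundary via Event(interprocess=True): the parent records an event before shipping the op, the worker's stream waits on it before launching, and the reverse happens on wait. A single-process sketch of that handshake (stream names and sizes are illustrative; needs a GPU):

import torch

producer = torch.cuda.Stream()
consumer = torch.cuda.Stream()

with torch.cuda.stream(producer):
    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x
    # record() marks a point on the producer stream once y is computed.
    ready = torch.cuda.Event(interprocess=True)
    ready.record()

with torch.cuda.stream(consumer):
    # wait() queues a GPU-side dependency; the CPU thread is not blocked.
    ready.wait()
    z = y.sum()

torch.cuda.synchronize()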
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
                        except Exception as e:
                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
         except Exception as e:
             logger.exception("worker errored")
             tx.put(e)
+            raise
 
     def _future_handler(self, future_queue: mp.Queue) -> None:
         try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
            logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the process group is alive. This is used to ensure that
+        operations are not performed on a dead process group and any errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
 
     def allreduce(
         self,
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
    @classmethod
    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        # pyre-fixme[16]: no attribute ProcessGroupNCCL

torchft/process_group_test.py

+31-3
@@ -266,6 +266,31 @@ def test_baby_gloo_apis(self) -> None:
 
         self.assertEqual(a.num_active_work(), 0)
 
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_baby_nccl_apis(self) -> None:
+        # set to 1 if more than >=2 gpus
+        device_id = 1 % torch.cuda.device_count()
+        torch.cuda.set_device(device_id)
+
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(a, torch.randn((2, 3), device="cuda"))
+
+        torch.cuda.synchronize()
+
+        # force collection to ensure no BabyWork objects remain
+        gc.collect()
+
+        self.assertEqual(a.num_active_work(), 0)
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -282,12 +307,15 @@ def test_baby_nccl_2gpu(self) -> None:
         store_addr: str = f"localhost:{store.port}/prefix"
 
         def run(rank: int) -> Tuple[torch.Tensor, Work]:
-            a = ProcessGroupBabyNCCL()
+            a = ProcessGroupBabyNCCL(
+                timeout=timedelta(seconds=10.0),
+            )
             a.configure(store_addr, rank, 2)
-
             self.assertEqual(a.size(), 2)
 
-            at = torch.tensor([rank + 1], device=f"cuda:{rank}")
+            # We test using set_device to ensure stream device is correct.
+            torch.cuda.set_device(rank)
+            at = torch.tensor([rank + 1], device="cuda")
 
             a_work = a.allreduce([at], ReduceOp.SUM)
             return at, a_work
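To exercise just this coverage locally, pytest's -k filter should select the new tests, for example -k "baby_nccl" against torchft/process_group_test.py; test_baby_nccl_apis needs a single GPU, while test_baby_nccl_2gpu needs two.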
