import logging
import queue
import threading
+ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
+     Generator,
    List,
    Optional,
    Tuple,
-     Type,
    TypeVar,
    Union,
    cast,

    BroadcastOptions,
    ReduceOp,
    Work,
-     _world,
)
from torch.futures import Future
+ from torch.utils._pytree import tree_any

if TYPE_CHECKING:
    from torchft.manager import Manager
@@ -586,29 +587,52 @@ def __init__(
        self._timeout = timeout

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+         self._pg._assert_alive()
+ 
        self._tx.put(("wait", self._op_id), timeout=self._timeout)
-         assert _get(self._rx, self._timeout) == self._op_id
+         op_id, event = cast(
+             Tuple[int, Optional[torch.cuda.Event]],
+             _get(self._rx, timeout or self._timeout),
+         )
+         assert op_id == self._op_id
+         if event is not None:
+             event.wait()
        return True

+     def synchronize(self) -> None:
+         # TODO: No one seems to use this and NCCL wait already only waits the
+         # stream and is non-blocking on the CPU side so no real need for a
+         # separate call.
+         raise NotImplementedError("not implemented")
+ 
    def get_future(self) -> Future[object]:
        return self._pg._get_future(self._op_id)

    def __del__(self) -> None:
        self._tx.put(("del", self._op_id), timeout=self._timeout)


- class _BabyWorkNCCL(_BabyWork):
-     def wait(self, timeout: Optional[timedelta] = None) -> bool:
-         self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-         # pyre-fixme[23]: unable to unpack into 2 values
-         op_id, event = _get(self._rx, self._timeout)
-         assert op_id == self._op_id
-         assert isinstance(event, torch.cuda.Event)
+ def _is_any_cuda(obj: object) -> bool:
+     """
+     Returns true if any of the tensors in the object are CUDA tensors.

-         # Wait on Event makes the stream wait but not the CPU thread.
-         event.wait()
+     Supports lists, tuples, dicts, and tensors.
+     """
+     return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)

-         return True
+ 
+ @dataclass
+ class _OpMetadata:
+     work: Work
+     stream: Optional[torch.cuda.Stream]
+ 
+     @contextmanager
+     def set_stream(self) -> Generator[None, None, None]:
+         if self.stream is not None:
+             with torch.cuda.stream(self.stream):
+                 yield
+         else:
+             yield


class ProcessGroupBaby(ProcessGroup):
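The new wait protocol above returns (op_id, event) from the worker; when the op involved CUDA tensors, the parent only makes its current stream wait on the returned event rather than blocking the CPU. The CUDA detection itself is a plain pytree walk. A minimal standalone sketch of that helper, assuming a PyTorch build that ships torch.utils._pytree.tree_any (as imported by this diff); names here are illustrative, not part of the change:

import torch
from torch.utils._pytree import tree_any


def is_any_cuda(obj: object) -> bool:
    # True if any tensor nested inside lists/tuples/dicts is on a CUDA device.
    return tree_any(lambda t: isinstance(t, torch.Tensor) and t.is_cuda, obj)


print(is_any_cuda([torch.zeros(2), {"x": torch.ones(2)}]))  # False
if torch.cuda.is_available():
    print(is_any_cuda((torch.zeros(2, device="cuda"),)))  # True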
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
    subprocess. Since it's running in a subprocess all tensors need to be in
    shared memory or will be moved to shared memory. CUDA tensors are implicitly
    shareable and don't need any changes.
-
    """

-     WORK_CLASS: Type[_BabyWork] = _BabyWork
-
    def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
        super().__init__(0, 1)

@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:

        self._p = ctx.Process(
            target=self._worker,
-             args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+             args=(
+                 store_addr,
+                 rank,
+                 world_size,
+                 self._tx,
+                 self._rx,
+                 self._future_queue,
+             ),
            daemon=True,
        )
        self._p.start()
@@ -716,23 +744,70 @@ def _worker(
                return
            tx.put(None)

-             work = {}
+             streams: Dict[str, torch.cuda.Stream] = {}
+             work: Dict[int, _OpMetadata] = {}
            next_op_id: int = 0

            while True:
                op = rx.get()
                cmd = op[0]
                if cmd == "func":
-                     func_name, args, kwargs = op[1:]
-                     args = _PickleSafeOptions.unsafe_args(args)
-                     fn = getattr(pg, func_name)
-                     work[next_op_id] = fn(*args, **kwargs)
+                     func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+ 
+                     # To avoid potential deadlocks we need to preserve the
+                     # stream/synchronization behavior of the parent process.
+                     # We allocate one Stream per stream_id to make sure that we
+                     # don't accidentally introduce cross stream synchronization
+                     # points.
+                     if stream_id is not None:
+                         stream_key = f"{stream_device}/{stream_id}"
+                         if stream_key not in streams:
+                             streams[stream_key] = torch.cuda.Stream(
+                                 device=stream_device
+                             )
+                         stream = streams[stream_key]
+                     else:
+                         stream = None
+ 
+                     with (
+                         torch.cuda.stream(stream)
+                         if stream is not None
+                         else nullcontext()
+                     ):
+                         # Make the stream wait on the cuda event to make sure we
+                         # don't start the operation until the tensor is ready.
+                         if event is not None:
+                             event.wait()
+ 
+                         args = _PickleSafeOptions.unsafe_args(args)
+                         fn = getattr(pg, func_name)
+                         work[next_op_id] = _OpMetadata(
+                             work=fn(*args, **kwargs),
+                             stream=stream,
+                         )
                    tx.put(next_op_id)
                    next_op_id += 1
                elif cmd == "wait":
                    op_id: int = op[1]
-                     work[op_id].wait()
-                     tx.put(op_id)
+ 
+                     metadata = work[op_id]
+ 
+                     with metadata.set_stream():
+                         # With WorkNCCL this makes the stream wait not the CPU when
+                         # no timeout is passed.
+                         metadata.work.wait()
+ 
+                         # Register event on the stream that we can pass to the main
+                         # process.
+                         event = (
+                             torch.cuda.current_stream().record_event(
+                                 torch.cuda.Event(interprocess=True)
+                             )
+                             if metadata.stream is not None
+                             else None
+                         )
+ 
+                     tx.put((op_id, event))
                elif cmd == "del":
                    op_id: int = op[1]
                    del work[op_id]
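The worker comments above capture the invariant: the child mirrors the parent's stream layout (one child Stream per parent stream_id) and orders work through CUDA events rather than CPU blocking, so a collective never starts before its input tensors are ready and never introduces synchronization between unrelated streams. A single-process illustration of that event-ordering pattern, purely for intuition (the real code records the event in the parent process and waits on it in the child):

import torch

if torch.cuda.is_available():
    producer = torch.cuda.Stream()  # stands in for the parent's stream
    consumer = torch.cuda.Stream()  # stands in for the worker's per-stream_id stream

    with torch.cuda.stream(producer):
        t = torch.randn(1024, device="cuda")
        t += 1  # last write on the producer stream
        ready = producer.record_event(torch.cuda.Event(interprocess=True))

    with torch.cuda.stream(consumer):
        # Blocks the consumer stream (not the CPU) until the work captured by
        # `ready` has finished, so the reduction sees the updated tensor.
        ready.wait()
        total = t.sum()

    torch.cuda.synchronize()
    print(total.item())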
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
                        except Exception as e:
                            future_queue.put((op_id, _FUTURE_EXCEPTION, e))

-                     work[op_id].get_future().add_done_callback(callback)
+                     work[op_id].work.get_future().add_done_callback(callback)
                    tx.put(op_id)
-                 elif cmd == "synchronize":
-                     # CUDA only, use events instead of waiting on CPU
-                     op_id = op[1]
- 
-                     # With WorkNCCL this makes the stream wait not the CPU when
-                     # no timeout is passed.
-                     work[op_id].wait()
- 
-                     # Register event on the stream that we can pass to the main
-                     # process.
-                     event = torch.cuda.Event(interprocess=True)
-                     event.record()
- 
-                     del work[op_id]
-                     tx.put((op_id, event))
                elif cmd == "num_active_work":
                    tx.put(len(work))
                else:
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
        except Exception as e:
            logger.exception("worker errored")
            tx.put(e)
+             raise

    def _future_handler(self, future_queue: mp.Queue) -> None:
        try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
            logger.exception(f"got unexpected error in future handler: {e}")

    def _get_future(self, op_id: int) -> Future[object]:
+         self._assert_alive()
+ 
        with self._futures_lock:
            fut = Future()  # pyre-fixme[29]: is not a function
            self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _get_future(self, op_id: int) -> Future[object]:
        return fut

    def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+         self._assert_alive()
+ 
        rx = self._rx
        tx = self._tx
        assert rx is not None
        assert tx is not None

+         is_cuda = _is_any_cuda(args)
+ 
+         stream_device = torch.cuda.current_stream().device if is_cuda else None
+         stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
+         event = (
+             torch.cuda.current_stream().record_event(
+                 torch.cuda.Event(interprocess=True)
+             )
+             if is_cuda
+             else None
+         )
+ 
        tx.put(
-             ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+             (
+                 "func",
+                 func,
+                 _PickleSafeOptions.safe_args(args),
+                 kwargs,
+                 stream_device,
+                 stream_id,
+                 event,
+             ),
            timeout=self._timeout,
        )

        op_id = _get(rx, self._timeout)
        assert isinstance(op_id, int), f"invalid return {op_id}"

-         return self.WORK_CLASS(
-             pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-         )
+         return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+ 
+     def _assert_alive(self) -> None:
+         """
+         Assert that the process group is alive. This is used to ensure that
+         operations are not performed on a dead process group and any errors are surfaced.
+         """
+         p = self._p
+         assert p is not None
+         if not p.is_alive():
+             raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")

    def allreduce(
        self,
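Before enqueueing a CUDA op, _run_func now snapshots the caller's stream (its device and stream_id, so the worker can map it onto a dedicated child stream) and records an interprocess event after the inputs' last writes. A hypothetical helper showing just that capture step, not part of the diff:

import torch
from typing import Optional, Tuple


def capture_cuda_context(
    is_cuda: bool,
) -> Tuple[Optional[torch.device], Optional[int], Optional[torch.cuda.Event]]:
    # Mirrors the parent-side bookkeeping in _run_func: no CUDA args, nothing to capture.
    if not is_cuda:
        return None, None, None
    stream = torch.cuda.current_stream()
    # interprocess=True lets the event be shared with (and waited on by) the
    # worker process after being sent over the queue.
    event = stream.record_event(torch.cuda.Event(interprocess=True))
    return stream.device, stream.stream_id, event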
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
        tensors may leak in the current PyTorch implementation. TODO fix
    """

-     WORK_CLASS = _BabyWorkNCCL
- 
    @classmethod
    def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        # pyre-fixme[16]: no attribute ProcessGroupNCCL
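With WORK_CLASS removed, CUDA and CPU ops share the single _BabyWork path and the NCCL flavor only needs to supply _create_pg. A rough single-rank usage sketch based on the APIs visible in this file; the TCPStore setup and the host:port/prefix store-address format are assumptions, not taken from this diff:

import torch
from torch.distributed import ReduceOp, TCPStore
from torchft.process_group import ProcessGroupBabyNCCL

# Assumed setup: a local TCPStore and the host:port/prefix address format.
store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)
pg = ProcessGroupBabyNCCL()
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

t = torch.ones(8, device="cuda")
work = pg.allreduce([t], ReduceOp.SUM)  # runs in the child process
work.wait()  # with CUDA args this waits via a stream event, not the CPU
print(t)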