@@ -768,8 +768,6 @@ def __init__(
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
-
-        self._embedding_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}

     def _grad_swap(self) -> None:
@@ -779,14 +777,6 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad

-    def _init_embedding_streams(self) -> None:
-        for _ in self._pipelined_modules:
-            self._embedding_streams.append(
-                (torch.get_device_module(self._device).Stream(priority=0))
-                if self._device.type in ["cuda", "mtia"]
-                else None
-            )
-
     def _validate_optimizer(self) -> None:
         for pipelined_module in self._pipelined_modules:
             pipelined_params = set(pipelined_module.parameters())
@@ -815,7 +805,6 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
                 # pyre-ignore [6]
                 EmbeddingPipelinedForward,
             )
-            self._init_embedding_streams()
             self.wait_sparse_data_dist(self.contexts[0])
             self._validate_optimizer()
             # pyre-ignore [6]
@@ -916,43 +905,36 @@ def _mlp_forward(
             return self._model_fwd(batch)

     def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
-        default_stream = torch.get_device_module(self._device).current_stream()
         assert len(context.embedding_features) == len(context.embedding_tensors)
-        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
-            self._embedding_streams,
+        for emb_tensors, embedding_features, detached_emb_tensors in zip(
             context.embedding_tensors,
             context.embedding_features,
             context.detached_embedding_tensors,
         ):
-            with self._stream_context(stream):
-                grads = [tensor.grad for tensor in detached_emb_tensors]
-                if stream:
-                    stream.wait_stream(default_stream)
-                # Some embeddings may never get used in the final loss computation,
-                # so the grads will be `None`. If we don't exclude these, it will fail
-                # with error: "grad can be implicitly created only for scalar outputs"
-                # Alternatively, if the tensor has only 1 element, pytorch can still
-                # figure out how to do autograd
-                embs_to_backprop, grads_to_use, invalid_features = [], [], []
-                assert len(embedding_features) == len(emb_tensors)
-                for features, tensor, grad in zip(
-                    embedding_features, emb_tensors, grads
-                ):
-                    if tensor.numel() == 1 or grad is not None:
-                        embs_to_backprop.append(tensor)
-                        grads_to_use.append(grad)
+            grads = [tensor.grad for tensor in detached_emb_tensors]
+            # Some embeddings may never get used in the final loss computation,
+            # so the grads will be `None`. If we don't exclude these, it will fail
+            # with error: "grad can be implicitly created only for scalar outputs"
+            # Alternatively, if the tensor has only 1 element, pytorch can still
+            # figure out how to do autograd
+            embs_to_backprop, grads_to_use, invalid_features = [], [], []
+            assert len(embedding_features) == len(emb_tensors)
+            for features, tensor, grad in zip(embedding_features, emb_tensors, grads):
+                if tensor.numel() == 1 or grad is not None:
+                    embs_to_backprop.append(tensor)
+                    grads_to_use.append(grad)
+                else:
+                    if isinstance(features, str):
+                        invalid_features.append(features)
+                    elif isinstance(features, Iterable):
+                        invalid_features.extend(features)
                     else:
-                        if isinstance(features, str):
-                            invalid_features.append(features)
-                        elif isinstance(features, Iterable):
-                            invalid_features.extend(features)
-                        else:
-                            invalid_features.append(features)
-                if invalid_features and context.index == 0:
-                    logger.warning(
-                        f"SemiSync, the following features have no gradients: {invalid_features}"
-                    )
-                torch.autograd.backward(embs_to_backprop, grads_to_use)
+                        invalid_features.append(features)
+            if invalid_features and context.index == 0:
+                logger.warning(
+                    f"SemiSync, the following features have no gradients: {invalid_features}"
+                )
+            torch.autograd.backward(embs_to_backprop, grads_to_use)

     def copy_batch_to_gpu(
         self,
@@ -1012,23 +994,21 @@ def start_embedding_lookup(
1012994 """
1013995 if batch is None :
1014996 return
997+
1015998 with record_function (f"## start_embedding_lookup { context .index } ##" ):
1016- _wait_for_events (
1017- batch , context , torch .get_device_module (self ._device ).current_stream ()
1018- )
999+ current_stream = torch .get_device_module (self ._device ).current_stream ()
1000+ _wait_for_events (batch , context , current_stream )
10191001 for i , module in enumerate (self ._pipelined_modules ):
1020- stream = self ._embedding_streams [i ]
1021- with self ._stream_context (stream ):
1022- _start_embedding_lookup (
1023- module ,
1024- context ,
1025- source_stream = self ._data_dist_stream ,
1026- target_stream = stream ,
1027- stream_context = self ._stream_context ,
1028- )
1029- event = torch .get_device_module (self ._device ).Event ()
1030- event .record ()
1031- context .events .append (event )
1002+ _start_embedding_lookup (
1003+ module ,
1004+ context ,
1005+ source_stream = self ._data_dist_stream ,
1006+ target_stream = current_stream ,
1007+ stream_context = self ._stream_context ,
1008+ )
1009+ event = torch .get_device_module (self ._device ).Event ()
1010+ event .record ()
1011+ context .events .append (event )
10321012
10331013
10341014class PrefetchTrainPipelineSparseDist (TrainPipelineSparseDist [In , Out ]):
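The diff above drops the per-module embedding streams, but the core gradient flow of `embedding_backward` is unchanged: the dense forward runs on detached copies of the embedding tensors, and their `.grad` values are re-injected into the still-attached tensors, skipping any embedding that never received a gradient. A minimal, self-contained sketch of that detach-and-reinject pattern, using plain PyTorch only; the tensor names and shapes below are illustrative, not torchrec APIs:

```python
import torch

emb_weight = torch.nn.Parameter(torch.randn(8, 4))

# Two "embedding lookups"; each is detached before the dense forward.
emb_a = emb_weight[torch.tensor([0, 2])]
emb_b = emb_weight[torch.tensor([5])]
det_a = emb_a.detach().requires_grad_()
det_b = emb_b.detach().requires_grad_()

# Dense forward/backward only touches det_a, so det_b.grad stays None.
loss = det_a.pow(2).sum()
loss.backward()

# Filter out embeddings that received no gradient (unless scalar-sized),
# then re-inject the gradients into the original (attached) tensors.
embs_to_backprop, grads_to_use = [], []
for tensor, grad in zip([emb_a, emb_b], [det_a.grad, det_b.grad]):
    if tensor.numel() == 1 or grad is not None:
        embs_to_backprop.append(tensor)
        grads_to_use.append(grad)

torch.autograd.backward(embs_to_backprop, grads_to_use)
assert emb_weight.grad is not None  # gradient reached the embedding table
```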
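Similarly, the rewritten `start_embedding_lookup` queues the lookup on the current stream and records a device event into `context.events` so that later pipeline stages can wait on it. Below is a hedged sketch of that record-then-wait pattern, assuming a CUDA-like device; `producer`, `consumer`, and their bodies are hypothetical stand-ins, not torchrec functions:

```python
from typing import List

import torch

def producer(device: torch.device, events: List[torch.Event]) -> None:
    dev_mod = torch.get_device_module(device)
    # ... queue embedding-lookup kernels on the current stream here ...
    event = dev_mod.Event()
    event.record()          # marks the point after the queued lookup work
    events.append(event)

def consumer(device: torch.device, events: List[torch.Event]) -> None:
    dev_mod = torch.get_device_module(device)
    stream = dev_mod.current_stream()
    for event in events:
        stream.wait_event(event)  # later kernels on this stream wait for the lookups
    # ... safe to consume the lookup results on this stream ...
```

Recording an event per pipelined module keeps the synchronization point explicit even though, after this change, all lookups share the current stream rather than dedicated per-module streams.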