@@ -768,8 +768,6 @@ def __init__(
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
-
-        self._embedding_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -779,14 +777,6 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad
 
-    def _init_embedding_streams(self) -> None:
-        for _ in self._pipelined_modules:
-            self._embedding_streams.append(
-                (torch.get_device_module(self._device).Stream(priority=0))
-                if self._device.type in ["cuda", "mtia"]
-                else None
-            )
-
     def _validate_optimizer(self) -> None:
         for pipelined_module in self._pipelined_modules:
             pipelined_params = set(pipelined_module.parameters())
@@ -815,7 +805,6 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
-        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
         self._validate_optimizer()
         # pyre-ignore [6]
@@ -916,43 +905,36 @@ def _mlp_forward(
             return self._model_fwd(batch)
 
     def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
-        default_stream = torch.get_device_module(self._device).current_stream()
         assert len(context.embedding_features) == len(context.embedding_tensors)
-        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
-            self._embedding_streams,
+        for emb_tensors, embedding_features, detached_emb_tensors in zip(
             context.embedding_tensors,
             context.embedding_features,
             context.detached_embedding_tensors,
         ):
-            with self._stream_context(stream):
-                grads = [tensor.grad for tensor in detached_emb_tensors]
-                if stream:
-                    stream.wait_stream(default_stream)
-                # Some embeddings may never get used in the final loss computation,
-                # so the grads will be `None`. If we don't exclude these, it will fail
-                # with error: "grad can be implicitly created only for scalar outputs"
-                # Alternatively, if the tensor has only 1 element, pytorch can still
-                # figure out how to do autograd
-                embs_to_backprop, grads_to_use, invalid_features = [], [], []
-                assert len(embedding_features) == len(emb_tensors)
-                for features, tensor, grad in zip(
-                    embedding_features, emb_tensors, grads
-                ):
-                    if tensor.numel() == 1 or grad is not None:
-                        embs_to_backprop.append(tensor)
-                        grads_to_use.append(grad)
+            grads = [tensor.grad for tensor in detached_emb_tensors]
+            # Some embeddings may never get used in the final loss computation,
+            # so the grads will be `None`. If we don't exclude these, it will fail
+            # with error: "grad can be implicitly created only for scalar outputs"
+            # Alternatively, if the tensor has only 1 element, pytorch can still
+            # figure out how to do autograd
+            embs_to_backprop, grads_to_use, invalid_features = [], [], []
+            assert len(embedding_features) == len(emb_tensors)
+            for features, tensor, grad in zip(embedding_features, emb_tensors, grads):
+                if tensor.numel() == 1 or grad is not None:
+                    embs_to_backprop.append(tensor)
+                    grads_to_use.append(grad)
+                else:
+                    if isinstance(features, str):
+                        invalid_features.append(features)
+                    elif isinstance(features, Iterable):
+                        invalid_features.extend(features)
                     else:
-                        if isinstance(features, str):
-                            invalid_features.append(features)
-                        elif isinstance(features, Iterable):
-                            invalid_features.extend(features)
-                        else:
-                            invalid_features.append(features)
-                if invalid_features and context.index == 0:
-                    logger.warning(
-                        f"SemiSync, the following features have no gradients: {invalid_features}"
-                    )
-                torch.autograd.backward(embs_to_backprop, grads_to_use)
+                        invalid_features.append(features)
+            if invalid_features and context.index == 0:
+                logger.warning(
+                    f"SemiSync, the following features have no gradients: {invalid_features}"
+                )
+            torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
@@ -1012,23 +994,21 @@ def start_embedding_lookup(
         """
         if batch is None:
             return
+
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            _wait_for_events(
-                batch, context, torch.get_device_module(self._device).current_stream()
-            )
+            current_stream = torch.get_device_module(self._device).current_stream()
+            _wait_for_events(batch, context, current_stream)
             for i, module in enumerate(self._pipelined_modules):
-                stream = self._embedding_streams[i]
-                with self._stream_context(stream):
-                    _start_embedding_lookup(
-                        module,
-                        context,
-                        source_stream=self._data_dist_stream,
-                        target_stream=stream,
-                        stream_context=self._stream_context,
-                    )
-                    event = torch.get_device_module(self._device).Event()
-                    event.record()
-                    context.events.append(event)
+                _start_embedding_lookup(
+                    module,
+                    context,
+                    source_stream=self._data_dist_stream,
+                    target_stream=current_stream,
+                    stream_context=self._stream_context,
+                )
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
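
Note (editor's sketch, not part of the diff): the filtering that embedding_backward performs before calling torch.autograd.backward can be illustrated with a small, self-contained example. All names below (tables, emb_tensors, detached_tensors) are toy stand-ins, not the pipeline's real objects.

import torch

# Toy stand-ins for embedding tables and their lookup outputs.
tables = [torch.randn(10, 8, requires_grad=True) for _ in range(3)]
emb_tensors = [table[:4] for table in tables]  # non-leaf lookup outputs
detached_tensors = [t.detach().requires_grad_(True) for t in emb_tensors]

# The downstream "dense" part only uses the first two embeddings, so the
# third detached tensor never receives a gradient (.grad stays None).
loss = detached_tensors[0].sum() + detached_tensors[1].mean()
loss.backward()

grads = [t.grad for t in detached_tensors]

# Skip entries without an upstream gradient, as embedding_backward does;
# passing None for a non-scalar output would raise
# "grad can be implicitly created only for scalar outputs".
embs_to_backprop, grads_to_use = [], []
for tensor, grad in zip(emb_tensors, grads):
    if tensor.numel() == 1 or grad is not None:
        embs_to_backprop.append(tensor)
        grads_to_use.append(grad)

torch.autograd.backward(embs_to_backprop, grads_to_use)
assert tables[2].grad is None  # the unused table gets no gradient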
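
Note (editor's sketch, not part of the diff): the event bookkeeping that survives this change, event.record() after each lookup with the events later consumed by _wait_for_events, follows the standard CUDA pattern of fencing work on one stream and waiting on it from another. A generic illustration, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    side_stream = torch.cuda.Stream(device)

    # Queue some work on the current (default) stream.
    x = torch.randn(1024, 1024, device=device)
    y = x @ x

    # Fence the queued work with an event, analogous to the
    # event.record() / context.events.append(event) pattern above.
    event = torch.cuda.Event()
    event.record()

    # A consumer on another stream waits on the event before touching `y`,
    # so it never observes a partially computed result.
    with torch.cuda.stream(side_stream):
        side_stream.wait_event(event)
        z = y.sum()
    torch.cuda.synchronize()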