@@ -308,6 +308,7 @@ def __init__(
         overlap: bool = False,
         process_group: dist.ProcessGroup = None,
         is_moe: bool = False,
+        early_reduce_scatter_release: bool = True,
     ) -> None:
         self.process_group = process_group
         self.overlap = overlap
@@ -317,6 +318,11 @@ def __init__(
         self.reduce_scatter_handlers = {}
         self._module_shapes = {}
         self._forward_prefetch_prerequisites = []
+        self._zero_const_pool = {}
+
+        self._enable_early_reduce_scatter_release = early_reduce_scatter_release
+        self._early_prev_layer_rs_handles = []
+        self._early_curr_layer_rs_handles = []
 
         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -510,6 +516,13 @@ def _post_backward_hook_for_module(self, module, *args): # pylint: disable=W061
         self._clear_handle(module)
         self._clear_weight(module)
 
+    def _early_reduce_scatter_release_hook(self, *args):  # pylint: disable=W0613
+        for handle in self._early_prev_layer_rs_handles:
+            handle.wait()
+
+        self._early_prev_layer_rs_handles = self._early_curr_layer_rs_handles
+        self._early_curr_layer_rs_handles = []
+
     def _register_sync_parameters_hook(self) -> None:
         """
         register forward hooks and backward hooks for isp modules.
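
The new hook defers waiting on reduce-scatter work until the next block finishes its backward pass. A minimal, self-contained sketch of that one-layer-delayed release pattern, using stand-in names rather than InternEvo's actual classes:

```python
# Sketch only: illustrates the delayed wait that _early_reduce_scatter_release_hook
# performs; class and method names here are hypothetical.
class EarlyReleaseTracker:
    def __init__(self):
        self._prev_layer_handles = []  # handles issued for the previously finished block
        self._curr_layer_handles = []  # handles issued while the current block runs backward

    def add_handle(self, handle):
        # Called right after an async reduce-scatter is launched in the grad hook.
        self._curr_layer_handles.append(handle)

    def on_block_backward_done(self, *args):
        # Registered as a full-backward hook on each block: wait only on the
        # previous block's communication so the current block still overlaps
        # its own reduce-scatter with compute.
        for handle in self._prev_layer_handles:
            handle.wait()
        self._prev_layer_handles = self._curr_layer_handles
        self._curr_layer_handles = []
```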
@@ -545,12 +558,18 @@ def _register_sync_parameters_hook(self) -> None:
         for module in self._isp_modules:
             module.register_full_backward_hook(self._post_backward_hook_for_module)
 
+        if self._enable_early_reduce_scatter_release:
+            for block_idx in range(self._num_blocks):
+                block = self._index_to_block[block_idx]
+                block.register_full_backward_hook(self._early_reduce_scatter_release_hook)
+
     def _get_constant_zero(self, size: tuple) -> torch.Tensor:
-        return torch.zeros(
-            *size,
-            dtype=self.model_conf.dtype,
-            device=self.model_conf.device,
-        ).contiguous()
+        if size not in self._zero_const_pool:
+            self._zero_const_pool[size] = torch.zeros(
+                *size, dtype=self.model_conf.dtype, device=self.model_conf.device
+            ).contiguous()
+
+        return self._zero_const_pool[size]
 
     def communication_mode(self) -> str:
         return "wp"
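
The rewritten `_get_constant_zero` memoizes the dummy zero tensor by shape instead of allocating a fresh one on every call. A rough standalone illustration of that cache (the module-level pool and defaults are assumptions for the sketch; callers must treat the returned tensor as read-only):

```python
import torch

# Hypothetical stand-alone version of the shape-keyed zero-tensor cache.
_zero_pool = {}

def get_constant_zero(size: tuple, dtype=torch.float32, device="cpu") -> torch.Tensor:
    # Allocate at most one zero tensor per shape and reuse it afterwards.
    if size not in _zero_pool:
        _zero_pool[size] = torch.zeros(*size, dtype=dtype, device=device).contiguous()
    return _zero_pool[size]
```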
@@ -637,13 +656,18 @@ def grad_hook(
             assert hasattr(module.weight, "isp_reduce_scatter_name")
             key = getattr(module.weight, "isp_reduce_scatter_name")
 
-            self.reduce_scatter_handlers[key] = reduce_scatter_raw(
+            output, handle = reduce_scatter_raw(
                 tensor,
                 self.process_group,
                 op=reduce_op,
                 async_op=async_op,
             )
 
+            if self._enable_early_reduce_scatter_release:
+                self._early_curr_layer_rs_handles.append(handle)
+
+            self.reduce_scatter_handlers[key] = (output, handle)
+
             result, handle = (
                 self._get_constant_zero(
                     (
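
`grad_hook` now unpacks the `(output, handle)` pair from `reduce_scatter_raw` so the async handle can also be tracked for early release. For context, a hedged sketch of what such a helper might look like on top of `torch.distributed.reduce_scatter_tensor` (an illustration only, not InternEvo's actual `reduce_scatter_raw`):

```python
import torch
import torch.distributed as dist

def reduce_scatter_sketch(tensor, process_group, op=dist.ReduceOp.SUM, async_op=True):
    # Returns the scattered output shard plus the async work handle that the
    # grad hook stores in reduce_scatter_handlers and, when early release is
    # enabled, appends to _early_curr_layer_rs_handles.
    world_size = dist.get_world_size(process_group)
    output = torch.empty(
        (tensor.shape[0] // world_size, *tensor.shape[1:]),
        dtype=tensor.dtype,
        device=tensor.device,
    )
    handle = dist.reduce_scatter_tensor(output, tensor, op=op, group=process_group, async_op=async_op)
    return output, handle
```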
@@ -698,6 +722,10 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06
         ):
             self._zero_optim.reduce_left_grads_after_backward()
 
+        if self._isp_communicator and self._isp_communicator._enable_early_reduce_scatter_release:
+            self._isp_communicator._early_prev_layer_rs_handles = []
+            self._isp_communicator._early_curr_layer_rs_handles = []
+
     def post_helper_func(self, scheduler, outputs, label) -> None:  # pylint: disable=W0613
         pass
 
@@ -872,9 +900,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten
 
         q, kv = _SeqAllToAll.apply(self.spg, [2, 3], [1, 1], q, kv)
 
-        torch.cuda.synchronize()
         context = self.local_attn(q, kv, *args, **kwargs)
-        torch.cuda.synchronize()
 
         context = _SeqAllToAll.apply(self.spg, 1, 2, context)
 