
Commit 48f1b94

feat(isp): add early_reduce_scatter_release support
1 parent 4a6b453 commit 48f1b94

File tree: 3 files changed, +44 −8 lines changed


internlm/core/parallel/comm/isp.py

Lines changed: 34 additions & 8 deletions
@@ -308,6 +308,7 @@ def __init__(
         overlap: bool = False,
         process_group: dist.ProcessGroup = None,
         is_moe: bool = False,
+        early_reduce_scatter_release: bool = True,
     ) -> None:
         self.process_group = process_group
         self.overlap = overlap
@@ -317,6 +318,11 @@ def __init__(
         self.reduce_scatter_handlers = {}
         self._module_shapes = {}
         self._forward_prefetch_prerequisites = []
+        self._zero_const_pool = {}
+
+        self._enable_early_reduce_scatter_release = early_reduce_scatter_release
+        self._early_prev_layer_rs_handles = []
+        self._early_curr_layer_rs_handles = []

         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -510,6 +516,13 @@ def _post_backward_hook_for_module(self, module, *args):  # pylint: disable=W061
         self._clear_handle(module)
         self._clear_weight(module)

+    def _early_reduce_scatter_release_hook(self, *args):  # pylint: disable=W0613
+        for handle in self._early_prev_layer_rs_handles:
+            handle.wait()
+
+        self._early_prev_layer_rs_handles = self._early_curr_layer_rs_handles
+        self._early_curr_layer_rs_handles = []
+
     def _register_sync_parameters_hook(self) -> None:
         """
         register forward hooks and backward hooks for isp modules.
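
For context on what this hook buys: reduce-scatter handles issued during one block's backward are waited on only when the next block's backward finishes, so each block's collectives overlap with the following block's backward compute instead of all being resolved at the end. A minimal, illustrative-only sketch of that rotation, with a dummy class standing in for the torch.distributed work handles (names here are not from the repository):

# Illustrative sketch of the rotation in _early_reduce_scatter_release_hook.
class FakeWork:
    def __init__(self, name: str) -> None:
        self.name = name

    def wait(self) -> None:
        print(f"waited on {self.name}")


prev_layer_handles: list = []  # handles queued during the previously finished block
curr_layer_handles: list = []  # handles queued while the current block runs backward


def block_backward_done() -> None:
    # Mirror of the hook: finish the previous block's reduce-scatters,
    # then promote the current block's handles for the next call.
    for handle in prev_layer_handles:
        handle.wait()
    prev_layer_handles[:] = curr_layer_handles
    curr_layer_handles.clear()


# Backward visits blocks in reverse order; pretend each block issues one reduce-scatter.
for block_idx in reversed(range(3)):
    curr_layer_handles.append(FakeWork(f"rs_block_{block_idx}"))
    block_backward_done()
# Prints "waited on rs_block_2", then "waited on rs_block_1";
# rs_block_0 is left for whoever consumes reduce_scatter_handlers afterwards.

The last block's handles are never waited on by this hook; they remain reachable through the (output, handle) tuples in reduce_scatter_handlers, and after_backward (see the later hunk) resets both lists before the next step.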
@@ -545,12 +558,18 @@ def _register_sync_parameters_hook(self) -> None:
         for module in self._isp_modules:
             module.register_full_backward_hook(self._post_backward_hook_for_module)

+        if self._enable_early_reduce_scatter_release:
+            for block_idx in range(self._num_blocks):
+                block = self._index_to_block[block_idx]
+                block.register_full_backward_hook(self._early_reduce_scatter_release_hook)
+
     def _get_constant_zero(self, size: tuple) -> torch.Tensor:
-        return torch.zeros(
-            *size,
-            dtype=self.model_conf.dtype,
-            device=self.model_conf.device,
-        ).contiguous()
+        if size not in self._zero_const_pool:
+            self._zero_const_pool[size] = torch.zeros(
+                *size, dtype=self.model_conf.dtype, device=self.model_conf.device
+            ).contiguous()
+
+        return self._zero_const_pool[size]

     def communication_mode(self) -> str:
         return "wp"
@@ -637,13 +656,18 @@ def grad_hook(
             assert hasattr(module.weight, "isp_reduce_scatter_name")
             key = getattr(module.weight, "isp_reduce_scatter_name")

-            self.reduce_scatter_handlers[key] = reduce_scatter_raw(
+            output, handle = reduce_scatter_raw(
                 tensor,
                 self.process_group,
                 op=reduce_op,
                 async_op=async_op,
             )

+            if self._enable_early_reduce_scatter_release:
+                self._early_curr_layer_rs_handles.append(handle)
+
+            self.reduce_scatter_handlers[key] = (output, handle)
+
             result, handle = (
                 self._get_constant_zero(
                     (
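
This hunk relies on reduce_scatter_raw returning an (output, handle) pair, which previously went straight into reduce_scatter_handlers[key] and is now unpacked so the handle can also be queued for early release. For readers unfamiliar with the helper, it presumably looks roughly like the following async reduce-scatter wrapper (a sketch under that assumption, not the repository's implementation):

import torch
import torch.distributed as dist


def reduce_scatter_raw_sketch(tensor, group, op=dist.ReduceOp.SUM, async_op=False):
    """Assumed shape of reduce_scatter_raw: reduce-scatter `tensor` along dim 0
    across `group` and return the local shard plus the communication handle."""
    world_size = dist.get_world_size(group)
    output = torch.empty(
        (tensor.shape[0] // world_size, *tensor.shape[1:]),
        dtype=tensor.dtype,
        device=tensor.device,
    )
    # With async_op=True this returns a Work object; with async_op=False it returns None.
    handle = dist.reduce_scatter_tensor(output, tensor.contiguous(), op=op, group=group, async_op=async_op)
    return output, handle

With early release enabled, the same handle therefore lives in two places: in _early_curr_layer_rs_handles so the block-boundary hook can wait on it one block later, and in reduce_scatter_handlers[key] for whichever consumer eventually needs the gradient shard.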
@@ -698,6 +722,10 @@ def after_backward(self, scheduler, inputs_grad) -> None:  # pylint: disable=W06
         ):
             self._zero_optim.reduce_left_grads_after_backward()

+        if self._isp_communicator and self._isp_communicator._enable_early_reduce_scatter_release:
+            self._isp_communicator._early_prev_layer_rs_handles = []
+            self._isp_communicator._early_curr_layer_rs_handles = []
+
     def post_helper_func(self, scheduler, outputs, label) -> None:  # pylint: disable=W0613
         pass

@@ -872,9 +900,7 @@ def _q_kv(self, q: torch.Tensor, kv: torch.Tensor, *args, **kwargs) -> torch.Ten

         q, kv = _SeqAllToAll.apply(self.spg, [2, 3], [1, 1], q, kv)

-        torch.cuda.synchronize()
         context = self.local_attn(q, kv, *args, **kwargs)
-        torch.cuda.synchronize()

         context = _SeqAllToAll.apply(self.spg, 1, 2, context)

internlm/initialize/launch.py

Lines changed: 8 additions & 0 deletions
@@ -450,11 +450,19 @@ def args_sanity_check():
             gpc.config.parallel["weight"]["overlap"] = False
     if gpc.config.parallel["tensor"]["mode"] != TensorParallelMode.isp.name:
         assert gpc.config.parallel["weight"]["size"] <= 1, "weight parallel is only supported with isp"
+
+    if "early_reduce_scatter_release" not in gpc.config.parallel.weight:
+        gpc.config.parallel.weight["early_reduce_scatter_release"] = True
+
     # set default value for expert_weight parallel
     if gpc.config.parallel["expert_weight"].get("overlap", None) is None:
         gpc.config.parallel["expert_weight"]["overlap"] = False
     if gpc.config.parallel["expert"].get("no_tp", None) is None:
         gpc.config.parallel["expert"]["no_tp"] = False
+
+    if "early_reduce_scatter_release" not in gpc.config.parallel.expert_weight:
+        gpc.config.parallel.expert_weight["early_reduce_scatter_release"] = True
+
     # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
     if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
         assert (
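
Because args_sanity_check fills in the default, existing configs keep running with the feature enabled; turning it off is a one-line change under the weight (or expert_weight) parallel settings. A hypothetical config fragment, where everything except early_reduce_scatter_release is boilerplate that may differ from any real config file:

# Hypothetical excerpt from a training config file.
parallel = dict(
    # ... zero1 / tensor / pipeline settings elided ...
    weight=dict(
        size=4,
        overlap=True,
        early_reduce_scatter_release=False,  # opt out of the new behaviour for weight parallel
    ),
    expert_weight=dict(
        size=1,
        overlap=False,
        # omitted -> args_sanity_check sets early_reduce_scatter_release=True
    ),
)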

internlm/train/pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -361,6 +361,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
             gpc.config.parallel.weight.overlap,
             gpc.get_group(ParallelMode.WEIGHT),
             is_moe=False,
+            early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release,
         )
         # register communicator for isp column parallel linear.
         ColumnParallelLinear.register_cls_communicator(isp_communicator)
@@ -386,6 +387,7 @@
             gpc.config.parallel.expert_weight.overlap,
             gpc.get_group(ParallelMode.EXPERT_WEIGHT),
             is_moe=True,
+            early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release,
         )
         for moe in _submodule_filter(model, Experts):
             for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):
