Skip to content

Commit cbf73d6

Browse files
committed
Merge branch 'develop' into feat/refactor-impl
2 parents 8e04b09 + 30bb508 commit cbf73d6

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

internlm/solver/optimizer/hybrid_zero_optim.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from internlm.core.context import global_context as gpc
2525
from internlm.core.parallel.comm import ISPCommunicatorWrapper, ParamAsyncBcastHandler
26-
from internlm.model.model_ops.modules.utils import is_gate_param, is_moe_param
26+
from internlm.model.model_ops.modules.utils import is_moe_param
2727
from internlm.monitor import send_alert_message
2828
from internlm.solver.optimizer.store import (
2929
BucketStore,
@@ -44,7 +44,7 @@
4444
from internlm.utils.config import Config
4545
from internlm.utils.logger import get_logger
4646
from internlm.utils.megatron_timers import megatron_timer as timer
47-
from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
47+
from internlm.utils.parallel import is_using_isp, should_reduce_replica_param
4848
from internlm.utils.timeout import llm_timeout
4949

5050
from .base_optimizer import BaseOptimizer
@@ -393,11 +393,7 @@ def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613
393393
# the grad of layernorm should be all-reduce across the global process group
394394
# here is the first stage all-reduce in tp/wp process group
395395
# the second stage all-reduce will be processed in reduce_grad_hook
396-
if (
397-
is_using_sequence_parallel()
398-
and hasattr(param, IS_REPLICA_ZERO_PARALLEL)
399-
and getattr(param, IS_REPLICA_ZERO_PARALLEL) is True
400-
) or (is_gate_param(param) and gpc.config.parallel.expert.no_tp):
396+
if should_reduce_replica_param(param):
401397
accum_grad_obj.register_hook(extra_layernorm_reduce_grad_hook)
402398

403399
# we should not only register for parameters which have isp_reduce_scatter_name attr.

internlm/utils/parallel.py

+30
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
ParallelMode,
1414
)
1515
from internlm.core.context import global_context as gpc
16+
from internlm.model.modules.utils import is_gate_param
1617
from internlm.utils.utils import TensorParallelMode
1718

1819

@@ -85,6 +86,35 @@ def is_replica_expert_data_parallel_parameter(p):
8586
return hasattr(p, IS_REPLICA_EXPERT_DATA_PARALLEL) and getattr(p, IS_REPLICA_EXPERT_DATA_PARALLEL)
8687

8788

89+
def should_reduce_replica_param(p):
    """Tell whether parameter ``p`` needs the extra replica gradient reduce.

    The decision depends on the configured tensor-parallel mode
    (``gpc.config.parallel["tensor"]["mode"]``, defaulting to ``mtp``):

    * ``p`` is not a replica-zero-parallel parameter -> ``False``.
    * ``mtp``: ``False`` for ordinary replica parameters; for a gate
      parameter the result is ``is_using_parallel_mode(TENSOR) and no_tp``,
      where ``no_tp`` is read from ``gpc.config.parallel.expert`` and
      defaults to ``False``.
    * ``msp`` / ``fsp``: whether the TENSOR parallel mode is in use.
    * ``isp``: whether the WEIGHT parallel mode is in use.
    * any other mode value -> ``False``.

    Args:
        p: the parameter to inspect.

    Returns:
        A truthy value when the reduce is needed. In the ``mtp`` gate case the
        raw result of the ``and`` expression is returned, so it may be the
        configured ``no_tp`` value rather than a strict ``bool``.
    """
    # Only replica-zero-parallel parameters are candidates at all.
    if not is_replica_zero_parallel_parameter(p):
        return False

    # Resolve the tensor-parallel mode once; a missing "mode" key means mtp.
    mtp_name = TensorParallelMode.mtp.name
    tp_mode = gpc.config.parallel["tensor"].get("mode", mtp_name)
    uses_mtp = tp_mode == mtp_name

    # Baseline decision that applies to every replica parameter.
    if uses_mtp:
        need_reduce = False
    elif tp_mode in (TensorParallelMode.msp.name, TensorParallelMode.fsp.name):
        need_reduce = gpc.is_using_parallel_mode(ParallelMode.TENSOR)
    elif tp_mode == TensorParallelMode.isp.name:
        need_reduce = gpc.is_using_parallel_mode(ParallelMode.WEIGHT)
    else:
        need_reduce = False

    # MoE gate parameters override the baseline only under mtp; in every other
    # mode they keep the baseline decision computed above.
    if is_gate_param(p) and uses_mtp:
        # Keep the short-circuit: the expert config is read only when the
        # TENSOR parallel mode is actually in use.
        need_reduce = gpc.is_using_parallel_mode(ParallelMode.TENSOR) and getattr(
            gpc.config.parallel.expert, "no_tp", False
        )

    return need_reduce
116+
117+
88118
def sync_model_param(model):
89119
r"""Make sure data parameters are consistent during Data Parallel Mode.
90120

0 commit comments

Comments
 (0)