@@ -23,7 +23,7 @@
 )
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.comm import ISPCommunicatorWrapper, ParamAsyncBcastHandler
-from internlm.model.model_ops.modules.utils import is_gate_param, is_moe_param
+from internlm.model.model_ops.modules.utils import is_moe_param
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -44,7 +44,7 @@
 from internlm.utils.config import Config
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
+from internlm.utils.parallel import is_using_isp, should_reduce_replica_param
 from internlm.utils.timeout import llm_timeout
 
 from .base_optimizer import BaseOptimizer
@@ -393,11 +393,7 @@ def extra_layernorm_reduce_grad_hook(*args):  # pylint: disable=W0613
             # the grad of layernorm should be all-reduce across the global process group
             # here is the first stage all-reduce in tp/wp process group
             # the second stage all-reduce will be processed in reduce_grad_hook
-            if (
-                is_using_sequence_parallel()
-                and hasattr(param, IS_REPLICA_ZERO_PARALLEL)
-                and getattr(param, IS_REPLICA_ZERO_PARALLEL) is True
-            ) or (is_gate_param(param) and gpc.config.parallel.expert.no_tp):
+            if should_reduce_replica_param(param):
                 accum_grad_obj.register_hook(extra_layernorm_reduce_grad_hook)
 
     # we should not only register for parameters which have isp_reduce_scatter_name attr.
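
For context, the removed predicate is folded into the new `should_reduce_replica_param` helper imported from `internlm.utils.parallel`. Its implementation is not part of this diff; the following is a minimal sketch of what it plausibly contains, reusing only names that appear in the removed lines (the import path for `IS_REPLICA_ZERO_PARALLEL` is an assumption, and in the real module the helper would sit next to `is_using_sequence_parallel` rather than import it):

# Hypothetical sketch -- the actual helper lives in internlm/utils/parallel.py
# and is not shown in this diff.
from internlm.core.context import IS_REPLICA_ZERO_PARALLEL  # assumed import path
from internlm.core.context import global_context as gpc
from internlm.model.model_ops.modules.utils import is_gate_param
from internlm.utils.parallel import is_using_sequence_parallel  # same module as the helper


def should_reduce_replica_param(param) -> bool:
    """Return True if ``param`` is replicated and its gradient therefore
    needs the first-stage all-reduce in the tp/wp process group."""
    # Layernorm-style params are replicated under sequence parallelism:
    # every tp/wp rank holds a full copy, so grads must be summed.
    # getattr(..., False) is True preserves the old hasattr + getattr check.
    if is_using_sequence_parallel() and getattr(param, IS_REPLICA_ZERO_PARALLEL, False) is True:
        return True
    # MoE gate params are replicated when expert tensor parallelism is
    # disabled (parallel.expert.no_tp), so they need the same treatment.
    return is_gate_param(param) and gpc.config.parallel.expert.no_tp

Under this sketch, the hook-registration site reduces to the single call shown in the diff, keeping the replica-gradient policy in one place instead of duplicating the compound condition at every registration site.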