@@ -23,7 +23,7 @@
 )
 from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper
 from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
-from internlm.model.modules.utils import is_gate_param, is_moe_param
+from internlm.model.modules.utils import is_moe_param
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -44,7 +44,7 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
-from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel
+from internlm.utils.parallel import is_using_isp, should_reduce_replica_param
 from internlm.utils.timeout import llm_timeout

 from .base_optimizer import BaseOptimizer
@@ -393,11 +393,7 @@ def extra_layernorm_reduce_grad_hook(*args):  # pylint: disable=W0613
                 # the grad of layernorm should be all-reduce across the global process group
                 # here is the first stage all-reduce in tp/wp process group
                 # the second stage all-reduce will be processed in reduce_grad_hook
-                if (
-                    is_using_sequence_parallel()
-                    and hasattr(param, IS_REPLICA_ZERO_PARALLEL)
-                    and getattr(param, IS_REPLICA_ZERO_PARALLEL) is True
-                ) or (is_gate_param(param) and gpc.config.parallel.expert.no_tp):
+                if should_reduce_replica_param(param):
                     accum_grad_obj.register_hook(extra_layernorm_reduce_grad_hook)

                 # we should not only register for parameters which have isp_reduce_scatter_name attr.
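For reference, this refactor replaces the inline replica/gate check with a single predicate imported from `internlm.utils.parallel`, which is also why `is_gate_param` and `is_using_sequence_parallel` are no longer imported by the optimizer. Below is a minimal sketch of what `should_reduce_replica_param` presumably encapsulates, reconstructed from the condition removed above; the import paths marked as assumed (and the helper's exact location) are guesses, not the PR's actual implementation.

```python
# Hypothetical reconstruction of the new predicate, inferred from the inline
# condition it replaces in this diff. Import paths marked "assumed" may differ
# from the actual PR.
from internlm.core.context import global_context as gpc  # assumed path
from internlm.core.parallel.shard import IS_REPLICA_ZERO_PARALLEL  # assumed path
from internlm.model.modules.utils import is_gate_param
from internlm.utils.parallel import is_using_sequence_parallel


def should_reduce_replica_param(param) -> bool:
    """Return True if this param's grad needs the extra first-stage all-reduce.

    Folds together the two cases the old inline check covered:
    1. sequence parallelism is enabled and the param is tagged IS_REPLICA_ZERO_PARALLEL;
    2. the param is an MoE gate and expert tensor parallelism is disabled (expert.no_tp).
    """
    is_replica_param = (
        is_using_sequence_parallel()
        and hasattr(param, IS_REPLICA_ZERO_PARALLEL)
        and getattr(param, IS_REPLICA_ZERO_PARALLEL) is True
    )
    is_no_tp_gate = is_gate_param(param) and gpc.config.parallel.expert.no_tp
    return is_replica_param or is_no_tp_gate
```

If the helper does live in `internlm/utils/parallel.py`, `is_using_sequence_parallel` would already be in scope there; the import appears here only to keep the sketch self-contained.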