
Commit 8e04b09

update fsdp wrap
1 parent e854c90 commit 8e04b09

File tree

2 files changed (+10, -12 lines)


Diff for: internlm/core/fsdp.py (+2, -3)
@@ -1,5 +1,4 @@
 import collections
-import functools
 import itertools
 from typing import List, Optional, Set, Union
 
@@ -11,7 +10,7 @@
     BackwardPrefetch,
     ShardingStrategy,
 )
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 
 from internlm.accelerator.abstract_accelerator import get_accelerator
 from internlm.core.context import ParallelMode
@@ -170,7 +169,7 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
         module=model,
         process_group=gpc.get_group(ParallelMode.GLOBAL),
         sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO2: SHARD_GRAD_OP, ZeRO3: FULL_SHARD
-        auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls)),
+        auto_wrap_policy=ModuleWrapPolicy(wrap_cls),
         sync_module_states=fsdp_init_method != "cuda",  # sync model paramters
         forward_prefetch=True,
         backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
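
Both wrap policies tell FSDP which layer classes become their own shard units; the commit only changes how that set of classes is passed, dropping the functools.partial indirection. Below is a minimal sketch of the before/after usage, not InternLM's actual code: the Block class and wrap helper are illustrative, and a torch.distributed process group must already be initialized before FSDP is constructed.

    import functools

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy


    class Block(nn.Module):
        # Stand-in for the layer classes InternLM collects into wrap_cls.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x):
            return self.proj(x)


    def wrap(model: nn.Module, wrap_cls=(Block,)):
        # Before this commit: a callable policy built with functools.partial.
        _old_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=set(wrap_cls))
        # After this commit: ModuleWrapPolicy takes the module classes directly.
        new_policy = ModuleWrapPolicy(wrap_cls)
        return FSDP(model, auto_wrap_policy=new_policy)

For plain class-based matching the two policies wrap the same modules, so only the construction changes, and functools no longer needs to be imported for this purpose.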

Diff for: internlm/initialize/initialize_launcher.py (+8, -9)
@@ -579,10 +579,14 @@ def args_sanity_check():
         assert (
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
-        assert gpc.config.parallel.zero1.size in (
-            -1,
-            gpc.get_world_size(ParallelMode.DATA),
-        ) or is_using_fsdp(), "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
+        assert (
+            gpc.config.parallel.zero1.size
+            in (
+                -1,
+                gpc.get_world_size(ParallelMode.DATA),
+            )
+            or is_using_fsdp()
+        ), "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
 
         if gpc.config.parallel.tensor.mode != "isp":
             assert gpc.config.parallel.expert_weight.size <= 1, "expert weight parallel is only supported with isp"

@@ -637,11 +641,6 @@ def args_sanity_check():
         assert (
             gpc.config.parallel.weight.size == 1
         ), f"fsdp only compatible with weight size = 1, but get weight size = {gpc.config.parallel.weight.size}"
-        if "expert" in gpc.config.parallel:
-            assert gpc.config.parallel.expert.size in (
-                1,
-                -1,
-            ), f"fsdp only compatible with expert size = (-1, 1), but get expert size = {gpc.config.parallel.expert.size}"
         if "expert_zero1" in gpc.config.parallel:
             assert gpc.config.parallel.expert_zero1.size == 1, (
                 f"fsdp only compatible with expert_zero1 size = 1, "
