Commit 24e48f3

ENH: Allow FSDP ignored modules to be regex (#3698)
* ENH: Allow FSDP ignored modules to be regex

  Description

  For FSDP, there is an option to indicate ignored_modules, which should be a list of modules that are ignored by FSDP. Even though this argument was supported in accelerate, it was not very usable:

  1. Listing all modules can be tricky, especially with something like PEFT, where the whole model is wrapped and thus the module structure changes.
  2. When configuring this argument, accelerate takes a detour via environment variables. These can only be strings, so passing a list of modules is not feasible.

  Moreover, I noticed that the environment variable for ignored_modules was not even set, so configuring this argument didn't even work.

  Status

  This PR is lacking tests. I would be happy for pointers on how to add those.

  Context

  When using PEFT with LoRA and the target_parameters feature, I ran into an issue training such a model with FSDP. The only working fix I found was to ignore the layers targeted by LoRA. However, I could not configure accelerate to do that. With this PR, it is possible: I could successfully train such a PEFT model that targets q_proj and v_proj by setting fsdp_ignored_modules: '.*\.(q_proj$|v_proj$)'.

* Fix type annotation

* Fix failing test
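To illustrate the matching rule described above: the regex is applied with re.fullmatch to every name yielded by model.named_modules(), and the matching submodules are the ones FSDP ignores. A minimal sketch, using a made-up toy model rather than anything from this PR:

    import re

    import torch.nn as nn

    # Toy stand-in for a transformer attention block; only the module names matter here.
    block = nn.ModuleDict({"q_proj": nn.Linear(8, 8), "v_proj": nn.Linear(8, 8), "o_proj": nn.Linear(8, 8)})
    model = nn.ModuleDict({"self_attn": block})

    # The example pattern from the commit message: any module whose name ends in ".q_proj" or ".v_proj".
    pattern = re.compile(r".*\.(q_proj$|v_proj$)")

    ignored = [name for name, _ in model.named_modules() if pattern.fullmatch(name)]
    print(ignored)  # ['self_attn.q_proj', 'self_attn.v_proj']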
1 parent 6640ff4 commit 24e48f3

4 files changed: +21 -3 lines changed


src/accelerate/accelerator.py

Lines changed: 11 additions & 0 deletions
@@ -1872,6 +1872,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                     "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                     "device_id": self.device,
                 }
+
+                if isinstance(kwargs["ignored_modules"], str):
+                    reg = re.compile(kwargs["ignored_modules"])
+                    ignored = []
+                    for name, module in model.named_modules():
+                        if reg.fullmatch(name):
+                            # ensure that the device for these modules is still set correctly
+                            module.to(self.device)
+                            ignored.append(module)
+                    kwargs["ignored_modules"] = ignored
+
                 model = FSDP(model, **kwargs)
                 if fsdp_plugin.activation_checkpointing:
                     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (

src/accelerate/utils/dataclasses.py

Lines changed: 7 additions & 3 deletions
@@ -1561,8 +1561,9 @@ class FullyShardedDataParallelPlugin:
             Whether to offload parameters to CPU. Should be either a `bool` or an instance of
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
-        ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`):
-            A list of modules to ignore when wrapping with FSDP.
+        ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
+            A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
+            using regex fullmatch.
         state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
             State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
             `sharded_state_dict`.

@@ -1660,7 +1661,7 @@ class FullyShardedDataParallelPlugin:
             "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
         },
     )
-    ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
+    ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
         default=None,
         metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
     )

@@ -1896,6 +1897,9 @@ def __post_init__(self):
                 str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
             )

+        if self.ignored_modules is None:
+            self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)
+
         if self.cpu_ram_efficient_loading is None:
             self.cpu_ram_efficient_loading = (
                 str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
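A sketch of the new environment-variable fallback in __post_init__, assuming env_prefix resolves to "FSDP_" (the prefix used by the launcher below) and that the plugin can be constructed on its own: if ignored_modules is not passed explicitly, the regex is picked up from the environment and kept as a string until prepare_model resolves it to modules.

    import os

    from accelerate.utils import FullyShardedDataParallelPlugin

    # Assumption: this mirrors the FSDP_IGNORED_MODULES variable exported by the launcher (see launch.py below).
    os.environ["FSDP_IGNORED_MODULES"] = r".*\.(q_proj$|v_proj$)"

    plugin = FullyShardedDataParallelPlugin()
    # Expected: the regex string itself; it is only turned into module objects later, in prepare_model.
    print(plugin.ignored_modules)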

src/accelerate/utils/launch.py

Lines changed: 2 additions & 0 deletions
@@ -328,6 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
         current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
         current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
         current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
+        if getattr(args, "fsdp_ignored_modules", None) is not None:
+            current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)

     if args.use_megatron_lm:
         prefix = "MEGATRON_LM_"

tests/test_configs/latest_fsdp.yaml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ fsdp_config:
   fsdp_sync_module_states: true
   fsdp_transformer_layer_cls_to_wrap: BertLayer
   fsdp_use_orig_params: true
+  fsdp_ignored_modules: null
 machine_rank: 0
 main_training_function: main
 mixed_precision: 'no'
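Putting the pieces together, a hedged end-to-end sketch of passing the regex directly through the plugin instead of the config file; the model is a placeholder and the script is assumed to run under accelerate launch with FSDP enabled:

    import torch.nn as nn

    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    # Placeholder model; in the motivating use case this would be a PEFT/LoRA-wrapped
    # transformer whose q_proj/v_proj layers FSDP should leave alone.
    model = nn.Transformer()

    # With this change, ignored_modules may be a regex string; module names that
    # fullmatch the pattern are excluded from FSDP wrapping and moved to the device.
    fsdp_plugin = FullyShardedDataParallelPlugin(ignored_modules=r".*\.(q_proj$|v_proj$)")

    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    model = accelerator.prepare(model)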
