ENH: Allow FSDP ignored modules to be regex (#3698)
* ENH: Allow FSDP ignored modules to be regex
Description
For FSDP, there is an option to indicate ignored_modules, which should
be a list of modules that are ignored by FSDP. Even though this argument was
supported in accelerate, it was not very usable:
1. Listing all modules can be tricky, especially with something like PEFT,
where the whole model is wrapped and thus the module structure changes.
2. When configuring this argument, accelerate takes a detour via
environment variables, which can only hold strings, so passing a
list of modules is not feasible.
Moreover, I noticed that the environment variable for ignored_modules
was not even being set, so configuring this argument had no effect.
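To make the idea concrete, here is a minimal sketch (not the code from this PR) of how a regex string, which fits into an environment variable, can be resolved to the set of modules FSDP should ignore; the helper name `resolve_ignored_modules` is made up for illustration:

```python
import re

import torch.nn as nn


def resolve_ignored_modules(model: nn.Module, pattern: str) -> list[nn.Module]:
    # Match full module names against the regex, mirroring the "regex fullmatch"
    # behaviour described in the docstring change below.
    return [module for name, module in model.named_modules() if re.fullmatch(pattern, name)]
```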
Status
This PR is lacking tests. I would be happy for pointers on how to add
those.
Context
When using PEFT with LoRA and the target_parameters feature, I ran into
an issue training such a model with FSDP. The only working fix I found
was to ignore the layers targeted by LoRA. However, I could not
configure accelerate to do that. With this PR, it is possible: I
successfully trained such a PEFT model that targets q_proj and v_proj by
setting fsdp_ignored_modules: '.*\.(q_proj$|v_proj$)'.
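For reference, here is a hedged example of how the new string form could be used programmatically, assuming the usual FullyShardedDataParallelPlugin/Accelerator setup in a distributed run; this snippet is illustrative, not taken from the PR:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

# The regex is matched against module names via fullmatch, so q_proj/v_proj
# submodules anywhere in the model are ignored by FSDP.
fsdp_plugin = FullyShardedDataParallelPlugin(ignored_modules=r".*\.(q_proj$|v_proj$)")
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```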
* Fix type annotation
* Fix failing test
src/accelerate/utils/dataclasses.py (7 additions & 3 deletions)
```diff
@@ -1561,8 +1561,9 @@ class FullyShardedDataParallelPlugin:
             Whether to offload parameters to CPU. Should be either a `bool` or an instance of
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
             `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
-        ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`):
-            A list of modules to ignore when wrapping with FSDP.
+        ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
+            A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
+            using regex fullmatch.
         state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
             State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
             `sharded_state_dict`.
@@ -1660,7 +1661,7 @@ class FullyShardedDataParallelPlugin:
             "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
```