Skip to content

Commit 90d5023

Browse files
authored
Add small util to enable FSDP offloading quickly (#3006)
* Wrap up util
* Add small util
* Update doc
* Don't req
* Clean
1 parent 3bde615 commit 90d5023

File tree

5 files changed

+58
-16
lines changed

5 files changed

+58
-16
lines changed

docs/source/package_reference/fsdp.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ rendered properly in your Markdown viewer.
1515

1616
# Utilities for Fully Sharded Data Parallelism
1717

18+
[[autodoc]] utils.enable_fsdp_ram_efficient_loading
19+
20+
[[autodoc]] utils.disable_fsdp_ram_efficient_loading
21+
1822
[[autodoc]] utils.merge_fsdp_weights
1923

20-
[[autodoc]] utils.FullyShardedDataParallelPlugin
24+
[[autodoc]] utils.FullyShardedDataParallelPlugin

src/accelerate/utils/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,15 @@
188188
)
189189

190190
from .bnb import has_4bit_bnb_layers, load_and_quantize_model
191-
from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, merge_fsdp_weights, save_fsdp_model, save_fsdp_optimizer
191+
from .fsdp_utils import (
192+
disable_fsdp_ram_efficient_loading,
193+
enable_fsdp_ram_efficient_loading,
194+
load_fsdp_model,
195+
load_fsdp_optimizer,
196+
merge_fsdp_weights,
197+
save_fsdp_model,
198+
save_fsdp_optimizer,
199+
)
192200
from .launch import (
193201
PrepareForLaunch,
194202
_filter_args,

src/accelerate/utils/dataclasses.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,7 @@ class FullyShardedDataParallelPlugin:
13001300
"for reduced memory usage. Defaults to `False`"
13011301
},
13021302
)
1303-
ram_efficient_loading: bool = field(
1303+
cpu_ram_efficient_loading: bool = field(
13041304
default=None,
13051305
metadata={
13061306
"help": "If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. "
@@ -1399,12 +1399,12 @@ def __post_init__(self):
13991399
str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
14001400
)
14011401

1402-
if self.ram_efficient_loading is None:
1403-
self.ram_efficient_loading = (
1404-
str_to_bool(os.environ.get(env_prefix + "RAM_EFFICIENT_LOADING", "False")) == 1
1402+
if self.cpu_ram_efficient_loading is None:
1403+
self.cpu_ram_efficient_loading = (
1404+
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
14051405
)
14061406

1407-
if self.ram_efficient_loading and not self.sync_module_states:
1407+
if self.cpu_ram_efficient_loading and not self.sync_module_states:
14081408
warnings.warn(
14091409
"sync_module_states cannot be False since efficient cpu ram loading is enabled. "
14101410
"Setting sync_module_states to True."

src/accelerate/utils/fsdp_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,23 @@
2727
logger = get_logger(__name__)
2828

2929

30+
def enable_fsdp_ram_efficient_loading():
31+
"""
32+
Enables RAM efficient loading of Hugging Face models for FSDP in the environment.
33+
"""
34+
# Sets values for `transformers.modeling_utils.is_fsdp_enabled`
35+
if "ACCELERATE_USE_FSDP" not in os.environ:
36+
os.environ["ACCELERATE_USE_FSDP"] = "True"
37+
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "True"
38+
39+
40+
def disable_fsdp_ram_efficient_loading():
41+
"""
42+
Disables RAM efficient loading of Hugging Face models for FSDP in the environment.
43+
"""
44+
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "False"
45+
46+
3047
def _get_model_state_dict(model, adapter_only=False):
3148
if adapter_only and is_peft_model(model):
3249
from peft import get_peft_model_state_dict

tests/fsdp/test_fsdp.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
FSDP_STATE_DICT_TYPE,
4343
)
4444
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
45+
from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
4546
from accelerate.utils.other import patch_environment
4647

4748

@@ -98,16 +99,18 @@ def test_backward_prefetch(self):
9899

99100
for i, prefetch_policy in enumerate(FSDP_BACKWARD_PREFETCH):
100101
expected_value = None if prefetch_policy == "NO_PREFETCH" else BackwardPrefetch(i + 1)
101-
# env = self.fsdp_env.copy()
102-
# env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy
103-
# with mockenv_context(**env):
104-
# fsdp_plugin = FullyShardedDataParallelPlugin()
105-
# assert fsdp_plugin.backward_prefetch == expected_value, f"Actual: {fsdp_plugin.backward_prefetch} != Expected: {expected_value}"
102+
env = self.fsdp_env.copy()
103+
env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy
104+
with mockenv_context(**env):
105+
fsdp_plugin = FullyShardedDataParallelPlugin()
106+
assert (
107+
fsdp_plugin.backward_prefetch == expected_value
108+
), f"Actual: {fsdp_plugin.backward_prefetch} != Expected: {expected_value}"
106109

107-
# # Check if torch enum works
108-
# if prefetch_policy != "NO_PREFETCH":
109-
# fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=BackwardPrefetch(i + 1))
110-
# assert fsdp_plugin.backward_prefetch == expected_value
110+
# Check if torch enum works
111+
if prefetch_policy != "NO_PREFETCH":
112+
fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=BackwardPrefetch(i + 1))
113+
assert fsdp_plugin.backward_prefetch == expected_value
111114

112115
# Check if name works
113116
fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=prefetch_policy)
@@ -263,6 +266,16 @@ def test_cpu_offload(self):
263266
fsdp_plugin = FullyShardedDataParallelPlugin(cpu_offload=flag)
264267
assert fsdp_plugin.cpu_offload == CPUOffload(offload_params=flag)
265268

269+
def test_cpu_ram_efficient_loading(self):
270+
enable_fsdp_ram_efficient_loading()
271+
fsdp_plugin = FullyShardedDataParallelPlugin()
272+
assert fsdp_plugin.cpu_ram_efficient_loading is True
273+
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "True"
274+
disable_fsdp_ram_efficient_loading()
275+
fsdp_plugin = FullyShardedDataParallelPlugin()
276+
assert fsdp_plugin.cpu_ram_efficient_loading is False
277+
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
278+
266279

267280
# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
268281
@require_non_torch_xla

0 commit comments

Comments (0)