@@ -42,6 +42,7 @@
     FSDP_STATE_DICT_TYPE,
 )
 from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
+from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
 from accelerate.utils.other import patch_environment
 
 
@@ -98,16 +99,18 @@ def test_backward_prefetch(self):
 
         for i, prefetch_policy in enumerate(FSDP_BACKWARD_PREFETCH):
             expected_value = None if prefetch_policy == "NO_PREFETCH" else BackwardPrefetch(i + 1)
-            # env = self.fsdp_env.copy()
-            # env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy
-            # with mockenv_context(**env):
-            #     fsdp_plugin = FullyShardedDataParallelPlugin()
-            #     assert fsdp_plugin.backward_prefetch == expected_value, f"Actual: {fsdp_plugin.backward_prefetch} != Expected: {expected_value}"
+            env = self.fsdp_env.copy()
+            env["FSDP_BACKWARD_PREFETCH"] = prefetch_policy
+            with mockenv_context(**env):
+                fsdp_plugin = FullyShardedDataParallelPlugin()
+                assert (
+                    fsdp_plugin.backward_prefetch == expected_value
+                ), f"Actual: {fsdp_plugin.backward_prefetch} != Expected: {expected_value}"
 
-            # # Check if torch enum works
-            # if prefetch_policy != "NO_PREFETCH":
-            #     fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=BackwardPrefetch(i + 1))
-            #     assert fsdp_plugin.backward_prefetch == expected_value
+            # Check if torch enum works
+            if prefetch_policy != "NO_PREFETCH":
+                fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=BackwardPrefetch(i + 1))
+                assert fsdp_plugin.backward_prefetch == expected_value
 
             # Check if name works
             fsdp_plugin = FullyShardedDataParallelPlugin(backward_prefetch=prefetch_policy)
@@ -263,6 +266,16 @@ def test_cpu_offload(self):
             fsdp_plugin = FullyShardedDataParallelPlugin(cpu_offload=flag)
             assert fsdp_plugin.cpu_offload == CPUOffload(offload_params=flag)
 
+    def test_cpu_ram_efficient_loading(self):
+        enable_fsdp_ram_efficient_loading()
+        fsdp_plugin = FullyShardedDataParallelPlugin()
+        assert fsdp_plugin.cpu_ram_efficient_loading is True
+        assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "True"
+        disable_fsdp_ram_efficient_loading()
+        fsdp_plugin = FullyShardedDataParallelPlugin()
+        assert fsdp_plugin.cpu_ram_efficient_loading is False
+        assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
+
 
 # Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
 @require_non_torch_xla
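
For context, a minimal standalone sketch (not part of the diff) of the behavior the new test_cpu_ram_efficient_loading test exercises: the enable/disable helpers flip the FSDP_CPU_RAM_EFFICIENT_LOADING environment variable, and a freshly constructed FullyShardedDataParallelPlugin reads that value into its cpu_ram_efficient_loading field. Only names appearing in the diff above are used; the surrounding script is illustrative.

import os

from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
from accelerate.utils.fsdp_utils import (
    disable_fsdp_ram_efficient_loading,
    enable_fsdp_ram_efficient_loading,
)

# Enabling sets the env var that the plugin consults at construction time.
enable_fsdp_ram_efficient_loading()
assert os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] == "True"
assert FullyShardedDataParallelPlugin().cpu_ram_efficient_loading is True

# Disabling flips it back; a newly constructed plugin picks up the change.
disable_fsdp_ram_efficient_loading()
assert FullyShardedDataParallelPlugin().cpu_ram_efficient_loading is False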