
Commit bab15df

Revert "[FSDP2] Move to public torch.distributed.fsdp (pytorch#141868)"
This reverts commit 45583a5. Reverted pytorch#141868 on behalf of https://github.com/atalman due to failing internally (see the comment on pytorch#141868).
1 parent: 4af7aa5 · commit: bab15df


44 files changed: +174, -363 lines
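For context, the revert restores FSDP2's pre-#141868 import surface: the test diffs below switch back to importing fully_shard and its policies from the private torch.distributed._composable.fsdp namespace instead of the public torch.distributed.fsdp module. Below is a minimal sketch (not part of this commit, mirroring the deleted TestFullyShardOldImport test) of that usage; it assumes a distributed process group is already initialized (e.g. launched with torchrun) and a CUDA device is available.

# Sketch: FSDP2 usage via the private namespace that this revert keeps in place.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
fully_shard(model[0], mp_policy=mp_policy)  # shard each submodule first
fully_shard(model[1], mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)     # then the root module
inp = torch.randn((8, 16), device="cuda")
model(inp).sum().backward()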

docs/source/distributed.fsdp.fully_shard.rst

Lines changed: 0 additions & 85 deletions
This file was deleted.

docs/source/index.rst

Lines changed: 0 additions & 1 deletion
@@ -79,7 +79,6 @@ Features described in this documentation are classified by release status:
    torch.distributed.algorithms.join <distributed.algorithms.join>
    torch.distributed.elastic <distributed.elastic>
    torch.distributed.fsdp <fsdp>
-   torch.distributed.fsdp.fully_shard <distributed.fsdp.fully_shard>
    torch.distributed.tensor.parallel <distributed.tensor.parallel>
    torch.distributed.optim <distributed.optim>
    torch.distributed.pipelining <distributed.pipelining>

test/distributed/_composable/fsdp/test_fully_shard_autograd.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed.fsdp import fully_shard
+from torch.distributed._composable.fsdp import fully_shard
 from torch.nn.parallel.scatter_gather import _is_namedtuple
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

test/distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py

Lines changed: 1 addition & 1 deletion
@@ -7,8 +7,8 @@
 import torch
 import torch.nn as nn
 from torch.distributed._composable import replicate
+from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest, MLPStack

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 9 additions & 9 deletions
@@ -11,30 +11,30 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import (
+from torch.distributed._composable.fsdp import (
     FSDPModule,
     fully_shard,
     MixedPrecisionPolicy,
     OffloadPolicy,
 )
-from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
+from torch.distributed._composable.fsdp._fsdp_collectives import (
     _div_if_needed,
     _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
 )
-from torch.distributed.fsdp._fully_shard._fsdp_common import FSDPMeshInfo, TrainingState
-from torch.distributed.fsdp._fully_shard._fsdp_init import (
+from torch.distributed._composable.fsdp._fsdp_common import FSDPMeshInfo, TrainingState
+from torch.distributed._composable.fsdp._fsdp_init import (
     _get_post_forward_mesh_info,
     _init_default_fully_shard_mesh,
 )
-from torch.distributed.fsdp._fully_shard._fsdp_param import ShardedState
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
-from torch.distributed.tensor import DTensor
+from torch.distributed._composable.fsdp._fsdp_param import ShardedState
+from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
+from torch.distributed._tensor import DTensor
+from torch.distributed._tensor.experimental import implicit_replication
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.distributed.tensor.experimental import implicit_replication
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

test/distributed/_composable/fsdp/test_fully_shard_compile.py

Lines changed: 6 additions & 8 deletions
@@ -12,19 +12,17 @@

 import torch
 import torch._dynamo.testing
+import torch.distributed._composable.fsdp._fsdp_param
 import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.utils import counters
 from torch._inductor import comms
 from torch._inductor.utils import is_fallback_op, run_and_get_code
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_common import TrainingState
+from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
 from torch.distributed._tensor import init_device_mesh
-from torch.distributed.fsdp import (
-    fully_shard,
-    FullyShardedDataParallel as FSDP,
-    ShardingStrategy,
-)
-from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
 from torch.testing import FileCheck
 from torch.testing._internal.common_distributed import (
     at_least_x_gpu,
@@ -85,7 +83,7 @@ def _test_disable_compiling_hooks(
     ):
         torch._dynamo.reset()
         trace_rules_check_count = 0
-        HOOKS_FILE_NAME = "torch/distributed/fsdp/_fully_shard/_fsdp_state.py"
+        HOOKS_FILE_NAME = "torch/distributed/_composable/fsdp/_fsdp_state.py"
         HOOK_WRAPPER_NAME = "fsdp_hook_wrapper"

         def patched_trace_rules_check(*args, **kwargs):

test/distributed/_composable/fsdp/test_fully_shard_extensions.py

Lines changed: 1 addition & 1 deletion
@@ -13,8 +13,8 @@
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from torch.autograd.grad_mode import _unsafe_preserve_version_counter
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
-from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (

test/distributed/_composable/fsdp/test_fully_shard_frozen.py

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed._composable import checkpoint, replicate
-from torch.distributed.fsdp import fully_shard
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_param_group import (
     RegisterPostBackwardFunction,
 )
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu

test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
 from torch.amp.grad_scaler import GradScaler, OptState
+from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed._tensor import init_device_mesh
-from torch.distributed.fsdp import fully_shard
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,

test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 7 additions & 30 deletions
@@ -9,6 +9,13 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed._composable import replicate
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp._fsdp_init import (
+    _get_managed_modules,
+    _get_managed_states,
+)
+from torch.distributed._composable.fsdp._fsdp_param import ParamModuleInfo
+from torch.distributed._composable.fsdp._fsdp_param_group import _get_param_module_infos
 from torch.distributed._tensor import (
     DeviceMesh,
     distribute_tensor,
@@ -17,15 +24,6 @@
     Shard,
 )
 from torch.distributed.device_mesh import init_device_mesh
-from torch.distributed.fsdp import fully_shard
-from torch.distributed.fsdp._fully_shard._fsdp_init import (
-    _get_managed_modules,
-    _get_managed_states,
-)
-from torch.distributed.fsdp._fully_shard._fsdp_param import ParamModuleInfo
-from torch.distributed.fsdp._fully_shard._fsdp_param_group import (
-    _get_param_module_infos,
-)
 from torch.distributed.fsdp._init_utils import (
     _init_inter_node_process_group,
     _init_intra_node_process_group,
@@ -1158,26 +1156,5 @@ def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
         fully_shard(model, shard_placement_fn=shard_placement_fn)


-# TODO: Remove this test class once we remove the old import path:
-# torch/distributed/_composable/fsdp
-class TestFullyShardOldImport(FSDPTestMultiThread):
-    @property
-    def world_size(self) -> int:
-        return 2
-
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_old_import_training(self):
-        from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
-
-        model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
-        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
-        fully_shard(model[0], mp_policy=mp_policy)
-        fully_shard(model[1], mp_policy=mp_policy)
-        fully_shard(model, mp_policy=mp_policy)
-
-        inp = torch.randn((8, 16), device="cuda")
-        model(inp).sum().backward()
-
-
 if __name__ == "__main__":
     run_tests()
