Commit eebf115

wanchaol authored and pytorchmergebot committed
[fsdp][2d] FSDP sync module states handle tensor subclass (pytorch#117336)
This PR lets the FSDP sync_module_states kwarg handle tensor subclasses. Because FSDP operates on the "dp" mesh dimension, as long as the DTensor is placed on a different device mesh dimension, FSDP can safely just broadcast the DTensor local shards.

Fixes pytorch#117126

Pull Request resolved: pytorch#117336
Approved by: https://github.com/awgu
1 parent fc044b5 commit eebf115
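In other words, a rough sketch of the idea (not the code in this commit; broadcast_module_state_over_dp is a made-up name for illustration): on a 2D ("dp", "tp") mesh, a DTensor's placements only involve the "tp" dimension, so each rank's local shard looks like a plain tensor to the "dp" process group and can be broadcast over it directly.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor

def broadcast_module_state_over_dp(t: torch.Tensor, dp_group: dist.ProcessGroup) -> None:
    # Hypothetical helper for illustration only; rank 0 of the dp group is the source.
    src = dist.get_global_rank(dp_group, 0)
    if isinstance(t, DTensor):
        # The DTensor's placements only involve the "tp" mesh dimension, so its
        # local shard has the same shape on every rank of the dp group and can
        # be broadcast like a plain tensor.
        dist.broadcast(t.to_local(), src=src, group=dp_group)
    else:
        dist.broadcast(t, src=src, group=dp_group)

The actual change in _init_utils.py below does this generically through the traceable-wrapper-subclass flatten protocol rather than special-casing DTensor.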

File tree
2 files changed: +83 -2 lines changed

test/distributed/fsdp/test_fsdp_tp_integration.py (+62)

@@ -11,6 +11,7 @@
     distribute_module,
     DTensor,
     init_device_mesh,
+    Replicate,
     Shard,
 )
 from torch.distributed._tensor.debug import CommDebugMode
@@ -378,6 +379,67 @@ def forward(self, x):
         for grad in grads:
             self.assertFalse(grad.isnan().any().item())

+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp_sync_module_state(self):
+        mesh_2d = init_device_mesh(
+            "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"]
+        )
+        tp_mesh = mesh_2d["tp"]
+        dp_mesh = mesh_2d["dp"]
+
+        # set a different random seed on each rank
+        torch.manual_seed(mesh_2d.get_rank())
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                replicated_dt = DTensor.from_local(
+                    torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False
+                )
+                replicated_buffer_dt = DTensor.from_local(
+                    torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False
+                )
+                self.param = torch.nn.Parameter(replicated_dt)
+                self.register_buffer("buf", replicated_buffer_dt)
+
+            def forward(self, x):
+                return self.param + self.buf + 1
+
+        model = TestModel()
+
+        def assert_local_shard_across_ranks(local_tensor, group, check_equal=True):
+            gathered_tensors = [
+                torch.empty_like(local_tensor) for _ in range(group.size())
+            ]
+            dist.all_gather(gathered_tensors, local_tensor, group=group)
+            # compare every rank's gathered copy against rank 0's
+            tensor_to_compare = gathered_tensors[0]
+            for tensor in gathered_tensors[1:]:
+                if check_equal:
+                    self.assertTrue(torch.equal(tensor, tensor_to_compare))
+                else:
+                    self.assertFalse(torch.equal(tensor, tensor_to_compare))
+
+        dp_group = dp_mesh.get_group()
+
+        # before wrapping, the param's local tensor differs across the dp mesh dim
+        local_param = model.param.to_local()
+        assert_local_shard_across_ranks(local_param, dp_group, check_equal=False)
+        # before wrapping, the buffer's local tensor differs across the dp mesh dim
+        local_buf = model.buf.to_local()
+        assert_local_shard_across_ranks(local_buf, dp_group, check_equal=False)
+
+        # wrapping with FSDP and sync_module_states=True should sync along the dp mesh dim
+        fsdp_mod = FSDP(model, device_mesh=dp_mesh, sync_module_states=True)
+        with fsdp_mod.summon_full_params(fsdp_mod):
+            # after sync_module_states, the param's local tensor is equal across the dp mesh dim
+            local_param = fsdp_mod.param.to_local()
+            assert_local_shard_across_ranks(local_param, dp_group, check_equal=True)
+
+            # after sync_module_states, the buffer's local tensor is equal across the dp mesh dim
+            local_buf = fsdp_mod.buf.to_local()
+            assert_local_shard_across_ranks(local_buf, dp_group, check_equal=True)
+

 instantiate_parametrized_tests(TestTPFSDPIntegration)
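Note on running the new test: it is gated by @skip_if_lt_x_gpu(4), so it needs at least 4 GPUs; assuming the standard PyTorch test harness, it can typically be invoked as python test/distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_sync_module_state.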

torch/distributed/fsdp/_init_utils.py (+21 -2)

@@ -56,6 +56,8 @@
 from torch.distributed.fsdp.wrap import _Policy
 from torch.distributed.tensor.parallel.fsdp import DTensorExtensions
 from torch.distributed.utils import _sync_params_and_buffers
+
+from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 from torch.utils.hooks import RemovableHandle

 _TORCHDISTX_AVAIL = True
@@ -1062,8 +1064,25 @@ def _sync_module_params_and_buffers(
         # Avoid re-synchronizing buffers in case of nested wrapping
         if not getattr(buffer, FSDP_SYNCED, False):
             setattr(buffer, FSDP_SYNCED, True)
-            module_states.append(buffer.detach())
-    module_states.extend(param.detach() for param in params)
+            detached_buffer = buffer.detach()
+            if is_traceable_wrapper_subclass(detached_buffer):
+                # NOTE: Here we assume no nested subclasses, at most one level of subclass
+                # in both model's buffers and params
+                attrs, _ = detached_buffer.__tensor_flatten__()  # type: ignore[attr-defined]
+                inner_buffers = [getattr(detached_buffer, attr) for attr in attrs]
+                module_states.extend(inner_buffers)
+            else:
+                module_states.append(detached_buffer)
+
+    for param in params:
+        detached_param = param.detach()
+        if is_traceable_wrapper_subclass(detached_param):
+            attrs, _ = detached_param.__tensor_flatten__()  # type: ignore[attr-defined]
+            inner_params = [getattr(detached_param, attr) for attr in attrs]
+            module_states.extend(inner_params)
+        else:
+            module_states.append(detached_param)
+
     _check_module_states_for_sync_module_states(module_states)
     _sync_params_and_buffers(
         process_group,
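For context, a minimal sketch (not part of this commit) of the wrapper-subclass protocol the new branch relies on: is_traceable_wrapper_subclass detects subclasses such as DTensor, and __tensor_flatten__ names the inner plain-tensor attributes (for DTensor, its local shard), which are what now get appended to module_states and broadcast.

import torch
from torch.distributed._tensor import DTensor, Replicate, init_device_mesh
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

# Assumes 2 ranks launched via torchrun, each with its own CUDA device.
mesh = init_device_mesh("cuda", (2,))
dt = DTensor.from_local(torch.randn(8, 8), mesh, [Replicate()], run_check=False)

assert is_traceable_wrapper_subclass(dt)
attrs, _ctx = dt.__tensor_flatten__()          # names of the inner plain-tensor attributes
inner = [getattr(dt, attr) for attr in attrs]  # for DTensor, its local shard
# These inner tensors are what _sync_module_params_and_buffers now appends to
# module_states and broadcasts, instead of the DTensor wrapper itself.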
