|
     distribute_module,
     DTensor,
     init_device_mesh,
+    Replicate,
     Shard,
 )
 from torch.distributed._tensor.debug import CommDebugMode
@@ -378,6 +379,67 @@ def forward(self, x):
|
         for grad in grads:
             self.assertFalse(grad.isnan().any().item())

|
+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp_sync_module_state(self):
+        mesh_2d = init_device_mesh(
+            "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"]
+        )
+        tp_mesh = mesh_2d["tp"]
+        dp_mesh = mesh_2d["dp"]
+
+        # seed each rank differently so params and buffers start out unequal across ranks
+        torch.manual_seed(mesh_2d.get_rank())
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                replicated_dt = DTensor.from_local(
+                    torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False
+                )
+                replicated_buffer_dt = DTensor.from_local(
+                    torch.randn(8, 8), tp_mesh, [Replicate()], run_check=False
+                )
+                self.param = torch.nn.Parameter(replicated_dt)
+                self.register_buffer("buf", replicated_buffer_dt)
+
+            def forward(self, x):
+                return self.param + self.buf + 1
+
+        model = TestModel()
+
+        def assert_local_shard_across_ranks(local_tensor, group, check_equal=True):
+            gathered_tensors = [
+                torch.empty_like(local_tensor) for _ in range(group.size())
+            ]
+            dist.all_gather(gathered_tensors, local_tensor, group=group)
+            # compare every rank's gathered local tensor against the first one
+            tensor_to_compare = gathered_tensors[0]
+            for tensor in gathered_tensors[1:]:
+                if check_equal:
+                    self.assertTrue(torch.equal(tensor, tensor_to_compare))
+                else:
+                    self.assertFalse(torch.equal(tensor, tensor_to_compare))
+
+        dp_group = dp_mesh.get_group()
+
+        # before FSDP wrapping, the param's local tensor differs across dp ranks
+        local_param = model.param.to_local()
+        assert_local_shard_across_ranks(local_param, dp_group, check_equal=False)
+        # before FSDP wrapping, the buffer's local tensor differs across dp ranks
+        local_buf = model.buf.to_local()
+        assert_local_shard_across_ranks(local_buf, dp_group, check_equal=False)
+
+        # wrapping with FSDP and sync_module_states=True should sync along the dp mesh dim
+        fsdp_mod = FSDP(model, device_mesh=dp_mesh, sync_module_states=True)
+        with fsdp_mod.summon_full_params(fsdp_mod):
+            # after sync_module_states, the param's local tensor is equal across dp ranks
+            local_param = fsdp_mod.param.to_local()
+            assert_local_shard_across_ranks(local_param, dp_group, check_equal=True)
+
+            # after sync_module_states, the buffer's local tensor is equal across dp ranks
+            local_buf = fsdp_mod.buf.to_local()
+            assert_local_shard_across_ranks(local_buf, dp_group, check_equal=True)
+

 instantiate_parametrized_tests(TestTPFSDPIntegration)
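As context for what the new test asserts (not part of this commit): FSDP's documented sync_module_states=True behavior is to broadcast module parameters and buffers from rank 0 of the data-parallel group so that every dp rank starts from identical values. Below is a minimal sketch of that dp-dim broadcast for a replicated DTensor; the helper name sync_dtensor_across_dp and the standalone usage are illustrative assumptions, not FSDP internals.

import torch.distributed as dist

def sync_dtensor_across_dp(dtensor, dp_group):
    # Illustrative only: broadcast the DTensor's local tensor from the first rank
    # of the dp process group so every dp rank holds identical local data,
    # i.e. the post-wrap state that the check_equal=True assertions above verify.
    local = dtensor.to_local()
    src = dist.get_global_rank(dp_group, 0)  # translate group rank 0 to a global rank
    dist.broadcast(local, src=src, group=dp_group)

# hypothetical usage: sync_dtensor_across_dp(model.buf, dp_mesh.get_group())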
|
|