|
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 | import torch.nn.functional as F
|
9 |
| -from torch.distributed._tensor import DeviceMesh, distribute_tensor |
10 |
| -from torch.distributed._tensor.api import DTensor |
11 |
| -from torch.distributed._tensor.placement_types import ( |
| 9 | +from torch.distributed import DeviceMesh, init_device_mesh |
| 10 | +from torch.distributed.tensor import ( |
| 11 | + distribute_tensor, |
| 12 | + DTensor, |
12 | 13 | Partial,
|
13 | 14 | Placement,
|
14 | 15 | Replicate,
|
@@ -339,6 +340,31 @@ def test_scaled_dot_product_attention(self):
|
339 | 340 | self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1))
|
340 | 341 | self.assertEqual(dist_value.grad.full_tensor(), value.grad)
|
341 | 342 |
|
| 343 | + @skip_unless_torch_gpu |
| 344 | + @with_comms() |
| 345 | + def test_dtensor_mm(self): |
| 346 | + """ |
| 347 | + Test mm with DTensor with 2D mesh. |
| 348 | + We need to add the test here since we only test 1D mesh in test_dtensor_ops.py. |
| 349 | + Also, we added tests for the corner case where one of the 2D dimension is 1. |
| 350 | +
|
| 351 | + # TODO: we need to test more DTensor ops with 2D mesh, especially when 1 of the |
| 352 | + mesh dimension of the 2D mesh is 1. |
| 353 | + """ |
| 354 | + mesh_0 = init_device_mesh(self.device_type, (self.world_size // 2, 2)) |
| 355 | + mesh_1 = init_device_mesh(self.device_type, (self.world_size, 1)) |
| 356 | + mesh_2 = init_device_mesh(self.device_type, (1, self.world_size)) |
| 357 | + |
| 358 | + for mesh in [mesh_0, mesh_1, mesh_2]: |
| 359 | + lhs = torch.randn(256, 128) |
| 360 | + rhs = torch.randn(128, 256) |
| 361 | + mm_result = lhs @ rhs |
| 362 | + |
| 363 | + lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()]) |
| 364 | + rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)]) |
| 365 | + dtensor_result = lhs_dtensor @ rhs_dtensor |
| 366 | + self.assertEqual(dtensor_result.full_tensor(), mm_result) |
| 367 | + |
342 | 368 |
|
343 | 369 | if __name__ == "__main__":
|
344 | 370 | run_tests()
|
0 commit comments