
Commit b71ab3f

wz337 authored and pytorchmergebot committed
[DTensor][Bug Fix] Fix 2D DTensor mm with mesh_shape (1, n) or (n, 1) (pytorch#139134)
Fixes pytorch#138742. In the issue, matrix multiplication with DTensor failed when the mesh is more than 1D and one of the mesh dimensions has size 1. We were missing test coverage for this corner case where mesh_shape is (n, 1) or (1, n); the DTensor mm op was already correct for a 1D mesh of shape (self.world_size,) and for a 2D mesh in which no mesh dimension has size 1.

This PR fixes the corner case by updating `gen_einsum_strategies` in `_einsum_strategy.py`. Specifically, we must not skip generating `mesh_dim_strategies` when `mesh.size(mesh_dim) <= 1`, as that is not valid for an nD mesh in which one of the mesh dimensions has size 1. Without the fix, the OpStrategy generated for a 2D mesh with mesh_shape (1, n) or (n, 1) is wrong, since only a 1D strategy is produced:

```
all_mesh_dim_strategies=[[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]]

OpStrategy(all_strategies) = [(R, R) -> R, (S(1), S(0)) -> P, (S(0), R) -> S(0), (R, S(1)) -> S(1)] @ mesh: (4, 1)
```

After the fix, the generated OpStrategy is correct, containing 2D strategies:

```
all_mesh_dim_strategies=[[[Replicate(), Replicate(), Replicate()], [Partial(sum), Shard(dim=1), Shard(dim=0)], [Shard(dim=0), Shard(dim=0), Replicate()], [Shard(dim=1), Replicate(), Shard(dim=1)]]]

OpStrategy(all_strategies) = [(RR, RR) -> RR, (RS(1), RS(0)) -> RP, (RS(0), RR) -> RS(0), (RR, RS(1)) -> RS(1), (S(1)R, S(0)R) -> PR, (S(1)S(1), S(0)S(0)) -> PP, (S(1)S(0), S(0)R) -> PS(0), (S(1)R, S(0)S(1)) -> PS(1), (S(0)R, RR) -> S(0)R, (S(0)S(1), RS(0)) -> S(0)P, (S(0)S(0), RR) -> S(0)S(0), (S(0)R, RS(1)) -> S(0)S(1), (RR, S(1)R) -> S(1)R, (RS(1), S(1)S(0)) -> S(1)P, (RS(0), S(1)R) -> S(1)S(0), (RR, S(1)S(1)) -> S(1)S(1)] @ mesh: (4, 1)
```

As a follow-up, we should add more test coverage for DTensor ops with a 2D mesh, especially a 2D mesh in which one of the mesh dimensions has size 1.

Pull Request resolved: pytorch#139134
Approved by: https://github.com/fegin
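For reference, a minimal standalone sketch of the scenario this commit fixes. It assumes 4 ranks launched via torchrun and uses the public `torch.distributed.tensor` API; the tensor sizes and names (`lhs`, `rhs`, `repro.py`) are illustrative, not taken from the commit.

```python
# Minimal sketch, assuming 4 ranks: torchrun --nproc-per-node=4 repro.py
import torch
from torch.distributed import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

torch.manual_seed(0)  # same global tensors on every rank

# 2D mesh where one mesh dimension has size 1 -- the corner case from pytorch#138742.
mesh = init_device_mesh("cuda", (4, 1))

lhs = torch.randn(256, 128)
rhs = torch.randn(128, 256)

# Shard lhs rows along mesh dim 0 and rhs columns along mesh dim 1 (size 1).
lhs_d = distribute_tensor(lhs, mesh, [Shard(0), Replicate()])
rhs_d = distribute_tensor(rhs, mesh, [Replicate(), Shard(1)])

# Before this fix, sharding propagation built a 1D OpStrategy for the 2D mesh
# and the matmul failed; with the fix, the result matches the local matmul.
out = lhs_d @ rhs_d
torch.testing.assert_close(out.full_tensor(), lhs @ rhs)
```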
1 parent ceab24d · commit b71ab3f

2 files changed: +29, -8 lines

test/distributed/_tensor/test_matrix_ops.py (29 additions, 3 deletions)

@@ -6,9 +6,10 @@

 import torch
 import torch.nn.functional as F
-from torch.distributed._tensor import DeviceMesh, distribute_tensor
-from torch.distributed._tensor.api import DTensor
-from torch.distributed._tensor.placement_types import (
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor import (
+    distribute_tensor,
+    DTensor,
     Partial,
     Placement,
     Replicate,
@@ -339,6 +340,31 @@ def test_scaled_dot_product_attention(self):
         self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1))
         self.assertEqual(dist_value.grad.full_tensor(), value.grad)

+    @skip_unless_torch_gpu
+    @with_comms()
+    def test_dtensor_mm(self):
+        """
+        Test mm with DTensor with 2D mesh.
+        We need to add the test here since we only test 1D mesh in test_dtensor_ops.py.
+        Also, we added tests for the corner case where one of the 2D dimension is 1.
+
+        # TODO: we need to test more DTensor ops with 2D mesh, especially when 1 of the
+        mesh dimension of the 2D mesh is 1.
+        """
+        mesh_0 = init_device_mesh(self.device_type, (self.world_size // 2, 2))
+        mesh_1 = init_device_mesh(self.device_type, (self.world_size, 1))
+        mesh_2 = init_device_mesh(self.device_type, (1, self.world_size))
+
+        for mesh in [mesh_0, mesh_1, mesh_2]:
+            lhs = torch.randn(256, 128)
+            rhs = torch.randn(128, 256)
+            mm_result = lhs @ rhs
+
+            lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()])
+            rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)])
+            dtensor_result = lhs_dtensor @ rhs_dtensor
+            self.assertEqual(dtensor_result.full_tensor(), mm_result)
+

 if __name__ == "__main__":
     run_tests()
torch/distributed/tensor/_ops/_einsum_strategy.py (0 additions, 5 deletions)

@@ -107,11 +107,6 @@ def gen_einsum_strategies(
         placement_list: List[Placement] = [Replicate()] * (len(input_dims) + 1)
         mesh_dim_strategies.append(placement_list)

-        if mesh.size(mesh_dim) <= 1:
-            # only replicate strategy for mesh dim with size 1
-            # TODO: see if this is valid for the submesh case
-            continue
-
         # split batch dim
         for batch_dim in edims.batch_dims:
             output_batch_dim = output_dim.index(batch_dim)