
Commit d8bdb50

wanchaol authored and pytorchmergebot committed
[reland] pass shape/stride during tensor unflatten (pytorch#117340)
Reland of pytorch#113547; the previous PR was reverted because of a torch.compile symbolic shape issue. Since tensor unflatten is now disabled with dynamo.disable, we should not hit this issue again.

Pull Request resolved: pytorch#117340
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#117336
1 parent eebf115 commit d8bdb50
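At the API level, the relanded change lets the unflatten path hand the global shape and stride recorded in a DTensor sharding spec directly to DTensor.from_local, rather than having from_local re-derive them from the local shard. A minimal sketch of that call pattern on a PyTorch build that includes this change, assuming an already-initialized process group (e.g. launched via torchrun); the mesh and tensor sizes are illustrative:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

# Assumes the default process group is already initialized (e.g. via torchrun).
world_size = dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

# Each rank holds a (4, 8) shard of a (4 * world_size, 8) global tensor.
local_shard = torch.randn(4, 8)

dt = DTensor.from_local(
    local_shard,
    mesh,
    [Shard(0)],
    run_check=False,
    # Passing the global geometry explicitly avoids recomputing it from the
    # local shard, which can be wrong when shards are uneven across ranks.
    shape=torch.Size((4 * world_size, 8)),
    stride=(8, 1),
)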

File tree: 2 files changed (+8, −0 lines)

torch/distributed/_tensor/placement_types.py (+6)

@@ -457,6 +457,12 @@ def shape(self) -> torch.Size:
             raise ValueError("tensor_meta is not set")
         return self.tensor_meta.shape
 
+    @property
+    def stride(self) -> Tuple[int, ...]:
+        if self.tensor_meta is None:
+            raise ValueError("tensor_meta is not set")
+        return self.tensor_meta.stride
+
     @property
     def ndim(self) -> int:
         if self.tensor_meta is None:
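For context, the spec's tensor_meta records the global shape, stride, and dtype, and the new stride property mirrors the existing shape accessor. A self-contained sketch of that pattern; SpecSketch and this TensorMeta stand-in are simplified illustrations, not the actual DTensorSpec/TensorMeta definitions:

from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple

import torch


class TensorMeta(NamedTuple):
    # Simplified stand-in for the metadata the spec carries.
    shape: torch.Size
    stride: Tuple[int, ...]
    dtype: torch.dtype


@dataclass
class SpecSketch:
    tensor_meta: Optional[TensorMeta] = None

    @property
    def shape(self) -> torch.Size:
        if self.tensor_meta is None:
            raise ValueError("tensor_meta is not set")
        return self.tensor_meta.shape

    @property
    def stride(self) -> Tuple[int, ...]:
        # The accessor added by this commit: expose the global stride
        # alongside the global shape.
        if self.tensor_meta is None:
            raise ValueError("tensor_meta is not set")
        return self.tensor_meta.stride


spec = SpecSketch(TensorMeta(torch.Size((8, 4)), (4, 1), torch.float32))
assert spec.shape == (8, 4) and spec.stride == (4, 1)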

torch/distributed/tensor/parallel/_data_parallel_utils.py (+2)

@@ -36,6 +36,8 @@ def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None):
         spec.mesh,
         spec.placements,
         run_check=False,
+        shape=spec.shape,
+        stride=spec.stride,
     )
     if tensor.requires_grad:
         # only register the hook if the tensor requires grad
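This hunk is where the new accessors are consumed: when a DTensor is rebuilt from its flattened local shard, the global shape and stride saved in the spec are forwarded instead of being re-derived. A hedged sketch of that save/rebuild roundtrip; dt._spec is an internal attribute used here purely for illustration, and the module's real helpers (_flatten_tensor/_unflatten_tensor, the latter kept out of torch.compile tracing via dynamo.disable per the commit message) are not reproduced:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Assumes an initialized process group, e.g. launched via torchrun.
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

full = torch.randn(16, 4)
dt = distribute_tensor(full, mesh, [Shard(0)])

# "Flatten": keep just the local shard plus the spec describing the global tensor.
local, spec = dt.to_local(), dt._spec

# "Rebuild": with spec.shape/spec.stride passed through (the properties added
# in this commit), from_local does not have to infer the global geometry.
rebuilt = DTensor.from_local(
    local,
    spec.mesh,
    spec.placements,
    run_check=False,
    shape=spec.shape,
    stride=spec.stride,
)
assert rebuilt.shape == dt.shape and rebuilt.stride() == dt.stride()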
