
Commit 9bcd60f

Wanchao Liang authored and pytorchmergebot committed
[shard] ShardedTensor Interface (pytorch#74695)
Summary: Pull Request resolved: pytorch#74695

ShardedTensor Interface:
1. Make a ShardedTensorInterface class that is a subclass of torch.Tensor and define the basic APIs on it; ShardedTensor is now a subclass of ShardedTensorInterface.
2. Disable `__torch_dispatch__` by default in ShardedTensorInterface; ShardedTensor will use `__torch_function__` for now. Subclasses of ShardedTensorInterface can opt into `__torch_dispatch__` by overriding it.
3. Remove the attribute functions from ShardedTensor and handle them in `__torch_function__`.

ghstack-source-id: 155141823
(Note: this ignores all push blocking failures!)
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D35123200
fbshipit-source-id: 04ad48ae373e6f61d48bb3bc83021e97b0721362
(cherry picked from commit 71ad555)
1 parent a240d45 commit 9bcd60f
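
As background for the diffs below, here is a rough sketch of the dispatch pattern the summary describes: a torch.Tensor subclass that intercepts ops via __torch_function__ while __torch_dispatch__ stays out of the picture. The class name and the _CUSTOM_OPS registry are illustrative placeholders, not the actual ShardedTensor code; the real implementation populates its registry through the sharded_op_impl decorator shown further down.

import torch

# Hypothetical registry of ops that have sharded implementations. The real
# code fills a similar table through the sharded_op_impl decorator.
_CUSTOM_OPS = {}

class MyShardedTensorInterface(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func in _CUSTOM_OPS:
            # Ops with a registered sharded implementation are handled there.
            return _CUSTOM_OPS[func](types, args, kwargs)
        # Everything else falls back to the plain tensor behavior, with
        # torch_function disabled inside the call to avoid infinite recursion.
        with torch._C.DisableTorchFunction():
            return func(*args, **kwargs)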

9 files changed: +202, -138 lines

test/distributed/_shard/sharded_tensor/ops/test_linear.py (+1)

@@ -192,6 +192,7 @@ def test_sharded_linear_rowwise(self):
     def test_sharded_linear_errors(self):
         for spec in generate_chunk_sharding_specs_for_test(0):
             fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
+            shard_parameter(fc1, "weight", spec)
             shard_parameter(fc1, "bias", spec)
             with self.assertRaisesRegex(TypeError, 'bias needs to be torch.Tensor'):
                 fc1(torch.rand(10, 10).cuda(self.rank))

test/distributed/_shard/sharded_tensor/test_sharded_tensor.py (+15, -8)

@@ -421,6 +421,7 @@ def test_sharded_tensor_metadata(self):
         st = sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
         st_metadata = st.metadata()
         self.assertEqual(torch.Size([10, 20]), st_metadata.size)
+        self.assertEqual(torch.Size([10, 20]), st.size())
         self.assertEqual(torch.float, st.dtype)
         self.assertEqual(torch.strided, st.layout)
         self.assertEqual(False, st.requires_grad)
@@ -449,7 +450,7 @@ def test_sharded_tensor_metadata(self):

         # test read only properties, they're read only as we can't simply change
         # the global metadata without changing the underlying shard's properties
-        with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+        with self.assertRaisesRegex(RuntimeError, "torch function '__set__'"):
             st.requires_grad = True

     @with_comms
@@ -908,7 +909,7 @@ def test_invalid_sharding(self):

         spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
         with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'):
-            sharded_tensor.empty(spec, 10, 20, layout=torch.sparse)
+            sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo)

         spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
         with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'):
@@ -1025,11 +1026,17 @@ def test_sharded_tensor_sizes(self):
         st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
         self.assertEqual(st.size(1), 20)

+        # Test with negative indexed size
+        st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
+        self.assertEqual(st.size(-1), 20)
+
+        # Test with dim/ndim
+        self.assertEqual(st.dim(), 2)
+        self.assertEqual(st.ndim, 2)
+
         # Test with invalid input
         st = sharded_tensor.empty(spec, (10, 20), init_rrefs=True)
-        with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[0, 2\\)'):
-            st.size(-1)
-        with self.assertRaisesRegex(ValueError, 'must be within the range of tensor dimensions \\[0, 2\\)'):
+        with self.assertRaisesRegex(IndexError, 'Dimension out of range'):
             st.size(2)

         with self.assertRaises(TypeError):
@@ -1493,15 +1500,15 @@ def test_sharded_tensor_to_cpu(self):
         # CPU sharded tensor should return the same instance (no copy)
         st_cpu = sharded_tensor.zeros(cpu_spec, h, w, process_group=gloo_pg)
         new_st_cpu = st_cpu.cpu()
-        self.assertEqual(st_cpu, new_st_cpu)
+        self.assertTrue(st_cpu is new_st_cpu)

         # GPU sharded tensor to cpu
         st = sharded_tensor.zeros(spec, h, w)
         # test ability to move st to CPU
         spec_before_move = st.sharding_spec()
         new_st = st.cpu(process_group=gloo_pg)
         # return a copy of orginal st
-        self.assertNotEqual(st, new_st)
+        self.assertFalse(st is new_st)
         # check the spec is still ChunkShardingSpec
         spec_after_move = new_st.sharding_spec()
         self.assertIsInstance(spec_after_move, ChunkShardingSpec)
@@ -1534,7 +1541,7 @@ def test_sharded_tensor_to_cpu(self):
         st = sharded_tensor.zeros(mixed_spec, h, w, process_group=gloo_pg)
         new_st = st.cpu()
         # return a copy of orginal st
-        self.assertNotEqual(st, new_st)
+        self.assertFalse(st is new_st)
         # check the spec is still ChunkShardingSpec
         spec_after_move = new_st.sharding_spec()
         self.assertIsInstance(spec_after_move, ChunkShardingSpec)

torch/distributed/_shard/sharded_tensor/_ops/__init__.py (+1)

@@ -1,5 +1,6 @@
 import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
 import torch.distributed._shard.sharded_tensor._ops.math_ops
+import torch.distributed._shard.sharded_tensor._ops.default_tensor_ops

 from .binary_cmp import equal, allclose
 from .embedding import sharded_embedding
torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py (new file, +34)

@@ -0,0 +1,34 @@
+import torch
+from torch.distributed._shard.sharded_tensor import (
+    sharded_op_impl,
+)
+
+
+def register_default_op(op):
+    @sharded_op_impl(op)
+    def tensor_default_op(types, args=(), kwargs=None, pg=None):
+        """
+        Handles ``__torch_function__`` dispatch for the default tensor ops that
+        behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or
+        ``torch.Tensor.dtype``. We simply lower to the real op call with
+        DisableTorchFunction context like ``torch.Tensor.__torch_function__``
+        to avoid recursions.
+        """
+        if kwargs is None:
+            kwargs = {}
+
+        with torch._C.DisableTorchFunction():
+            return op(*args, **kwargs)
+
+# Tensor properties access
+register_default_op(torch.Tensor.requires_grad.__get__)  # type: ignore[attr-defined]
+register_default_op(torch.Tensor.shape.__get__)  # type: ignore[attr-defined]
+register_default_op(torch.Tensor.dtype.__get__)  # type: ignore[attr-defined]
+register_default_op(torch.Tensor.layout.__get__)  # type: ignore[attr-defined]
+register_default_op(torch.Tensor.size)
+register_default_op(torch.Tensor.dim)
+register_default_op(torch.Tensor.ndim.__get__)  # type: ignore[attr-defined]
+register_default_op(torch.Tensor.is_contiguous)
+
+# __reduce_ex__ to dispatch to get_state/set_state
+register_default_op(torch.Tensor.__reduce_ex__)
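
The practical effect of these registrations, judging from the test changes above, is that plain attribute and method access on a ShardedTensor now routes through __torch_function__ and is answered from the tensor's global metadata. A hedged usage sketch follows; it assumes an already-initialized 2-rank process group with one GPU per rank (the tests above set this up through their test harness), and the placements are illustrative.

from torch.distributed._shard.sharding_spec import ChunkShardingSpec
import torch.distributed._shard.sharded_tensor as sharded_tensor

# Assumes torch.distributed is already initialized with 2 ranks.
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
st = sharded_tensor.empty(spec, 10, 20)

# These calls now go through __torch_function__ and land in the default
# handlers registered above; they answer from the global metadata, so they
# behave like a regular torch.Tensor of shape (10, 20):
st.size()              # torch.Size([10, 20])
st.size(-1)            # 20, negative dims are accepted now
st.dim(), st.ndim      # 2, 2
st.dtype, st.layout, st.requires_grad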

torch/distributed/_shard/sharded_tensor/_ops/embedding.py (+1, -3)

@@ -1,7 +1,5 @@
 # coding=utf-8

-from typing import cast
-
 import torch
 import torch.distributed as dist
 from ._common import (
@@ -158,7 +156,7 @@ def _validate_embedding_param(args, kwargs):
         raise TypeError("input need to be torch.Tensor")
     if not isinstance(weight, ShardedTensor):
         raise TypeError("weight needs to be ShardedTensor")
-    weight_size = cast(torch.Size, weight.size())
+    weight_size = weight.size()
     if len(weight_size) != 2:
         raise ValueError("Weight needs to have exactly 2 dims")
     if int(torch.min(input).item()) < 0:

torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py (+1, -1)

@@ -204,7 +204,7 @@ def _validate_embedding_bag_param(args, kwargs):
         raise TypeError("weight needs to be ShardedTensor")
     if len(input.size()) > 2:
         raise ValueError("Input more than 2 dims not supported")
-    weight_size = cast(torch.Size, weight.size())
+    weight_size = weight.size()
     if len(weight_size) != 2:
         raise ValueError("Weight needs to have exactly 2 dims")
     if int(torch.min(input).item()) < 0:

torch/distributed/_shard/sharded_tensor/_ops/linear.py (+9, -9)

@@ -1,4 +1,4 @@
-from typing import List, cast
+from typing import List

 import torch
 import torch.distributed as dist
@@ -105,14 +105,14 @@ def sharded_linear(types, args, kwargs, pg):
     world_size = dist.get_world_size(pg)
     rank = dist.get_rank(pg)

-    if sharding_dim == 1 and isinstance(input, torch.Tensor):
-        return _handle_row_wise_sharding_tensor(
-            input, world_size, weight, rank, local_shard_t, bias, pg
-        )
-    elif sharding_dim == 1 and isinstance(input, ShardedTensor):
+    if sharding_dim == 1 and isinstance(input, ShardedTensor):
         return _handle_row_wise_sharding_sharded_tensor(
             input, world_size, weight, local_shard_t, bias, pg
         )
+    elif sharding_dim == 1 and isinstance(input, torch.Tensor):
+        return _handle_row_wise_sharding_tensor(
+            input, world_size, weight, rank, local_shard_t, bias, pg
+        )
     elif sharding_dim == 0:
         return _handle_col_wise_sharding(
             input, world_size, weight, rank, local_shard_t, bias, pg
@@ -125,7 +125,7 @@ def sharded_linear(types, args, kwargs, pg):

 def _validate_linear_op_param(args, kwargs):
     """
-    Validate input params of sharded embedding op.
+    Validate input params of sharded linear op.

     Args:
         input: input of the linear layer.
@@ -141,13 +141,13 @@ def _validate_linear_op_param(args, kwargs):
     # Validate types
     if not isinstance(input, torch.Tensor) and not isinstance(input, ShardedTensor):
         raise TypeError("input needs to be either torch.Tensor or ShardedTensor")
-    if not isinstance(bias, torch.Tensor):
+    if type(bias) != torch.Tensor and type(bias) != torch.nn.Parameter:
         raise TypeError("bias needs to be torch.Tensor")
     if not isinstance(weight, ShardedTensor):
         raise TypeError("weight needs to be ShardedTensor")
     if len(input.size()) < 1:  # type: ignore[arg-type]
         raise ValueError("Input needs to have at least 1 dim")
-    weight_size = cast(torch.Size, weight.size())
+    weight_size = weight.size()
     if len(weight_size) != 2:
         raise ValueError("Weight needs to have exactly 2 dims")
     if len(bias.size()) != 1:
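
One subtle point in the sharded_linear change above: because ShardedTensor is now a subclass of torch.Tensor (point 1 of the summary), isinstance(input, torch.Tensor) is also true for sharded inputs, so the ShardedTensor branch has to be checked first. A minimal, self-contained illustration of that ordering pitfall; FakeShardedTensor is a stand-in class, not the real type.

import torch

class FakeShardedTensor(torch.Tensor):
    # Stand-in subclass, only used to demonstrate isinstance ordering.
    pass

x = FakeShardedTensor()

# Old ordering: the torch.Tensor branch shadows the subclass branch.
if isinstance(x, torch.Tensor):
    picked = "dense path"
elif isinstance(x, FakeShardedTensor):
    picked = "sharded path"  # unreachable
assert picked == "dense path"

# New ordering (what the diff switches to): check the subclass first.
if isinstance(x, FakeShardedTensor):
    picked = "sharded path"
elif isinstance(x, torch.Tensor):
    picked = "dense path"
assert picked == "sharded path"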
