Skip to content

Commit ace2b4f

Browse files
wanchaolfacebook-github-bot
authored andcommitted
[resubmit] try to infer rref type from python (pytorch#33992)
Summary: Pull Request resolved: pytorch#33992 resubmit of pytorch#33369 with tweaks on when the rref type being created to ensure ivalue->type() hold the correct RRef type inside of inner element type. Test Plan: Imported from OSS Differential Revision: D20175043 Pulled By: wanchaol fbshipit-source-id: a08b178e989c995632374e6c868d23c5a85526ae
1 parent 7747fe8 commit ace2b4f

File tree

3 files changed

+32
-1
lines changed

3 files changed

+32
-1
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ TypePtr IValue::type() const {
4848
case Tag::Future:
4949
return toFuture()->type();
5050
case Tag::RRef:
51-
return toRRef()->type();
51+
return RRefType::create(toRRef()->type());
5252
case Tag::Device:
5353
return DeviceObjType::get();
5454
case Tag::Object:

torch/csrc/jit/python/pybind_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,11 @@ inline InferredType tryToInferType(py::handle input) {
177177
if (py::isinstance<script::Object>(input)) {
178178
auto object = py::cast<script::Object>(input);
179179
return InferredType(object.type());
180+
#ifdef USE_DISTRIBUTED
181+
} else if (py::isinstance<torch::distributed::rpc::PyRRef>(input)) {
182+
auto rref_ivalue = input.cast<torch::distributed::rpc::PyRRef>().toIValue();
183+
return InferredType(rref_ivalue.type());
184+
#endif
180185
}
181186

182187
// Try container types

torch/testing/_internal/distributed/rpc/jit/rpc_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,23 @@ def one_arg(value):
1717
return value + 1
1818

1919

20+
class MyScriptModuleWithRRefs(torch.jit.ScriptModule):
21+
def __init__(self, dst_worker):
22+
super().__init__()
23+
self.rrefs = []
24+
for _ in range(4):
25+
self.rrefs.append(rpc_return_rref(dst_worker))
26+
27+
@torch.jit.script_method
28+
def forward(self):
29+
# type: () -> Tensor
30+
res_tensor = torch.ones(2, 2)
31+
for rref in self.rrefs:
32+
res_tensor += rref.to_here()
33+
34+
return res_tensor
35+
36+
2037
@torch.jit.script
2138
class MyScriptClass:
2239
def __init__(self):
@@ -205,6 +222,15 @@ def rref_tensor_is_owner(rref_var):
205222
res = rref_tensor_is_owner(rref_var)
206223
self.assertEqual(res, False)
207224

225+
@dist_init
226+
def test_my_script_module_with_rrefs(self):
227+
n = self.rank + 1
228+
dst_rank = n % self.world_size
229+
230+
module_with_rrefs = MyScriptModuleWithRRefs("worker{}".format(dst_rank))
231+
res = module_with_rrefs()
232+
self.assertEqual(res, torch.ones(2, 2) * 9)
233+
208234
@dist_init
209235
def test_rref_python_annotation(self):
210236
n = self.rank + 1

0 commit comments

Comments
 (0)