Skip to content

Commit f053be2

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo] Graph break on random_ op (pytorch#130222)
Fixes pytorch#121621 Pull Request resolved: pytorch#130222 Approved by: https://github.com/jansel
1 parent 31bb65d commit f053be2

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

test/dynamo/test_repros.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5293,6 +5293,17 @@ def forward(self):
52935293
# the second call causes a failure
52945294
m()
52955295

5296+
# https://github.com/pytorch/pytorch/issues/121621
5297+
def test_tensor_random(self):
5298+
def random_op(tensor, params):
5299+
res = tensor.random_(**params)
5300+
return res
5301+
5302+
random_op = torch.compile(random_op)
5303+
params = {"from": -10, "to": 10}
5304+
tensor = torch.randn([2, 3])
5305+
res = random_op(tensor, params)
5306+
52965307

52975308
instantiate_parametrized_tests(ReproTests)
52985309

torch/_dynamo/symbolic_convert.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,16 +1495,29 @@ def CALL_FUNCTION_EX(self, inst):
14951495
null = self.pop()
14961496
assert isinstance(null, NullVariable)
14971497

1498-
if (
1499-
isinstance(fn, GetAttrVariable)
1500-
and isinstance(fn.obj, TensorVariable)
1501-
and fn.name == "view"
1502-
and isinstance(argsvars, (ConstantVariable, TensorVariable))
1503-
):
1504-
# Hack to handle special case in some bert models. Converts
1505-
# x.view(*shape) into x.view(shape), which is correct for view()
1506-
# but not generally. See test_transpose_for_scores().
1507-
argsvars = TupleVariable([argsvars])
1498+
if isinstance(fn, GetAttrVariable) and isinstance(fn.obj, TensorVariable):
1499+
# realize is requires for Python 3.8
1500+
kwargsvars = kwargsvars.realize()
1501+
if fn.name == "view" and isinstance(
1502+
argsvars, (ConstantVariable, TensorVariable)
1503+
):
1504+
# Hack to handle special case in some bert models. Converts
1505+
# x.view(*shape) into x.view(shape), which is correct for view()
1506+
# but not generally. See test_transpose_for_scores().
1507+
argsvars = TupleVariable([argsvars])
1508+
elif (
1509+
fn.name == "random_"
1510+
and isinstance(argsvars, TupleVariable)
1511+
and len(argsvars.items) == 0
1512+
and isinstance(kwargsvars, ConstDictVariable)
1513+
and ConstantVariable.create("from") in kwargsvars
1514+
):
1515+
# `from`` is python keyword. Adding random_ with `from` in the
1516+
# Fx graph causes syntax error. Even if we convert the kwargs to
1517+
# args, aot_autograd/inductor while lowering generates
1518+
# aten.random.from, again causing syntax errors. Since this
1519+
# usecase is uncommon, graph break.
1520+
unimplemented("random_ op is called with from keyword")
15081521

15091522
if not isinstance(
15101523
argsvars, BaseListVariable

0 commit comments

Comments
 (0)