Skip to content

Commit c5c9dbe

Browse files
anijain2305 authored and pytorchmergebot committed
[dynamo][user-defined] Simplify and improve scope of UserDefinedObject var_getattr (pytorch#130169)
Fixes pytorch#122649

Pull Request resolved: pytorch#130169
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#118448, pytorch#130159
1 parent d0ad13f commit c5c9dbe

9 files changed

+40
-73
lines changed

test/dynamo/test_export.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3428,8 +3428,7 @@ def f_mismatch_return_length(x):
34283428

34293429
example_inputs = (torch.rand(5),)
34303430
with self.assertRaisesRegex(
3431-
torch._dynamo.exc.UncapturedHigherOrderOpError,
3432-
"Cond doesn't work unless it is captured completely with torch.compile",
3431+
RuntimeError, "Unmatched number of outputs from cond"
34333432
):
34343433
torch._dynamo.export(
34353434
f_mismatch_return_length,

test/dynamo_expected_failures/TestTensorProtoSummary.test_half_tensor_proto_bfloat16_proto_type_14

Whitespace-only changes.

test/dynamo_expected_failures/TestTensorProtoSummary.test_half_tensor_proto_float16_proto_type_19

Whitespace-only changes.

test/functorch/test_control_flow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,8 +1411,8 @@ def f(x, y):
14111411

14121412
x = torch.randn(4)
14131413
with self.assertRaisesRegex(
1414-
torch._dynamo.exc.UncapturedHigherOrderOpError,
1415-
"Cond doesn't work unless it is captured completely with torch.compile",
1414+
torch._dynamo.exc.CondOpArgsMismatchError,
1415+
"Expected to return same number of outputs but got:",
14161416
):
14171417
make_fx(f)(x, torch.tensor(False))
14181418

@@ -1584,8 +1584,8 @@ def f(x, y):
15841584

15851585
x = torch.randn(4)
15861586
with self.assertRaisesRegex(
1587-
torch._dynamo.exc.UncapturedHigherOrderOpError,
1588-
"Cond doesn't work unless it is captured completely with torch.compile",
1587+
torch._dynamo.exc.CondOpArgsMismatchError,
1588+
"Expected to return same number of outputs but got:",
15891589
):
15901590
make_fx(f, tracing_mode="fake")(x, torch.tensor(False))
15911591

test/test_tensorboard.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
run_tests,
4646
TEST_WITH_CROSSREF,
4747
TestCase,
48+
skipIfTorchDynamo,
4849
)
4950

5051

@@ -760,6 +761,7 @@ class TestTensorProtoSummary(BaseTestCase):
760761
(torch.bfloat16, DataType.DT_BFLOAT16),
761762
],
762763
)
764+
@skipIfTorchDynamo("Unsuitable test for Dynamo, behavior changes with version")
763765
def test_half_tensor_proto(self, tensor_type, proto_type):
764766
float_values = [1.0, 2.0, 3.0]
765767
actual_proto = tensor_proto(

torch/_dynamo/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,16 @@ def is_wrapper_or_member_descriptor(value):
587587
return isinstance(
588588
value,
589589
(
590-
types.MethodWrapperType,
590+
# set up by PyGetSetDef
591+
types.GetSetDescriptorType,
592+
# set by PyMethodDef, e.g. list.append
593+
types.MethodDescriptorType,
594+
# slots - list.__add__
591595
types.WrapperDescriptorType,
596+
# set up by PyMemberDef
592597
types.MemberDescriptorType,
598+
# wrapper over C functions
599+
types.MethodWrapperType,
593600
),
594601
)
595602

torch/_dynamo/variables/builder.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2538,6 +2538,22 @@ def make_type_handlers():
25382538
handlers[immutable_list] = handlers[list]
25392539
handlers[types.ModuleType] = lambda tx, value: PythonModuleVariable(value)
25402540

2541+
handlers[
2542+
torch.distributions.constraints._Real
2543+
] = lambda tx, value: UserDefinedObjectVariable(
2544+
value, mutable_local=MutableLocal()
2545+
)
2546+
handlers[
2547+
torch.distributions.constraints._Interval
2548+
] = lambda tx, value: UserDefinedObjectVariable(
2549+
value, mutable_local=MutableLocal()
2550+
)
2551+
handlers[
2552+
torch.distributions.constraints.Constraint
2553+
] = lambda tx, value: UserDefinedObjectVariable(
2554+
value, mutable_local=MutableLocal()
2555+
)
2556+
25412557
def passthrough(tx, value):
25422558
return value
25432559

torch/_dynamo/variables/user_defined.py

Lines changed: 9 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
import contextlib
55
import enum
66
import functools
7-
import importlib
87
import inspect
98
import itertools
109
import random
11-
import re
1210
import sys
1311
import threading
1412
import types
@@ -45,7 +43,6 @@
4543
WeakRefCallSource,
4644
)
4745
from ..utils import (
48-
all_hook_names,
4946
build_checkpoint_variable,
5047
check_constant_args,
5148
get_custom_getattr,
@@ -830,7 +827,6 @@ def is_supported_nn_module_method(self, method):
830827
def var_getattr(self, tx, name):
831828
from .. import trace_rules
832829
from . import ConstantVariable
833-
from .builder import VariableBuilder
834830

835831
value = self.value
836832
source = AttrSource(self.source, name) if self.source else None
@@ -843,6 +839,13 @@ def var_getattr(self, tx, name):
843839
options = {"source": source}
844840
return variables.GetAttrVariable(self, name, **options)
845841

842+
# TODO(anijain2305) - Investigate if we need specialization for more
843+
# dunder attrs. inspect.getattr_static does not return correct value for
844+
# them.
845+
if name == "__class__":
846+
options = {"source": source}
847+
return UserDefinedClassVariable(type(self.value), **options)
848+
846849
try:
847850
subobj = self._getattr_static(name)
848851
except AttributeError:
@@ -958,75 +961,15 @@ def var_getattr(self, tx, name):
958961
else:
959962
return trace_rules.lookup(func)(func)
960963

961-
if (
962-
name in getattr(value, "__dict__", {})
963-
or ConstantVariable.is_literal(subobj)
964-
or isinstance(
965-
subobj,
966-
(
967-
torch.Tensor,
968-
torch.nn.Module,
969-
re.Pattern,
970-
),
971-
)
972-
):
964+
if subobj is not NO_SUCH_SUBOBJ and not is_wrapper_or_member_descriptor(subobj):
973965
if source:
974966
return variables.LazyVariableTracker.create(subobj, source)
975-
elif ConstantVariable.is_literal(subobj):
976-
return ConstantVariable.create(subobj)
977-
elif (
978-
type(subobj) == torch.utils._pytree.TreeSpec
979-
or type(subobj) == torch.utils._pytree.LeafSpec
980-
or type(value) == torch.utils._pytree.TreeSpec
981-
):
967+
else:
982968
from .builder import SourcelessBuilder
983969

984970
return SourcelessBuilder.create(tx, subobj)
985971

986-
if (
987-
subobj is not NO_SUCH_SUBOBJ
988-
and name not in getattr(value, "__dict__", {})
989-
and (
990-
type(value).__module__.startswith("torch.")
991-
or isinstance(subobj, re.Pattern)
992-
)
993-
and "torch.optim" not in type(value).__module__
994-
and not callable(value)
995-
and not isinstance(subobj, types.MethodDescriptorType)
996-
):
997-
if not source:
998-
assert getattr(
999-
importlib.import_module(type(value).__module__),
1000-
type(value).__name__,
1001-
) is type(value)
1002-
source = AttrSource(
1003-
AttrSource(
1004-
tx.import_source(type(value).__module__), type(value).__name__
1005-
),
1006-
name,
1007-
)
1008-
1009-
return VariableBuilder(tx, source)(subobj)
1010972
options = {"source": source}
1011-
if isinstance(
1012-
subobj,
1013-
(
1014-
torch.distributions.constraints._Interval,
1015-
torch.distributions.constraints._Real,
1016-
torch.distributions.constraints.Constraint,
1017-
),
1018-
):
1019-
return UserDefinedObjectVariable(subobj, **options)
1020-
elif isinstance(self.value, torch.nn.Module) and name in all_hook_names:
1021-
assert isinstance(subobj, collections.OrderedDict)
1022-
if not subobj:
1023-
return variables.ConstDictVariable(
1024-
subobj, collections.OrderedDict, **options
1025-
)
1026-
1027-
if name == "__class__":
1028-
return UserDefinedClassVariable(type(self.value), **options)
1029-
1030973
return variables.GetAttrVariable(self, name, **options)
1031974

1032975
def call_hasattr(self, tx, name: str) -> "VariableTracker":

0 commit comments

Comments (0)