
Commit 0a16f8d

mlazos authored and pobin6 committed

[Dynamo] allow dynamic callables on tensor variables (pytorch#137940)

Fixes pytorch#134844
Pull Request resolved: pytorch#137940
Approved by: https://github.com/williamwen42

1 parent bc350b3 commit 0a16f8d
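
In practice, the change lets torch.compile trace through a Python callable that a user has attached directly to a tensor instance, where Dynamo previously raised NotImplementedError on the attribute access. A minimal sketch of the newly supported pattern, adapted from the test added in this commit:

```python
import torch

def add_one(x):
    return x + 1

# Attach a dynamic callable to a tensor instance.
t = torch.nn.Parameter(torch.ones(1))
t.add_one = add_one

@torch.compile(fullgraph=True)
def fn(x):
    return t.add_one(t) + x

print(fn(torch.ones(1)))  # tensor([3.], grad_fn=...): (1 + 1) + 1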

File tree: 4 files changed, +63 -25 lines changed

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv (+16, -15)

```diff
@@ -1,64 +1,65 @@
-add_loop_eager,compile_time_instruction_count,3073000000,0.015
+add_loop_eager,compile_time_instruction_count,3077000000,0.015
 
-add_loop_eager_dynamic,compile_time_instruction_count,5700000000,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,5719000000,0.025
 
-add_loop_inductor,compile_time_instruction_count,24580000000,0.015
+add_loop_inductor,compile_time_instruction_count,24630000000,0.015
 
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40810000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40910000000,0.025
 
-add_loop_inductor_gpu,compile_time_instruction_count,23290000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,23330000000,0.015
 
 basic_modules_ListOfLinears_eager,compile_time_instruction_count,1037000000,0.015
 
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19200000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19210000000,0.015
 
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15820000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15840000000,0.015
 
-basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16890000000,0.2
+basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,16510000000,0.2
 
-update_hint_regression,compile_time_instruction_count,1757000000,0.02
+update_hint_regression,compile_time_instruction_count,1753000000,0.02
 
-sum_floordiv_regression,compile_time_instruction_count,1171000000,0.015
+sum_floordiv_regression,compile_time_instruction_count,1241000000,0.015
 
-symint_sum,compile_time_instruction_count,3321000000,0.015
+symint_sum,compile_time_instruction_count,3331000000,0.015
 
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2014000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2011000000,0.015
 
-aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5826000000,0.015
+aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5827000000,0.015
 
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9022000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9054000000,0.015
 
-aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3848000000,0.015
+
+aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3844000000,0.015
```
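For readers unfamiliar with this file: each row appears to pair an expected compile_time_instruction_count with a relative noise tolerance in the last column (e.g. 0.015 for roughly ±1.5%), which would be why the small counter drifts from this PR require updating the expected values. A hedged sketch of such a check; the helper below is hypothetical and illustrative, not the benchmark harness's actual code:

```python
def within_tolerance(actual: int, expected: int, rel_tol: float) -> bool:
    # Hypothetical check: accept a measured instruction count that lands
    # within rel_tol (e.g. 0.015 = 1.5%) of the recorded expectation.
    return abs(actual - expected) <= rel_tol * expected

# The updated add_loop_eager row would accept counts near 3077000000:
assert within_tolerance(3_080_000_000, 3_077_000_000, 0.015)
```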
test/dynamo/test_misc.py (+14)

```diff
@@ -1201,6 +1201,20 @@ def fn(x):
         inp = torch.ones(2, 2)
         fn(inp)
 
+    def test_tensor_dynamic_method(self):
+        def add_one(x):
+            return x + 1
+
+        t = torch.nn.Parameter(torch.ones(1))
+        t.add_one = add_one
+
+        @torch.compile(fullgraph=True)
+        def fn(x):
+            return t.add_one(t) + x
+
+        result = fn(torch.ones(1))
+        self.assertEqual(torch.ones(1) + 2, result)
+
     def test_shape_unpack(self):
         def fn(x):
             a, b = x.size()
```
torch/_dynamo/variables/misc.py (+9, -1)

```diff
@@ -963,15 +963,23 @@ def call_function(
 class GetAttrVariable(VariableTracker):
     _nonvar_fields = {
         "name",
+        "py_type",
         *VariableTracker._nonvar_fields,
     }
 
-    def __init__(self, obj, name, **kwargs) -> None:
+    def __init__(self, obj, name, py_type=None, **kwargs) -> None:
         super().__init__(**kwargs)
         assert isinstance(obj, VariableTracker)
         assert isinstance(name, str)
         self.obj = obj
         self.name = name
+        self.py_type = py_type  # In some cases we know the type (ex. tensor methods)
+
+    def python_type(self):
+        if self.py_type is not None:
+            return self.py_type
+        else:
+            return super().python_type()
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.obj}, {self.name})"
```

torch/_dynamo/variables/tensor.py (+24, -9)

```diff
@@ -89,6 +89,16 @@
 )
 
 
+def is_bound_tensor_method(value):
+    return (
+        callable(value)
+        and not torch._dynamo.utils.object_has_getattribute(value)
+        and hasattr(value, "__self__")
+        and isinstance(value.__self__, torch.Tensor)
+        and getattr(value.__self__, value.__name__, None)
+    )
+
+
 class TensorVariable(VariableTracker):
     """A torch.Tensor input or an intermediate value in the FX graph"""
 
@@ -273,14 +283,19 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name):
             raise NotImplementedError
 
         real_value = getattr(_input_associated_real_value, name)
-        if callable(real_value):
-            # Callables have more nuanced handling, and we should let the existing system delegate here.
-            # Raising was past behavior and so should always be sound to fall back.
-            # Note - at a certain point we may want to handle
-            raise NotImplementedError
 
         attr_source = AttrSource(self.source, name)
         install_guard(attr_source.make_guard(GuardBuilder.HASATTR))
+
+        # Typically we'd want to use variable builder here
+        # but unfortunately id(real_value.__self__) is not id(<original value>)
+        if is_bound_tensor_method(real_value):
+            from .misc import GetAttrVariable
+
+            return GetAttrVariable(
+                self, name, source=attr_source, py_type=type(real_value)
+            )
+
         return VariableTracker.build(tx, real_value, attr_source)
 
     def method_attr_ndim(self, tx):
@@ -522,16 +537,16 @@ def call_method(
         # Only override builtin tensor methods
         # The user can manually add override handling
         # with a decorator for other methods (e.g. a dispatch subclass with other methods)
-        has_torch_function_override = False
+        is_base_tensor_method = False
         try:
             inspect.getattr_static(torch.Tensor, name)
-            has_torch_function_override = True
+            is_base_tensor_method = True
         except AttributeError:
-            has_torch_function_override = False
+            is_base_tensor_method = False
 
         if (
             can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs)
-            and has_torch_function_override
+            and is_base_tensor_method
         ):
             if self.source:
                 func_var = VariableBuilder(
```
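To make the predicate concrete: is_bound_tensor_method accepts a callable only when it is bound to a real torch.Tensor via __self__ and that tensor actually exposes an attribute of the same name. A quick sanity sketch; calling the helper directly like this is purely illustrative, since inside Dynamo it runs on the real value recovered for a traced input:

```python
import torch
from torch._dynamo.variables.tensor import is_bound_tensor_method

t = torch.ones(1)

# Bound method of a concrete tensor: satisfies every conjunct.
print(bool(is_bound_tensor_method(t.add)))        # True

# A plain function has no Tensor __self__, so it is rejected.
print(bool(is_bound_tensor_method(lambda x: x)))  # False
```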