Skip to content

Commit f144fe9

Browse files
anijain2305facebook-github-bot
authored andcommitted
Simplify disabling of the helper functions on tensor properties (#155259)
Summary: X-link: pytorch/pytorch#155259 Approved by: https://github.com/zhxchen17 Reviewed By: seemethere Differential Revision: D76158167 fbshipit-source-id: ad35724cd12b20b1d21cf777dfaa34d0703842d4
1 parent 68cf9df commit f144fe9

File tree

1 file changed

+9
-25
lines changed
  • userbenchmark/dynamo/dynamobench/_dynamo

1 file changed

+9
-25
lines changed

userbenchmark/dynamo/dynamobench/_dynamo/utils.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4495,38 +4495,22 @@ def does_not_override_dict_iter_methods(user_cls):
44954495
)
44964496

44974497

4498-
# Helper functions below are to prevent __torch_function__
4499-
# calls from happening in the middle of __torch_function__
4500-
# compiled bytecode
4501-
# They will be skipped which is the desired result
4498+
# Helper functions below are to prevent TorchDynamo to prevent tracing of
4499+
# __torch_function__ calls triggered on tensor properties in the pre graph
4500+
# bytecode.
4501+
@torch._disable_dynamo
45024502
def call_size(x, i):
4503-
@torch._dynamo.disable(
4504-
recursive=True, reason="__torch_function__ tracing helper function"
4505-
)
4506-
def fn(x, i):
4507-
return x.size(i)
4508-
4509-
return fn(x, i)
4503+
return x.size(i)
45104504

45114505

4506+
@torch._disable_dynamo
45124507
def call_stride(x, i):
4513-
@torch._dynamo.disable(
4514-
recursive=True, reason="__torch_function__ tracing helper function"
4515-
)
4516-
def fn(x, i):
4517-
return x.stride(i)
4518-
4519-
return fn(x, i)
4508+
return x.stride(i)
45204509

45214510

4511+
@torch._disable_dynamo
45224512
def call_storage_offset(x):
4523-
@torch._dynamo.disable(
4524-
recursive=True, reason="__torch_function__ tracing helper function"
4525-
)
4526-
def fn(x):
4527-
return x.storage_offset()
4528-
4529-
return fn(x)
4513+
return x.storage_offset()
45304514

45314515

45324516
# Helper function to extract relevant parts of a tensor's __dict__ to store in node meta.

0 commit comments

Comments
 (0)