Skip to content

Commit 3f248a5

Browse files
laithsakkapytorchmergebot
authored andcommitted
Classify miss-inplaced tensors in logs. (pytorch#139240)
Summary: use signpost logs, a followup is to remove the field possibly_missed_reinplacing_opportunities form dynamo compile table. Differential Revision: D65180194 Pull Request resolved: pytorch#139240 Approved by: https://github.com/zou3519
1 parent e947649 commit 3f248a5

File tree

4 files changed

+86
-50
lines changed

4 files changed

+86
-50
lines changed

test/inductor/test_inplacing_pass.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch._inductor.config as inductor_config
77
from functorch import make_fx
88
from torch import Tensor
9-
from torch._dynamo.utils import counters
9+
from torch._dynamo.utils import ReinplaceCounters
1010
from torch._higher_order_ops.auto_functionalize import (
1111
auto_functionalized,
1212
auto_functionalized_v2,
@@ -31,11 +31,11 @@
3131

3232

3333
def num_reinplacing_failures():
34-
return counters["inductor"]["possibly_missed_reinplacing_opportunities"]
34+
return ReinplaceCounters.get_total_missed()
3535

3636

3737
def miss_inplaced_bytes():
38-
return counters["inductor"]["possibly_missed_reinplacing_bytes"]
38+
return ReinplaceCounters.get_total_missed_bytes()
3939

4040

4141
@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
@@ -85,7 +85,7 @@ def boo(x: torch.Tensor) -> None:
8585

8686
class TestReinplacingPassCorrectness(InductorTestCase):
8787
def setUp(self):
88-
counters.clear()
88+
ReinplaceCounters.clear()
8989
return super().setUp()
9090

9191
def _test(self, f):
@@ -138,7 +138,7 @@ def f(x, y):
138138
self._test(f)
139139

140140
def test_counters_functionalize_old(self):
141-
counters.clear()
141+
ReinplaceCounters.clear()
142142

143143
def f(x):
144144
out = torch.empty_like(x)
@@ -158,7 +158,7 @@ def f(x):
158158
self.assertEqual(miss_inplaced_bytes(), 12)
159159

160160
def test_counters_functionalize_v2(self):
161-
counters.clear()
161+
ReinplaceCounters.clear()
162162

163163
def f(x):
164164
out = torch.empty_like(x)
@@ -314,7 +314,7 @@ def test_multi_output_intermediate(self):
314314
with inductor_config.patch(
315315
{"enable_auto_functionalized_v2": enable_v2}
316316
):
317-
counters.clear()
317+
ReinplaceCounters.clear()
318318

319319
def f(x):
320320
out1 = torch.empty_like(x)
@@ -329,7 +329,7 @@ def f(x):
329329
self.assertEqual(num_reinplacing_failures(), 0)
330330

331331
def test_multiple_mutations(self):
332-
counters.clear()
332+
ReinplaceCounters.clear()
333333

334334
def f(x, out):
335335
sin(x, out)
@@ -345,7 +345,7 @@ def f(x, out):
345345
self.assertEqual(num_reinplacing_failures(), 0)
346346

347347
def test_multiple_intermediate(self):
348-
counters.clear()
348+
ReinplaceCounters.clear()
349349

350350
def f(x):
351351
out = torch.empty_like(x)

torch/_dynamo/convert_frame.py

+3-26
Original file line numberDiff line numberDiff line change
@@ -971,12 +971,7 @@ def format_guard_failures() -> str:
971971
fail_reason: Optional[str] = None
972972
fail_user_frame_filename: Optional[str] = None
973973
fail_user_frame_lineno: Optional[int] = None
974-
start_possibly_missed_reinplacing_opportunities = torch._dynamo.utils.counters[
975-
"inductor"
976-
]["possibly_missed_reinplacing_opportunities"]
977-
start_possibly_missed_reinplacing_bytes = torch._dynamo.utils.counters[
978-
"inductor"
979-
]["start_possibly_missed_reinplacing_bytes"]
974+
torch._dynamo.utils.ReinplaceCounters.clear()
980975
guarded_code = None
981976
try:
982977
guarded_code = compile_inner(code, one_graph, hooks, transform)
@@ -1054,33 +1049,17 @@ def format_guard_failures() -> str:
10541049
compliant_custom_ops = {
10551050
op.__qualname__ for op in output.compliant_custom_ops
10561051
}
1057-
possibly_missed_reinplacing_opportunities = (
1058-
torch._dynamo.utils.counters["inductor"][
1059-
"possibly_missed_reinplacing_opportunities"
1060-
]
1061-
- start_possibly_missed_reinplacing_opportunities
1062-
)
10631052
remote_cache_time_saved = frame_phase_timing[frame_key].get(
10641053
"remote_cache_time_saved", 0
10651054
)
1066-
possibly_missed_reinplacing_bytes = (
1067-
torch._dynamo.utils.counters["inductor"][
1068-
"possibly_missed_reinplacing_bytes"
1069-
]
1070-
- start_possibly_missed_reinplacing_bytes
1071-
)
1072-
if possibly_missed_reinplacing_bytes != 0:
1073-
signpost_event(
1074-
"inductor",
1075-
"auto_functionalize",
1076-
{"missed_reinplacing_bytes": possibly_missed_reinplacing_bytes},
1077-
)
10781055
remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get(
10791056
"remote_fx_graph_cache_get", None
10801057
)
10811058
remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
10821059
"remote_fx_graph_cache_put", None
10831060
)
1061+
torch._dynamo.utils.ReinplaceCounters.log()
1062+
10841063
else:
10851064
guard_count = None
10861065
shape_env_guard_count = None
@@ -1096,7 +1075,6 @@ def format_guard_failures() -> str:
10961075
restart_reasons = set()
10971076
# If compilation failed, the entire time is wasted
10981077
dynamo_time_before_restart = duration_ns / 1e9
1099-
possibly_missed_reinplacing_opportunities = None
11001078
remote_cache_time_saved = None
11011079
remote_fx_graph_cache_get_time = None
11021080
remote_fx_graph_cache_put_time = None
@@ -1161,7 +1139,6 @@ def clean_for_json(d: Dict[str, Any]) -> Dict[str, Any]:
11611139
restart_reasons,
11621140
dynamo_time_before_restart,
11631141
guarded_code is not None,
1164-
possibly_missed_reinplacing_opportunities,
11651142
remote_cache_time_saved,
11661143
structured_logging_overhead_s,
11671144
config.suppress_errors,

torch/_dynamo/utils.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@
7373
from torch._dispatch.python import enable_python_dispatcher
7474
from torch._guards import Source, TracingContext
7575
from torch._subclasses.meta_utils import is_sparse_compressed
76-
from torch._utils_internal import log_chromium_event_internal, log_compilation_event
76+
from torch._utils_internal import (
77+
log_chromium_event_internal,
78+
log_compilation_event,
79+
signpost_event,
80+
)
7781
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
7882
from torch.nn.modules.lazy import LazyModuleMixin
7983
from torch.utils._triton import has_triton, has_triton_package
@@ -143,6 +147,51 @@
143147
timer_counter = itertools.count()
144148

145149

150+
# Abstraction on top of counters.
151+
class ReInplaceTrigger(enum.Enum):
152+
AUTO_FUNC_V1 = 1
153+
AUTO_FUNC_V2 = 2
154+
TRITON_OPS = 3
155+
156+
157+
class ReinplaceCounters:
158+
_values: DefaultDict[str, int] = collections.defaultdict(int)
159+
160+
# Track sizes of known not re-inplaced tensors (exclude dynamic shapes).
161+
@classmethod
162+
def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int):
163+
cls._values[f"missed_bytes_{trigger.name}"] += bytes
164+
165+
# Track number of not re-inplaced tensors.
166+
@classmethod
167+
def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int):
168+
cls._values[f"missed_tensors_{trigger}"] += count
169+
170+
@classmethod
171+
def clear(cls):
172+
cls._values.clear()
173+
174+
@classmethod
175+
def get_total_missed(cls):
176+
sum = 0
177+
for trigger in ReInplaceTrigger:
178+
sum += cls._values.get(f"missed_tensors_{trigger}", 0)
179+
return sum
180+
181+
@classmethod
182+
def get_total_missed_bytes(cls):
183+
sum = 0
184+
for trigger in ReInplaceTrigger:
185+
sum += cls._values.get(f"missed_bytes_{trigger.name}", 0)
186+
return sum
187+
188+
@classmethod
189+
def log(cls):
190+
# if not empty log.
191+
if cls._values:
192+
signpost_event("inductor", "reinplace_counters", cls._values)
193+
194+
146195
def tabulate(
147196
rows: Union[List[Tuple[str, object]], List[List[object]]],
148197
headers: Union[Tuple[str, ...], List[str]],
@@ -843,7 +892,6 @@ class CompilationMetrics:
843892
# to install any guarded code. True means we actually decided to install
844893
# a compiled frame
845894
has_guarded_code: Optional[bool] = None
846-
possibly_missed_reinplacing_opportunities: Optional[int] = None
847895
remote_cache_time_saved_s: Optional[float] = None
848896
structured_logging_overhead_s: Optional[float] = None
849897
config_suppress_errors: Optional[bool] = None

torch/_inductor/fx_passes/reinplace.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010
from torch._dispatch.python import enable_python_dispatcher
11+
from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger
1112
from torch._higher_order_ops.triton_kernel_wrap import (
1213
kernel_side_table,
1314
triton_kernel_wrapper_functional,
@@ -497,7 +498,12 @@ def can_inplace(node, mutated_arg):
497498
)
498499

499500
def log_inplace_results(
500-
node_name, old_tensors_to_clone, tensors_to_clone, missed_args, missed_nodes
501+
node_name,
502+
old_tensors_to_clone,
503+
tensors_to_clone,
504+
missed_args,
505+
missed_nodes,
506+
trigger,
501507
):
502508
# Total size of possibly_missed_reinplacing_opportunities for tensors with static shapes.
503509
missed_bytes = 0
@@ -531,17 +537,14 @@ def bytes(node):
531537
missed_args,
532538
missed_bytes,
533539
)
534-
torch._dynamo.utils.counters["inductor"][
535-
"possibly_missed_reinplacing_opportunities"
536-
] += len(missed_args)
537-
torch._dynamo.utils.counters["inductor"][
538-
"possibly_missed_reinplacing_bytes"
539-
] += missed_bytes
540+
541+
ReinplaceCounters.add_missed_opportunities(trigger, len(missed_args))
542+
ReinplaceCounters.add_missed_bytes(trigger, missed_bytes)
540543

541544
replace_dict: Dict[torch.fx.Node, torch.fx.Node] = {}
542545

543546
def reinplace_and_refine_tensors_to_clone(
544-
old_tensors_to_clone, kwargs, node_name, auto_functionalize_v2=False
547+
old_tensors_to_clone, kwargs, node_name, trigger
545548
):
546549
tensors_to_clone: List[str] = []
547550
storage_of_reinplaced_args = set()
@@ -580,7 +583,7 @@ def tensor_with_same_storage_already_reinplaced(arg):
580583
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
581584
if copy_node is not None:
582585
replace_dict[copy_node] = copy_node.args[0]
583-
if not auto_functionalize_v2:
586+
if not trigger == ReInplaceTrigger.AUTO_FUNC_V2:
584587
for user in node.users:
585588
# For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to
586589
# output atindex size(out)+i.
@@ -602,7 +605,12 @@ def tensor_with_same_storage_already_reinplaced(arg):
602605
tensors_to_clone.append(arg)
603606

604607
log_inplace_results(
605-
node_name, old_tensors_to_clone, tensors_to_clone, missed_args, missed_nodes
608+
node_name,
609+
old_tensors_to_clone,
610+
tensors_to_clone,
611+
missed_args,
612+
missed_nodes,
613+
trigger,
606614
)
607615
return tensors_to_clone
608616

@@ -628,7 +636,7 @@ def tensor_with_same_storage_already_reinplaced(arg):
628636
bases_to_clone,
629637
base_tensors_dct,
630638
node.target,
631-
auto_functionalize_v2=True,
639+
ReInplaceTrigger.AUTO_FUNC_V2,
632640
)
633641
# Stash the metadata. There is a pass later on where we decompose
634642
# auto_functionalized into clones + a mutable op; this metadata
@@ -647,7 +655,7 @@ def tensor_with_same_storage_already_reinplaced(arg):
647655
tensors_to_clone,
648656
node.kwargs,
649657
_mutable_op._name,
650-
auto_functionalize_v2=False,
658+
ReInplaceTrigger.AUTO_FUNC_V1,
651659
)
652660

653661
# Stash the metadata. There is a pass later on where we decompose
@@ -679,7 +687,10 @@ def tensor_with_same_storage_already_reinplaced(arg):
679687
# This pass iterates over them and sees which ones are safe
680688
# to eliminate (i.e. no longer need the clones)
681689
tensors_to_clone = reinplace_and_refine_tensors_to_clone(
682-
node.kwargs["tensors_to_clone"], node.kwargs["kwargs"], kernel_name
690+
node.kwargs["tensors_to_clone"],
691+
node.kwargs["kwargs"],
692+
kernel_name,
693+
ReInplaceTrigger.TRITON_OPS,
683694
)
684695

685696
kwargs = dict(node.kwargs)

0 commit comments

Comments
 (0)