Skip to content

Commit 7815ebe

Browse files
committed
Use defaultdict in graph/rewriting/basic.py
1 parent 87b51d6 commit 7815ebe

File tree

1 file changed

+15
-41
lines changed

1 file changed

+15
-41
lines changed

pytensor/graph/rewriting/basic.py

+15-41
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import time
1111
import traceback
1212
import warnings
13-
from collections import UserList, defaultdict, deque
13+
from collections import Counter, UserList, defaultdict, deque
1414
from collections.abc import Callable, Iterable, Sequence
1515
from collections.abc import Iterable as IterableType
1616
from functools import _compose_mro, partial, reduce # type: ignore
@@ -1153,8 +1153,8 @@ class OpToRewriterTracker:
11531153
r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
11541154

11551155
def __init__(self):
1156-
self.tracked_instances: dict[Op, list[NodeRewriter]] = {}
1157-
self.tracked_types: dict[type, list[NodeRewriter]] = {}
1156+
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
1157+
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
11581158
self.untracked_rewrites: list[NodeRewriter] = []
11591159

11601160
def add_tracker(self, rw: NodeRewriter):
@@ -1166,9 +1166,9 @@ def add_tracker(self, rw: NodeRewriter):
11661166
else:
11671167
for c in tracks:
11681168
if isinstance(c, type):
1169-
self.tracked_types.setdefault(c, []).append(rw)
1169+
self.tracked_types[c].append(rw)
11701170
else:
1171-
self.tracked_instances.setdefault(c, []).append(rw)
1171+
self.tracked_instances[c].append(rw)
11721172

11731173
def _find_impl(self, cls) -> list[NodeRewriter]:
11741174
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
@@ -1250,22 +1250,16 @@ def __init__(
12501250

12511251
self.profile = profile
12521252
if self.profile:
1253-
self.time_rewrites: dict[Rewriter, float] = {}
1254-
self.process_count: dict[Rewriter, int] = {}
1255-
self.applied_true: dict[Rewriter, int] = {}
1256-
self.node_created: dict[Rewriter, int] = {}
1253+
self.time_rewrites: dict[Rewriter, float] = defaultdict(float)
1254+
self.process_count: dict[Rewriter, int] = Counter()
1255+
self.applied_true: dict[Rewriter, int] = Counter()
1256+
self.node_created: dict[Rewriter, int] = Counter()
12571257

12581258
self.tracker = OpToRewriterTracker()
12591259

12601260
for o in self.rewrites:
12611261
self.tracker.add_tracker(o)
12621262

1263-
if self.profile:
1264-
self.time_rewrites.setdefault(o, 0.0)
1265-
self.process_count.setdefault(o, 0)
1266-
self.applied_true.setdefault(o, 0)
1267-
self.node_created.setdefault(o, 0)
1268-
12691263
def __str__(self):
12701264
return getattr(
12711265
self,
@@ -2316,30 +2310,21 @@ def apply(self, fgraph, start_from=None):
23162310
changed = True
23172311
max_use_abort = False
23182312
rewriter_name = None
2319-
global_process_count = {}
2313+
global_process_count = Counter()
23202314
start_nb_nodes = len(fgraph.apply_nodes)
23212315
max_nb_nodes = len(fgraph.apply_nodes)
23222316
max_use = max_nb_nodes * self.max_use_ratio
23232317

23242318
loop_timing = []
23252319
loop_process_count = []
23262320
global_rewriter_timing = []
2327-
time_rewriters = {}
2321+
time_rewriters = defaultdict(float)
23282322
io_toposort_timing = []
23292323
nb_nodes = []
2330-
node_created = {}
2324+
node_created = Counter()
23312325
global_sub_profs = []
23322326
final_sub_profs = []
23332327
cleanup_sub_profs = []
2334-
for rewriter in (
2335-
self.global_rewriters
2336-
+ list(self.get_node_rewriters())
2337-
+ self.final_rewriters
2338-
+ self.cleanup_rewriters
2339-
):
2340-
global_process_count.setdefault(rewriter, 0)
2341-
time_rewriters.setdefault(rewriter, 0)
2342-
node_created.setdefault(rewriter, 0)
23432328

23442329
def apply_cleanup(profs_dict):
23452330
changed = False
@@ -2351,15 +2336,14 @@ def apply_cleanup(profs_dict):
23512336
time_rewriters[crewriter] += time.perf_counter() - t_rewrite
23522337
profs_dict[crewriter].append(sub_prof)
23532338
if change_tracker.changed:
2354-
process_count.setdefault(crewriter, 0)
23552339
process_count[crewriter] += 1
23562340
global_process_count[crewriter] += 1
23572341
changed = True
23582342
node_created[crewriter] += change_tracker.nb_imported - nb
23592343
return changed
23602344

23612345
while changed and not max_use_abort:
2362-
process_count = {}
2346+
process_count = Counter()
23632347
t0 = time.perf_counter()
23642348
changed = False
23652349
iter_cleanup_sub_profs = {}
@@ -2376,7 +2360,6 @@ def apply_cleanup(profs_dict):
23762360
time_rewriters[grewrite] += time.perf_counter() - t_rewrite
23772361
sub_profs.append(sub_prof)
23782362
if change_tracker.changed:
2379-
process_count.setdefault(grewrite, 0)
23802363
process_count[grewrite] += 1
23812364
global_process_count[grewrite] += 1
23822365
changed = True
@@ -2431,7 +2414,6 @@ def chin_(node, i, r, new_r, reason):
24312414
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
24322415
if not node_rewriter_change:
24332416
continue
2434-
process_count.setdefault(node_rewriter, 0)
24352417
process_count[node_rewriter] += 1
24362418
global_process_count[node_rewriter] += 1
24372419
changed = True
@@ -2459,7 +2441,6 @@ def chin_(node, i, r, new_r, reason):
24592441
time_rewriters[grewrite] += time.perf_counter() - t_rewrite
24602442
sub_profs.append(sub_prof)
24612443
if change_tracker.changed:
2462-
process_count.setdefault(grewrite, 0)
24632444
process_count[grewrite] += 1
24642445
global_process_count[grewrite] += 1
24652446
changed = True
@@ -2514,7 +2495,7 @@ def chin_(node, i, r, new_r, reason):
25142495
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
25152496
global_rewriter_timing,
25162497
nb_nodes,
2517-
time_rewriters,
2498+
dict(time_rewriters),
25182499
io_toposort_timing,
25192500
node_created,
25202501
global_sub_profs,
@@ -2597,14 +2578,7 @@ def print_profile(cls, stream, prof, level=0):
25972578
count_rewrite = []
25982579
not_used = []
25992580
not_used_time = 0
2600-
process_count = {}
2601-
for o in (
2602-
rewrite.global_rewriters
2603-
+ list(rewrite.get_node_rewriters())
2604-
+ list(rewrite.final_rewriters)
2605-
+ list(rewrite.cleanup_rewriters)
2606-
):
2607-
process_count.setdefault(o, 0)
2581+
process_count = Counter()
26082582
for count in loop_process_count:
26092583
for o, v in count.items():
26102584
process_count[o] += v

0 commit comments

Comments
 (0)