Skip to content

Commit f1c3c15

Browse files
committed
Use defaultdict in graph/rewriting/basic.py
1 parent 1977e5a commit f1c3c15

File tree

1 file changed

+23
-41
lines changed

1 file changed

+23
-41
lines changed

pytensor/graph/rewriting/basic.py

+23-41
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import time
1111
import traceback
1212
import warnings
13-
from collections import UserList, defaultdict, deque
13+
from collections import Counter, UserList, defaultdict, deque
1414
from collections.abc import Callable, Iterable, Sequence
1515
from collections.abc import Iterable as IterableType
1616
from functools import _compose_mro, partial, reduce # type: ignore
@@ -1153,8 +1153,8 @@ class OpToRewriterTracker:
11531153
r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
11541154

11551155
def __init__(self):
1156-
self.tracked_instances: dict[Op, list[NodeRewriter]] = {}
1157-
self.tracked_types: dict[type, list[NodeRewriter]] = {}
1156+
self.tracked_instances: dict[Op, list[NodeRewriter]] = defaultdict(list)
1157+
self.tracked_types: dict[type, list[NodeRewriter]] = defaultdict(list)
11581158
self.untracked_rewrites: list[NodeRewriter] = []
11591159

11601160
def add_tracker(self, rw: NodeRewriter):
@@ -1166,9 +1166,9 @@ def add_tracker(self, rw: NodeRewriter):
11661166
else:
11671167
for c in tracks:
11681168
if isinstance(c, type):
1169-
self.tracked_types.setdefault(c, []).append(rw)
1169+
self.tracked_types[c].append(rw)
11701170
else:
1171-
self.tracked_instances.setdefault(c, []).append(rw)
1171+
self.tracked_instances[c].append(rw)
11721172

11731173
def _find_impl(self, cls) -> list[NodeRewriter]:
11741174
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
@@ -1250,22 +1250,16 @@ def __init__(
12501250

12511251
self.profile = profile
12521252
if self.profile:
1253-
self.time_rewrites: dict[Rewriter, float] = {}
1254-
self.process_count: dict[Rewriter, int] = {}
1255-
self.applied_true: dict[Rewriter, int] = {}
1256-
self.node_created: dict[Rewriter, int] = {}
1253+
self.time_rewrites: dict[Rewriter, float] = defaultdict(float)
1254+
self.process_count: dict[Rewriter, int] = Counter()
1255+
self.applied_true: dict[Rewriter, int] = Counter()
1256+
self.node_created: dict[Rewriter, int] = Counter()
12571257

12581258
self.tracker = OpToRewriterTracker()
12591259

12601260
for o in self.rewrites:
12611261
self.tracker.add_tracker(o)
12621262

1263-
if self.profile:
1264-
self.time_rewrites.setdefault(o, 0.0)
1265-
self.process_count.setdefault(o, 0)
1266-
self.applied_true.setdefault(o, 0)
1267-
self.node_created.setdefault(o, 0)
1268-
12691263
def __str__(self):
12701264
return getattr(
12711265
self,
@@ -2316,30 +2310,29 @@ def apply(self, fgraph, start_from=None):
23162310
changed = True
23172311
max_use_abort = False
23182312
rewriter_name = None
2319-
global_process_count = {}
2313+
global_process_count = Counter()
23202314
start_nb_nodes = len(fgraph.apply_nodes)
23212315
max_nb_nodes = len(fgraph.apply_nodes)
23222316
max_use = max_nb_nodes * self.max_use_ratio
23232317

23242318
loop_timing = []
23252319
loop_process_count = []
23262320
global_rewriter_timing = []
2327-
time_rewriters = {}
2321+
time_rewriters = defaultdict(float)
23282322
io_toposort_timing = []
23292323
nb_nodes = []
2330-
node_created = {}
2324+
node_created = Counter()
23312325
global_sub_profs = []
23322326
final_sub_profs = []
23332327
cleanup_sub_profs = []
2334-
for rewriter in (
2335-
self.global_rewriters
2336-
+ list(self.get_node_rewriters())
2337-
+ self.final_rewriters
2338-
+ self.cleanup_rewriters
2339-
):
2340-
global_process_count.setdefault(rewriter, 0)
2341-
time_rewriters.setdefault(rewriter, 0)
2342-
node_created.setdefault(rewriter, 0)
2328+
2329+
for rewriter in [
2330+
*self.global_rewriters,
2331+
*self.get_node_rewriters(),
2332+
*self.final_rewriters,
2333+
*self.cleanup_rewriters,
2334+
]:
2335+
time_rewriters[rewriter] += 0
23432336

23442337
def apply_cleanup(profs_dict):
23452338
changed = False
@@ -2351,15 +2344,14 @@ def apply_cleanup(profs_dict):
23512344
time_rewriters[crewriter] += time.perf_counter() - t_rewrite
23522345
profs_dict[crewriter].append(sub_prof)
23532346
if change_tracker.changed:
2354-
process_count.setdefault(crewriter, 0)
23552347
process_count[crewriter] += 1
23562348
global_process_count[crewriter] += 1
23572349
changed = True
23582350
node_created[crewriter] += change_tracker.nb_imported - nb
23592351
return changed
23602352

23612353
while changed and not max_use_abort:
2362-
process_count = {}
2354+
process_count = Counter()
23632355
t0 = time.perf_counter()
23642356
changed = False
23652357
iter_cleanup_sub_profs = {}
@@ -2376,7 +2368,6 @@ def apply_cleanup(profs_dict):
23762368
time_rewriters[grewrite] += time.perf_counter() - t_rewrite
23772369
sub_profs.append(sub_prof)
23782370
if change_tracker.changed:
2379-
process_count.setdefault(grewrite, 0)
23802371
process_count[grewrite] += 1
23812372
global_process_count[grewrite] += 1
23822373
changed = True
@@ -2431,7 +2422,6 @@ def chin_(node, i, r, new_r, reason):
24312422
time_rewriters[node_rewriter] += time.perf_counter() - t_rewrite
24322423
if not node_rewriter_change:
24332424
continue
2434-
process_count.setdefault(node_rewriter, 0)
24352425
process_count[node_rewriter] += 1
24362426
global_process_count[node_rewriter] += 1
24372427
changed = True
@@ -2459,7 +2449,6 @@ def chin_(node, i, r, new_r, reason):
24592449
time_rewriters[grewrite] += time.perf_counter() - t_rewrite
24602450
sub_profs.append(sub_prof)
24612451
if change_tracker.changed:
2462-
process_count.setdefault(grewrite, 0)
24632452
process_count[grewrite] += 1
24642453
global_process_count[grewrite] += 1
24652454
changed = True
@@ -2514,7 +2503,7 @@ def chin_(node, i, r, new_r, reason):
25142503
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
25152504
global_rewriter_timing,
25162505
nb_nodes,
2517-
time_rewriters,
2506+
dict(time_rewriters),
25182507
io_toposort_timing,
25192508
node_created,
25202509
global_sub_profs,
@@ -2597,14 +2586,7 @@ def print_profile(cls, stream, prof, level=0):
25972586
count_rewrite = []
25982587
not_used = []
25992588
not_used_time = 0
2600-
process_count = {}
2601-
for o in (
2602-
rewrite.global_rewriters
2603-
+ list(rewrite.get_node_rewriters())
2604-
+ list(rewrite.final_rewriters)
2605-
+ list(rewrite.cleanup_rewriters)
2606-
):
2607-
process_count.setdefault(o, 0)
2589+
process_count = Counter()
26082590
for count in loop_process_count:
26092591
for o, v in count.items():
26102592
process_count[o] += v

0 commit comments

Comments
 (0)