10
10
import time
11
11
import traceback
12
12
import warnings
13
- from collections import UserList , defaultdict , deque
13
+ from collections import Counter , UserList , defaultdict , deque
14
14
from collections .abc import Callable , Iterable , Sequence
15
15
from collections .abc import Iterable as IterableType
16
16
from functools import _compose_mro , partial , reduce # type: ignore
@@ -1153,8 +1153,8 @@ class OpToRewriterTracker:
1153
1153
r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance."""
1154
1154
1155
1155
def __init__ (self ):
1156
- self .tracked_instances : dict [Op , list [NodeRewriter ]] = {}
1157
- self .tracked_types : dict [type , list [NodeRewriter ]] = {}
1156
+ self .tracked_instances : dict [Op , list [NodeRewriter ]] = defaultdict ( list )
1157
+ self .tracked_types : dict [type , list [NodeRewriter ]] = defaultdict ( list )
1158
1158
self .untracked_rewrites : list [NodeRewriter ] = []
1159
1159
1160
1160
def add_tracker (self , rw : NodeRewriter ):
@@ -1166,9 +1166,9 @@ def add_tracker(self, rw: NodeRewriter):
1166
1166
else :
1167
1167
for c in tracks :
1168
1168
if isinstance (c , type ):
1169
- self .tracked_types . setdefault ( c , []) .append (rw )
1169
+ self .tracked_types [ c ] .append (rw )
1170
1170
else :
1171
- self .tracked_instances . setdefault ( c , []) .append (rw )
1171
+ self .tracked_instances [ c ] .append (rw )
1172
1172
1173
1173
def _find_impl (self , cls ) -> list [NodeRewriter ]:
1174
1174
r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance.
@@ -1250,22 +1250,16 @@ def __init__(
1250
1250
1251
1251
self .profile = profile
1252
1252
if self .profile :
1253
- self .time_rewrites : dict [Rewriter , float ] = {}
1254
- self .process_count : dict [Rewriter , int ] = {}
1255
- self .applied_true : dict [Rewriter , int ] = {}
1256
- self .node_created : dict [Rewriter , int ] = {}
1253
+ self .time_rewrites : dict [Rewriter , float ] = defaultdict ( float )
1254
+ self .process_count : dict [Rewriter , int ] = Counter ()
1255
+ self .applied_true : dict [Rewriter , int ] = Counter ()
1256
+ self .node_created : dict [Rewriter , int ] = Counter ()
1257
1257
1258
1258
self .tracker = OpToRewriterTracker ()
1259
1259
1260
1260
for o in self .rewrites :
1261
1261
self .tracker .add_tracker (o )
1262
1262
1263
- if self .profile :
1264
- self .time_rewrites .setdefault (o , 0.0 )
1265
- self .process_count .setdefault (o , 0 )
1266
- self .applied_true .setdefault (o , 0 )
1267
- self .node_created .setdefault (o , 0 )
1268
-
1269
1263
def __str__ (self ):
1270
1264
return getattr (
1271
1265
self ,
@@ -2316,30 +2310,29 @@ def apply(self, fgraph, start_from=None):
2316
2310
changed = True
2317
2311
max_use_abort = False
2318
2312
rewriter_name = None
2319
- global_process_count = {}
2313
+ global_process_count = Counter ()
2320
2314
start_nb_nodes = len (fgraph .apply_nodes )
2321
2315
max_nb_nodes = len (fgraph .apply_nodes )
2322
2316
max_use = max_nb_nodes * self .max_use_ratio
2323
2317
2324
2318
loop_timing = []
2325
2319
loop_process_count = []
2326
2320
global_rewriter_timing = []
2327
- time_rewriters = {}
2321
+ time_rewriters = defaultdict ( float )
2328
2322
io_toposort_timing = []
2329
2323
nb_nodes = []
2330
- node_created = {}
2324
+ node_created = Counter ()
2331
2325
global_sub_profs = []
2332
2326
final_sub_profs = []
2333
2327
cleanup_sub_profs = []
2334
- for rewriter in (
2335
- self .global_rewriters
2336
- + list (self .get_node_rewriters ())
2337
- + self .final_rewriters
2338
- + self .cleanup_rewriters
2339
- ):
2340
- global_process_count .setdefault (rewriter , 0 )
2341
- time_rewriters .setdefault (rewriter , 0 )
2342
- node_created .setdefault (rewriter , 0 )
2328
+
2329
+ for rewriter in [
2330
+ * self .global_rewriters ,
2331
+ * self .get_node_rewriters (),
2332
+ * self .final_rewriters ,
2333
+ * self .cleanup_rewriters ,
2334
+ ]:
2335
+ time_rewriters [rewriter ] += 0
2343
2336
2344
2337
def apply_cleanup (profs_dict ):
2345
2338
changed = False
@@ -2351,15 +2344,14 @@ def apply_cleanup(profs_dict):
2351
2344
time_rewriters [crewriter ] += time .perf_counter () - t_rewrite
2352
2345
profs_dict [crewriter ].append (sub_prof )
2353
2346
if change_tracker .changed :
2354
- process_count .setdefault (crewriter , 0 )
2355
2347
process_count [crewriter ] += 1
2356
2348
global_process_count [crewriter ] += 1
2357
2349
changed = True
2358
2350
node_created [crewriter ] += change_tracker .nb_imported - nb
2359
2351
return changed
2360
2352
2361
2353
while changed and not max_use_abort :
2362
- process_count = {}
2354
+ process_count = Counter ()
2363
2355
t0 = time .perf_counter ()
2364
2356
changed = False
2365
2357
iter_cleanup_sub_profs = {}
@@ -2376,7 +2368,6 @@ def apply_cleanup(profs_dict):
2376
2368
time_rewriters [grewrite ] += time .perf_counter () - t_rewrite
2377
2369
sub_profs .append (sub_prof )
2378
2370
if change_tracker .changed :
2379
- process_count .setdefault (grewrite , 0 )
2380
2371
process_count [grewrite ] += 1
2381
2372
global_process_count [grewrite ] += 1
2382
2373
changed = True
@@ -2431,7 +2422,6 @@ def chin_(node, i, r, new_r, reason):
2431
2422
time_rewriters [node_rewriter ] += time .perf_counter () - t_rewrite
2432
2423
if not node_rewriter_change :
2433
2424
continue
2434
- process_count .setdefault (node_rewriter , 0 )
2435
2425
process_count [node_rewriter ] += 1
2436
2426
global_process_count [node_rewriter ] += 1
2437
2427
changed = True
@@ -2459,7 +2449,6 @@ def chin_(node, i, r, new_r, reason):
2459
2449
time_rewriters [grewrite ] += time .perf_counter () - t_rewrite
2460
2450
sub_profs .append (sub_prof )
2461
2451
if change_tracker .changed :
2462
- process_count .setdefault (grewrite , 0 )
2463
2452
process_count [grewrite ] += 1
2464
2453
global_process_count [grewrite ] += 1
2465
2454
changed = True
@@ -2514,7 +2503,7 @@ def chin_(node, i, r, new_r, reason):
2514
2503
(start_nb_nodes , end_nb_nodes , max_nb_nodes ),
2515
2504
global_rewriter_timing ,
2516
2505
nb_nodes ,
2517
- time_rewriters ,
2506
+ dict ( time_rewriters ) ,
2518
2507
io_toposort_timing ,
2519
2508
node_created ,
2520
2509
global_sub_profs ,
@@ -2597,14 +2586,7 @@ def print_profile(cls, stream, prof, level=0):
2597
2586
count_rewrite = []
2598
2587
not_used = []
2599
2588
not_used_time = 0
2600
- process_count = {}
2601
- for o in (
2602
- rewrite .global_rewriters
2603
- + list (rewrite .get_node_rewriters ())
2604
- + list (rewrite .final_rewriters )
2605
- + list (rewrite .cleanup_rewriters )
2606
- ):
2607
- process_count .setdefault (o , 0 )
2589
+ process_count = Counter ()
2608
2590
for count in loop_process_count :
2609
2591
for o , v in count .items ():
2610
2592
process_count [o ] += v
0 commit comments