5
5
"""
6
6
7
7
import itertools
8
- from collections import OrderedDict , deque
8
+ from collections import deque
9
9
10
10
import pytensor
11
11
from pytensor .configdefaults import config
@@ -306,7 +306,7 @@ def __init__(self, do_imports_on_attach=True, algo=None):
306
306
TODO: change name to var_to_vroot.
307
307
308
308
"""
309
- self .droot = OrderedDict ()
309
+ self .droot = {}
310
310
311
311
"""
312
312
Maps a variable to all variables that are indirect or direct views of it
@@ -317,19 +317,19 @@ def __init__(self, do_imports_on_attach=True, algo=None):
317
317
TODO: rename to x_to_views after reverse engineering what x is
318
318
319
319
"""
320
- self .impact = OrderedDict ()
320
+ self .impact = {}
321
321
322
322
"""
323
323
If a var is destroyed, then this dict will map
324
324
droot[var] to the apply node that destroyed var
325
325
TODO: rename to vroot_to_destroyer
326
326
327
327
"""
328
- self .root_destroyer = OrderedDict ()
328
+ self .root_destroyer = {}
329
329
if algo is None :
330
330
algo = config .cycle_detection
331
331
self .algo = algo
332
- self .fail_validate = OrderedDict ()
332
+ self .fail_validate = {}
333
333
334
334
def clone (self ):
335
335
return type (self )(self .do_imports_on_attach , self .algo )
@@ -370,7 +370,7 @@ def on_attach(self, fgraph):
370
370
self .view_i = {} # variable -> variable used in calculation
371
371
self .view_o = {} # variable -> set of variables that use this one as a direct input
372
372
# clients: how many times does an apply use a given variable
373
- self .clients = OrderedDict () # variable -> apply -> ninputs
373
+ self .clients = {} # variable -> apply -> ninputs
374
374
self .stale_droot = True
375
375
376
376
self .debug_all_apps = set ()
@@ -527,11 +527,11 @@ def on_import(self, fgraph, app, reason):
527
527
528
528
# update self.clients
529
529
for i , input in enumerate (app .inputs ):
530
- self .clients .setdefault (input , OrderedDict () ).setdefault (app , 0 )
530
+ self .clients .setdefault (input , {} ).setdefault (app , 0 )
531
531
self .clients [input ][app ] += 1
532
532
533
533
for i , output in enumerate (app .outputs ):
534
- self .clients .setdefault (output , OrderedDict () )
534
+ self .clients .setdefault (output , {} )
535
535
536
536
self .stale_droot = True
537
537
@@ -591,7 +591,7 @@ def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
591
591
if self .clients [old_r ][app ] == 0 :
592
592
del self .clients [old_r ][app ]
593
593
594
- self .clients .setdefault (new_r , OrderedDict () ).setdefault (app , 0 )
594
+ self .clients .setdefault (new_r , {} ).setdefault (app , 0 )
595
595
self .clients [new_r ][app ] += 1
596
596
597
597
# UPDATE self.view_i, self.view_o
@@ -632,7 +632,7 @@ def validate(self, fgraph):
632
632
if self .algo == "fast" :
633
633
if self .fail_validate :
634
634
app_err_pairs = self .fail_validate
635
- self .fail_validate = OrderedDict ()
635
+ self .fail_validate = {}
636
636
# self.fail_validate can only be a hint that maybe/probably
637
637
# there is a cycle.This is because inside replace() we could
638
638
# record many reasons to not accept a change, but we don't
@@ -674,12 +674,8 @@ def orderings(self, fgraph, ordered=True):
674
674
c) an Apply destroys (illegally) one of its own inputs by aliasing
675
675
676
676
"""
677
- if ordered :
678
- set_type = OrderedSet
679
- rval = OrderedDict ()
680
- else :
681
- set_type = set
682
- rval = dict ()
677
+ set_type = OrderedSet if ordered else set
678
+ rval = {}
683
679
684
680
if self .destroyers :
685
681
# BUILD DATA STRUCTURES
0 commit comments