@@ -46,7 +46,6 @@
 import dataclasses
 import logging
 import time
-from collections import OrderedDict
 from collections.abc import Callable, Iterable
 from copy import copy
 from itertools import chain, product
@@ -2188,7 +2187,7 @@ def infer_shape(self, fgraph, node, input_shapes):
         # corresponding outer inputs that the Scan would use as input for
         # any given iteration. For simplicity, we use iteration 0.
         inner_ins_shapes = []
-        out_equivalent = OrderedDict()
+        out_equivalent = {}
 
         # The two following blocks are commented as it cause in some
         # cases extra scans in the graph. See gh-XXX for the
@@ -2469,7 +2468,7 @@ def compute_all_gradients(known_grads):
                 if (x in diff_inputs)
                 and get_inp_idx(self_inputs.index(x)) in connected_inputs
             ]
-            gmp = OrderedDict()
+            gmp = {}
 
             # Required in case there is a pair of variables X and Y, with X
             # used to compute Y, for both of which there is an external
@@ -2478,7 +2477,7 @@ def compute_all_gradients(known_grads):
             # it will be the sum of the external gradient signal and the
             # gradient obtained by propagating Y's external gradient signal
             # to X.
-            known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
+            known_grads = {k.copy(): v for (k, v) in known_grads.items()}
 
             grads = grad(
                 cost=None,
@@ -2548,7 +2547,7 @@ def compute_all_gradients(known_grads):
                 dC_dXt = safe_new(dC_douts[idx][0])
             dC_dXts.append(dC_dXt)
 
-        known_grads = OrderedDict()
+        known_grads = {}
         dc_dxts_idx = 0
         for i in range(len(diff_outputs)):
             if i < idx_nitsot_start or i >= idx_nitsot_end:
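Note (not part of the commit): the mechanical `OrderedDict` -> `dict` replacement above is behavior-preserving because plain dicts have had guaranteed insertion-order iteration since Python 3.7. A minimal sketch of the two remaining differences, neither of which appears to be relied on at these call sites (presumably why the swap is safe):

    from collections import OrderedDict

    # Plain dicts iterate in insertion order (a language guarantee since 3.7).
    d = {}
    d["b"] = 1
    d["a"] = 2
    assert list(d) == ["b", "a"]

    # Difference 1: equality. dict comparison ignores order,
    # OrderedDict comparison does not.
    assert {"a": 1, "b": 2} == {"b": 2, "a": 1}
    assert OrderedDict(a=1, b=2) != OrderedDict(b=2, a=1)

    # Difference 2: OrderedDict-only methods such as move_to_end().
    od = OrderedDict(a=1, b=2)
    od.move_to_end("a")
    assert list(od) == ["b", "a"]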