Skip to content

Commit 177124b

Browse files
ArmavicaricardoV94
authored andcommitted
Remove OrderedDict from scan/op
1 parent c7a99b6 commit 177124b

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

pytensor/scan/op.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import dataclasses
4747
import logging
4848
import time
49-
from collections import OrderedDict
5049
from collections.abc import Callable, Iterable
5150
from copy import copy
5251
from itertools import chain, product
@@ -2188,7 +2187,7 @@ def infer_shape(self, fgraph, node, input_shapes):
21882187
# corresponding outer inputs that the Scan would use as input for
21892188
# any given iteration. For simplicity, we use iteration 0.
21902189
inner_ins_shapes = []
2191-
out_equivalent = OrderedDict()
2190+
out_equivalent = {}
21922191

21932192
# The two following blocks are commented as it cause in some
21942193
# cases extra scans in the graph. See gh-XXX for the
@@ -2469,7 +2468,7 @@ def compute_all_gradients(known_grads):
24692468
if (x in diff_inputs)
24702469
and get_inp_idx(self_inputs.index(x)) in connected_inputs
24712470
]
2472-
gmp = OrderedDict()
2471+
gmp = {}
24732472

24742473
# Required in case there is a pair of variables X and Y, with X
24752474
# used to compute Y, for both of which there is an external
@@ -2478,7 +2477,7 @@ def compute_all_gradients(known_grads):
24782477
# it will be the sum of the external gradient signal and the
24792478
# gradient obtained by propagating Y's external gradient signal
24802479
# to X.
2481-
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()])
2480+
known_grads = {k.copy(): v for (k, v) in known_grads.items()}
24822481

24832482
grads = grad(
24842483
cost=None,
@@ -2548,7 +2547,7 @@ def compute_all_gradients(known_grads):
25482547
dC_dXt = safe_new(dC_douts[idx][0])
25492548
dC_dXts.append(dC_dXt)
25502549

2551-
known_grads = OrderedDict()
2550+
known_grads = {}
25522551
dc_dxts_idx = 0
25532552
for i in range(len(diff_outputs)):
25542553
if i < idx_nitsot_start or i >= idx_nitsot_end:

0 commit comments

Comments
 (0)