Skip to content

Commit 8c0d4d0

Browse files
committed
Remove OrderedDict from updates.py
1 parent 8cb8734 commit 8c0d4d0

File tree

1 file changed

+4
-32
lines changed

1 file changed

+4
-32
lines changed

pytensor/updates.py

+4-32
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Defines Updates object for storing a (SharedVariable, new_value) mapping."""
22

33
import logging
4-
import warnings
5-
from collections import OrderedDict
64

75
from pytensor.compile.sharedvalue import SharedVariable
86

@@ -12,30 +10,16 @@
1210
logger = logging.getLogger("pytensor.updates")
1311

1412

15-
# Must be an OrderedDict or updates will be applied in a non-deterministic
16-
# order.
17-
class OrderedUpdates(OrderedDict):
13+
# Relies on the fact that dict is ordered, otherwise updates will be applied
14+
# in a non-deterministic order.
15+
class OrderedUpdates(dict):
1816
"""
1917
Dict-like mapping from SharedVariable keys to their new values.
2018
2119
This mapping supports the use of the "+" operator for the union of updates.
2220
"""
2321

2422
def __init__(self, *key, **kwargs):
25-
if (
26-
len(key) >= 1
27-
and isinstance(key[0], dict)
28-
and len(key[0]) > 1
29-
and not isinstance(key[0], OrderedDict)
30-
):
31-
# Warn when using as input a non-ordered dictionary.
32-
warnings.warn(
33-
"Initializing an `OrderedUpdates` from a "
34-
"non-ordered dictionary with 2+ elements could "
35-
"make your code non-deterministic. You can use "
36-
"an OrderedDict that is available at "
37-
"collections.OrderedDict for python 2.6+."
38-
)
3923
super().__init__(*key, **kwargs)
4024
for key in self:
4125
if not isinstance(key, SharedVariable):
@@ -56,19 +40,7 @@ def __setitem__(self, key, value):
5640
def update(self, other=None):
5741
if other is None:
5842
return
59-
if (
60-
isinstance(other, dict)
61-
and len(other) > 1
62-
and not isinstance(other, OrderedDict)
63-
):
64-
# Warn about non-determinism.
65-
warnings.warn(
66-
"Updating an `OrderedUpdates` with a "
67-
"non-ordered dictionary with 2+ elements could "
68-
"make your code non-deterministic",
69-
stacklevel=2,
70-
)
71-
for key, val in OrderedDict(other).items():
43+
for key, val in dict(other).items():
7244
if key in self:
7345
if self[key] == val:
7446
continue

0 commit comments

Comments
 (0)