Skip to content

Commit 8fec4e9

Browse files
committed
Rewrite OrderedSet
1 parent 452a15b commit 8fec4e9

File tree

1 file changed

+29
-185
lines changed

1 file changed

+29
-185
lines changed

pytensor/misc/ordered_set.py

+29-185
Original file line numberDiff line numberDiff line change
@@ -1,197 +1,41 @@
1-
import types
2-
import weakref
3-
from collections.abc import MutableSet
4-
5-
6-
def check_deterministic(iterable):
7-
# Most places where OrderedSet is used, pytensor interprets any exception
8-
# whatsoever as a problem that an optimization introduced into the graph.
9-
# If I raise a TypeError when the DestroyHandler tries to do something
10-
# non-deterministic, it will just result in optimizations getting ignored.
11-
# So I must use an assert here. In the long term we should fix the rest of
12-
# pytensor to use exceptions correctly, so that this can be a TypeError.
13-
if iterable is not None:
14-
if not isinstance(
15-
iterable, list | tuple | OrderedSet | types.GeneratorType | str | dict
16-
):
17-
if len(iterable) > 1:
18-
# We need to accept iterables of length 1 to allow unpickling in tests.
19-
raise AssertionError(
20-
"Get an not ordered iterable when one was expected"
21-
)
22-
23-
24-
# Copyright (C) 2009 Raymond Hettinger
25-
# Permission is hereby granted, free of charge, to any person obtaining a
26-
# copy of this software and associated documentation files (the
27-
# "Software"), to deal in the Software without restriction, including
28-
# without limitation the rights to use, copy, modify, merge, publish,
29-
# distribute, sublicense, and/or sell copies of the Software, and to permit
30-
# persons to whom the Software is furnished to do so, subject to the
31-
# following conditions:
32-
33-
# The above copyright notice and this permission notice shall be included
34-
# in all copies or substantial portions of the Software.
35-
36-
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
37-
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
38-
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
39-
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
40-
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
41-
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
42-
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
43-
# {{{ http://code.activestate.com/recipes/576696/ (r5)
44-
45-
46-
class Link:
47-
# This means that we need to use a different pickle protocol
48-
# than the default. Otherwise, there are pickling errors
49-
__slots__ = "prev", "next", "key", "__weakref__"
50-
51-
def __getstate__(self):
52-
# weakref.proxy doesn't pickle well, so we use weakref.ref
53-
# manually and don't pickle the weakref.
54-
# We restore the weakref when we unpickle.
55-
ret = [self.prev(), self.next()]
56-
try:
57-
ret.append(self.key)
58-
except AttributeError:
59-
pass
60-
return ret
61-
62-
def __setstate__(self, state):
63-
self.prev = weakref.ref(state[0])
64-
self.next = weakref.ref(state[1])
65-
if len(state) == 3:
66-
self.key = state[2]
1+
from collections.abc import Iterable, Iterator, MutableSet
2+
from typing import Any
673

684

695
class OrderedSet(MutableSet):
70-
"Set that remembers the order elements were added"
71-
72-
# Big-O running times for all methods are the same as for regular sets.
73-
# The internal self.__map dictionary maps keys to links in a doubly linked list.
74-
# The circular doubly linked list starts and ends with a sentinel element.
75-
# The sentinel element never gets deleted (this simplifies the algorithm).
76-
# The prev/next links are weakref proxies (to prevent circular references).
77-
# Individual links are kept alive by the hard reference in self.__map.
78-
# Those hard references disappear when a key is deleted from an OrderedSet.
6+
values: dict[Any, None]
797

80-
# Added by IG-- pre-existing pytensor code expected sets
81-
# to have this method
82-
def update(self, iterable):
83-
check_deterministic(iterable)
84-
self |= iterable
85-
86-
def __init__(self, iterable=None):
87-
# Checks added by IG
88-
check_deterministic(iterable)
89-
self.__root = root = Link() # sentinel node for doubly linked list
90-
root.prev = root.next = weakref.ref(root)
91-
self.__map = {} # key --> link
92-
if iterable is not None:
93-
self |= iterable
94-
95-
def __len__(self):
96-
return len(self.__map)
97-
98-
def __contains__(self, key):
99-
return key in self.__map
100-
101-
def add(self, key):
102-
# Store new key in a new link at the end of the linked list
103-
if key not in self.__map:
104-
self.__map[key] = link = Link()
105-
root = self.__root
106-
last = root.prev
107-
link.prev, link.next, link.key = last, weakref.ref(root), key
108-
last().next = root.prev = weakref.ref(link)
109-
110-
def union(self, s):
111-
check_deterministic(s)
112-
n = self.copy()
113-
for elem in s:
114-
if elem not in n:
115-
n.add(elem)
116-
return n
117-
118-
def intersection_update(self, s):
119-
l = []
120-
for elem in self:
121-
if elem not in s:
122-
l.append(elem)
123-
for elem in l:
124-
self.remove(elem)
125-
return self
126-
127-
def difference_update(self, s):
128-
check_deterministic(s)
129-
for elem in s:
130-
if elem in self:
131-
self.remove(elem)
132-
return self
8+
def __init__(self, iterable: Iterable | None = None) -> None:
9+
if iterable is None:
10+
self.values = {}
11+
else:
12+
self.values = {value: None for value in iterable}
13313

134-
def copy(self):
135-
n = OrderedSet()
136-
n.update(self)
137-
return n
14+
def __contains__(self, value) -> bool:
15+
return value in self.values
13816

139-
def discard(self, key):
140-
# Remove an existing item using self.__map to find the link which is
141-
# then removed by updating the links in the predecessor and successors.
142-
if key in self.__map:
143-
link = self.__map.pop(key)
144-
link.prev().next = link.next
145-
link.next().prev = link.prev
17+
def __iter__(self) -> Iterator:
18+
yield from self.values
14619

147-
def __iter__(self):
148-
# Traverse the linked list in order.
149-
root = self.__root
150-
curr = root.next()
151-
while curr is not root:
152-
yield curr.key
153-
curr = curr.next()
20+
def __len__(self) -> int:
21+
return len(self.values)
15422

155-
def __reversed__(self):
156-
# Traverse the linked list in reverse order.
157-
root = self.__root
158-
curr = root.prev()
159-
while curr is not root:
160-
yield curr.key
161-
curr = curr.prev()
23+
def add(self, value) -> None:
24+
self.values[value] = None
16225

163-
def pop(self, last=True):
164-
if not self:
165-
raise KeyError("set is empty")
166-
if last:
167-
key = next(reversed(self))
168-
else:
169-
key = next(iter(self))
170-
self.discard(key)
171-
return key
26+
def discard(self, value) -> None:
27+
if value in self.values:
28+
del self.values[value]
17229

173-
def __repr__(self):
174-
if not self:
175-
return f"{self.__class__.__name__}()"
176-
return f"{self.__class__.__name__}({list(self)!r})"
177-
178-
def __eq__(self, other):
179-
# Note that we implement only the comparison to another
180-
# `OrderedSet`, and not to a regular `set`, because otherwise we
181-
# could have a non-symmetric equality relation like:
182-
# my_ordered_set == my_set and my_set != my_ordered_set
183-
if isinstance(other, OrderedSet):
184-
return len(self) == len(other) and list(self) == list(other)
185-
elif isinstance(other, set):
186-
# Raise exception to avoid confusion.
187-
raise TypeError(
188-
"Cannot compare an `OrderedSet` to a `set` because "
189-
"this comparison cannot be made symmetric: please "
190-
"manually cast your `OrderedSet` into `set` before "
191-
"performing this comparison."
192-
)
193-
else:
194-
return NotImplemented
30+
def copy(self) -> "OrderedSet":
31+
return OrderedSet(self)
19532

33+
def union(self, other: Iterable) -> "OrderedSet":
34+
new_set = OrderedSet(self)
35+
for value in other:
36+
new_set.add(value)
37+
return new_set
19638

197-
# end of http://code.activestate.com/recipes/576696/ }}}
39+
def difference_update(self, other: Iterable) -> None:
40+
for value in other:
41+
self.discard(value)

0 commit comments

Comments
 (0)