|
1 |
| -import types |
2 |
| -import weakref |
3 |
| -from collections.abc import MutableSet |
4 |
| - |
5 |
| - |
6 |
| -def check_deterministic(iterable): |
7 |
| - # Most places where OrderedSet is used, pytensor interprets any exception |
8 |
| - # whatsoever as a problem that an optimization introduced into the graph. |
9 |
| - # If I raise a TypeError when the DestroyHandler tries to do something |
10 |
| - # non-deterministic, it will just result in optimizations getting ignored. |
11 |
| - # So I must use an assert here. In the long term we should fix the rest of |
12 |
| - # pytensor to use exceptions correctly, so that this can be a TypeError. |
13 |
| - if iterable is not None: |
14 |
| - if not isinstance( |
15 |
| - iterable, list | tuple | OrderedSet | types.GeneratorType | str | dict |
16 |
| - ): |
17 |
| - if len(iterable) > 1: |
18 |
| - # We need to accept length 1 size to allow unpickle in tests. |
19 |
| - raise AssertionError( |
20 |
| - "Get an not ordered iterable when one was expected" |
21 |
| - ) |
22 |
| - |
23 |
| - |
24 |
| -# Copyright (C) 2009 Raymond Hettinger |
25 |
| -# Permission is hereby granted, free of charge, to any person obtaining a |
26 |
| -# copy of this software and associated documentation files (the |
27 |
| -# "Software"), to deal in the Software without restriction, including |
28 |
| -# without limitation the rights to use, copy, modify, merge, publish, |
29 |
| -# distribute, sublicense, and/or sell copies of the Software, and to permit |
30 |
| -# persons to whom the Software is furnished to do so, subject to the |
31 |
| -# following conditions: |
32 |
| - |
33 |
| -# The above copyright notice and this permission notice shall be included |
34 |
| -# in all copies or substantial portions of the Software. |
35 |
| - |
36 |
| -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS |
37 |
| -# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF |
38 |
| -# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. |
39 |
| -# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY |
40 |
| -# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
41 |
| -# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE |
42 |
| -# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
43 |
| -# {{{ http://code.activestate.com/recipes/576696/ (r5) |
44 |
| - |
45 |
| - |
46 |
| -class Link: |
47 |
| - # This make that we need to use a different pickle protocol |
48 |
| - # then the default. Otherwise, there is pickling errors |
49 |
| - __slots__ = "prev", "next", "key", "__weakref__" |
50 |
| - |
51 |
| - def __getstate__(self): |
52 |
| - # weakref.proxy don't pickle well, so we use weakref.ref |
53 |
| - # manually and don't pickle the weakref. |
54 |
| - # We restore the weakref when we unpickle. |
55 |
| - ret = [self.prev(), self.next()] |
56 |
| - try: |
57 |
| - ret.append(self.key) |
58 |
| - except AttributeError: |
59 |
| - pass |
60 |
| - return ret |
61 |
| - |
62 |
| - def __setstate__(self, state): |
63 |
| - self.prev = weakref.ref(state[0]) |
64 |
| - self.next = weakref.ref(state[1]) |
65 |
| - if len(state) == 3: |
66 |
| - self.key = state[2] |
| 1 | +from collections.abc import Iterable, Iterator, MutableSet |
| 2 | +from typing import Any |
67 | 3 |
|
68 | 4 |
|
69 | 5 | class OrderedSet(MutableSet):
|
70 |
| - "Set the remembers the order elements were added" |
71 |
| - |
72 |
| - # Big-O running times for all methods are the same as for regular sets. |
73 |
| - # The internal self.__map dictionary maps keys to links in a doubly linked list. |
74 |
| - # The circular doubly linked list starts and ends with a sentinel element. |
75 |
| - # The sentinel element never gets deleted (this simplifies the algorithm). |
76 |
| - # The prev/next links are weakref proxies (to prevent circular references). |
77 |
| - # Individual links are kept alive by the hard reference in self.__map. |
78 |
| - # Those hard references disappear when a key is deleted from an OrderedSet. |
| 6 | + values: dict[Any, None] |
79 | 7 |
|
80 |
| - # Added by IG-- pre-existing pytensor code expected sets |
81 |
| - # to have this method |
82 |
| - def update(self, iterable): |
83 |
| - check_deterministic(iterable) |
84 |
| - self |= iterable |
85 |
| - |
86 |
| - def __init__(self, iterable=None): |
87 |
| - # Checks added by IG |
88 |
| - check_deterministic(iterable) |
89 |
| - self.__root = root = Link() # sentinel node for doubly linked list |
90 |
| - root.prev = root.next = weakref.ref(root) |
91 |
| - self.__map = {} # key --> link |
92 |
| - if iterable is not None: |
93 |
| - self |= iterable |
94 |
| - |
95 |
| - def __len__(self): |
96 |
| - return len(self.__map) |
97 |
| - |
98 |
| - def __contains__(self, key): |
99 |
| - return key in self.__map |
100 |
| - |
101 |
| - def add(self, key): |
102 |
| - # Store new key in a new link at the end of the linked list |
103 |
| - if key not in self.__map: |
104 |
| - self.__map[key] = link = Link() |
105 |
| - root = self.__root |
106 |
| - last = root.prev |
107 |
| - link.prev, link.next, link.key = last, weakref.ref(root), key |
108 |
| - last().next = root.prev = weakref.ref(link) |
109 |
| - |
110 |
| - def union(self, s): |
111 |
| - check_deterministic(s) |
112 |
| - n = self.copy() |
113 |
| - for elem in s: |
114 |
| - if elem not in n: |
115 |
| - n.add(elem) |
116 |
| - return n |
117 |
| - |
118 |
| - def intersection_update(self, s): |
119 |
| - l = [] |
120 |
| - for elem in self: |
121 |
| - if elem not in s: |
122 |
| - l.append(elem) |
123 |
| - for elem in l: |
124 |
| - self.remove(elem) |
125 |
| - return self |
126 |
| - |
127 |
| - def difference_update(self, s): |
128 |
| - check_deterministic(s) |
129 |
| - for elem in s: |
130 |
| - if elem in self: |
131 |
| - self.remove(elem) |
132 |
| - return self |
| 8 | + def __init__(self, iterable: Iterable | None = None) -> None: |
| 9 | + if iterable is None: |
| 10 | + self.values = {} |
| 11 | + else: |
| 12 | + self.values = {value: None for value in iterable} |
133 | 13 |
|
134 |
| - def copy(self): |
135 |
| - n = OrderedSet() |
136 |
| - n.update(self) |
137 |
| - return n |
| 14 | + def __contains__(self, value) -> bool: |
| 15 | + return value in self.values |
138 | 16 |
|
139 |
| - def discard(self, key): |
140 |
| - # Remove an existing item using self.__map to find the link which is |
141 |
| - # then removed by updating the links in the predecessor and successors. |
142 |
| - if key in self.__map: |
143 |
| - link = self.__map.pop(key) |
144 |
| - link.prev().next = link.next |
145 |
| - link.next().prev = link.prev |
| 17 | + def __iter__(self) -> Iterator: |
| 18 | + yield from self.values |
146 | 19 |
|
147 |
| - def __iter__(self): |
148 |
| - # Traverse the linked list in order. |
149 |
| - root = self.__root |
150 |
| - curr = root.next() |
151 |
| - while curr is not root: |
152 |
| - yield curr.key |
153 |
| - curr = curr.next() |
| 20 | + def __len__(self) -> int: |
| 21 | + return len(self.values) |
154 | 22 |
|
155 |
| - def __reversed__(self): |
156 |
| - # Traverse the linked list in reverse order. |
157 |
| - root = self.__root |
158 |
| - curr = root.prev() |
159 |
| - while curr is not root: |
160 |
| - yield curr.key |
161 |
| - curr = curr.prev() |
| 23 | + def add(self, value) -> None: |
| 24 | + self.values[value] = None |
162 | 25 |
|
163 |
| - def pop(self, last=True): |
164 |
| - if not self: |
165 |
| - raise KeyError("set is empty") |
166 |
| - if last: |
167 |
| - key = next(reversed(self)) |
168 |
| - else: |
169 |
| - key = next(iter(self)) |
170 |
| - self.discard(key) |
171 |
| - return key |
| 26 | + def discard(self, value) -> None: |
| 27 | + if value in self.values: |
| 28 | + del self.values[value] |
172 | 29 |
|
173 |
| - def __repr__(self): |
174 |
| - if not self: |
175 |
| - return f"{self.__class__.__name__}()" |
176 |
| - return f"{self.__class__.__name__}({list(self)!r})" |
177 |
| - |
178 |
| - def __eq__(self, other): |
179 |
| - # Note that we implement only the comparison to another |
180 |
| - # `OrderedSet`, and not to a regular `set`, because otherwise we |
181 |
| - # could have a non-symmetric equality relation like: |
182 |
| - # my_ordered_set == my_set and my_set != my_ordered_set |
183 |
| - if isinstance(other, OrderedSet): |
184 |
| - return len(self) == len(other) and list(self) == list(other) |
185 |
| - elif isinstance(other, set): |
186 |
| - # Raise exception to avoid confusion. |
187 |
| - raise TypeError( |
188 |
| - "Cannot compare an `OrderedSet` to a `set` because " |
189 |
| - "this comparison cannot be made symmetric: please " |
190 |
| - "manually cast your `OrderedSet` into `set` before " |
191 |
| - "performing this comparison." |
192 |
| - ) |
193 |
| - else: |
194 |
| - return NotImplemented |
| 30 | + def copy(self) -> "OrderedSet": |
| 31 | + return OrderedSet(self) |
195 | 32 |
|
| 33 | + def union(self, other: Iterable) -> "OrderedSet": |
| 34 | + new_set = OrderedSet(self) |
| 35 | + for value in other: |
| 36 | + new_set.add(value) |
| 37 | + return new_set |
196 | 38 |
|
197 |
| -# end of http://code.activestate.com/recipes/576696/ }}} |
| 39 | + def difference_update(self, other: Iterable) -> None: |
| 40 | + for value in other: |
| 41 | + self.discard(value) |
0 commit comments