Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 24 additions & 22 deletions orderly_set/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
TypeVar,
Union,
overload,
Hashable,
)

SLICE_ALL = slice(None)

T = TypeVar("T")
S = TypeVar("S", bound="StableSet")

# SetLike[T] is either a set of elements of type T, or a sequence, which
# we will convert to a StableSet or to an OrderedSet by adding its elements in order.
Expand Down Expand Up @@ -211,7 +213,7 @@ def __and__(self, other: SetLike[T]) -> "StableSet[T]":
# (left hand and right hand - as the operands order does matter)
# based on the implementations of the super class (Set(Collection)),
# see _collections_abc.py
def __sub__(self, other):
def __sub__(self: S, other: AbstractSet[T]) -> S:
cls = type(
self
if isinstance(self, StableSet)
Expand All @@ -225,7 +227,7 @@ def __sub__(self, other):
other = cls(other)
return cls(value for value in self if value not in other)

def __rsub__(self, other):
def __rsub__(self: S, other: AbstractSet[T]) -> S:
cls = type(
self
if isinstance(self, StableSet)
Expand All @@ -240,7 +242,7 @@ def __rsub__(self, other):
return cls(value for value in other if value not in self)


def __or__(self, other):
def __or__(self: S, other: AbstractSet[T]) -> S:
cls = type(
self
if isinstance(self, StableSet)
Expand All @@ -253,7 +255,7 @@ def __or__(self, other):
chain = (e for s in (self, other) for e in s)
return cls(chain)

def __ror__(self, other):
def __ror__(self: S, other: AbstractSet[T]) -> S:
cls = type(
self
if isinstance(self, StableSet)
Expand All @@ -266,17 +268,17 @@ def __ror__(self, other):
chain = (e for s in (other, self) for e in s)
return cls(chain)

def __xor__(self, other):
def __xor__(self: S, other: AbstractSet[T]) -> S:
if not isinstance(other, Iterable):
return NotImplemented
return (self - other) | (other - self)

def __rxor__(self, other):
def __rxor__(self: S, other: AbstractSet[T]) -> S:
if not isinstance(other, Iterable):
return NotImplemented
return (other - self) | (self - other)

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, Iterable):
return False
if len(self._map) != len(other):
Expand Down Expand Up @@ -313,7 +315,7 @@ def copy(self) -> "StableSet[T]":
# Technically type-incompatible with MutableSet, because we return an
# int instead of nothing. This is also one of the things that makes
# StableSet convenient to use.
def add(self, key: T) -> int:
def add(self, key: T) -> int: # pyright: ignore
"""
Add `key` as an item to this StableSet, then return its index.

Expand Down Expand Up @@ -355,15 +357,15 @@ def update(self, sequence: SetLike[T]) -> int:
return len(self._map) - 1

@overload
def index(self, key: Sequence[T]) -> List[int]: # NOQA
def index(self, key: Sequence[T]) -> List[int]:
...

@overload
def index(self, key: T) -> int: # NOQA
def index(self, key: T) -> int:
...

# concrete implementation
def index(self, key): # NOQA
def index(self, key: Union[T, Sequence[T]]) -> Union[int, List[int]]: # pyright: ignore
"""
Get the index of a given entry, raising an IndexError if it's not present

Expand Down Expand Up @@ -443,7 +445,7 @@ def move_to_end(self, key) -> None:
self._map.pop(key)
self._map[key] = None

def discard(self, key: T) -> None:
def discard(self, value: T) -> None:
"""
Remove an element. Do not raise an exception if absent.

Expand All @@ -462,7 +464,7 @@ def discard(self, key: T) -> None:
if self._is_mutable is False:
raise ValueError("This object is not mutable.")

self._map.pop(key, None)
self._map.pop(value, None)

def union(self, *sets: SetLike[T]) -> "StableSet[T]":
"""
Expand All @@ -483,7 +485,7 @@ def union(self, *sets: SetLike[T]) -> "StableSet[T]":
items = it.chain.from_iterable(containers)
return cls(items) # type: ignore

def intersection(self, *sets: SetLike[T]) -> "StableSet[T]":
def intersection(self: S, *sets: SetLike[T]) -> S:
"""
Returns elements in common between all sets. Order is defined only
by the first set.
Expand All @@ -504,7 +506,7 @@ def intersection(self, *sets: SetLike[T]) -> "StableSet[T]":
items = (item for item in self if item in common)
return cls(items)

def difference(self, *sets: SetLike[T]) -> "StableSet[T]":
def difference(self: S, *sets: SetLike[T]) -> S:
"""
Returns all elements that are in this set but not the others.

Expand All @@ -525,7 +527,7 @@ def difference(self, *sets: SetLike[T]) -> "StableSet[T]":
items = (item for item in self if item not in other)
return cls(items)

def symmetric_difference(self, other: SetLike[T]) -> "StableSet[T]":
def symmetric_difference(self: S, other: SetLike[T]) -> S:
"""
Return the symmetric difference of two StableSets as a new set.
That is, the new set will contain all elements that are in exactly
Expand Down Expand Up @@ -644,7 +646,7 @@ def issuperset(self, other: SetLike[T]) -> bool:
return False
return all(item in self for item in other)

def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False):
def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False) -> bool:
if len(self) > len(other):
return False
if non_consecutive:
Expand All @@ -662,13 +664,13 @@ def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False
return False
return True

def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False):
def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False) -> bool:
return StableSet.isorderedsubset(other, self, non_consecutive)

def get(self):
def get(self) -> Hashable:
return next(iter(self._map))

def freeze(self):
def freeze(self) -> None:
"""
Once this function is run, the object becomes immutable
"""
Expand Down Expand Up @@ -1210,7 +1212,7 @@ def pop(self, index=None):
self.set_.remove(result)
return result

def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False):
def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False) -> bool:
if len(self) > len(other):
return False
if non_consecutive:
Expand All @@ -1228,5 +1230,5 @@ def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False
return False
return True

def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False):
def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False) -> bool:
return StableSet.isorderedsubset(other, self, non_consecutive)