diff --git a/orderly_set/sets.py b/orderly_set/sets.py index 899251d..4759967 100644 --- a/orderly_set/sets.py +++ b/orderly_set/sets.py @@ -14,11 +14,13 @@ TypeVar, Union, overload, + Hashable, ) SLICE_ALL = slice(None) T = TypeVar("T") +S = TypeVar("S", bound="StableSet") # SetLike[T] is either a set of elements of type T, or a sequence, which # we will convert to a StableSet or to an OrderedSet by adding its elements in order. @@ -211,7 +213,7 @@ def __and__(self, other: SetLike[T]) -> "StableSet[T]": # (left hand and right hand - as the operands order does matter) # based on the implementations of the super class (Set(Collection)), # see _collections_abc.py - def __sub__(self, other): + def __sub__(self: S, other: AbstractSet[T]) -> S: cls = type( self if isinstance(self, StableSet) @@ -225,7 +227,7 @@ def __sub__(self, other): other = cls(other) return cls(value for value in self if value not in other) - def __rsub__(self, other): + def __rsub__(self: S, other: AbstractSet[T]) -> S: cls = type( self if isinstance(self, StableSet) @@ -240,7 +242,7 @@ def __rsub__(self, other): return cls(value for value in other if value not in self) - def __or__(self, other): + def __or__(self: S, other: AbstractSet[T]) -> S: cls = type( self if isinstance(self, StableSet) @@ -253,7 +255,7 @@ def __or__(self, other): chain = (e for s in (self, other) for e in s) return cls(chain) - def __ror__(self, other): + def __ror__(self: S, other: AbstractSet[T]) -> S: cls = type( self if isinstance(self, StableSet) @@ -266,17 +268,17 @@ def __ror__(self, other): chain = (e for s in (other, self) for e in s) return cls(chain) - def __xor__(self, other): + def __xor__(self: S, other: AbstractSet[T]) -> S: if not isinstance(other, Iterable): return NotImplemented return (self - other) | (other - self) - def __rxor__(self, other): + def __rxor__(self: S, other: AbstractSet[T]) -> S: if not isinstance(other, Iterable): return NotImplemented return (other - self) | (self - other) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, Iterable): return False if len(self._map) != len(other): @@ -313,7 +315,7 @@ def copy(self) -> "StableSet[T]": # Technically type-incompatible with MutableSet, because we return an # int instead of nothing. This is also one of the things that makes # StableSet convenient to use. - def add(self, key: T) -> int: + def add(self, key: T) -> int: # pyright: ignore """ Add `key` as an item to this StableSet, then return its index. @@ -355,15 +357,15 @@ def update(self, sequence: SetLike[T]) -> int: return len(self._map) - 1 @overload - def index(self, key: Sequence[T]) -> List[int]: # NOQA + def index(self, key: Sequence[T]) -> List[int]: ... @overload - def index(self, key: T) -> int: # NOQA + def index(self, key: T) -> int: ... # concrete implementation - def index(self, key): # NOQA + def index(self, key: Union[T, Sequence[T]]) -> Union[int, List[int]]: # pyright: ignore """ Get the index of a given entry, raising an IndexError if it's not present @@ -443,7 +445,7 @@ def move_to_end(self, key) -> None: self._map.pop(key) self._map[key] = None - def discard(self, key: T) -> None: + def discard(self, value: T) -> None: """ Remove an element. Do not raise an exception if absent. @@ -462,7 +464,7 @@ def discard(self, key: T) -> None: if self._is_mutable is False: raise ValueError("This object is not mutable.") - self._map.pop(key, None) + self._map.pop(value, None) def union(self, *sets: SetLike[T]) -> "StableSet[T]": """ @@ -483,7 +485,7 @@ def union(self, *sets: SetLike[T]) -> "StableSet[T]": items = it.chain.from_iterable(containers) return cls(items) # type: ignore - def intersection(self, *sets: SetLike[T]) -> "StableSet[T]": + def intersection(self: S, *sets: SetLike[T]) -> S: """ Returns elements in common between all sets. Order is defined only by the first set. @@ -504,7 +506,7 @@ def intersection(self, *sets: SetLike[T]) -> "StableSet[T]": items = (item for item in self if item in common) return cls(items) - def difference(self, *sets: SetLike[T]) -> "StableSet[T]": + def difference(self: S, *sets: SetLike[T]) -> S: """ Returns all elements that are in this set but not the others. @@ -525,7 +527,7 @@ def difference(self, *sets: SetLike[T]) -> "StableSet[T]": items = (item for item in self if item not in other) return cls(items) - def symmetric_difference(self, other: SetLike[T]) -> "StableSet[T]": + def symmetric_difference(self: S, other: SetLike[T]) -> S: """ Return the symmetric difference of two StableSets as a new set. That is, the new set will contain all elements that are in exactly @@ -644,7 +646,7 @@ def issuperset(self, other: SetLike[T]) -> bool: return False return all(item in self for item in other) - def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False): + def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False) -> bool: if len(self) > len(other): return False if non_consecutive: @@ -662,13 +664,13 @@ def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False return False return True - def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False): + def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False) -> bool: return StableSet.isorderedsubset(other, self, non_consecutive) - def get(self): + def get(self) -> Hashable: return next(iter(self._map)) - def freeze(self): + def freeze(self) -> None: """ Once this function is run, the object becomes immutable """ @@ -1210,7 +1212,7 @@ def pop(self, index=None): self.set_.remove(result) return result - def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False): + def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False) -> bool: if len(self) > len(other): return False if non_consecutive: @@ -1228,5 +1230,5 @@ def isorderedsubset(self: SetLike, other: SetLike, non_consecutive: bool = False return False return True - def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False): + def isorderedsuperset(self, other: SetLike, non_consecutive: bool = False) -> bool: return StableSet.isorderedsubset(other, self, non_consecutive)