Skip to content
This repository was archived by the owner on Apr 4, 2024. It is now read-only.

Commit e13cf01

Browse files
committed
Implement special sort order and binary search
1 parent 4d6d4c4 commit e13cf01

File tree

1 file changed

+78
-54
lines changed

1 file changed

+78
-54
lines changed
Lines changed: 78 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
1+
from abc import ABC, abstractmethod
12
from collections.abc import Set, Iterator, Mapping
2-
from typing import List, TypeVar, Union
3-
from abc import abstractmethod, ABC
3+
from typing import List, TypeVar, Union, Any, Tuple
4+
import bisect
5+
6+
7+
class Comparable:
8+
def __lt__(self, other: Any) -> bool:
9+
return NotImplemented
10+
11+
def __le__(self, other: Any) -> bool:
12+
return NotImplemented
13+
14+
def __gt__(self, other: Any) -> bool:
15+
return NotImplemented
16+
17+
def __ge__(self, other: Any) -> bool:
18+
return NotImplemented
19+
420

521
T = TypeVar("T")
622
V = TypeVar("V")
7-
K = TypeVar("K")
23+
K = TypeVar("K", bound=Comparable)
24+
25+
26+
def string_slash_first_comparator(a: Any, b: Any) -> int:
27+
"""Special comparator for strings where '/' is considered the lowest."""
28+
if isinstance(a, str) and isinstance(b, str):
29+
return (a.replace("/", "\0"), a) < (b.replace("/", "\0"), b)
30+
return (a < b) - (a > b)
831

932

1033
class ListBackedSet(Set[T], ABC):
@@ -15,107 +38,108 @@ def __len__(self) -> int: ...
1538
def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ...
1639

1740
def __contains__(self, item: object) -> bool:
18-
for i in range(len(self)):
19-
if self[i] == item:
20-
return True
21-
return False
41+
try:
42+
index = self.__binary_search(item)
43+
except ValueError:
44+
return False
45+
return index >= 0
46+
47+
@abstractmethod
48+
def __binary_search(self, item: Any) -> int: ...
2249

2350

2451
class ArraySet(ListBackedSet[K]):
2552
__data: List[K]
2653

27-
def __init__(self, data: List[K]):
28-
raise NotImplementedError("Use ArraySet.empty() instead")
54+
def __init__(self):
55+
raise NotImplementedError("Use ArraySet.empty() or other class methods instead")
2956

3057
@classmethod
3158
def __create(cls, data: List[K]) -> "ArraySet[K]":
32-
# Create a new instance without calling __init__
3359
instance = super().__new__(cls)
3460
instance.__data = data
3561
return instance
3662

37-
def __iter__(self) -> Iterator[K]:
38-
return iter(self.__data)
39-
4063
@classmethod
4164
def empty(cls) -> "ArraySet[K]":
4265
if not hasattr(cls, "__EMPTY"):
43-
cls.__EMPTY = cls([])
66+
cls.__EMPTY = cls.__create([])
4467
return cls.__EMPTY
4568

4669
def __len__(self) -> int:
4770
return len(self.__data)
4871

4972
def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
50-
if isinstance(index, int):
51-
return self.__data[index]
52-
elif isinstance(index, slice):
53-
return self.__data[index]
54-
else:
55-
raise TypeError("Invalid argument type.")
73+
return self.__data[index]
74+
75+
def __binary_search(self, item: K) -> int:
76+
if isinstance(item, str):
77+
key = lambda x: x.replace("/", "\0")
78+
return (
79+
bisect.bisect_left(self.__data, item, key=key) - 1
80+
if item in self.__data
81+
else -1
82+
)
83+
return bisect.bisect_left(self.__data, item) - 1 if item in self.__data else -1
5684

5785
def plusOrThis(self, element: K) -> "ArraySet[K]":
58-
# TODO: use binary search, and also special sort order for strings
59-
if element in self.__data:
86+
index = self.__binary_search(element)
87+
if index >= 0:
6088
return self
61-
else:
62-
new_data = self.__data[:]
63-
new_data.append(element)
64-
new_data.sort() # type: ignore[reportOperatorIssue]
65-
return ArraySet.__create(new_data)
89+
new_data = self.__data[:]
90+
bisect.insort_left(new_data, element)
91+
return ArraySet.__create(new_data)
6692

6793

6894
class ArrayMap(Mapping[K, V]):
69-
def __init__(self, data: list):
70-
# TODO: hide this constructor as done in ArraySet
71-
self.__data = data
95+
__data: List[Tuple[K, V]]
96+
97+
def __init__(self):
98+
raise NotImplementedError("Use ArrayMap.empty() or other class methods instead")
99+
100+
@classmethod
101+
def __create(cls, data: List[Tuple[K, V]]) -> "ArrayMap[K, V]":
102+
instance = super().__new__(cls)
103+
instance.__data = data
104+
return instance
72105

73106
@classmethod
74107
def empty(cls) -> "ArrayMap[K, V]":
75108
if not hasattr(cls, "__EMPTY"):
76-
cls.__EMPTY = cls([])
109+
cls.__EMPTY = cls.__create([])
77110
return cls.__EMPTY
78111

79112
def __getitem__(self, key: K) -> V:
80113
index = self.__binary_search_key(key)
81114
if index >= 0:
82-
return self.__data[2 * index + 1]
115+
return self.__data[index][1]
83116
raise KeyError(key)
84117

85118
def __iter__(self) -> Iterator[K]:
86-
return (self.__data[i] for i in range(0, len(self.__data), 2))
119+
return (key for key, _ in self.__data)
87120

88121
def __len__(self) -> int:
89-
return len(self.__data) // 2
122+
return len(self.__data)
90123

91124
def __binary_search_key(self, key: K) -> int:
92-
# TODO: special sort order for strings
93-
low, high = 0, (len(self.__data) // 2) - 1
94-
while low <= high:
95-
mid = (low + high) // 2
96-
mid_key = self.__data[2 * mid]
97-
if mid_key < key:
98-
low = mid + 1
99-
elif mid_key > key:
100-
high = mid - 1
101-
else:
102-
return mid
103-
return -(low + 1)
125+
keys = [k for k, _ in self.__data]
126+
index = bisect.bisect_left(keys, key)
127+
if index < len(keys) and keys[index] == key:
128+
return index
129+
return -1
104130

105131
def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
106132
index = self.__binary_search_key(key)
107133
if index >= 0:
108134
raise ValueError("Key already exists")
109-
insert_at = -(index + 1)
110135
new_data = self.__data[:]
111-
new_data[insert_at * 2 : insert_at * 2] = [key, value]
112-
return ArrayMap(new_data)
136+
bisect.insort_left(new_data, (key, value))
137+
return ArrayMap.__create(new_data)
113138

114139
def minus_sorted_indices(self, indicesToRemove: List[int]) -> "ArrayMap[K, V]":
115140
if not indicesToRemove:
116141
return self
117-
newData = []
118-
for i in range(0, len(self.__data), 2):
119-
if i // 2 not in indicesToRemove:
120-
newData.extend(self.__data[i : i + 2])
121-
return ArrayMap(newData)
142+
new_data = [
143+
item for i, item in enumerate(self.__data) if i not in indicesToRemove
144+
]
145+
return ArrayMap.__create(new_data)

0 commit comments

Comments
 (0)