Skip to content
This repository was archived by the owner on Apr 4, 2024. It is now read-only.

Commit d6b542f

Browse files
committed
Update binary search
1 parent 1845272 commit d6b542f

File tree

1 file changed

+29
-32
lines changed

1 file changed

+29
-32
lines changed

Diff for: python/selfie-lib/selfie_lib/ArrayMap.py

+29-32
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,38 @@
11
from collections.abc import Set, Iterator, Mapping
2-
from typing import List, TypeVar, Union, Any, Callable, Optional, Generator
2+
from typing import List, TypeVar, Union, Any
33
from abc import abstractmethod, ABC
44

55
T = TypeVar("T")
66
V = TypeVar("V")
77
K = TypeVar("K")
88

99

10-
class BinarySearchUtil:
11-
@staticmethod
12-
def binary_search(
13-
data, item, compare_func: Optional[Callable[[Any, Any], int]] = None
14-
) -> int:
15-
low, high = 0, len(data) - 1
16-
while low <= high:
17-
mid = (low + high) // 2
18-
mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid]
19-
comparison = (
20-
compare_func(mid_val, item)
21-
if compare_func
22-
else (mid_val > item) - (mid_val < item)
23-
)
24-
25-
if comparison < 0:
26-
low = mid + 1
27-
elif comparison > 0:
28-
high = mid - 1
29-
else:
30-
return mid # item found
31-
return -(low + 1) # item not found
32-
33-
@staticmethod
34-
def default_compare(a: Any, b: Any) -> int:
35-
"""Default comparison function for binary search, with special handling for strings."""
36-
if isinstance(a, str) and isinstance(b, str):
37-
a, b = a.replace("/", "\0"), b.replace("/", "\0")
38-
return (a > b) - (a < b)
10+
def _compare_normal(a, b) -> int:
11+
if a == b:
12+
return 0
13+
elif a < b:
14+
return -1
15+
else:
16+
return 1
17+
18+
def _compare_string_slash_first(a: str, b: str) -> int:
19+
return _compare_normal(a.replace("/", "\0"), b.replace("/", "\0"))
20+
21+
def _binary_search(data, item) -> int:
22+
compare_func = _compare_string_slash_first if isinstance(item, str) else _compare_normal
23+
low, high = 0, len(data) - 1
24+
while low <= high:
25+
mid = (low + high) // 2
26+
mid_val = data[mid]
27+
comparison = compare_func(mid_val, item)
28+
29+
if comparison < 0:
30+
low = mid + 1
31+
elif comparison > 0:
32+
high = mid - 1
33+
else:
34+
return mid # item found
35+
return -(low + 1) # item not found
3936

4037

4138
class ListBackedSet(Set[T], ABC):
@@ -52,7 +49,7 @@ def __contains__(self, item: Any) -> bool:
5249
return self._binary_search(item) >= 0
5350

5451
def _binary_search(self, item: Any) -> int:
55-
return BinarySearchUtil.binary_search(self, item)
52+
return _binary_search(self, item)
5653

5754

5855
class ArraySet(ListBackedSet[K]):
@@ -125,7 +122,7 @@ def __len__(self) -> int:
125122

126123
def _binary_search_key(self, key: K) -> int:
127124
keys = [self.__data[i] for i in range(0, len(self.__data), 2)]
128-
return BinarySearchUtil.binary_search(keys, key)
125+
return _binary_search(keys, key)
129126

130127
def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
131128
index = self._binary_search_key(key)

0 commit comments

Comments
 (0)