1
1
from collections .abc import Set , Iterator , Mapping
2
- from typing import List , TypeVar , Union , Any , Callable , Optional , Generator
2
+ from typing import List , TypeVar , Union , Any
3
3
from abc import abstractmethod , ABC
4
4
5
5
T = TypeVar ("T" )
6
6
V = TypeVar ("V" )
7
7
K = TypeVar ("K" )
8
8
9
9
10
- class BinarySearchUtil :
11
- @staticmethod
12
- def binary_search (
13
- data , item , compare_func : Optional [Callable [[Any , Any ], int ]] = None
14
- ) -> int :
15
- low , high = 0 , len (data ) - 1
16
- while low <= high :
17
- mid = (low + high ) // 2
18
- mid_val = data [mid ] if not isinstance (data , ListBackedSet ) else data [mid ]
19
- comparison = (
20
- compare_func (mid_val , item )
21
- if compare_func
22
- else (mid_val > item ) - (mid_val < item )
23
- )
24
-
25
- if comparison < 0 :
26
- low = mid + 1
27
- elif comparison > 0 :
28
- high = mid - 1
29
- else :
30
- return mid # item found
31
- return - (low + 1 ) # item not found
32
-
33
- @staticmethod
34
- def default_compare (a : Any , b : Any ) -> int :
35
- """Default comparison function for binary search, with special handling for strings."""
36
- if isinstance (a , str ) and isinstance (b , str ):
37
- a , b = a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " )
38
- return (a > b ) - (a < b )
10
+ def _compare_normal (a , b ) -> int :
11
+ if a == b :
12
+ return 0
13
+ elif a < b :
14
+ return - 1
15
+ else :
16
+ return 1
17
+
18
+ def _compare_string_slash_first (a : str , b : str ) -> int :
19
+ return _compare_normal (a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " ))
20
+
21
+ def _binary_search (data , item ) -> int :
22
+ compare_func = _compare_string_slash_first if isinstance (item , str ) else _compare_normal
23
+ low , high = 0 , len (data ) - 1
24
+ while low <= high :
25
+ mid = (low + high ) // 2
26
+ mid_val = data [mid ]
27
+ comparison = compare_func (mid_val , item )
28
+
29
+ if comparison < 0 :
30
+ low = mid + 1
31
+ elif comparison > 0 :
32
+ high = mid - 1
33
+ else :
34
+ return mid # item found
35
+ return - (low + 1 ) # item not found
39
36
40
37
41
38
class ListBackedSet (Set [T ], ABC ):
@@ -52,7 +49,7 @@ def __contains__(self, item: Any) -> bool:
52
49
return self ._binary_search (item ) >= 0
53
50
54
51
def _binary_search (self , item : Any ) -> int :
55
- return BinarySearchUtil . binary_search (self , item )
52
+ return _binary_search (self , item )
56
53
57
54
58
55
class ArraySet (ListBackedSet [K ]):
@@ -125,7 +122,7 @@ def __len__(self) -> int:
125
122
126
123
def _binary_search_key (self , key : K ) -> int :
127
124
keys = [self .__data [i ] for i in range (0 , len (self .__data ), 2 )]
128
- return BinarySearchUtil . binary_search (keys , key )
125
+ return _binary_search (keys , key )
129
126
130
127
def plus (self , key : K , value : V ) -> "ArrayMap[K, V]" :
131
128
index = self ._binary_search_key (key )
0 commit comments