11from collections .abc import Set , Iterator , Mapping
2- from typing import List , TypeVar , Union , Any , Callable , Optional , Generator
2+ from typing import List , TypeVar , Union , Any
33from abc import abstractmethod , ABC
44
55T = TypeVar ("T" )
66V = TypeVar ("V" )
77K = TypeVar ("K" )
88
99
10- class BinarySearchUtil :
11- @staticmethod
12- def binary_search (
13- data , item , compare_func : Optional [Callable [[Any , Any ], int ]] = None
14- ) -> int :
15- low , high = 0 , len (data ) - 1
16- while low <= high :
17- mid = (low + high ) // 2
18- mid_val = data [mid ] if not isinstance (data , ListBackedSet ) else data [mid ]
19- comparison = (
20- compare_func (mid_val , item )
21- if compare_func
22- else (mid_val > item ) - (mid_val < item )
23- )
24-
25- if comparison < 0 :
26- low = mid + 1
27- elif comparison > 0 :
28- high = mid - 1
29- else :
30- return mid # item found
31- return - (low + 1 ) # item not found
32-
33- @staticmethod
34- def default_compare (a : Any , b : Any ) -> int :
35- """Default comparison function for binary search, with special handling for strings."""
36- if isinstance (a , str ) and isinstance (b , str ):
37- a , b = a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " )
38- return (a > b ) - (a < b )
10+ def _compare_normal (a , b ) -> int :
11+ if a == b :
12+ return 0
13+ elif a < b :
14+ return - 1
15+ else :
16+ return 1
17+
18+ def _compare_string_slash_first (a : str , b : str ) -> int :
19+ return _compare_normal (a .replace ("/" , "\0 " ), b .replace ("/" , "\0 " ))
20+
21+ def _binary_search (data , item ) -> int :
22+ compare_func = _compare_string_slash_first if isinstance (item , str ) else _compare_normal
23+ low , high = 0 , len (data ) - 1
24+ while low <= high :
25+ mid = (low + high ) // 2
26+ mid_val = data [mid ]
27+ comparison = compare_func (mid_val , item )
28+
29+ if comparison < 0 :
30+ low = mid + 1
31+ elif comparison > 0 :
32+ high = mid - 1
33+ else :
34+ return mid # item found
35+ return - (low + 1 ) # item not found
3936
4037
4138class ListBackedSet (Set [T ], ABC ):
@@ -52,7 +49,7 @@ def __contains__(self, item: Any) -> bool:
5249 return self ._binary_search (item ) >= 0
5350
5451 def _binary_search (self , item : Any ) -> int :
55- return BinarySearchUtil . binary_search (self , item )
52+ return _binary_search (self , item )
5653
5754
5855class ArraySet (ListBackedSet [K ]):
@@ -125,7 +122,7 @@ def __len__(self) -> int:
125122
126123 def _binary_search_key (self , key : K ) -> int :
127124 keys = [self .__data [i ] for i in range (0 , len (self .__data ), 2 )]
128- return BinarySearchUtil . binary_search (keys , key )
125+ return _binary_search (keys , key )
129126
130127 def plus (self , key : K , value : V ) -> "ArrayMap[K, V]" :
131128 index = self ._binary_search_key (key )
0 commit comments