Skip to content
This repository was archived by the owner on Apr 4, 2024. It is now read-only.

Update ArrayMap #50

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 76 additions & 51 deletions python/selfie-lib/selfie_lib/ArrayMap.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,68 @@
from collections.abc import Set, Iterator, Mapping
from typing import List, TypeVar, Union
from typing import List, TypeVar, Union, Any, Callable, Optional, Generator
from abc import abstractmethod, ABC

T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K")


class BinarySearchUtil:
@staticmethod
def binary_search(
data, item, compare_func: Optional[Callable[[Any, Any], int]] = None
) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. since we're in Python, this doesn't have to be in a class. It could just be def _binary_search_util(data, item):

  2. This function is only ever called with two parameters -compare_func is always None. The fix is

def _compare_normal(a, b) -> int:
  if a == b:
    return 0
  elif a < b:
    return -1
  else:
    return 1

def _compare_string_slash_first(a: str, b : str) -> int:
  return _compare_normal(a.replace("/", "\0"), b.replace("/", "\0"))

def _binary_search(data, item) -> int:
  compare_func = _compare_string_slash_first if isinstance(item, str) else _compare_normal
  low, high = 0, len(data) - 1
  ...
  the rest of the logic, but now you know that compare_func is never None

low, high = 0, len(data) - 1
while low <= high:
mid = (low + high) // 2
mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid]
comparison = (
compare_func(mid_val, item)
if compare_func
else (mid_val > item) - (mid_val < item)
)

if comparison < 0:
low = mid + 1
elif comparison > 0:
high = mid - 1
else:
return mid # item found
return -(low + 1) # item not found

@staticmethod
def default_compare(a: Any, b: Any) -> int:
    """Three-way comparison that sorts "/" ahead of every other character in strings.

    When both arguments are strings, each "/" is swapped for the NUL character
    before comparing. Any other pair of values is compared as-is.

    Returns:
        1 if ``a`` is greater than ``b``, -1 if it is smaller, 0 otherwise.
    """
    both_are_text = isinstance(a, str) and isinstance(b, str)
    if both_are_text:
        a = a.replace("/", "\0")
        b = b.replace("/", "\0")
    is_greater = a > b
    is_less = a < b
    # bool arithmetic: True - False == 1, False - True == -1, equal flags == 0
    return is_greater - is_less


class ListBackedSet(Set[T], ABC):
    """Abstract set whose elements can also be read by position.

    Membership is answered with a binary search over ``self`` (see
    ``_binary_search``), which presumes that concrete subclasses keep their
    elements in ascending order.
    """

    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ...

    @abstractmethod
    def __iter__(self) -> Iterator[T]: ...

    def __contains__(self, item: Any) -> bool:
        """Return True when ``item`` is found by the binary search."""
        # A non-negative result is the position of the match; misses are negative.
        return self._binary_search(item) >= 0

    def _binary_search(self, item: Any) -> int:
        """Locate ``item`` via ``BinarySearchUtil.binary_search``.

        Returns:
            The index of ``item`` when present, otherwise
            ``-(insertion_point + 1)``.
        """
        return BinarySearchUtil.binary_search(self, item)


class ArraySet(ListBackedSet[K]):
__data: List[K]

def __init__(self):
    """Direct construction is disabled; instances come from the class methods.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("Use ArraySet.empty() or other class methods instead")

@classmethod
def __create(cls, data: List[K]) -> "ArraySet[K]":
    """Wrap ``data`` in a new instance without running ``__init__``.

    The list is stored as given (no copy is made).
    """
    # Allocate through __new__ so that __init__ is never invoked.
    created = super().__new__(cls)
    created.__data = data
    return created
Expand All @@ -40,82 +73,74 @@ def __iter__(self) -> Iterator[K]:
@classmethod
def empty(cls) -> "ArraySet[K]":
    """Return the shared empty set, creating it on first use.

    The cached instance lives in the private class attribute ``__EMPTY``.
    Inside the class body that name is mangled, so it is read with attribute
    syntax (which is mangled the same way as the assignment) instead of
    ``hasattr(cls, "__EMPTY")``, whose string argument is never mangled and
    therefore never matched the stored attribute.
    """
    try:
        return cls.__EMPTY
    except AttributeError:
        # First call: build the instance through __create, because __init__
        # always raises.
        cls.__EMPTY = cls.__create([])
        return cls.__EMPTY

def __len__(self) -> int:
    """Return the number of elements held in the backing list."""
    backing = self.__data
    return len(backing)

def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
    """Return the element at ``index``, or a list of elements for a slice.

    Indexing is delegated to the backing list, so an ``index`` that is
    neither an integer nor a slice raises ``TypeError`` from the list itself.
    """
    return self.__data[index]

def plusOrThis(self, element: K) -> "ArraySet[K]":
    """Return a set that contains ``element``, reusing ``self`` when possible.

    Returns:
        ``self`` if ``element`` is already present; otherwise a new
        ``ArraySet`` whose backing list is a copy of this one with
        ``element`` inserted at the position reported by the binary search.
    """
    index = self._binary_search(element)
    if index >= 0:
        return self
    # A miss is encoded as -(insertion_point + 1); decode it.
    insert_at = -(index + 1)
    new_data = self.__data[:]
    new_data.insert(insert_at, element)
    return ArraySet.__create(new_data)


class ArrayMap(Mapping[K, V]):
    """Read-only mapping stored as one flat list ``[k0, v0, k1, v1, ...]``.

    Keys sit at even positions and their values at the following odd
    positions. Lookups binary-search the keys, and ``plus`` inserts each new
    pair at the position that search reports. Methods that "modify" the map
    return a new ``ArrayMap`` built from a copy of the data.
    """

    __data: List[Union[K, V]]

    def __init__(self):
        """Direct construction is disabled; use ``ArrayMap.empty()``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Use ArrayMap.empty() or other class methods instead")

    @classmethod
    def __create(cls, data: List[Union[K, V]]) -> "ArrayMap[K, V]":
        """Wrap ``data`` (stored as given, not copied) without running ``__init__``."""
        instance = cls.__new__(cls)
        instance.__data = data
        return instance

    @classmethod
    def empty(cls) -> "ArrayMap[K, V]":
        """Return the shared empty map, creating it on first use."""
        # Attribute syntax is name-mangled exactly like the assignment below;
        # hasattr(cls, "__EMPTY") is not, so it could never see the cached value.
        try:
            return cls.__EMPTY
        except AttributeError:
            cls.__EMPTY = cls.__create([])
            return cls.__EMPTY

    def __getitem__(self, key: K) -> V:
        """Return the value stored for ``key``.

        Raises:
            KeyError: if ``key`` is not present.
        """
        index = self._binary_search_key(key)
        if index >= 0:
            return self.__data[2 * index + 1]  # type: ignore
        raise KeyError(key)

    def __iter__(self) -> Iterator[K]:
        """Iterate over the keys in stored order."""
        return (self.__data[i] for i in range(0, len(self.__data), 2))  # type: ignore

    def __len__(self) -> int:
        """Return the number of key/value pairs."""
        return len(self.__data) // 2

    def _binary_search_key(self, key: K) -> int:
        """Binary-search the keys in place, without copying them out first.

        Returns:
            The pair index of ``key`` when present, otherwise
            ``-(insertion_point + 1)``.
        """
        low, high = 0, len(self.__data) // 2 - 1
        while low <= high:
            mid = (low + high) // 2
            mid_key = self.__data[2 * mid]  # keys live at even positions
            if mid_key < key:  # type: ignore
                low = mid + 1
            elif mid_key > key:  # type: ignore
                high = mid - 1
            else:
                return mid
        return -(low + 1)

    def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
        """Return a new map that additionally holds ``key`` -> ``value``.

        Raises:
            ValueError: if ``key`` is already present.
        """
        index = self._binary_search_key(key)
        if index >= 0:
            raise ValueError("Key already exists")
        insert_at = -(index + 1)
        new_data = self.__data[:]
        # Splice the pair in once, at the flat-list offset of the pair index.
        new_data[insert_at * 2 : insert_at * 2] = [key, value]
        return ArrayMap.__create(new_data)

    def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]":
        """Return a map without the pairs at the given pair indices.

        Args:
            indices: pair indices (0 .. len(self) - 1) to drop. Duplicates
                are tolerated.

        Returns:
            ``self`` when ``indices`` is empty, otherwise a new ``ArrayMap``.

        Raises:
            IndexError: if any index is outside ``0 .. len(self) - 1``.
        """
        if not indices:
            return self
        size = len(self)
        for i in indices:
            if not 0 <= i < size:
                raise IndexError(f"index {i} out of range for ArrayMap of size {size}")
        doomed = set(indices)
        new_data: List[Union[K, V]] = []
        # Single pass over the pairs, keeping the ones that are not removed.
        for i in range(size):
            if i not in doomed:
                new_data.extend(self.__data[2 * i : 2 * i + 2])
        return ArrayMap.__create(new_data)