Skip to content
This repository was archived by the owner on Apr 4, 2024. It is now read-only.

Update ArrayMap #50

Merged
merged 5 commits into from
Apr 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 55 additions & 62 deletions python/selfie-lib/selfie_lib/ArrayMap.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,41 @@
from collections.abc import Set, Iterator, Mapping
from typing import List, TypeVar, Union, Any
from typing import List, TypeVar, Union, Any, Callable, Optional, Generator
from abc import abstractmethod, ABC
from functools import total_ordering

T = TypeVar("T")
V = TypeVar("V")
K = TypeVar("K")


@total_ordering
class Comparable:
def __init__(self, value):
self.value = value
class BinarySearchUtil:
@staticmethod
def binary_search(
data, item, compare_func: Optional[Callable[[Any, Any], int]] = None
) -> int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. since we're in Python, this doesn't have to be in a class. It could just be def _binary_search_util(data, item):

  2. This function is only ever called with two parameters -compare_func is always None. The fix is

def _compare_normal(a, b) -> int:
  if a == b:
    return 0
  else if a < b:
    return -1
  else
    return 1

def _compare_string_slash_first(a: str, b : str) -> int:
  return _compare_normal(a.replace("/", "\0"), b.replace("/", "\0"))

def _binary_search(data, item) -> int:
  compare_func = _compare_string_slash_first if item is str else _compare_normal
  low, high = 0, len(data) - 1
  ...
  the rest of the logic, but now you know that compare_func is never None

low, high = 0, len(data) - 1
while low <= high:
mid = (low + high) // 2
mid_val = data[mid] if not isinstance(data, ListBackedSet) else data[mid]
comparison = (
compare_func(mid_val, item)
if compare_func
else (mid_val > item) - (mid_val < item)
)

def __lt__(self, other: Any) -> bool:
if not isinstance(other, Comparable):
return NotImplemented
return self.value < other.value
if comparison < 0:
low = mid + 1
elif comparison > 0:
high = mid - 1
else:
return mid # item found
return -(low + 1) # item not found

def __eq__(self, other: Any) -> bool:
if not isinstance(other, Comparable):
return NotImplemented
return self.value == other.value
@staticmethod
def default_compare(a: Any, b: Any) -> int:
"""Default comparison function for binary search, with special handling for strings."""
if isinstance(a, str) and isinstance(b, str):
a, b = a.replace("/", "\0"), b.replace("/", "\0")
return (a > b) - (a < b)


class ListBackedSet(Set[T], ABC):
Expand All @@ -31,25 +45,14 @@ def __len__(self) -> int: ...
@abstractmethod
def __getitem__(self, index: Union[int, slice]) -> Union[T, List[T]]: ...

@abstractmethod
def __iter__(self) -> Iterator[T]: ...

def __contains__(self, item: Any) -> bool:
return self._binary_search(item) >= 0

def _binary_search(self, item: Any) -> int:
low = 0
high = len(self) - 1
while low <= high:
mid = (low + high) // 2
try:
mid_val = self[mid]
if mid_val < item:
low = mid + 1
elif mid_val > item:
high = mid - 1
else:
return mid # item found
except TypeError:
raise ValueError(f"Cannot compare items due to a type mismatch.")
return -(low + 1) # item not found
return BinarySearchUtil.binary_search(self, item)


class ArraySet(ListBackedSet[K]):
Expand Down Expand Up @@ -80,59 +83,49 @@ def __getitem__(self, index: Union[int, slice]) -> Union[K, List[K]]:
return self.__data[index]

def plusOrThis(self, element: K) -> "ArraySet[K]":
if element in self:
index = self._binary_search(element)
if index >= 0:
return self
else:
insert_at = -(index + 1)
new_data = self.__data[:]
new_data.append(element)
new_data.sort(key=Comparable)
new_data.insert(insert_at, element)
return ArraySet.__create(new_data)


class ArrayMap(Mapping[K, V]):
def __init__(self, data=None):
if data is None:
self.__data = []
else:
self.__data = data
__data: List[Union[K, V]]

def __init__(self):
raise NotImplementedError("Use ArrayMap.empty() or other class methods instead")

@classmethod
def __create(cls, data: List[Union[K, V]]) -> "ArrayMap[K, V]":
instance = cls.__new__(cls)
instance.__data = data
return instance

@classmethod
def empty(cls) -> "ArrayMap[K, V]":
if not hasattr(cls, "__EMPTY"):
cls.__EMPTY = cls([])
cls.__EMPTY = cls.__create([])
return cls.__EMPTY

def __getitem__(self, key: K) -> V:
index = self._binary_search_key(key)
if index >= 0:
return self.__data[2 * index + 1]
return self.__data[2 * index + 1] # type: ignore
raise KeyError(key)

def __iter__(self) -> Iterator[K]:
return (self.__data[i] for i in range(0, len(self.__data), 2))
return (self.__data[i] for i in range(0, len(self.__data), 2)) # type: ignore

def __len__(self) -> int:
return len(self.__data) // 2

def _binary_search_key(self, key: K) -> int:
def compare(a, b):
"""Comparator that puts '/' first in strings."""
if isinstance(a, str) and isinstance(b, str):
a, b = a.replace("/", "\0"), b.replace("/", "\0")
return (a > b) - (a < b)

low, high = 0, len(self.__data) // 2 - 1
while low <= high:
mid = (low + high) // 2
mid_key = self.__data[2 * mid]
comparison = compare(mid_key, key)
if comparison < 0:
low = mid + 1
elif comparison > 0:
high = mid - 1
else:
return mid # key found
return -(low + 1) # key not found
keys = [self.__data[i] for i in range(0, len(self.__data), 2)]
return BinarySearchUtil.binary_search(keys, key)

def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
index = self._binary_search_key(key)
Expand All @@ -142,12 +135,12 @@ def plus(self, key: K, value: V) -> "ArrayMap[K, V]":
new_data = self.__data[:]
new_data.insert(insert_at * 2, key)
new_data.insert(insert_at * 2 + 1, value)
return ArrayMap(new_data)
return ArrayMap.__create(new_data)

def minus_sorted_indices(self, indices: List[int]) -> "ArrayMap[K, V]":
new_data = self.__data[:]
adjusted_indices = [i * 2 for i in indices] + [i * 2 + 1 for i in indices]
adjusted_indices.sort()
for index in reversed(adjusted_indices):
adjusted_indices.sort(reverse=True)
for index in adjusted_indices:
del new_data[index]
return ArrayMap(new_data)
return ArrayMap.__create(new_data)