Skip to content

feat: add Fusion Tree #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions docs/source/pydatastructs/trees/m_ary_trees.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
M-ary Trees
===========

.. autoclass:: pydatastructs.FusionTree
6 changes: 4 additions & 2 deletions pydatastructs/trees/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
CartesianTree,
Treap,
SplayTree,
RedBlackTree
RedBlackTree,
)
__all__.extend(binary_trees.__all__)

from .m_ary_trees import (
MAryTreeNode, MAryTree
MAryTreeNode,
MAryTree,
FusionTree
)

__all__.extend(m_ary_trees.__all__)
Expand Down
2 changes: 1 addition & 1 deletion pydatastructs/trees/binary_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'CartesianTree',
'Treap',
'SplayTree',
'RedBlackTree'
'RedBlackTree',
]

class BinaryTree(object):
Expand Down
162 changes: 161 additions & 1 deletion pydatastructs/trees/m_ary_trees.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import math
from pydatastructs.utils import MAryTreeNode
from pydatastructs.linear_data_structures.arrays import ArrayForTrees
from pydatastructs.utils.misc_util import (
Backend, raise_if_backend_is_not_python)

__all__ = [
'MAryTree'
'MAryTree',
'FusionTree'
]

class MAryTree(object):
Expand Down Expand Up @@ -170,3 +172,161 @@ def __str__(self):
if j is not None:
to_be_printed[i].append(j)
return str(to_be_printed)


class FusionTree(MAryTree):
"""
Implements a Fusion Tree, a multi-way search tree optimized for integer keys.

Parameters
==========

key: int
The integer key to insert.
root_data: Any
Optional data to store with the key.
backend: pydatastructs.Backend
The backend to be used. Available backends: Python and C++
Optional, by default, the Python backend is used. For faster execution, use the C++ backend.
word_size: int
The size of the integer keys in bits.
Optional, by default, set to 64.

Examples
========

>>> from pydatastructs import FusionTree
>>> ft = FusionTree()
>>> ft.insert(1, 1)
>>> ft.insert(2, 2)
>>> ft.search(1)
0
>>> ft.delete(1)
True
>>> ft.search(1)


References:
- https://en.wikipedia.org/wiki/Fusion_tree
- Fredman & Willard (1990): "Fusion Trees"
"""

__slots__ = ['root_idx', 'tree', 'size', 'B',
'sketch_mask', 'fingerprint_multiplier']

def __new__(cls, key=None, root_data=None, **kwargs):
backend = kwargs.get('backend', Backend.PYTHON)
raise_if_backend_is_not_python(cls, backend)

obj = object.__new__(cls)
key = None if root_data is None else key
root = MAryTreeNode(key, root_data)
root.is_root = True
obj.root_idx = 0
obj.tree, obj.size = ArrayForTrees(MAryTreeNode, [root]), 1
obj.B = int(math.log2(kwargs.get('word_size', 64))
** (1/5)) # Multi-way branching factor
obj.sketch_mask = 0 # Computed dynamically
obj.fingerprint_multiplier = 2654435761 # Prime multiplier for fingerprinting
return obj

def _compute_sketch_mask(self):
"""
Computes a sketch mask for efficient parallel comparisons.
"""
keys = [node.key for node in self.tree if node is not None]
if len(keys) > 1:
significant_bits = [max(k.bit_length() for k in keys)]
self.sketch_mask = sum(1 << b for b in significant_bits)

def insert(self, key, data=None):
"""
Inserts a key into the Fusion Tree.

Parameters
==========

key: int
The integer key to insert.
data: Any
Optional data to store with the key.
"""
# Edge case for root node if not intially inserted
if self.size == 1 and self.tree[0].key is None:
self.tree[0] = MAryTreeNode(key, data)
self.tree[0].is_root = True
return

node = MAryTreeNode(key, data)
self.tree.append(node)
self.size += 1
if self.size > 1:
self._compute_sketch_mask()

def _sketch_key(self, key):
"""
Applies the sketch mask to compress the key for fast comparison.
"""
return key & self.sketch_mask

def _fingerprint(self, key):
"""
Uses multiplication-based fingerprinting to create a unique identifier
for the key, allowing fast parallel searches.
"""
return (key * self.fingerprint_multiplier) & ((1 << 64) - 1)

def search(self, key):
"""
Searches for a key in the Fusion Tree using bitwise sketching and fingerprinting.

Parameters
==========

key: int
The integer key to search.

Returns
=======

int: The index of the key in the tree, or None if not found.
"""
sketch = self._sketch_key(key)
fingerprint = self._fingerprint(key)
for i in range(self.size):
if self._sketch_key(self.tree[i].key) == sketch and self._fingerprint(self.tree[i].key) == fingerprint:
return i
return None

def delete(self, key):
"""
Deletes a key from the Fusion Tree.

Parameters
==========

key: int
The integer key to delete.

Returns
=======

bool: True if the key was successfully deleted, False otherwise.

"""
index = self.search(key)
if index is not None:
self.tree[index] = None # Soft delete
# Compact tree
self.tree = [node for node in self.tree if node is not None]
self.size -= 1
if self.size > 1:
self._compute_sketch_mask()
return True
return False

def __str__(self):
"""
Returns a string representation of the Fusion Tree.
"""
return str([(node.key, node.data) for node in self.tree if node is not None])
2 changes: 1 addition & 1 deletion pydatastructs/trees/tests/test_binary_trees.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydatastructs.trees.binary_trees import (
BinaryTree, BinarySearchTree, BinaryTreeTraversal, AVLTree,
BinarySearchTree, BinaryTreeTraversal, AVLTree,
ArrayForTrees, BinaryIndexedTree, SelfBalancingBinaryTree, SplayTree, CartesianTree, Treap, RedBlackTree)
from pydatastructs.utils.raises_util import raises
from pydatastructs.utils.misc_util import TreeNode
Expand Down
90 changes: 89 additions & 1 deletion pydatastructs/trees/tests/test_m_ary_trees.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,93 @@
from pydatastructs import MAryTree
from pydatastructs.utils.misc_util import Backend
from pydatastructs.trees.m_ary_trees import MAryTree, FusionTree

def test_MAryTree():
m = MAryTree(1, 1)
assert str(m) == '[(1, 1)]'

def _test_FusionTree(backend):
FT = FusionTree
f_tree = FT(backend=backend)

f_tree.insert(8, 8)
f_tree.insert(3, 3)
f_tree.insert(10, 10)
f_tree.insert(1, 1)
f_tree.insert(6, 6)
f_tree.insert(4, 4)
f_tree.insert(7, 7)
f_tree.insert(14, 14)
f_tree.insert(13, 13)

assert f_tree.search(10) is not None
assert f_tree.search(-1) is None

assert f_tree.delete(13) is True
assert f_tree.search(13) is None
assert f_tree.delete(10) is True
assert f_tree.search(10) is None
assert f_tree.delete(3) is True
assert f_tree.search(3) is None
assert f_tree.delete(13) is False # Already deleted

expected_str = '[(8, 8), (1, 1), (6, 6), (4, 4), (7, 7), (14, 14)]'
assert str(f_tree) == expected_str

f_tree.insert(8, 9)
assert f_tree.search(8) is not None

large_key = 10**9
f_tree.insert(large_key, large_key)
assert f_tree.search(large_key) is not None

expected_str = '[(8, 8), (1, 1), (6, 6), (4, 4), (7, 7), (14, 14), (8, 9), (1000000000, 1000000000)]'
assert str(f_tree) == expected_str
assert f_tree.delete(8) is True

expected_str = '[(1, 1), (6, 6), (4, 4), (7, 7), (14, 14), (8, 9), (1000000000, 1000000000)]'
assert str(f_tree) == expected_str

FT = FusionTree
f_tree = FT(8, 8, backend=backend)

f_tree.insert(8, 8)
f_tree.insert(3, 3)
f_tree.insert(10, 10)
f_tree.insert(1, 1)
f_tree.insert(6, 6)
f_tree.insert(4, 4)
f_tree.insert(7, 7)
f_tree.insert(14, 14)
f_tree.insert(13, 13)

assert f_tree.search(10) is not None
assert f_tree.search(-1) is None

assert f_tree.delete(13) is True
assert f_tree.search(13) is None
assert f_tree.delete(10) is True
assert f_tree.search(10) is None
assert f_tree.delete(3) is True
assert f_tree.search(3) is None
assert f_tree.delete(13) is False # Already deleted

expected_str = '[(8, 8), (8, 8), (1, 1), (6, 6), (4, 4), (7, 7), (14, 14)]'
assert str(f_tree) == expected_str

f_tree.insert(8, 9)
assert f_tree.search(8) is not None

large_key = 10**9
f_tree.insert(large_key, large_key)
assert f_tree.search(large_key) is not None

expected_str = '[(8, 8), (8, 8), (1, 1), (6, 6), (4, 4), (7, 7), (14, 14), (8, 9), (1000000000, 1000000000)]'
assert str(f_tree) == expected_str
assert f_tree.delete(8) is True

expected_str = '[(8, 8), (1, 1), (6, 6), (4, 4), (7, 7), (14, 14), (8, 9), (1000000000, 1000000000)]'
assert str(f_tree) == expected_str


def test_FusionTree():
_test_FusionTree(Backend.PYTHON)
Loading