Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon-Rey committed Nov 13, 2024
2 parents 6c5850d + de0c099 commit fad41dc
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 41 deletions.
2 changes: 1 addition & 1 deletion prefsampling/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__author__ = "Simon Rey and Stanisław Szufa"
__email__ = "[email protected]"
__version__ = "0.1.22"
__version__ = "0.1.24"

from enum import Enum
from itertools import chain
Expand Down
30 changes: 16 additions & 14 deletions prefsampling/ordinal/groupseparable.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
schroeder_tree_brute_force,
schroeder_tree_lescanne,
)
from prefsampling.tree.balanced import balanced_tree
from prefsampling.tree.caterpillar import caterpillar_tree
from prefsampling.inputvalidators import validate_num_voters_candidates
from prefsampling.combinatorics import comb

Expand All @@ -36,15 +38,15 @@ class TreeSampler(Enum):
Random Schröder trees sampled following Lescanne (2022)
"""

# CATERPILLAR = "Caterpillar Tree"
# """
# Caterpillar trees
# """
#
# BALANCED = "Balanced Tree"
# """
# Balanced trees
# """
CATERPILLAR = "Caterpillar Tree"
"""
Caterpillar trees
"""

BALANCED = "Balanced Tree"
"""
Balanced trees
"""


@validate_num_voters_candidates
Expand Down Expand Up @@ -159,10 +161,10 @@ def group_separable(
)
else:
raise ValueError("There is something weird with the tree_sampler value...")
# elif tree_sampler == TreeSampler.CATERPILLAR:
# tree_root = caterpillar_tree(num_candidates)
# elif tree_sampler == TreeSampler.BALANCED:
# tree_root = balanced_tree(num_candidates)
elif tree_sampler == TreeSampler.CATERPILLAR:
tree_root = caterpillar_tree(num_candidates)
elif tree_sampler == TreeSampler.BALANCED:
tree_root = balanced_tree(num_candidates)
else:
raise ValueError(
"The `tree` argument needs to be one of the constant defined in the "
Expand All @@ -183,7 +185,7 @@ def group_separable(
signatures = np.zeros((num_voters - 1, num_internal_nodes), dtype=bool)
for r in range(num_internal_nodes):
values_at_pos = rng.choice((True, False), size=num_voters - 1)
while r > 0 and not any(values_at_pos):
while r > 0 and not any(values_at_pos) and num_voters - 1 > 0:
values_at_pos = rng.choice((True, False), size=num_voters - 1)
for i in range(num_voters - 1):
signatures[i][r] = values_at_pos[i]
Expand Down
10 changes: 3 additions & 7 deletions prefsampling/ordinal/singlecrossing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,9 @@ def single_crossing(
new_line[swap_indices[1]] = domain[line - 1][swap_indices[0]]
domain.append(new_line)

votes = []
last_sampled_index = 0
votes.append(domain[0])
for i in range(1, num_voters):
index = rng.integers(last_sampled_index, domain_size)
votes.append(domain[index])
last_sampled_index = index
indeces = rng.integers(0, domain_size, num_voters)
indeces.sort()
votes = [domain[index] for index in indeces]

return votes

Expand Down
2 changes: 1 addition & 1 deletion prefsampling/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
"schroeder_tree_brute_force",
"all_schroeder_tree",
"caterpillar_tree",
"caterpillar_tree",
"balanced_tree",
]
20 changes: 7 additions & 13 deletions prefsampling/tree/balanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from prefsampling.tree.node import Node


def balanced_tree(num_leaves: int) -> Node:
def balanced_tree(num_leaves: int, seed: int = None) -> Node:
"""
Generates a balanced tree.
Expand All @@ -17,26 +17,20 @@ def balanced_tree(num_leaves: int) -> Node:
"""
validate_int(num_leaves, "number of leaves", lower_bound=1)
num_leaves = int(num_leaves)
root = Node("root")
ctr = 0
root = Node(0)
if num_leaves == 1:
return root
ctr = 1

q = queue.Queue()
q.put(root)

while q.qsize() * 2 < num_leaves:
while ctr < 2*num_leaves-1:
tmp_root = q.get()
for _ in range(2):
inner_node = Node(ctr)
ctr += 1
tmp_root.add_child(inner_node)
q.put(inner_node)
ctr += 1

ctr = 0
while ctr < num_leaves:
tmp_root = q.get()
for _ in range(2):
node = Node(ctr)
tmp_root.add_child(node)
ctr += 1

return root
6 changes: 4 additions & 2 deletions prefsampling/tree/caterpillar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ def caterpillar_tree(num_leaves: int, seed: int = None) -> Node:

while num_leaves > 2:
leaf = Node(ctr)
ctr += 1
inner_node = Node(ctr)
ctr += 1
tmp_root.add_child(leaf)
tmp_root.add_child(inner_node)
tmp_root = inner_node
num_leaves -= 1
ctr += 1

leaf_1 = Node(ctr)
leaf_2 = Node(ctr + 1)
ctr += 1
leaf_2 = Node(ctr)
tmp_root.add_child(leaf_1)
tmp_root.add_child(leaf_2)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "prefsampling"
version = "0.1.22"
version = "0.1.24"
description = "Algorithms to sample preferences of all kinds."
authors = [
{ name = "Simon Rey", email = "[email protected]" },
Expand Down
2 changes: 2 additions & 0 deletions tests/test_samplers/ordinal/test_ordinal_group_separable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ def all_test_samplers_ordinal_group_separable():
for tree_sampler in [
TreeSampler.SCHROEDER,
TreeSampler.SCHROEDER_LESCANNE,
TreeSampler.CATERPILLAR,
TreeSampler.BALANCED
]
]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def test_all_schroeder_trees(self):
5: 45,
6: 197,
7: 903,
8: 4279,
9: 20793,
# 8: 4279,
# 9: 20793,
# 10: 103049,
# 11: 518859,
# 12: 2646723
Expand Down

0 comments on commit fad41dc

Please sign in to comment.