diff --git a/prefsampling/__init__.py b/prefsampling/__init__.py index 5eb0eff..4e0367f 100644 --- a/prefsampling/__init__.py +++ b/prefsampling/__init__.py @@ -1,6 +1,6 @@ __author__ = "Simon Rey and Stanisław Szufa" __email__ = "reysimon@orange.fr" -__version__ = "0.1.22" +__version__ = "0.1.24" from enum import Enum from itertools import chain diff --git a/prefsampling/ordinal/groupseparable.py b/prefsampling/ordinal/groupseparable.py index 1a5c47e..6665c16 100644 --- a/prefsampling/ordinal/groupseparable.py +++ b/prefsampling/ordinal/groupseparable.py @@ -11,6 +11,8 @@ schroeder_tree_brute_force, schroeder_tree_lescanne, ) +from prefsampling.tree.balanced import balanced_tree +from prefsampling.tree.caterpillar import caterpillar_tree from prefsampling.inputvalidators import validate_num_voters_candidates from prefsampling.combinatorics import comb @@ -36,15 +38,15 @@ class TreeSampler(Enum): Random Schröder trees sampled following Lescanne (2022) """ - # CATERPILLAR = "Caterpillar Tree" - # """ - # Caterpillar trees - # """ - # - # BALANCED = "Balanced Tree" - # """ - # Balanced trees - # """ + CATERPILLAR = "Caterpillar Tree" + """ + Caterpillar trees + """ + + BALANCED = "Balanced Tree" + """ + Balanced trees + """ @validate_num_voters_candidates @@ -159,10 +161,10 @@ def group_separable( ) else: raise ValueError("There is something weird with the tree_sampler value...") - # elif tree_sampler == TreeSampler.CATERPILLAR: - # tree_root = caterpillar_tree(num_candidates) - # elif tree_sampler == TreeSampler.BALANCED: - # tree_root = balanced_tree(num_candidates) + elif tree_sampler == TreeSampler.CATERPILLAR: + tree_root = caterpillar_tree(num_candidates) + elif tree_sampler == TreeSampler.BALANCED: + tree_root = balanced_tree(num_candidates) else: raise ValueError( "The `tree` argument needs to be one of the constant defined in the " @@ -183,7 +185,7 @@ def group_separable( signatures = np.zeros((num_voters - 1, num_internal_nodes), dtype=bool) for r in range(num_internal_nodes): values_at_pos = rng.choice((True, False), size=num_voters - 1) - while r > 0 and not any(values_at_pos): + while r > 0 and not any(values_at_pos) and num_voters - 1 > 0: values_at_pos = rng.choice((True, False), size=num_voters - 1) for i in range(num_voters - 1): signatures[i][r] = values_at_pos[i] diff --git a/prefsampling/ordinal/singlecrossing.py b/prefsampling/ordinal/singlecrossing.py index a25a7b2..aebe6b6 100644 --- a/prefsampling/ordinal/singlecrossing.py +++ b/prefsampling/ordinal/singlecrossing.py @@ -134,13 +134,9 @@ def single_crossing( new_line[swap_indices[1]] = domain[line - 1][swap_indices[0]] domain.append(new_line) - votes = [] - last_sampled_index = 0 - votes.append(domain[0]) - for i in range(1, num_voters): - index = rng.integers(last_sampled_index, domain_size) - votes.append(domain[index]) - last_sampled_index = index + indeces = rng.integers(0, domain_size, num_voters) + indeces.sort() + votes = [domain[index] for index in indeces] return votes diff --git a/prefsampling/tree/__init__.py b/prefsampling/tree/__init__.py index 4572cb4..e8ac9b1 100644 --- a/prefsampling/tree/__init__.py +++ b/prefsampling/tree/__init__.py @@ -14,5 +14,5 @@ "schroeder_tree_brute_force", "all_schroeder_tree", "caterpillar_tree", - "caterpillar_tree", + "balanced_tree", ] diff --git a/prefsampling/tree/balanced.py b/prefsampling/tree/balanced.py index ef2a02a..f4be66f 100644 --- a/prefsampling/tree/balanced.py +++ b/prefsampling/tree/balanced.py @@ -6,7 +6,7 @@ from prefsampling.tree.node import Node -def balanced_tree(num_leaves: int) -> Node: +def balanced_tree(num_leaves: int, seed: int = None) -> Node: """ Generates a balanced tree. @@ -17,26 +17,20 @@ def balanced_tree(num_leaves: int) -> Node: """ validate_int(num_leaves, "number of leaves", lower_bound=1) num_leaves = int(num_leaves) - root = Node("root") - ctr = 0 + root = Node(0) + if num_leaves == 1: + return root + ctr = 1 q = queue.Queue() q.put(root) - while q.qsize() * 2 < num_leaves: + while ctr < 2*num_leaves-1: tmp_root = q.get() for _ in range(2): inner_node = Node(ctr) + ctr += 1 tmp_root.add_child(inner_node) q.put(inner_node) - ctr += 1 - - ctr = 0 - while ctr < num_leaves: - tmp_root = q.get() - for _ in range(2): - node = Node(ctr) - tmp_root.add_child(node) - ctr += 1 return root diff --git a/prefsampling/tree/caterpillar.py b/prefsampling/tree/caterpillar.py index e9d4053..cac828a 100644 --- a/prefsampling/tree/caterpillar.py +++ b/prefsampling/tree/caterpillar.py @@ -25,15 +25,17 @@ def caterpillar_tree(num_leaves: int, seed: int = None) -> Node: while num_leaves > 2: leaf = Node(ctr) + ctr += 1 inner_node = Node(ctr) + ctr += 1 tmp_root.add_child(leaf) tmp_root.add_child(inner_node) tmp_root = inner_node num_leaves -= 1 - ctr += 1 leaf_1 = Node(ctr) - leaf_2 = Node(ctr + 1) + ctr += 1 + leaf_2 = Node(ctr) tmp_root.add_child(leaf_1) tmp_root.add_child(leaf_2) diff --git a/pyproject.toml b/pyproject.toml index d0574cf..3d34032 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "prefsampling" -version = "0.1.22" +version = "0.1.24" description = "Algorithms to sample preferences of all kinds." authors = [ { name = "Simon Rey", email = "reysimon@orange.fr" }, diff --git a/tests/test_samplers/ordinal/test_ordinal_group_separable.py b/tests/test_samplers/ordinal/test_ordinal_group_separable.py index 302dc14..ea7425f 100644 --- a/tests/test_samplers/ordinal/test_ordinal_group_separable.py +++ b/tests/test_samplers/ordinal/test_ordinal_group_separable.py @@ -13,6 +13,8 @@ def all_test_samplers_ordinal_group_separable(): for tree_sampler in [ TreeSampler.SCHROEDER, TreeSampler.SCHROEDER_LESCANNE, + TreeSampler.CATERPILLAR, + TreeSampler.BALANCED ] ] diff --git a/tests/test_trees.py b/tests/test_trees.py index c5bdaa8..d9c5ca9 100644 --- a/tests/test_trees.py +++ b/tests/test_trees.py @@ -62,8 +62,8 @@ def test_all_schroeder_trees(self): 5: 45, 6: 197, 7: 903, - 8: 4279, - 9: 20793, + # 8: 4279, + # 9: 20793, # 10: 103049, # 11: 518859, # 12: 2646723