Skip to content

Commit fad41dc

Browse files
committed
2 parents 6c5850d + de0c099 commit fad41dc

File tree

9 files changed

+37
-41
lines changed

9 files changed

+37
-41
lines changed

prefsampling/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__author__ = "Simon Rey and Stanisław Szufa"
22
__email__ = "[email protected]"
3-
__version__ = "0.1.22"
3+
__version__ = "0.1.24"
44

55
from enum import Enum
66
from itertools import chain

prefsampling/ordinal/groupseparable.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
schroeder_tree_brute_force,
1212
schroeder_tree_lescanne,
1313
)
14+
from prefsampling.tree.balanced import balanced_tree
15+
from prefsampling.tree.caterpillar import caterpillar_tree
1416
from prefsampling.inputvalidators import validate_num_voters_candidates
1517
from prefsampling.combinatorics import comb
1618

@@ -36,15 +38,15 @@ class TreeSampler(Enum):
3638
Random Schröder trees sampled following Lescanne (2022)
3739
"""
3840

39-
# CATERPILLAR = "Caterpillar Tree"
40-
# """
41-
# Caterpillar trees
42-
# """
43-
#
44-
# BALANCED = "Balanced Tree"
45-
# """
46-
# Balanced trees
47-
# """
41+
CATERPILLAR = "Caterpillar Tree"
42+
"""
43+
Caterpillar trees
44+
"""
45+
46+
BALANCED = "Balanced Tree"
47+
"""
48+
Balanced trees
49+
"""
4850

4951

5052
@validate_num_voters_candidates
@@ -159,10 +161,10 @@ def group_separable(
159161
)
160162
else:
161163
raise ValueError("There is something weird with the tree_sampler value...")
162-
# elif tree_sampler == TreeSampler.CATERPILLAR:
163-
# tree_root = caterpillar_tree(num_candidates)
164-
# elif tree_sampler == TreeSampler.BALANCED:
165-
# tree_root = balanced_tree(num_candidates)
164+
elif tree_sampler == TreeSampler.CATERPILLAR:
165+
tree_root = caterpillar_tree(num_candidates)
166+
elif tree_sampler == TreeSampler.BALANCED:
167+
tree_root = balanced_tree(num_candidates)
166168
else:
167169
raise ValueError(
168170
"The `tree` argument needs to be one of the constant defined in the "
@@ -183,7 +185,7 @@ def group_separable(
183185
signatures = np.zeros((num_voters - 1, num_internal_nodes), dtype=bool)
184186
for r in range(num_internal_nodes):
185187
values_at_pos = rng.choice((True, False), size=num_voters - 1)
186-
while r > 0 and not any(values_at_pos):
188+
while r > 0 and not any(values_at_pos) and num_voters - 1 > 0:
187189
values_at_pos = rng.choice((True, False), size=num_voters - 1)
188190
for i in range(num_voters - 1):
189191
signatures[i][r] = values_at_pos[i]

prefsampling/ordinal/singlecrossing.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,9 @@ def single_crossing(
134134
new_line[swap_indices[1]] = domain[line - 1][swap_indices[0]]
135135
domain.append(new_line)
136136

137-
votes = []
138-
last_sampled_index = 0
139-
votes.append(domain[0])
140-
for i in range(1, num_voters):
141-
index = rng.integers(last_sampled_index, domain_size)
142-
votes.append(domain[index])
143-
last_sampled_index = index
137+
indeces = rng.integers(0, domain_size, num_voters)
138+
indeces.sort()
139+
votes = [domain[index] for index in indeces]
144140

145141
return votes
146142

prefsampling/tree/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@
1414
"schroeder_tree_brute_force",
1515
"all_schroeder_tree",
1616
"caterpillar_tree",
17-
"caterpillar_tree",
17+
"balanced_tree",
1818
]

prefsampling/tree/balanced.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from prefsampling.tree.node import Node
77

88

9-
def balanced_tree(num_leaves: int) -> Node:
9+
def balanced_tree(num_leaves: int, seed: int = None) -> Node:
1010
"""
1111
Generates a balanced tree.
1212
@@ -17,26 +17,20 @@ def balanced_tree(num_leaves: int) -> Node:
1717
"""
1818
validate_int(num_leaves, "number of leaves", lower_bound=1)
1919
num_leaves = int(num_leaves)
20-
root = Node("root")
21-
ctr = 0
20+
root = Node(0)
21+
if num_leaves == 1:
22+
return root
23+
ctr = 1
2224

2325
q = queue.Queue()
2426
q.put(root)
2527

26-
while q.qsize() * 2 < num_leaves:
28+
while ctr < 2*num_leaves-1:
2729
tmp_root = q.get()
2830
for _ in range(2):
2931
inner_node = Node(ctr)
32+
ctr += 1
3033
tmp_root.add_child(inner_node)
3134
q.put(inner_node)
32-
ctr += 1
33-
34-
ctr = 0
35-
while ctr < num_leaves:
36-
tmp_root = q.get()
37-
for _ in range(2):
38-
node = Node(ctr)
39-
tmp_root.add_child(node)
40-
ctr += 1
4135

4236
return root

prefsampling/tree/caterpillar.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@ def caterpillar_tree(num_leaves: int, seed: int = None) -> Node:
2525

2626
while num_leaves > 2:
2727
leaf = Node(ctr)
28+
ctr += 1
2829
inner_node = Node(ctr)
30+
ctr += 1
2931
tmp_root.add_child(leaf)
3032
tmp_root.add_child(inner_node)
3133
tmp_root = inner_node
3234
num_leaves -= 1
33-
ctr += 1
3435

3536
leaf_1 = Node(ctr)
36-
leaf_2 = Node(ctr + 1)
37+
ctr += 1
38+
leaf_2 = Node(ctr)
3739
tmp_root.add_child(leaf_1)
3840
tmp_root.add_child(leaf_2)
3941

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
44

55
[project]
66
name = "prefsampling"
7-
version = "0.1.22"
7+
version = "0.1.24"
88
description = "Algorithms to sample preferences of all kinds."
99
authors = [
1010
{ name = "Simon Rey", email = "[email protected]" },

tests/test_samplers/ordinal/test_ordinal_group_separable.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ def all_test_samplers_ordinal_group_separable():
1313
for tree_sampler in [
1414
TreeSampler.SCHROEDER,
1515
TreeSampler.SCHROEDER_LESCANNE,
16+
TreeSampler.CATERPILLAR,
17+
TreeSampler.BALANCED
1618
]
1719
]
1820

tests/test_trees.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def test_all_schroeder_trees(self):
6262
5: 45,
6363
6: 197,
6464
7: 903,
65-
8: 4279,
66-
9: 20793,
65+
# 8: 4279,
66+
# 9: 20793,
6767
# 10: 103049,
6868
# 11: 518859,
6969
# 12: 2646723

0 commit comments

Comments
 (0)