Skip to content

Commit cc8edab

Browse files
committed
Represent continuous attributes as a bitvector
1 parent 7369697 commit cc8edab

File tree

6 files changed

+15
-14
lines changed

6 files changed

+15
-14
lines changed

balance.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ def sample(ins, out):
634634
sample([5, 5, 5, 2], 0),
635635
sample([5, 5, 5, 3], 0),
636636
sample([5, 5, 5, 4], 0),
637-
continuous_attributes=[0, 1, 2, 3]
637+
continuous=[True, True, True, True]
638638
)
639639

640640

src/dataset.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def __output__(self):
4444
class __Dataset__(ObliviousSequence):
4545
samples: ObliviousArray
4646
number_of_attributes: int
47-
continuous_attributes: [int]
47+
continuous: [bool]
4848

4949
def len(self):
5050
return self.samples.len()
@@ -80,26 +80,27 @@ def select(self, *include):
8080
return ObliviousDatasetSelection(
8181
selection,
8282
self.number_of_attributes,
83-
self.continuous_attributes
83+
self.continuous
8484
)
8585

8686
def is_continuous(self, attribute_index):
87-
return attribute_index in self.continuous_attributes
87+
return self.continuous[attribute_index]
8888

8989
async def __output__(self):
9090
return await output(self.samples)
9191

9292

9393
class ObliviousDataset(__Dataset__, Secret):
94-
def __init__(self, values, continuous_attributes=()):
94+
def __init__(self, values, continuous=None):
9595
number_of_attributes = len(values[0]) if len(values) > 0 else 0
9696
samples = ObliviousArray.create(values)
97-
__Dataset__.__init__(
98-
self, samples, number_of_attributes, continuous_attributes)
97+
if not continuous:
98+
continuous = [False for i in range(number_of_attributes)]
99+
__Dataset__.__init__(self, samples, number_of_attributes, continuous)
99100

100101
@classmethod
101-
def create(cls, *values, continuous_attributes=()):
102-
return ObliviousDataset(values, continuous_attributes)
102+
def create(cls, *values, continuous=None):
103+
return ObliviousDataset(values, continuous)
103104

104105
def __len__(self):
105106
return len(self.samples)

tests/best_split.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_select_best_continuous_attribute(self):
4343
Sample([s(3)], s(0)),
4444
Sample([s(4)], s(1)),
4545
Sample([s(5)], s(1)),
46-
continuous_attributes=[0]
46+
continuous=[True]
4747
)
4848
(best_attribute, threshold) = select_best_attribute(samples)
4949
self.assertEqual(reveal(best_attribute), 0)
@@ -56,7 +56,7 @@ def test_select_best_attribute_from_continuous_and_binary(self):
5656
Sample([s(1), s(3)], s(0)),
5757
Sample([s(1), s(4)], s(1)),
5858
Sample([s(1), s(5)], s(1)),
59-
continuous_attributes=[1]
59+
continuous=[False, True]
6060
)
6161
(best_attribute, threshold) = select_best_attribute(samples)
6262
self.assertEqual(reveal(best_attribute), 1)

tests/dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_continuous_attributes(self):
9494
dataset = ObliviousDataset.create(
9595
Sample([s(0), s(1), s(1)], s(0)),
9696
Sample([s(1), s(2), s(1)], s(1)),
97-
continuous_attributes=[1]
97+
continuous=[False, True, False]
9898
)
9999
self.assertFalse(dataset.is_continuous(0))
100100
self.assertTrue(dataset.is_continuous(1))

tests/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@ def sample(ins, out):
4242
sample([3, 3, 2, 4], 0),
4343
sample([3, 3, 2, 5], 1),
4444
sample([3, 3, 3, 1], 0),
45-
continuous_attributes=[0, 1, 2, 3]
45+
continuous=[True, True, True, True]
4646
)

tests/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_continuous_attributes(self):
6060
samples = ObliviousDataset.create(
6161
Sample([s(1), s(2)], s(0)),
6262
Sample([s(1), s(3)], s(1)),
63-
continuous_attributes=[1])
63+
continuous=[False, True])
6464
self.assertEqual(
6565
reveal(train(samples, depth=1)),
6666
Branch(1, threshold=2, left=leaf(0), right=leaf(1)))

0 commit comments

Comments
 (0)