-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdataset.py
122 lines (102 loc) · 4.6 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import unittest
from src.dataset import ObliviousDataset, Sample
from src.secint import secint as s
from tests.reveal import reveal
def sample(*inputs):
    """Build a ``Sample`` from positional attribute values with outcome ``s(0)``.

    The positional arguments are forwarded to ``Sample`` as one tuple of
    attributes; the outcome is always a newly created ``s(0)``.
    """
    attributes = inputs
    default_outcome = s(0)
    return Sample(attributes, default_outcome)
class ObliviousDatasetTest(unittest.TestCase):
    """Tests for ``ObliviousDataset``: columns, outcomes, sampling, metadata."""

    @staticmethod
    def _grid_dataset():
        """Return a new dataset with rows (0, 1, 2), (10, 11, 12), (20, 21, 22)."""
        rows = [
            sample(s(first), s(first + 1), s(first + 2))
            for first in (0, 10, 20)
        ]
        return ObliviousDataset.create(*rows)

    @staticmethod
    def _partly_continuous_dataset():
        """Return a new two-sample dataset created with ``continuous`` flags."""
        return ObliviousDataset.create(
            Sample([s(0), s(1), s(1)], s(0)),
            Sample([s(1), s(2), s(1)], s(1)),
            continuous=[False, True, False],
        )

    def test_column_with_public_index(self):
        """A plain integer index yields the matching value of every row."""
        grid = self._grid_dataset()
        expected_columns = ([0, 10, 20], [1, 11, 21], [2, 12, 22])
        for index, expected in enumerate(expected_columns):
            self.assertEqual(reveal(grid.column(index)), expected)

    def test_column_of_subset_with_public_index(self):
        """After ``select`` only the first and third rows are expected."""
        grid = self._grid_dataset()
        subset = grid.select([s(1), s(0), s(1)])
        self.assertEqual(reveal(subset.column(1)), [1, 21])

    def test_column_with_secret_index(self):
        """An ``s(...)`` index yields the same columns as a plain integer."""
        grid = self._grid_dataset()
        expected_columns = ([0, 10, 20], [1, 11, 21], [2, 12, 22])
        for index, expected in enumerate(expected_columns):
            self.assertEqual(reveal(grid.column(s(index))), expected)

    def test_column_of_subset_with_secret_index(self):
        """``select`` followed by an ``s(...)`` column index."""
        grid = self._grid_dataset()
        subset = grid.select([s(1), s(0), s(1)])
        self.assertEqual(reveal(subset.column(s(1))), [1, 21])

    def test_outcomes(self):
        """``outcomes`` of a selected subset lists the remaining outcomes."""
        samples = [
            Sample([s(0), s(1), s(2)], outcome=s(60)),
            Sample([s(10), s(11), s(12)], outcome=s(70)),
            Sample([s(20), s(21), s(22)], outcome=s(80)),
        ]
        subset = ObliviousDataset.create(*samples).select([s(1), s(0), s(1)])
        self.assertEqual(reveal(subset.outcomes), [60, 80])

    def test_number_of_attributes(self):
        """Two samples with three inputs each report three attributes."""
        two_rows = ObliviousDataset.create(
            sample(s(1), s(2), s(3)),
            sample(s(4), s(5), s(6)),
        )
        self.assertEqual(two_rows.number_of_attributes, 3)

    def test_number_of_attributes_empty_set(self):
        """A dataset created without samples reports zero attributes."""
        empty = ObliviousDataset.create()
        self.assertEqual(empty.number_of_attributes, 0)

    def test_random_sample_with_one_sample(self):
        """``choice`` on a single-sample dataset returns that sample."""
        only = Sample([s(1), s(2), s(3)], s(4))
        singleton = ObliviousDataset.create(only)
        picked = reveal(singleton.choice())
        self.assertEqual(picked, Sample([1, 2, 3], 4))

    def test_random_sample(self):
        """Across 100 ``choice`` calls both samples are expected to appear."""
        pair = ObliviousDataset.create(
            Sample([s(1), s(2), s(3)], s(4)),
            Sample([s(11), s(12), s(13)], s(14)),
        )
        drawn = [reveal(pair.choice()) for _ in range(100)]
        for expected in (Sample([1, 2, 3], 4), Sample([11, 12, 13], 14)):
            self.assertIn(expected, drawn)

    def test_determine_class_single_sample(self):
        """With one sample of outcome 0, ``determine_class`` reveals 0."""
        singleton = ObliviousDataset.create(Sample([s(0)], s(0)))
        decided = reveal(singleton.determine_class())
        self.assertEqual(decided, 0)

    def test_determine_class_multiple_samples(self):
        """Outcomes 0, 1, 1 reveal class 1 (presumably the most common one)."""
        samples = [
            Sample([s(0)], s(0)),
            Sample([s(0)], s(1)),
            Sample([s(0)], s(1)),
        ]
        decided = reveal(ObliviousDataset.create(*samples).determine_class())
        self.assertEqual(decided, 1)

    def test_continuous_attributes(self):
        """``is_continuous`` with plain indices mirrors the ``continuous`` flags."""
        dataset = self._partly_continuous_dataset()
        first = dataset.is_continuous(0)
        self.assertFalse(first)
        second = dataset.is_continuous(1)
        self.assertTrue(second)
        third = dataset.is_continuous(2)
        self.assertFalse(third)

    def test_continuous_attribute_check_with_secret_index(self):
        """``is_continuous`` with ``s(...)`` indices, checked after ``reveal``."""
        dataset = self._partly_continuous_dataset()
        first = reveal(dataset.is_continuous(s(0)))
        self.assertFalse(first)
        second = reveal(dataset.is_continuous(s(1)))
        self.assertTrue(second)
        third = reveal(dataset.is_continuous(s(2)))
        self.assertFalse(third)
class SampleTest(unittest.TestCase):
    """Tests for the ``+`` and ``*`` operators on ``Sample``."""

    def test_add_samples(self):
        """The sum of two samples reveals position-wise sums, outcome included."""
        left = Sample([s(1), s(2), s(3)], s(4))
        right = Sample([s(5), s(6), s(7)], s(8))
        total = left + right
        self.assertEqual(reveal(total), Sample([6, 8, 10], 12))

    def test_multiply_samples(self):
        """Multiplying a sample by ``s(2)`` reveals every value doubled."""
        original = Sample([s(1), s(2), s(3)], s(4))
        doubled = original * s(2)
        self.assertEqual(reveal(doubled), Sample([2, 4, 6], 8))