-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgini.py
35 lines (28 loc) · 1.16 KB
/
gini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import unittest
from src.gini import avoid_zero, gini_gain_quotient
from src.secint import secint as s
from tests.reveal import reveal
class GiniTest(unittest.TestCase):
    """Tests for the Gini-gain helpers imported from ``src.gini``.

    ``gini_gain_quotient`` returns the gain as a ``(numerator, denominator)``
    pair. Each test converts that pair to a number with ``_gain`` and checks
    it against the expected value for a small, hand-checkable set of counts.
    """

    @staticmethod
    def _gain(numerator, denominator, total):
        """Return ``(1 / total) * (numerator / denominator)`` as a float.

        Shared by every test so the normalisation is written exactly once.
        ``float`` converts whatever numeric type the division produces
        (e.g. the revealed values in the MPC test) into a plain float.
        """
        return (1 / total) * float(numerator / denominator)

    def test_gini_gain_uniform(self):
        """Counts (2, 2, 1, 1, 1, 1) with total 4 yield a gain of 0.5.

        Presumably this is a split whose branches each hold one sample of
        each class -- confirm against the argument order in src.gini.
        """
        numerator, denominator = gini_gain_quotient(2, 2, 1, 1, 1, 1)
        # The gain is a float; compare with a tolerance, not exact equality.
        self.assertAlmostEqual(self._gain(numerator, denominator, 4), 0.5)

    def test_gini_gain_perfect_split(self):
        """Counts (2, 2, 2, 0, 0, 2) with total 4 yield a gain of 1.

        Presumably each branch is pure (one class only) -- confirm against
        the argument order in src.gini.
        """
        numerator, denominator = gini_gain_quotient(2, 2, 2, 0, 0, 2)
        self.assertAlmostEqual(self._gain(numerator, denominator, 4), 1.0)

    def test_avoidance_of_division_by_zero(self):
        """All-zero counts give a gain of 0 once ``avoid_zero`` is applied.

        The raw denominator may be 0 here. ``avoid_zero`` is expected to
        turn it into a safe non-zero value, so no ZeroDivisionError is
        raised.
        """
        numerator, denominator = gini_gain_quotient(0, 0, 0, 0, 0, 0)
        gain = self._gain(numerator, avoid_zero(denominator), 1)
        self.assertAlmostEqual(gain, 0.0)

    def test_gini_gain_mpc(self):
        """Secure-integer inputs reproduce the plain uniform-case gain of 0.5."""
        numerator, denominator = gini_gain_quotient(
            s(2), s(2), s(1), s(1), s(1), s(1))
        # Reveal the secure results so ordinary Python division can be
        # applied (presumably reveal() opens secint values to plain numbers).
        numerator = reveal(numerator)
        denominator = reveal(denominator)
        self.assertAlmostEqual(self._gain(numerator, denominator, 4), 0.5)