forked from minitorch/Module-0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_operators.py
232 lines (177 loc) · 5.77 KB
/
test_operators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from typing import Callable, List, Tuple
import pytest
from hypothesis import given
from hypothesis.strategies import lists
from minitorch import MathTest
import minitorch
from minitorch.operators import (
add,
addLists,
eq,
id,
inv,
inv_back,
log_back,
lt,
max,
mul,
neg,
negList,
prod,
relu,
relu_back,
sigmoid,
)
from .strategies import assert_close, small_floats
# ## Task 0.1 Basic hypothesis tests.
@pytest.mark.task0_1
@given(small_floats, small_floats)
def test_same_as_python(x: float, y: float) -> None:
    """Check that each basic operator agrees with the equivalent built-in Python expression."""
    # (operator result, reference Python result) pairs; the operators are pure,
    # so evaluating them all up front does not change what is checked.
    pairs = [
        (mul(x, y), x * y),
        (add(x, y), x + y),
        (neg(x), -x),
        (max(x, y), x if x > y else y),
    ]
    for got, want in pairs:
        assert_close(got, want)
    # inv is only compared when |x| > 1e-5, i.e. away from the pole at zero
    # where 1.0 / x becomes very large.
    if abs(x) > 1e-5:
        assert_close(inv(x), 1.0 / x)
@pytest.mark.task0_1
@given(small_floats)
def test_relu(a: float) -> None:
    """relu maps negative inputs to 0.0 and passes positive inputs through unchanged."""
    if a < 0:
        assert relu(a) == 0.0
    elif a > 0:
        assert relu(a) == a
    # a == 0 is deliberately left unchecked.
@pytest.mark.task0_1
@given(small_floats, small_floats)
def test_relu_back(a: float, b: float) -> None:
    """relu_back returns the upstream value b for positive a and 0.0 for negative a."""
    if a < 0:
        assert relu_back(a, b) == 0.0
    elif a > 0:
        assert relu_back(a, b) == b
    # a == 0 is deliberately left unchecked.
@pytest.mark.task0_1
@given(small_floats)
def test_id(a: float) -> None:
    """id returns its argument unchanged."""
    result = id(a)
    assert result == a
@pytest.mark.task0_1
@given(small_floats)
def test_lt(a: float) -> None:
    """lt returns 1.0 when its first argument is smaller and 0.0 when it is larger."""
    smaller = a - 1.0
    assert lt(smaller, a) == 1.0
    assert lt(a, smaller) == 0.0
@pytest.mark.task0_1
@given(small_floats)
def test_max(a: float) -> None:
    """max returns the larger of two values regardless of argument order."""
    for bigger, smaller in ((a, a - 1.0), (a + 1.0, a)):
        assert max(smaller, bigger) == bigger
        assert max(bigger, smaller) == bigger
@pytest.mark.task0_1
@given(small_floats)
def test_eq(a: float) -> None:
    """eq returns 1.0 for a value compared with itself and 0.0 for a different value."""
    assert eq(a, a) == 1.0
    for other in (a - 1.0, a + 1.0):
        assert eq(a, other) == 0.0
# ## Task 0.2 - Property Testing
# Implement the following property checks
# that ensure that your operators obey basic
# mathematical rules.
@pytest.mark.task0_2
@given(small_floats)
def test_sigmoid(a: float) -> None:
    """Check properties of the sigmoid function, specifically
    * It is always between 0.0 and 1.0.
    * One minus sigmoid is the same as sigmoid of the negative.
    * It crosses 0.5 at 0.
    * It is strictly increasing.
    """
    s = sigmoid(a)
    assert 0.0 <= s <= 1.0
    assert_close(1 - s, sigmoid(-a))
    assert_close(sigmoid(0.0), 0.5)
    # Monotonicity. Mathematically sigmoid is strictly increasing, but in
    # floating point it saturates: for large enough a, sigmoid(a) rounds to
    # exactly 1.0. Only the non-strict comparison holds over the whole input
    # range...
    assert sigmoid(a) <= sigmoid(a + 1.0)
    # ...while the strict comparison is checked only away from saturation.
    if abs(a) < 10.0:
        assert sigmoid(a) < sigmoid(a + 1.0)
@pytest.mark.task0_2
@given(small_floats, small_floats, small_floats)
def test_transitive(a: float, b: float, c: float) -> None:
    """lt must be transitive: whenever a < b and b < c hold, a < c must hold too."""
    chain_holds = lt(a, b) and lt(b, c)
    if not chain_holds:
        # The premise a < b < c is not met for this draw; nothing to check.
        return
    assert lt(a, c)
@pytest.mark.task0_2
@given(small_floats, small_floats)
def test_symmetric(a: float, b: float) -> None:
    """mul is commutative: swapping its two arguments gives the same product."""
    forward = mul(a, b)
    backward = mul(b, a)
    assert_close(forward, backward)
@pytest.mark.task0_2
@given(small_floats, small_floats, small_floats)
def test_distribute(a: float, b: float, c: float) -> None:
    r"""mul distributes over add: :math:`c \times (a + b) = c \times a + c \times b`."""
    lhs = mul(c, add(a, b))
    rhs = add(mul(c, a), mul(c, b))
    assert_close(lhs, rhs)
@pytest.mark.task0_2
@given(small_floats)
def test_other(a: float) -> None:
    """Check additive-inverse properties of ``neg`` and ``add``.

    * Negating twice returns the original value.
    * Adding a value to its negation gives zero.
    """
    assert_close(neg(neg(a)), a)
    assert_close(add(a, neg(a)), 0.0)
# ## Task 0.3 - Higher-order functions
# These tests check that your higher-order functions obey basic
# properties.
@pytest.mark.task0_3
@given(small_floats, small_floats, small_floats, small_floats)
def test_zip_with(a: float, b: float, c: float, d: float) -> None:
    """addLists adds two equal-length lists element by element."""
    # Unpacking into exactly two names also checks the result has two elements.
    first, second = addLists([a, b], [c, d])
    for got, want in ((first, a + c), (second, b + d)):
        assert_close(got, want)
@pytest.mark.task0_3
@given(
    lists(small_floats, min_size=5, max_size=5),
    lists(small_floats, min_size=5, max_size=5),
)
def test_sum_distribute(ls1: List[float], ls2: List[float]) -> None:
    """Summing the element-wise sums of two lists equals adding the two lists' sums."""
    op_sum = minitorch.operators.sum
    sum_of_pairs = op_sum(addLists(ls1, ls2))
    total_of_sums = add(op_sum(ls1), op_sum(ls2))
    assert_close(sum_of_pairs, total_of_sums)
@pytest.mark.task0_3
@given(lists(small_floats))
def test_sum(ls: List[float]) -> None:
    """minitorch's sum agrees with Python's built-in sum on any list of small floats."""
    reference = sum(ls)  # built-in sum; minitorch's version is accessed via the module
    assert_close(reference, minitorch.operators.sum(ls))
@pytest.mark.task0_3
@given(small_floats, small_floats, small_floats)
def test_prod(x: float, y: float, z: float) -> None:
    """prod of a three-element list equals the product of its elements."""
    factors = [x, y, z]
    assert_close(prod(factors), x * y * z)
@pytest.mark.task0_3
@given(lists(small_floats))
def test_negList(ls: List[float]) -> None:
    """negList returns exactly one negated element for every input element."""
    # Materialize in case negList returns a lazy iterable rather than a list.
    check = list(negList(ls))
    # zip() stops at the shorter input, so without this check a result that is
    # too short (e.g. empty) would pass the element-wise loop vacuously.
    assert len(check) == len(ls)
    for i, j in zip(ls, check):
        assert_close(i, -j)
# ## Generic mathematical tests
# This generic set of mathematical tests runs for each unit.
# MathTest._tests() yields three collections. The first two parametrize
# test_one_args / test_two_args below (presumably as (name, fn) pairs, per
# those tests' annotations); the third is not used in this file.
one_arg, two_arg, _ = MathTest._tests()
@given(small_floats)
@pytest.mark.parametrize("fn", one_arg)
def test_one_args(fn: Tuple[str, Callable[[float], float]], t1: float) -> None:
    """Smoke test: each one-argument function from MathTest runs on a small float without raising."""
    _label, unary = fn
    unary(t1)
@given(small_floats, small_floats)
@pytest.mark.parametrize("fn", two_arg)
def test_two_args(
    fn: Tuple[str, Callable[[float, float], float]], t1: float, t2: float
) -> None:
    """Smoke test: each two-argument function from MathTest runs on two small floats without raising."""
    _label, binary = fn
    binary(t1, t2)
@given(small_floats, small_floats)
def test_backs(a: float, b: float) -> None:
    """Smoke test that the backward helpers run without raising.

    log_back always receives abs(a) + 4 (at least 4). inv_back receives a + 2.4.
    NOTE(review): that value is 0.0 if a == -2.4 exactly.
    """
    inv_arg = a + 2.4
    log_arg = abs(a) + 4
    relu_back(a, b)
    inv_back(inv_arg, b)
    log_back(log_arg, b)