test_autodiff.py
from typing import Tuple

import pytest

import minitorch
from minitorch import Context, ScalarFunction, ScalarHistory

# ## Task 1.3 - Tests for the autodifferentiation machinery.

# Simple sanity check and debugging tests.


class Function1(ScalarFunction):
    @staticmethod
    def forward(ctx: Context, x: float, y: float) -> float:
        "$f(x, y) = x + y + 10$"
        return x + y + 10

    @staticmethod
    def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
        "Derivatives are $f'_x(x, y) = 1$ and $f'_y(x, y) = 1$"
        return d_output, d_output
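
# Note: Function1 never calls ctx.save_for_backward because both partial
# derivatives are the constant 1, so backward can return d_output unchanged
# for each input without remembering the forward values.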


class Function2(ScalarFunction):
    @staticmethod
    def forward(ctx: Context, x: float, y: float) -> float:
        r"$f(x, y) = x \times y + x$"
        ctx.save_for_backward(x, y)
        return x * y + x

    @staticmethod
    def backward(ctx: Context, d_output: float) -> Tuple[float, float]:
        "Derivatives are $f'_x(x, y) = y + 1$ and $f'_y(x, y) = x$"
        x, y = ctx.saved_values
        return d_output * (y + 1), d_output * x
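

# A minimal forward-only sanity check for Function2; it assumes, as the
# tests below do, that ScalarFunction.apply accepts plain floats and that
# the resulting minitorch.Scalar exposes its value as `.data`.
@pytest.mark.task1_3
def test_function2_forward() -> None:
    out = Function2.apply(2.0, 3.0)
    # f(x, y) = x * y + x = 2 * 3 + 2 = 8
    assert out.data == 8.0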


# Checks for the chain rule function.


@pytest.mark.task1_3
def test_chain_rule1() -> None:
    x = minitorch.Scalar(0.0)
    constant = minitorch.Scalar(
        0.0, ScalarHistory(Function1, ctx=Context(), inputs=[x, x])
    )
    back = constant.chain_rule(d_output=5)
    assert len(list(back)) == 2
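
# chain_rule is expected to yield one (variable, derivative) pair per forward
# input; Function1 takes two inputs, hence the length-2 checks here.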


@pytest.mark.task1_3
def test_chain_rule2() -> None:
    var = minitorch.Scalar(0.0, ScalarHistory())
    constant = minitorch.Scalar(
        0.0, ScalarHistory(Function1, ctx=Context(), inputs=[var, var])
    )
    back = constant.chain_rule(d_output=5)
    back = list(back)
    assert len(back) == 2
    variable, deriv = back[0]
    assert deriv == 5


@pytest.mark.task1_3
def test_chain_rule3() -> None:
    "Check that constants are ignored and variables get derivatives."
    constant = 10
    var = minitorch.Scalar(5)
    y = Function2.apply(constant, var)
    back = y.chain_rule(d_output=5)
    back = list(back)
    assert len(back) == 2
    variable, deriv = back[1]
    # assert variable.name == var.name
    assert deriv == 5 * 10
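
# In test_chain_rule3, x is the plain constant 10, so f'_y(x, y) = x = 10 and
# the derivative reaching var is d_output * 10 = 50.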


@pytest.mark.task1_3
def test_chain_rule4() -> None:
    var1 = minitorch.Scalar(5)
    var2 = minitorch.Scalar(10)
    y = Function2.apply(var1, var2)
    back = y.chain_rule(d_output=5)
    back = list(back)
    assert len(back) == 2
    variable, deriv = back[0]
    # assert variable.name == var1.name
    assert deriv == 5 * (10 + 1)
    variable, deriv = back[1]
    # assert variable.name == var2.name
    assert deriv == 5 * 5
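
# Worked values for test_chain_rule4: with x = 5 and y = 10,
# f'_x = y + 1 = 11 gives 5 * 11 = 55 for var1, and f'_y = x = 5 gives
# 5 * 5 = 25 for var2.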


# ## Task 1.4 - Run some simple backprop tests
# Main tests are in test_scalar.py


@pytest.mark.task1_4
def test_backprop1() -> None:
    # Example 1: F1(0, v)
    var = minitorch.Scalar(0)
    var2 = Function1.apply(0, var)
    var2.backward(d_output=5)
    assert var.derivative == 5


@pytest.mark.task1_4
def test_backprop2() -> None:
    # Example 2: F1(0, F1(0, v))
    var = minitorch.Scalar(0)
    var2 = Function1.apply(0, var)
    var3 = Function1.apply(0, var2)
    var3.backward(d_output=5)
    assert var.derivative == 5
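
# Both of Function1's local derivatives are 1, so composing it twice still
# delivers the full d_output of 5 to var.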


@pytest.mark.task1_4
def test_backprop3() -> None:
    # Example 3: F1(F1(0, v1), F1(0, v1))
    var1 = minitorch.Scalar(0)
    var2 = Function1.apply(0, var1)
    var3 = Function1.apply(0, var1)
    var4 = Function1.apply(var2, var3)
    var4.backward(d_output=5)
    assert var1.derivative == 10
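
# var1 feeds both var2 and var3, so backward accumulates a contribution of 5
# along each of the two paths: 5 + 5 = 10.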


@pytest.mark.task1_4
def test_backprop4() -> None:
    # Example 4: F1(F1(0, v1), F1(0, v1)) where v1 = F1(0, v0)
    var0 = minitorch.Scalar(0)
    var1 = Function1.apply(0, var0)
    var2 = Function1.apply(0, var1)
    var3 = Function1.apply(0, var1)
    var4 = Function1.apply(var2, var3)
    var4.backward(d_output=5)
    assert var0.derivative == 10
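

# A minimal extra sketch of backprop through Function2, assuming (as
# test_chain_rule3 does) that apply may mix plain constants with Scalars.
# With x = 2 held constant, f'_y(x, y) = x = 2, so a d_output of 5 should
# reach var as 5 * 2 = 10.
@pytest.mark.task1_4
def test_backprop5() -> None:
    var = minitorch.Scalar(0)
    out = Function2.apply(2, var)
    out.backward(d_output=5)
    assert var.derivative == 10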