@@ -82,10 +82,17 @@ def test_xor():
82
82
er = np .mean (prediction != Y )
83
83
84
84
LL .append (ll )
85
- W2 += learning_rate * (derivative_w2 (Z , Y , pY ) - regularization * W2 )
86
- b2 += learning_rate * (derivative_b2 (Y , pY ) - regularization * b2 )
87
- W1 += learning_rate * (derivative_w1 (X , Z , Y , pY , W2 ) - regularization * W1 )
88
- b1 += learning_rate * (derivative_b1 (Z , Y , pY , W2 ) - regularization * b1 )
85
+
86
+ # compute every gradient before applying any update, so gW1 / gb1 use the pre-update W2
87
+ gW2 = derivative_w2 (Z , Y , pY )
88
+ gb2 = derivative_b2 (Y , pY )
89
+ gW1 = derivative_w1 (X , Z , Y , pY , W2 )
90
+ gb1 = derivative_b1 (Z , Y , pY , W2 )
91
+
92
+ W2 += learning_rate * (gW2 - regularization * W2 )
93
+ b2 += learning_rate * (gb2 - regularization * b2 )
94
+ W1 += learning_rate * (gW1 - regularization * W1 )
95
+ b1 += learning_rate * (gb1 - regularization * b1 )
89
96
if i % 1000 == 0 :
90
97
print (ll )
91
98
@@ -128,19 +135,26 @@ def test_donut():
128
135
prediction = predict (X , W1 , b1 , W2 , b2 )
129
136
er = np .abs (prediction - Y ).mean ()
130
137
LL .append (ll )
131
- W2 += learning_rate * (derivative_w2 (Z , Y , pY ) - regularization * W2 )
132
- b2 += learning_rate * (derivative_b2 (Y , pY ) - regularization * b2 )
133
- W1 += learning_rate * (derivative_w1 (X , Z , Y , pY , W2 ) - regularization * W1 )
134
- b1 += learning_rate * (derivative_b1 (Z , Y , pY , W2 ) - regularization * b1 )
138
+
139
+ # compute every gradient before applying any update, so gW1 / gb1 use the pre-update W2
140
+ gW2 = derivative_w2 (Z , Y , pY )
141
+ gb2 = derivative_b2 (Y , pY )
142
+ gW1 = derivative_w1 (X , Z , Y , pY , W2 )
143
+ gb1 = derivative_b1 (Z , Y , pY , W2 )
144
+
145
+ W2 += learning_rate * (gW2 - regularization * W2 )
146
+ b2 += learning_rate * (gb2 - regularization * b2 )
147
+ W1 += learning_rate * (gW1 - regularization * W1 )
148
+ b1 += learning_rate * (gb1 - regularization * b1 )
135
149
if i % 300 == 0 :
136
150
print ("i:" , i , "ll:" , ll , "classification rate:" , 1 - er )
137
151
plt .plot (LL )
138
152
plt .show ()
139
153
140
154
141
155
if __name__ == '__main__' :
142
- # test_xor()
143
- test_donut ()
156
+ test_xor ()
157
+ # test_donut()
144
158
145
159
146
160
0 commit comments