Skip to content

Commit 213e786

Browse files
author
User
committed
update
1 parent e12db82 commit 213e786

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

ann_class/xor_donut.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,17 @@ def test_xor():
8282
er = np.mean(prediction != Y)
8383

8484
LL.append(ll)
85-
W2 += learning_rate * (derivative_w2(Z, Y, pY) - regularization * W2)
86-
b2 += learning_rate * (derivative_b2(Y, pY) - regularization * b2)
87-
W1 += learning_rate * (derivative_w1(X, Z, Y, pY, W2) - regularization * W1)
88-
b1 += learning_rate * (derivative_b1(Z, Y, pY, W2) - regularization * b1)
85+
86+
# get gradients
87+
gW2 = derivative_w2(Z, Y, pY)
88+
gb2 = derivative_b2(Y, pY)
89+
gW1 = derivative_w1(X, Z, Y, pY, W2)
90+
gb1 = derivative_b1(Z, Y, pY, W2)
91+
92+
W2 += learning_rate * (gW2 - regularization * W2)
93+
b2 += learning_rate * (gb2 - regularization * b2)
94+
W1 += learning_rate * (gW1 - regularization * W1)
95+
b1 += learning_rate * (gb1 - regularization * b1)
8996
if i % 1000 == 0:
9097
print(ll)
9198

@@ -128,19 +135,26 @@ def test_donut():
128135
prediction = predict(X, W1, b1, W2, b2)
129136
er = np.abs(prediction - Y).mean()
130137
LL.append(ll)
131-
W2 += learning_rate * (derivative_w2(Z, Y, pY) - regularization * W2)
132-
b2 += learning_rate * (derivative_b2(Y, pY) - regularization * b2)
133-
W1 += learning_rate * (derivative_w1(X, Z, Y, pY, W2) - regularization * W1)
134-
b1 += learning_rate * (derivative_b1(Z, Y, pY, W2) - regularization * b1)
138+
139+
# get gradients
140+
gW2 = derivative_w2(Z, Y, pY)
141+
gb2 = derivative_b2(Y, pY)
142+
gW1 = derivative_w1(X, Z, Y, pY, W2)
143+
gb1 = derivative_b1(Z, Y, pY, W2)
144+
145+
W2 += learning_rate * (gW2 - regularization * W2)
146+
b2 += learning_rate * (gb2 - regularization * b2)
147+
W1 += learning_rate * (gW1 - regularization * W1)
148+
b1 += learning_rate * (gb1 - regularization * b1)
135149
if i % 300 == 0:
136150
print("i:", i, "ll:", ll, "classification rate:", 1 - er)
137151
plt.plot(LL)
138152
plt.show()
139153

140154

141155
if __name__ == '__main__':
142-
# test_xor()
143-
test_donut()
156+
test_xor()
157+
# test_donut()
144158

145159

146160

0 commit comments

Comments
 (0)