Skip to content

Commit c13853d

Browse files
author
User
committed
update
1 parent 213e786 commit c13853d

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

ann_class/backprop.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,16 @@ def main():
144144
# this is gradient ASCENT, not DESCENT
145145
# be comfortable with both!
146146
# oldW2 = W2.copy()
147-
W2 += learning_rate * derivative_w2(hidden, T, output)
148-
b2 += learning_rate * derivative_b2(T, output)
149-
W1 += learning_rate * derivative_w1(X, hidden, T, output, W2)
150-
b1 += learning_rate * derivative_b1(T, output, W2, hidden)
147+
148+
gW2 = derivative_w2(hidden, T, output)
149+
gb2 = derivative_b2(T, output)
150+
gW1 = derivative_w1(X, hidden, T, output, W2)
151+
gb1 = derivative_b1(T, output, W2, hidden)
152+
153+
W2 += learning_rate * gW2
154+
b2 += learning_rate * gb2
155+
W1 += learning_rate * gW1
156+
b1 += learning_rate * gb1
151157

152158
plt.plot(costs)
153159
plt.show()

0 commit comments

Comments
 (0)