Skip to content

Commit 0b06cae

Browse files
author
User
committed
update
1 parent 241bbba commit 0b06cae

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

ann_logistic_extra/ann_train.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,15 @@ def cross_entropy(T, pY):
6666
test_costs.append(ctest)
6767

6868
# gradient descent
69-
W2 -= learning_rate*Ztrain.T.dot(pYtrain - Ytrain_ind)
70-
b2 -= learning_rate*(pYtrain - Ytrain_ind).sum(axis=0)
71-
dZ = (pYtrain - Ytrain_ind).dot(W2.T) * (1 - Ztrain*Ztrain)
72-
W1 -= learning_rate*Xtrain.T.dot(dZ)
73-
b1 -= learning_rate*dZ.sum(axis=0)
69+
gW2 = Ztrain.T.dot(pYtrain - Ytrain_ind)
70+
gb2 = (pYtrain - Ytrain_ind).sum(axis=0)
71+
dZ = (pYtrain - Ytrain_ind).dot(W2.T) * (1 - Ztrain * Ztrain)
72+
gW1 = Xtrain.T.dot(dZ)
73+
gb1 = dZ.sum(axis=0)
74+
W2 -= learning_rate * gW2
75+
b2 -= learning_rate * gb2
76+
W1 -= learning_rate * gW1
77+
b1 -= learning_rate * gb1
7478
if i % 1000 == 0:
7579
print(i, ctrain, ctest)
7680

0 commit comments

Comments
 (0)