We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 241bbba commit 0b06cae — Copy full SHA for 0b06cae
ann_logistic_extra/ann_train.py
@@ -66,11 +66,15 @@ def cross_entropy(T, pY):
66
test_costs.append(ctest)
67
68
# gradient descent
69
- W2 -= learning_rate*Ztrain.T.dot(pYtrain - Ytrain_ind)
70
- b2 -= learning_rate*(pYtrain - Ytrain_ind).sum(axis=0)
71
- dZ = (pYtrain - Ytrain_ind).dot(W2.T) * (1 - Ztrain*Ztrain)
72
- W1 -= learning_rate*Xtrain.T.dot(dZ)
73
- b1 -= learning_rate*dZ.sum(axis=0)
+ gW2 = Ztrain.T.dot(pYtrain - Ytrain_ind)
+ gb2 = (pYtrain - Ytrain_ind).sum(axis=0)
+ dZ = (pYtrain - Ytrain_ind).dot(W2.T) * (1 - Ztrain * Ztrain)
+ gW1 = Xtrain.T.dot(dZ)
+ gb1 = dZ.sum(axis=0)
74
+ W2 -= learning_rate * gW2
75
+ b2 -= learning_rate * gb2
76
+ W1 -= learning_rate * gW1
77
+ b1 -= learning_rate * gb1
78
if i % 1000 == 0:
79
print(i, ctrain, ctest)
80
0 commit comments