File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -66,11 +66,15 @@ def cross_entropy(T, pY):
66
66
test_costs .append (ctest )
67
67
68
68
# gradient descent
69
- W2 -= learning_rate * Ztrain .T .dot (pYtrain - Ytrain_ind )
70
- b2 -= learning_rate * (pYtrain - Ytrain_ind ).sum (axis = 0 )
71
- dZ = (pYtrain - Ytrain_ind ).dot (W2 .T ) * (1 - Ztrain * Ztrain )
72
- W1 -= learning_rate * Xtrain .T .dot (dZ )
73
- b1 -= learning_rate * dZ .sum (axis = 0 )
69
+ gW2 = Ztrain .T .dot (pYtrain - Ytrain_ind )
70
+ gb2 = (pYtrain - Ytrain_ind ).sum (axis = 0 )
71
+ dZ = (pYtrain - Ytrain_ind ).dot (W2 .T ) * (1 - Ztrain * Ztrain )
72
+ gW1 = Xtrain .T .dot (dZ )
73
+ gb1 = dZ .sum (axis = 0 )
74
+ W2 -= learning_rate * gW2
75
+ b2 -= learning_rate * gb2
76
+ W1 -= learning_rate * gW1
77
+ b1 -= learning_rate * gb1
74
78
if i % 1000 == 0 :
75
79
print (i , ctrain , ctest )
76
80
You can’t perform that action at this time.
0 commit comments