diff --git a/learning4e.py b/learning4e.py index 12c0defa5..5669c453e 100644 --- a/learning4e.py +++ b/learning4e.py @@ -910,7 +910,7 @@ def weighted_replicate(seq, weights, n): def accuracy_score(y_pred, y_true): assert y_pred.shape == y_true.shape - return np.mean(np.equal(y_pred, y_true)) + return np.mean(y_pred == y_true) def r2_score(y_pred, y_true):