Skip to content

Commit 76091d9

Browse files
authored
Merge pull request #55 from NACLab/dev
fixed adj-acc bug in analyze_scores in metrics
2 parents 8e2e66d + f00c63e commit 76091d9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ngclearn/utils/metric_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st
122122
confusion matrix, precision, recall, misses (empty predictions/all-zero rows),
123123
accuracy, adjusted-accuracy (counts all misses as incorrect)
124124
"""
125-
miss_mask = (jnp.sum(mu, axis=1, keepdims=True) == 0.) * 1.
125+
miss_mask = (jnp.sum(mu, axis=1) == 0.) * 1.
126126
misses = jnp.sum(miss_mask) ## how many misses?
127127
labels = y
128128
if extract_label_indx:
@@ -133,7 +133,7 @@ def analyze_scores(mu, y, extract_label_indx=True): ## examines classifcation st
133133
recall = recall_score(labels, guesses, average='macro')
134134
## produce accuracy score measurements
135135
guess = jnp.argmax(mu, axis=1) ## gather all model/output guesses
136-
equality_mask = jnp.equal(guess, labels)
136+
equality_mask = jnp.equal(guess, labels) * 1.
137137
### compute raw accuracy
138138
acc = jnp.sum(equality_mask) / (y.shape[0] * 1.)
139139
### compute hit-masked accuracy (adjusted accuracy

0 commit comments

Comments
 (0)