Skip to content

Commit d910389

Browse files
relf, levkk, and montanalow
authored
Fix confusion matrix (#374)
* Add serialization for LogisticRegression
* Confusion matrix should use labels from predictions and ground truth
* Clippy fixes
* This is the correct test
* clippy lints
* fix ownership
* fix ownership
* cleanup lints
* Remove blank lines
* Dedup labels (review)
* Improve combined_labels API (review)
* Make confusion_matrix layout reproducible, use sensible default when boolean classes

---------

Co-authored-by: Lev Kokotov <[email protected]>
Co-authored-by: Montana Low <[email protected]>
1 parent 9744625 commit d910389

File tree

5 files changed

+57
-9
lines changed

5 files changed

+57
-9
lines changed

algorithms/linfa-ftrl/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ version = "1.0"
2424
features = ["derive"]
2525

2626
[dependencies]
27-
ndarray = { version = "0.15.4", features = ["serde"] }
27+
ndarray = { version = "0.15", features = ["serde"] }
2828
ndarray-rand = "0.14.0"
2929
argmin = { version = "0.9.0", default-features = false }
3030
argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] }

algorithms/linfa-logistic/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ argmin = { version = "0.9.0", default-features = false }
3030
argmin-math = { version = "0.3", features = ["ndarray_v0_15-nolinalg"] }
3131
thiserror = "1.0"
3232

33-
3433
linfa = { version = "0.7.1", path = "../.." }
3534

3635
[dev-dependencies]

algorithms/linfa-trees/src/decision_trees/hyperparams.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub enum SplitQuality {
5050
/// let tree = params.fit(&train).unwrap();
5151
/// // Predict on validation and check accuracy
5252
/// let val_accuracy = tree.predict(&val).confusion_matrix(&val).unwrap().accuracy();
53-
/// assert!(val_accuracy > 0.99);
53+
/// assert!(val_accuracy > 0.9);
5454
/// ```
5555
///
5656
#[cfg_attr(

src/dataset/mod.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,28 @@ pub trait Labels {
324324
}
325325

326326
fn labels(&self) -> Vec<Self::Elem> {
327-
self.label_set().into_iter().flatten().collect()
327+
self.label_set()
328+
.into_iter()
329+
.flatten()
330+
.collect::<HashSet<_>>()
331+
.into_iter()
332+
.collect()
333+
}
334+
335+
fn combined_labels<T>(&self, other: &T) -> Vec<Self::Elem>
336+
where
337+
T: Labels<Elem = <Self as Labels>::Elem>,
338+
{
339+
let mut combined = self.label_set();
340+
combined.extend(other.label_set());
341+
342+
combined
343+
.iter()
344+
.flatten()
345+
.collect::<HashSet<_>>()
346+
.into_iter()
347+
.cloned()
348+
.collect()
328349
}
329350
}
330351

src/metrics_classification.rs

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//!
33
//! Scoring is essential for classification and regression tasks. This module implements
44
//! common scoring functions like precision, accuracy, recall, f1-score, ROC and ROC
5-
//! Aread-Under-Curve.
5+
//! Area-Under-Curve.
66
use std::collections::HashMap;
77
use std::fmt;
88

@@ -290,7 +290,23 @@ where
290290
return Err(Error::MismatchedShapes(targets.len(), ground_truth.len()));
291291
}
292292

293-
let classes = self.labels();
293+
let mut classes = self.combined_labels(ground_truth);
294+
// Sort classes to get reproducible confusion_matrix
295+
classes.sort();
296+
if classes.len() == 2 {
297+
// In case of binary classes, we sort in reverse order to get a sensible default for
298+
// boolean values and get a confusion matrix with the conventional layout by default:
299+
//
300+
// | actual true | actual false
301+
// pred true | TP | FP
302+
// -----------------------------------------
303+
// pred false | FN | TN
304+
//
305+
// So to get classes to be [true, false], as false < true or 0 < 1, we reverse the order.
306+
// As precision and recall metrics are computed wrt the first label,
307+
// it is less confusing if it corresponds to true.
308+
classes.reverse();
309+
}
294310

295311
let indices = map_prediction_to_idx(
296312
targets.as_slice().unwrap(),
@@ -595,10 +611,11 @@ mod tests {
595611

596612
let cm = predicted.confusion_matrix(ground_truth).unwrap();
597613

598-
let labels = array![0, 1];
599-
let expected = array![[2., 1.], [0., 3.]];
614+
let expected_labels = array![1, 0];
615+
let expected = array![[3., 0.], [1., 2.]];
600616

601-
assert_cm_eq(&cm, &expected, &labels);
617+
assert_eq!(expected_labels, cm.members);
618+
assert_abs_diff_eq!(expected, cm.matrix);
602619
}
603620

604621
#[test]
@@ -636,6 +653,17 @@ mod tests {
636653
);
637654
}
638655

656+
#[test]
657+
fn test_division_by_zero_cm() {
658+
let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]);
659+
let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]);
660+
661+
let x = ground_truth.confusion_matrix(predicted).unwrap();
662+
let f1 = x.f1_score();
663+
664+
assert!(f1.is_nan());
665+
}
666+
639667
#[test]
640668
fn test_roc_curve() {
641669
let predicted = ArrayView1::from(&[0.1, 0.3, 0.5, 0.7, 0.8, 0.9]).mapv(Pr::new);

0 commit comments

Comments
 (0)