diff --git a/src/__tests__/confusionMatrix.js b/src/__tests__/confusionMatrix.js index 5480858..f1354c2 100644 --- a/src/__tests__/confusionMatrix.js +++ b/src/__tests__/confusionMatrix.js @@ -150,6 +150,13 @@ test('Predicted Negatives', () => { expect(cm.getPredNegative('other')).toStrictEqual(11) }) +test('Support', () => { + const cm = new CM(CATEGORIES, M0) + expect(cm.getSupport('bug')).toStrictEqual(6) + expect(cm.getSupport('code')).toStrictEqual(3) + expect(cm.getSupport('other')).toStrictEqual(11) +}) + describe('Accuracy', () => { const cm = new CM(CATEGORIES, M0) test('Accuracy', () => { @@ -158,12 +165,19 @@ describe('Accuracy', () => { expect(cm.getAccuracy('other')).toStrictEqual(0.8) }) - test('Macro accuracy', () => { + test('Macro Accuracy', () => { expect(cm.getMacroAccuracy()).toStrictEqual(M[56]) }) - test('Micro accuracy', () => { + + test('Micro Accuracy', () => { expect(cm.getMicroAccuracy()).toStrictEqual(0.75) }) + + test('Weighted Accuracy', () => { + expect( + Math.round(cm.getWeightedAccuracy() * 100000) / 100000, + ).toStrictEqual(0.83) + }) }) describe('Recall', () => { @@ -174,13 +188,19 @@ describe('Recall', () => { expect(cm.getRecall('other')).toStrictEqual(M[811]) //.727 }) - test('Macro recall', () => { + test('Macro Recall', () => { expect(cm.getMacroRecall()).toStrictEqual((3 / 2 + M[811]) / 3) //~.742 }) - test('Micro recall', () => { + test('Micro Recall', () => { expect(cm.getMicroRecall()).toStrictEqual(0.75) }) + + test('Weighted Recall', () => { + expect(Math.round(cm.getWeightedRecall() * 100000) / 100000).toStrictEqual( + 0.75, + ) + }) }) describe('Precision', () => { @@ -191,13 +211,19 @@ describe('Precision', () => { expect(cm.getPrecision('other')).toStrictEqual(M[89]) //.889 }) - test('Macro precision', () => { + test('Macro Precision', () => { expect(cm.getMacroPrecision()).toStrictEqual((M[56] + 0.4 + M[89]) / 3) //~.707 }) - test('Micro precision', () => { + test('Micro Precision', () => { expect(cm.getMicroPrecision()).toStrictEqual(0.75) }) + + test('Weighted Precision', () => { + expect( + Math.round(cm.getWeightedPrecision() * 100000) / 100000, + ).toStrictEqual(0.79889) + }) }) describe('F1', () => { @@ -215,6 +241,12 @@ describe('F1', () => { test('Micro F1', () => { expect(cm.getMicroF1()).toStrictEqual(0.75) }) + + test('Weighted F1', () => { + expect(Math.round(cm.getWeightedF1() * 100000) / 100000).toStrictEqual( + 0.765, + ) + }) }) describe('MissRate', () => { @@ -232,6 +264,12 @@ describe('MissRate', () => { test('Micro MissRate', () => { expect(cm.getMicroMissRate()).toStrictEqual(0.25) }) + + test('Weighted MissRate', () => { + expect( + Math.round(cm.getWeightedMissRate() * 100000) / 100000, + ).toStrictEqual(0.25) + }) }) describe('FallOut', () => { @@ -250,6 +288,12 @@ describe('FallOut', () => { test('Micro FallOut', () => { expect(cm.getMicroFallOut()).toStrictEqual(0.125) }) + + test('Weighted FallOut', () => { + expect(Math.round(cm.getWeightedFallOut() * 100000) / 100000).toStrictEqual( + 0.10901, + ) + }) }) describe('Specificity', () => { @@ -268,6 +312,12 @@ describe('Specificity', () => { test('Micro Specificity', () => { expect(cm.getMicroSpecificity()).toStrictEqual(0.875) }) + + test('Weighted Specificity', () => { + expect( + Math.round(cm.getWeightedSpecificity() * 100000) / 100000, + ).toStrictEqual(0.89099) + }) }) describe('Prevalence', () => { @@ -285,6 +335,12 @@ describe('Prevalence', () => { test('Micro Prevalence', () => { expect(cm.getMicroPrevalence()).toStrictEqual(M[13]) }) + + test('Weighted Prevalence', () => { + expect( + Math.round(cm.getWeightedPrevalence() * 100000) / 100000, + ).toStrictEqual(0.415) + }) }) describe('fromData', () => { @@ -416,16 +472,42 @@ describe('toString', () => { }) }) -test('shortStats', () => { - const cm = new CM(CATEGORIES, M0) - const ss = `Total: 20 +describe('shortStats', () => { + test('default', () => { + const cm = new CM(CATEGORIES, M0) + const ss = `Total: 20 True: 15 False: 5 Accuracy: 75% Precision: 75% Recall: 75% F1: 75%` - expect(cm.getShortStats()).toStrictEqual(ss) + expect(cm.getShortStats()).toStrictEqual(ss) + }) + + test('macro', () => { + const cm = new CM(CATEGORIES, M0) + const ss = `Total: 20 +True: 15 +False: 5 +Accuracy: 83.33333333333334% +Precision: 70.74074074074073% +Recall: 74.24242424242425% +F1: 71.11111111111111%` + expect(cm.getShortStats('macro')).toStrictEqual(ss) + }) + + test('weighted', () => { + const cm = new CM(CATEGORIES, M0) + const ss = `Total: 20 +True: 15 +False: 5 +Accuracy: 83% +Precision: 79.88888888888889% +Recall: 75% +F1: 76.49999999999999%` + expect(cm.getShortStats('weighted')).toStrictEqual(ss) + }) }) describe('Long stats', () => { @@ -439,6 +521,7 @@ describe('Long stats', () => { classes: CATEGORIES, microAvg: {}, macroAvg: {}, + weightedAvg: {}, results: { bug: {}, code: {}, @@ -473,6 +556,19 @@ describe('Long stats', () => { }) }) + it('has weighted details', () => { + expect(stats.weightedAvg).toMatchObject({ + accuracy: 0.83 + 1e-16, + f1: 0.7649999999999999, + fallOut: 0.10901027077497664, + missRate: 0.25, + precision: 0.7988888888888889, + prevalence: 0.41500000000000004, + recall: 0.75, + specificity: 0.8909897292250232, + }) + }) + it('has class details', () => { const bugStats = stats.results.bug expect(bugStats).toMatchObject({ diff --git a/src/confusionMatrix.js b/src/confusionMatrix.js index e56fe3f..133c3e4 100644 --- a/src/confusionMatrix.js +++ b/src/confusionMatrix.js @@ -9,6 +9,7 @@ const { rmEmpty, clrVal, fxSum, + fxWeightedSum, mapObject, } = require('./utils') @@ -258,6 +259,16 @@ class ConfusionMatrix { return this.getTN(category) + this.getFN(category) } + /** + * Support value (count/occurrences) of `category` in the matrix + * @param {string} category Class/category to look at + * @returns {number} Support value + */ + getSupport(category) { + const counts = Object.values(this.matrix[category]) + return sum(...counts) + } + /** * Prediction accuracy for `category`. * @param {string} category Class/category considered as positive @@ -279,13 +290,21 @@ class ConfusionMatrix { /** * Macro-average of accuracy. - * @returns {number} (A0 + ... An_1) / n + * @returns {number} (A0 + ...+ An_1) / n * @protected */ getMacroAccuracy() { return fxSum(this, 'Accuracy') / this.classes.length } + /** + * Weighted accuracy. + * @returns {number} (A0 * s0 + ... + An * sn) / Total + */ + getWeightedAccuracy() { + return fxWeightedSum(this, 'Accuracy') / this.getTotal() + } + /** * Predicition recall. * @param {string} category Class/category considered as positive @@ -318,6 +337,14 @@ class ConfusionMatrix { return fxSum(this, 'Recall') / this.classes.length } + /** + * Weighted recalll. + * @returns {number} (R0 * s0 + ... + Rn * sn) / Total + */ + getWeightedRecall() { + return fxWeightedSum(this, 'Recall') / this.getTotal() + } + /** * Prediction precision for `category`. * @alias getPositivePredictiveValue @@ -349,6 +376,14 @@ class ConfusionMatrix { return fxSum(this, 'Precision') / this.classes.length } + /** + * Weighted precision. + * @returns {number} (Pr0 * s0 + ... + Prn * sn) / Total + */ + getWeightedPrecision() { + return fxWeightedSum(this, 'Precision') / this.getTotal() + } + /** * Prediction F1 score for `category`. * @alias getPositivePredictiveValue @@ -375,14 +410,23 @@ class ConfusionMatrix { } /** - * Macro-average of the precision. + * Macro-average of the F1 score. * @returns {number} (F0_1 + F1_1 + ... + F_n-1_1) / n * @protected */ getMacroF1() { + //@todo Perhaps convert NaNs to 0's to reflect correct calculations (e.g https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html) return fxSum(this, 'F1') / this.classes.length } + /** + * Weighted F1. + * @returns {number} (F0_1 * s0 + ... + Fn_1 * sn) / Total + */ + getWeightedF1() { + return fxWeightedSum(this, 'F1') / this.getTotal() + } + /** * Miss rates on predictions for `category`. * @alias getFalseNegativeRate @@ -414,6 +458,14 @@ class ConfusionMatrix { return fxSum(this, 'MissRate') / this.classes.length } + /** + * Weighted miss rate. + * @returns {number} (M0 * s0 + ... + Mn * sn) / Total + */ + getWeightedMissRate() { + return fxWeightedSum(this, 'MissRate') / this.getTotal() + } + /** * Fall out (false alarm) on predictions for `category`. * @alias getFalsePositiveRate @@ -445,6 +497,14 @@ class ConfusionMatrix { return fxSum(this, 'FallOut') / this.classes.length } + /** + * Weighted fall out. + * @returns {number} (Fo0 * s0 + ... + Fon * sn) / Total + */ + getWeightedFallOut() { + return fxWeightedSum(this, 'FallOut') / this.getTotal() + } + /** * Specificity on predictions for `category`. * @alias getSelectivity @@ -477,6 +537,14 @@ class ConfusionMatrix { return fxSum(this, 'Specificity') / this.classes.length } + /** + * Weighted specificity. + * @returns {number} (S0 * s0 + ... + Sn * sn) / Total + */ + getWeightedSpecificity() { + return fxWeightedSum(this, 'Specificity') / this.getTotal() + } + /** * Prevalence on predictions for `category`. * @param {string} category Class/category considered as positive @@ -500,13 +568,21 @@ class ConfusionMatrix { /** * Macro-average of the prevalence. - * @returns {number} (S0 + S1 + ... + Sn) / n + * @returns {number} (Pe0 + Pe1 + ... + Pen) / n * @protected */ getMacroPrevalence() { return fxSum(this, 'Prevalence') / this.classes.length } + /** + * Weighted prevalence. + * @returns {number} (Pe0 * s0 + ... + Pen * sn) / Total + */ + getWeightedPrevalence() { + return fxWeightedSum(this, 'Prevalence') / this.getTotal() + } + //getFalseDiscoveryRate: getFP() / getPredictedPositive() //getFalseOmmissionRate: getFN() / getPredictedNegative() //getNegPredictiveVal: getTN() / getPredictedNegative() @@ -623,13 +699,38 @@ class ConfusionMatrix { /** * @returns {string} Short statistics (total, true, false, accuracy, precision, recall and f1) - * @protected - */ - getShortStats() { - return `Total: ${this.getTotal()}\nTrue: ${this.getTrue()}\nFalse: ${this.getFalse()}\nAccuracy: ${this.getMicroAccuracy() * - 100}%\nPrecision: ${this.getMicroPrecision() * - 100}%\nRecall: ${this.getMicroRecall() * 100}%\nF1: ${this.getMicroF1() * - 100}%` + * @param {string} [type='micro'] Type of stats (`micro`/`macro`/`weighted` average) + * @todo Add options to use `micro`/`macro`/`weighted` + * @protected + */ + getShortStats(type = 'micro') { + const stats = `Total: ${this.getTotal()}\nTrue: ${this.getTrue()}\nFalse: ${this.getFalse()}\n` + let Ac = 0 + let Pr = 0 + let R = 0 + let F1 = 0 + switch (type) { + case 'macro': + Ac = this.getMacroAccuracy() + Pr = this.getMacroPrecision() + R = this.getMacroRecall() + F1 = this.getMacroF1() + break + case 'weighted': + Ac = this.getWeightedAccuracy() + Pr = this.getWeightedPrecision() + R = this.getWeightedRecall() + F1 = this.getWeightedF1() + break + default: + Ac = this.getMicroAccuracy() + Pr = this.getMicroPrecision() + R = this.getMicroRecall() + F1 = this.getMicroF1() + } + + return `${stats}Accuracy: ${Ac * 100}%\nPrecision: ${Pr * + 100}%\nRecall: ${R * 100}%\nF1: ${F1 * 100}%` } /** @@ -639,6 +740,11 @@ class ConfusionMatrix { */ getStats() { const total = this.getTotal() + const weightedAverage = () => { + const res = {} + for (const m of METRICS) res[camel(m)] = this[`getWeighted${m}`]() + return res + } return { total, correctPredictions: this.getTrue(), @@ -646,6 +752,7 @@ class ConfusionMatrix { classes: this.classes, microAvg: mAvg(this, 'Mi'), macroAvg: mAvg(this, 'Ma'), + weightedAvg: weightedAverage(), results: getResults(this, total), } } diff --git a/src/utils.js b/src/utils.js index a95fbf3..58cfae3 100644 --- a/src/utils.js +++ b/src/utils.js @@ -181,6 +181,18 @@ const clrVal = (num, maxVal, goodValue = false) => { */ const fxSum = (cm, fx) => sum(...cm.classes.map(c => cm[`get${fx}`](c))) +/** + * Functional weighted sum on all classes of a confusion matrix. + * @param {ConfusionMatrix} cm Confusion matrix instance + * @param {string} fx Function name (without the `get`) + * @returns {number} Sum + * @private + */ +const fxWeightedSum = (cm, fx) => { + const nums = cm.classes.map(c => cm[`get${fx}`](c) * cm.getSupport(c)) + return sum(...nums) +} + /** * Maps the values of an array to an object using a function, * where the key-value pairs consist of the original value as the key and the mapped value. @@ -208,5 +220,6 @@ module.exports = { rmEmpty, clrVal, fxSum, + fxWeightedSum, mapObject, }