From c1d5d15eb1ae5755f75d52407b041b24a2d3bc4a Mon Sep 17 00:00:00 2001 From: Maximilian Berkmann Date: Mon, 22 Jul 2019 16:27:24 +0100 Subject: [PATCH] feat(index): added status update (#21) * feat(index): added status update ... For a better DX/UX * chore(playground): update --- playground/playground.js | 2 + src/index.js | 79 ++++++++++++++++++++++++---------------- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/playground/playground.js b/playground/playground.js index c8bb4bc..24a9669 100644 --- a/playground/playground.js +++ b/playground/playground.js @@ -47,3 +47,5 @@ writeFileSync( 'playground-fullStats.json', JSON.stringify(longStats, null, 2), ) && console.log('Saved learner to "playground-fullStats.json"') + +process.exit(0) diff --git a/src/index.js b/src/index.js index 473b493..cac2b4e 100644 --- a/src/index.js +++ b/src/index.js @@ -8,6 +8,17 @@ const classifierBuilder = require('./classifier') const categories = require('./categories') const ConfusionMatrix = require('./confusionMatrix') +const spinner = new Spinner('Loading...', [ + '⣾', + '⣽', + '⣻', + '⢿', + '⡿', + '⣟', + '⣯', + '⣷', +]) + /** * NodeJS Classification-based learner. * @class Learner @@ -60,20 +71,11 @@ class Learner { */ train(trainSet = this.trainSet) { //@todo Move this so it could be used for any potentially lengthy ops - const training = new Spinner('Training...', [ - '⣾', - '⣽', - '⣻', - '⢿', - '⡿', - '⣟', - '⣯', - '⣷', - ]) - training.start() + spinner.message('Training...') + spinner.start() this.classifier.trainBatch(trainSet) - training.message('Training complete') - training.stop() + // spinner.message('Training complete') + spinner.stop() } /** @@ -82,18 +84,27 @@ class Learner { * @public */ eval() { + spinner.message('Evaluating...') + spinner.start() const actual = [] const predicted = [] + const len = this.testSet.length + let idx = 0 for (const data of this.testSet) { const predictions = this.classify(data.input) actual.push(data.output) predicted.push(predictions.length ? predictions[0] : 'null') //Ignores the rest (as it only wants one guess) + spinner.message( + `Evaluating instances (${Math.round((idx++ / len) * 10000) / 100}%)`, + ) } this.confusionMatrix = ConfusionMatrix.fromData( actual, predicted, categories, ) + // spinner.message('Evaluation complete') + spinner.stop() return this.confusionMatrix.getStats() } @@ -182,29 +193,30 @@ class Learner { F_1 (or effectiveness) = 2 * (Pr * R) / (Pr + R) ... */ + spinner.message('Cross-validating...') + spinner.start() this.macroAvg = new PrecisionRecall() this.microAvg = new PrecisionRecall() + const set = [...this.trainSet, ...this.validationSet] - partitions.partitions( - [...this.trainSet, ...this.validationSet], - numOfFolds, - (trainSet, validationSet) => { - if (log) - process.stdout.write( - `Training on ${trainSet.length} samples, testing ${validationSet.length} samples`, - ) - this.train(trainSet) - test( - this.classifier, - validationSet, - verboseLevel, - this.microAvg, - this.macroAvg, - ) - }, - ) + partitions.partitions(set, numOfFolds, (trainSet, validationSet) => { + const status = `Training on ${trainSet.length} samples, testing ${validationSet.length} samples` + //eslint-disable-next-line babel/no-unused-expressions + log ? process.stdout.write(status) : spinner.message(status) + this.train(trainSet) + test( + this.classifier, + validationSet, + verboseLevel, + this.microAvg, + this.macroAvg, + ) + }) + spinner.message('Calculating stats') this.macroAvg.calculateMacroAverageStats(numOfFolds) this.microAvg.calculateStats() + // spinner.message('Cross-validation complete') + spinner.stop() return { macroAvg: this.macroAvg.fullStats(), //preferable in 2-class settings or in balanced multi-class settings microAvg: this.microAvg.fullStats(), //preferable in multi-class settings (in case of class imbalance) @@ -278,6 +290,8 @@ class Learner { * @public */ getCategoryPartition() { + spinner.message('Generating category partitions...') + spinner.start() const res = {} categories.forEach(cat => { res[cat] = { @@ -288,11 +302,14 @@ class Learner { } }) this.dataset.forEach(data => { + spinner.message(`Adding ${data.output} data`) ++res[data.output].overall if (this.trainSet.includes(data)) ++res[data.output].train if (this.validationSet.includes(data)) ++res[data.output].validation if (this.testSet.includes(data)) ++res[data.output].test }) + // spinner.message('Category partitions complete') + spinner.stop() return res }