Skip to content

Commit

Permalink
feat(index): added status update (#21)
Browse files Browse the repository at this point in the history
* feat(index): added status update

... For a better DX/UX

* chore(playground): update
  • Loading branch information
Berkmann18 authored Jul 22, 2019
1 parent 95c8e76 commit c1d5d15
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 31 deletions.
2 changes: 2 additions & 0 deletions playground/playground.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ writeFileSync(
'playground-fullStats.json',
JSON.stringify(longStats, null, 2),
) && console.log('Saved learner to "playground-fullStats.json"')

process.exit(0)
79 changes: 48 additions & 31 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ const classifierBuilder = require('./classifier')
const categories = require('./categories')
const ConfusionMatrix = require('./confusionMatrix')

const spinner = new Spinner('Loading...', [
'⣾',
'⣽',
'⣻',
'⢿',
'⡿',
'⣟',
'⣯',
'⣷',
])

/**
* NodeJS Classification-based learner.
* @class Learner
Expand Down Expand Up @@ -60,20 +71,11 @@ class Learner {
*/
train(trainSet = this.trainSet) {
//@todo Move this so it could be used for any potentially lengthy ops
const training = new Spinner('Training...', [
'⣾',
'⣽',
'⣻',
'⢿',
'⡿',
'⣟',
'⣯',
'⣷',
])
training.start()
spinner.message('Training...')
spinner.start()
this.classifier.trainBatch(trainSet)
training.message('Training complete')
training.stop()
// spinner.message('Training complete')
spinner.stop()
}

/**
Expand All @@ -82,18 +84,27 @@ class Learner {
* @public
*/
eval() {
spinner.message('Evaluating...')
spinner.start()
const actual = []
const predicted = []
const len = this.testSet.length
let idx = 0
for (const data of this.testSet) {
const predictions = this.classify(data.input)
actual.push(data.output)
predicted.push(predictions.length ? predictions[0] : 'null') //Ignores the rest (as it only wants one guess)
spinner.message(
`Evaluating instances (${Math.round((idx++ / len) * 10000) / 100}%)`,
)
}
this.confusionMatrix = ConfusionMatrix.fromData(
actual,
predicted,
categories,
)
// spinner.message('Evaluation complete')
spinner.stop()
return this.confusionMatrix.getStats()
}

Expand Down Expand Up @@ -182,29 +193,30 @@ class Learner {
F_1 (or effectiveness) = 2 * (Pr * R) / (Pr + R)
...
*/
spinner.message('Cross-validating...')
spinner.start()
this.macroAvg = new PrecisionRecall()
this.microAvg = new PrecisionRecall()
const set = [...this.trainSet, ...this.validationSet]

partitions.partitions(
[...this.trainSet, ...this.validationSet],
numOfFolds,
(trainSet, validationSet) => {
if (log)
process.stdout.write(
`Training on ${trainSet.length} samples, testing ${validationSet.length} samples`,
)
this.train(trainSet)
test(
this.classifier,
validationSet,
verboseLevel,
this.microAvg,
this.macroAvg,
)
},
)
partitions.partitions(set, numOfFolds, (trainSet, validationSet) => {
const status = `Training on ${trainSet.length} samples, testing ${validationSet.length} samples`
//eslint-disable-next-line babel/no-unused-expressions
log ? process.stdout.write(status) : spinner.message(status)
this.train(trainSet)
test(
this.classifier,
validationSet,
verboseLevel,
this.microAvg,
this.macroAvg,
)
})
spinner.message('Calculating stats')
this.macroAvg.calculateMacroAverageStats(numOfFolds)
this.microAvg.calculateStats()
// spinner.message('Cross-validation complete')
spinner.stop()
return {
macroAvg: this.macroAvg.fullStats(), //preferable in 2-class settings or in balanced multi-class settings
microAvg: this.microAvg.fullStats(), //preferable in multi-class settings (in case of class imbalance)
Expand Down Expand Up @@ -278,6 +290,8 @@ class Learner {
* @public
*/
getCategoryPartition() {
spinner.message('Generating category partitions...')
spinner.start()
const res = {}
categories.forEach(cat => {
res[cat] = {
Expand All @@ -288,11 +302,14 @@ class Learner {
}
})
this.dataset.forEach(data => {
spinner.message(`Adding ${data.output} data`)
++res[data.output].overall
if (this.trainSet.includes(data)) ++res[data.output].train
if (this.validationSet.includes(data)) ++res[data.output].validation
if (this.testSet.includes(data)) ++res[data.output].test
})
// spinner.message('Category partitions complete')
spinner.stop()
return res
}

Expand Down

0 comments on commit c1d5d15

Please sign in to comment.