@@ -8,6 +8,17 @@ const classifierBuilder = require('./classifier')
88const categories = require ( './categories' )
99const ConfusionMatrix = require ( './confusionMatrix' )
1010
11+ const spinner = new Spinner ( 'Loading...' , [
12+ '⣾' ,
13+ '⣽' ,
14+ '⣻' ,
15+ '⢿' ,
16+ '⡿' ,
17+ '⣟' ,
18+ '⣯' ,
19+ '⣷' ,
20+ ] )
21+
1122/**
1223 * NodeJS Classification-based learner.
1324 * @class Learner
@@ -60,20 +71,11 @@ class Learner {
6071 */
6172 train ( trainSet = this . trainSet ) {
6273 //@todo Move this so it could be used for any potentially lengthy ops
63- const training = new Spinner ( 'Training...' , [
64- '⣾' ,
65- '⣽' ,
66- '⣻' ,
67- '⢿' ,
68- '⡿' ,
69- '⣟' ,
70- '⣯' ,
71- '⣷' ,
72- ] )
73- training . start ( )
74+ spinner . message ( 'Training...' )
75+ spinner . start ( )
7476 this . classifier . trainBatch ( trainSet )
75- training . message ( 'Training complete' )
76- training . stop ( )
77+ // spinner .message('Training complete')
78+ spinner . stop ( )
7779 }
7880
7981 /**
@@ -82,18 +84,27 @@ class Learner {
8284 * @public
8385 */
8486 eval ( ) {
87+ spinner . message ( 'Evaluating...' )
88+ spinner . start ( )
8589 const actual = [ ]
8690 const predicted = [ ]
91+ const len = this . testSet . length
92+ let idx = 0
8793 for ( const data of this . testSet ) {
8894 const predictions = this . classify ( data . input )
8995 actual . push ( data . output )
9096 predicted . push ( predictions . length ? predictions [ 0 ] : 'null' ) //Ignores the rest (as it only wants one guess)
97+ spinner . message (
98+ `Evaluating instances (${ Math . round ( ( idx ++ / len ) * 10000 ) / 100 } %)` ,
99+ )
91100 }
92101 this . confusionMatrix = ConfusionMatrix . fromData (
93102 actual ,
94103 predicted ,
95104 categories ,
96105 )
106+ // spinner.message('Evaluation complete')
107+ spinner . stop ( )
97108 return this . confusionMatrix . getStats ( )
98109 }
99110
@@ -182,29 +193,30 @@ class Learner {
182193 F_1 (or effectiveness) = 2 * (Pr * R) / (Pr + R)
183194 ...
184195 */
196+ spinner . message ( 'Cross-validating...' )
197+ spinner . start ( )
185198 this . macroAvg = new PrecisionRecall ( )
186199 this . microAvg = new PrecisionRecall ( )
200+ const set = [ ...this . trainSet , ...this . validationSet ]
187201
188- partitions . partitions (
189- [ ...this . trainSet , ...this . validationSet ] ,
190- numOfFolds ,
191- ( trainSet , validationSet ) => {
192- if ( log )
193- process . stdout . write (
194- `Training on ${ trainSet . length } samples, testing ${ validationSet . length } samples` ,
195- )
196- this . train ( trainSet )
197- test (
198- this . classifier ,
199- validationSet ,
200- verboseLevel ,
201- this . microAvg ,
202- this . macroAvg ,
203- )
204- } ,
205- )
202+ partitions . partitions ( set , numOfFolds , ( trainSet , validationSet ) => {
203+ const status = `Training on ${ trainSet . length } samples, testing ${ validationSet . length } samples`
204+ //eslint-disable-next-line babel/no-unused-expressions
205+ log ? process . stdout . write ( status ) : spinner . message ( status )
206+ this . train ( trainSet )
207+ test (
208+ this . classifier ,
209+ validationSet ,
210+ verboseLevel ,
211+ this . microAvg ,
212+ this . macroAvg ,
213+ )
214+ } )
215+ spinner . message ( 'Calculating stats' )
206216 this . macroAvg . calculateMacroAverageStats ( numOfFolds )
207217 this . microAvg . calculateStats ( )
218+ // spinner.message('Cross-validation complete')
219+ spinner . stop ( )
208220 return {
209221 macroAvg : this . macroAvg . fullStats ( ) , //preferable in 2-class settings or in balanced multi-class settings
210222 microAvg : this . microAvg . fullStats ( ) , //preferable in multi-class settings (in case of class imbalance)
@@ -278,6 +290,8 @@ class Learner {
278290 * @public
279291 */
280292 getCategoryPartition ( ) {
293+ spinner . message ( 'Generating category partitions...' )
294+ spinner . start ( )
281295 const res = { }
282296 categories . forEach ( cat => {
283297 res [ cat ] = {
@@ -288,11 +302,14 @@ class Learner {
288302 }
289303 } )
290304 this . dataset . forEach ( data => {
305+ spinner . message ( `Adding ${ data . output } data` )
291306 ++ res [ data . output ] . overall
292307 if ( this . trainSet . includes ( data ) ) ++ res [ data . output ] . train
293308 if ( this . validationSet . includes ( data ) ) ++ res [ data . output ] . validation
294309 if ( this . testSet . includes ( data ) ) ++ res [ data . output ] . test
295310 } )
311+ // spinner.message('Category partitions complete')
312+ spinner . stop ( )
296313 return res
297314 }
298315
0 commit comments