23
23
*/
24
24
25
25
import * as tfvis from '@tensorflow/tfjs-vis' ;
26
- const worker = new Worker ( './worker.js' ) ;
26
+ const worker =
27
+ new Worker ( new URL ( './worker.js' , import . meta. url ) , { type : 'module' } ) ;
27
28
28
29
async function runAdditionRNNDemo ( ) {
29
30
document . getElementById ( 'trainModel' ) . addEventListener ( 'click' , async ( ) => {
30
31
const digits = + ( document . getElementById ( 'digits' ) ) . value ;
31
32
const trainingSize = + ( document . getElementById ( 'trainingSize' ) ) . value ;
32
33
const rnnTypeSelect = document . getElementById ( 'rnnType' ) ;
33
34
const rnnType =
34
- rnnTypeSelect . options [ rnnTypeSelect . selectedIndex ] . getAttribute (
35
- 'value' ) ;
35
+ rnnTypeSelect . options [ rnnTypeSelect . selectedIndex ] . getAttribute (
36
+ 'value' ) ;
36
37
const layers = + ( document . getElementById ( 'rnnLayers' ) ) . value ;
37
38
const hiddenSize = + ( document . getElementById ( 'rnnLayerSize' ) ) . value ;
38
39
const batchSize = + ( document . getElementById ( 'batchSize' ) ) . value ;
@@ -48,45 +49,55 @@ async function runAdditionRNNDemo() {
48
49
const trainingSizeLimit = Math . pow ( Math . pow ( 10 , digits ) , 2 ) ;
49
50
if ( trainingSize > trainingSizeLimit ) {
50
51
status . textContent =
51
- `With digits = ${ digits } , you cannot have more than ` +
52
- `${ trainingSizeLimit } examples` ;
52
+ `With digits = ${ digits } , you cannot have more than ` +
53
+ `${ trainingSizeLimit } examples` ;
53
54
return ;
54
55
}
55
- worker . postMessage ( { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } ) ;
56
+ worker . postMessage ( {
57
+ digits,
58
+ trainingSize,
59
+ rnnType,
60
+ layers,
61
+ hiddenSize,
62
+ trainIterations,
63
+ batchSize,
64
+ numTestExamples
65
+ } ) ;
56
66
worker . addEventListener ( 'message' , ( e ) => {
57
67
if ( e . data . isPredict ) {
58
- const { i, iterations, modelFitTime, lossValues, accuracyValues } = e . data ;
68
+ const { i, iterations, modelFitTime, lossValues, accuracyValues} =
69
+ e . data ;
59
70
document . getElementById ( 'trainStatus' ) . textContent =
60
- `Iteration ${ i + 1 } of ${ iterations } : ` +
61
- `Time per iteration: ${ modelFitTime . toFixed ( 3 ) } (seconds)` ;
71
+ `Iteration ${ i + 1 } of ${ iterations } : ` +
72
+ `Time per iteration: ${ modelFitTime . toFixed ( 3 ) } (seconds)` ;
62
73
const lossContainer = document . getElementById ( 'lossChart' ) ;
63
74
tfvis . render . linechart (
64
- lossContainer , { values : lossValues , series : [ 'train' , 'validation' ] } ,
65
- {
66
- width : 420 ,
67
- height : 300 ,
68
- xLabel : 'epoch' ,
69
- yLabel : 'loss' ,
70
- } ) ;
75
+ lossContainer ,
76
+ { values : lossValues , series : [ 'train' , 'validation' ] } , {
77
+ width : 420 ,
78
+ height : 300 ,
79
+ xLabel : 'epoch' ,
80
+ yLabel : 'loss' ,
81
+ } ) ;
71
82
72
83
const accuracyContainer = document . getElementById ( 'accuracyChart' ) ;
73
84
tfvis . render . linechart (
74
- accuracyContainer ,
75
- { values : accuracyValues , series : [ 'train' , 'validation' ] } , {
76
- width : 420 ,
77
- height : 300 ,
78
- xLabel : 'epoch' ,
79
- yLabel : 'accuracy' ,
80
- } ) ;
85
+ accuracyContainer ,
86
+ { values : accuracyValues , series : [ 'train' , 'validation' ] } , {
87
+ width : 420 ,
88
+ height : 300 ,
89
+ xLabel : 'epoch' ,
90
+ yLabel : 'accuracy' ,
91
+ } ) ;
81
92
} else {
82
- const { isCorrect, examples } = e . data ;
93
+ const { isCorrect, examples} = e . data ;
83
94
const examplesDiv = document . getElementById ( 'testExamples' ) ;
84
95
const examplesContent = examples . map (
85
- ( example , i ) =>
86
- `<div class="${
87
- isCorrect [ i ] ? 'answer-correct' : 'answer-wrong' } ">` +
88
- `${ example } ` +
89
- `</div>` ) ;
96
+ ( example , i ) =>
97
+ `<div class="${
98
+ isCorrect [ i ] ? 'answer-correct' : 'answer-wrong' } ">` +
99
+ `${ example } ` +
100
+ `</div>` ) ;
90
101
91
102
examplesDiv . innerHTML = examplesContent . join ( '\n' ) ;
92
103
}
0 commit comments