16
16
*/
17
17
18
18
import * as tf from '@tensorflow/tfjs' ;
19
+ import * as tfd from '@tensorflow/tfjs-data' ;
19
20
20
21
import { ControllerDataset } from './controller_dataset' ;
21
22
import * as ui from './ui' ;
22
- import { Webcam } from './webcam' ;
23
23
24
24
// The number of classes we want to predict. In this example, we will be
25
25
// predicting 4 classes for up, down, left, and right.
26
26
const NUM_CLASSES = 4 ;
27
27
28
- // A webcam class that generates Tensors from the images from the webcam.
29
- const webcam = new Webcam ( document . getElementById ( 'webcam' ) ) ;
28
+ // A webcam iterator that generates Tensors from the images from the webcam.
29
+ let webcam ;
30
30
31
31
// The dataset object where we will store activations.
32
32
const controllerDataset = new ControllerDataset ( NUM_CLASSES ) ;
@@ -48,15 +48,15 @@ async function loadTruncatedMobileNet() {
48
48
// When the UI buttons are pressed, read a frame from the webcam and associate
49
49
// it with the class label given by the button. up, down, left, right are
50
50
// labels 0, 1, 2, 3 respectively.
51
- ui . setExampleHandler ( label => {
52
- tf . tidy ( ( ) => {
53
- const img = webcam . capture ( ) ;
54
- controllerDataset . addExample ( truncatedMobileNet . predict ( img ) , label ) ;
51
+ ui . setExampleHandler ( async label => {
52
+ let img = await getImage ( ) ;
55
53
56
- // Draw the preview thumbnail.
57
- ui . drawThumb ( img , label ) ;
58
- } ) ;
59
- } ) ;
54
+ controllerDataset . addExample ( truncatedMobileNet . predict ( img ) , label ) ;
55
+
56
+ // Draw the preview thumbnail.
57
+ ui . drawThumb ( img , label ) ;
58
+ img . dispose ( ) ;
59
+ } )
60
60
61
61
/**
62
62
* Sets up and trains the classifier.
@@ -129,32 +129,41 @@ let isPredicting = false;
129
129
async function predict ( ) {
130
130
ui . isPredicting ( ) ;
131
131
while ( isPredicting ) {
132
- const predictedClass = tf . tidy ( ( ) => {
133
- // Capture the frame from the webcam.
134
- const img = webcam . capture ( ) ;
132
+ // Capture the frame from the webcam.
133
+ const img = await getImage ( ) ;
135
134
136
- // Make a prediction through mobilenet, getting the internal activation of
137
- // the mobilenet model, i.e., "embeddings" of the input images.
138
- const embeddings = truncatedMobileNet . predict ( img ) ;
135
+ // Make a prediction through mobilenet, getting the internal activation of
136
+ // the mobilenet model, i.e., "embeddings" of the input images.
137
+ const embeddings = truncatedMobileNet . predict ( img ) ;
139
138
140
- // Make a prediction through our newly-trained model using the embeddings
141
- // from mobilenet as input.
142
- const predictions = model . predict ( embeddings ) ;
143
-
144
- // Returns the index with the maximum probability. This number corresponds
145
- // to the class the model thinks is the most probable given the input.
146
- return predictions . as1D ( ) . argMax ( ) ;
147
- } ) ;
139
+ // Make a prediction through our newly-trained model using the embeddings
140
+ // from mobilenet as input.
141
+ const predictions = model . predict ( embeddings ) ;
148
142
143
+ // Returns the index with the maximum probability. This number corresponds
144
+ // to the class the model thinks is the most probable given the input.
145
+ const predictedClass = predictions . as1D ( ) . argMax ( ) ;
149
146
const classId = ( await predictedClass . data ( ) ) [ 0 ] ;
150
- predictedClass . dispose ( ) ;
147
+ img . dispose ( ) ;
151
148
152
149
ui . predictClass ( classId ) ;
153
150
await tf . nextFrame ( ) ;
154
151
}
155
152
ui . donePredicting ( ) ;
156
153
}
157
154
155
+ /**
156
+ * Captures a frame from the webcam and normalizes it between -1 and 1.
157
+ * Returns a batched image (1-element batch) of shape [1, w, h, c].
158
+ */
159
+ async function getImage ( ) {
160
+ const img = await webcam . capture ( ) ;
161
+ const processedImg =
162
+ tf . tidy ( ( ) => img . expandDims ( 0 ) . toFloat ( ) . div ( 127 ) . sub ( 1 ) ) ;
163
+ img . dispose ( ) ;
164
+ return processedImg ;
165
+ }
166
+
158
167
document . getElementById ( 'train' ) . addEventListener ( 'click' , async ( ) => {
159
168
ui . trainStatus ( 'Training...' ) ;
160
169
await tf . nextFrame ( ) ;
@@ -170,18 +179,21 @@ document.getElementById('predict').addEventListener('click', () => {
170
179
171
180
async function init ( ) {
172
181
try {
173
- await webcam . setup ( ) ;
182
+ webcam = await tfd . webcam ( document . getElementById ( 'webcam' ) ) ;
174
183
} catch ( e ) {
184
+ console . log ( e ) ;
175
185
document . getElementById ( 'no-webcam' ) . style . display = 'block' ;
176
186
}
177
187
truncatedMobileNet = await loadTruncatedMobileNet ( ) ;
178
188
189
+ ui . init ( ) ;
190
+
179
191
// Warm up the model. This uploads weights to the GPU and compiles the WebGL
180
192
// programs so the first time we collect data from the webcam it will be
181
193
// quick.
182
- tf . tidy ( ( ) => truncatedMobileNet . predict ( webcam . capture ( ) ) ) ;
183
-
184
- ui . init ( ) ;
194
+ const screenShot = await webcam . capture ( ) ;
195
+ truncatedMobileNet . predict ( screenShot . expandDims ( 0 ) ) ;
196
+ screenShot . dispose ( ) ;
185
197
}
186
198
187
199
// Initialize the application.
0 commit comments