1616 */
1717
1818import * as tf from '@tensorflow/tfjs' ;
19+ import * as tfd from '@tensorflow/tfjs-data' ;
1920
2021import { ControllerDataset } from './controller_dataset' ;
2122import * as ui from './ui' ;
22- import { Webcam } from './webcam' ;
2323
2424// The number of classes we want to predict. In this example, we will be
2525// predicting 4 classes for up, down, left, and right.
2626const NUM_CLASSES = 4 ;
2727
28- // A webcam class that generates Tensors from the images from the webcam.
29- const webcam = new Webcam ( document . getElementById ( 'webcam' ) ) ;
28+ // A webcam iterator that generates Tensors from the images from the webcam.
29+ let webcam ;
3030
3131// The dataset object where we will store activations.
3232const controllerDataset = new ControllerDataset ( NUM_CLASSES ) ;
@@ -48,15 +48,15 @@ async function loadTruncatedMobileNet() {
4848// When the UI buttons are pressed, read a frame from the webcam and associate
4949// it with the class label given by the button. up, down, left, right are
5050// labels 0, 1, 2, 3 respectively.
ui.setExampleHandler(async label => {
  // Grab one preprocessed, batched frame; this handler owns it and must
  // release it.
  const img = await getImage();
  try {
    // Run the embedding inside tf.tidy so the intermediate activation does not
    // leak on every button press.
    // NOTE(review): assumes addExample tf.keep()s any tensor it stores, as it
    // must have when this call previously lived inside tf.tidy -- confirm in
    // controller_dataset.
    tf.tidy(() => {
      controllerDataset.addExample(truncatedMobileNet.predict(img), label);
    });

    // Draw the preview thumbnail.
    ui.drawThumb(img, label);
  } finally {
    // Release the frame even if embedding or drawing throws.
    img.dispose();
  }
});
6060
6161/**
6262 * Sets up and trains the classifier.
@@ -129,32 +129,41 @@ let isPredicting = false;
/**
 * Runs the prediction loop: while `isPredicting` is true, captures a frame,
 * classifies it and reports the predicted class id to the UI, yielding to the
 * browser between frames.
 *
 * All tensors created during an iteration are released before the next one.
 */
async function predict() {
  ui.isPredicting();
  while (isPredicting) {
    // Capture the frame from the webcam.
    const img = await getImage();

    let predictedClass;
    try {
      // tf.tidy disposes every intermediate tensor (embeddings, predictions,
      // the flattened view) and keeps only the returned argMax tensor.
      predictedClass = tf.tidy(() => {
        // Make a prediction through mobilenet, getting the internal activation
        // of the mobilenet model, i.e., "embeddings" of the input images.
        const embeddings = truncatedMobileNet.predict(img);

        // Make a prediction through our newly-trained model using the
        // embeddings from mobilenet as input.
        const predictions = model.predict(embeddings);

        // Returns the index with the maximum probability. This number
        // corresponds to the class the model thinks is the most probable given
        // the input.
        return predictions.as1D().argMax();
      });
    } finally {
      // The frame is no longer needed once the synchronous forward pass ends.
      img.dispose();
    }

    let classId;
    try {
      classId = (await predictedClass.data())[0];
    } finally {
      // The argMax tensor escaped tf.tidy, so it must be released by hand.
      predictedClass.dispose();
    }

    ui.predictClass(classId);
    await tf.nextFrame();
  }
  ui.donePredicting();
}
157154
/**
 * Captures a frame from the webcam and rescales it with `x / 127 - 1`, which
 * maps 8-bit pixel values to approximately [-1, 1].
 *
 * Returns a batched float image (1-element batch), i.e. the captured frame
 * with a leading batch dimension added -- presumably
 * [1, height, width, channels]; confirm against the webcam iterator.
 * The caller owns the returned tensor and must dispose it.
 */
async function getImage() {
  const img = await webcam.capture();
  try {
    // tf.tidy frees the intermediate expandDims/toFloat/div results and keeps
    // only the final normalized tensor.
    return tf.tidy(() => img.expandDims(0).toFloat().div(127).sub(1));
  } finally {
    // Always release the raw frame, even if preprocessing throws.
    img.dispose();
  }
}
166+
158167document . getElementById ( 'train' ) . addEventListener ( 'click' , async ( ) => {
159168 ui . trainStatus ( 'Training...' ) ;
160169 await tf . nextFrame ( ) ;
@@ -170,18 +179,21 @@ document.getElementById('predict').addEventListener('click', () => {
170179
/**
 * Initializes the application: opens the webcam, loads the truncated
 * MobileNet, initializes the UI and warms up the model with one frame.
 *
 * If the webcam cannot be opened, the 'no-webcam' element is shown and the
 * warm-up is skipped, since there is no frame source to capture from.
 */
async function init() {
  try {
    webcam = await tfd.webcam(document.getElementById('webcam'));
  } catch (e) {
    console.log(e);
    document.getElementById('no-webcam').style.display = 'block';
  }
  truncatedMobileNet = await loadTruncatedMobileNet();

  ui.init();

  // Without a webcam there is nothing to capture; calling capture() on an
  // undefined iterator would throw after the failure was already reported.
  if (webcam == null) {
    return;
  }

  // Warm up the model. This uploads weights to the GPU and compiles the WebGL
  // programs so the first time we collect data from the webcam it will be
  // quick.
  const screenShot = await webcam.capture();
  try {
    // The warm-up output is discarded, so let tf.tidy free both the batched
    // input and the prediction.
    tf.tidy(() => {
      truncatedMobileNet.predict(screenShot.expandDims(0));
    });
  } finally {
    screenShot.dispose();
  }
}
186198
187199// Initialize the application.
0 commit comments