Skip to content

Commit fc8646f

Browse files
Kangyi Zhangcaisq
Kangyi Zhang
authored andcommitted
Update webcam-transfer-learning with tf.data.webcam API (#267)
1 parent c80d0e2 commit fc8646f

File tree

4 files changed

+88
-169
lines changed

4 files changed

+88
-169
lines changed

webcam-transfer-learning/index.js

+42-30
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
*/
1717

1818
import * as tf from '@tensorflow/tfjs';
19+
import * as tfd from '@tensorflow/tfjs-data';
1920

2021
import {ControllerDataset} from './controller_dataset';
2122
import * as ui from './ui';
22-
import {Webcam} from './webcam';
2323

2424
// The number of classes we want to predict. In this example, we will be
2525
// predicting 4 classes for up, down, left, and right.
2626
const NUM_CLASSES = 4;
2727

28-
// A webcam class that generates Tensors from the images from the webcam.
29-
const webcam = new Webcam(document.getElementById('webcam'));
28+
// A webcam iterator that generates Tensors from the images from the webcam.
29+
let webcam;
3030

3131
// The dataset object where we will store activations.
3232
const controllerDataset = new ControllerDataset(NUM_CLASSES);
@@ -48,15 +48,15 @@ async function loadTruncatedMobileNet() {
4848
// When the UI buttons are pressed, read a frame from the webcam and associate
4949
// it with the class label given by the button. up, down, left, right are
5050
// labels 0, 1, 2, 3 respectively.
51-
ui.setExampleHandler(label => {
52-
tf.tidy(() => {
53-
const img = webcam.capture();
54-
controllerDataset.addExample(truncatedMobileNet.predict(img), label);
51+
ui.setExampleHandler(async label => {
52+
let img = await getImage();
5553

56-
// Draw the preview thumbnail.
57-
ui.drawThumb(img, label);
58-
});
59-
});
54+
controllerDataset.addExample(truncatedMobileNet.predict(img), label);
55+
56+
// Draw the preview thumbnail.
57+
ui.drawThumb(img, label);
58+
img.dispose();
59+
})
6060

6161
/**
6262
* Sets up and trains the classifier.
@@ -129,32 +129,41 @@ let isPredicting = false;
129129
async function predict() {
130130
ui.isPredicting();
131131
while (isPredicting) {
132-
const predictedClass = tf.tidy(() => {
133-
// Capture the frame from the webcam.
134-
const img = webcam.capture();
132+
// Capture the frame from the webcam.
133+
const img = await getImage();
135134

136-
// Make a prediction through mobilenet, getting the internal activation of
137-
// the mobilenet model, i.e., "embeddings" of the input images.
138-
const embeddings = truncatedMobileNet.predict(img);
135+
// Make a prediction through mobilenet, getting the internal activation of
136+
// the mobilenet model, i.e., "embeddings" of the input images.
137+
const embeddings = truncatedMobileNet.predict(img);
139138

140-
// Make a prediction through our newly-trained model using the embeddings
141-
// from mobilenet as input.
142-
const predictions = model.predict(embeddings);
143-
144-
// Returns the index with the maximum probability. This number corresponds
145-
// to the class the model thinks is the most probable given the input.
146-
return predictions.as1D().argMax();
147-
});
139+
// Make a prediction through our newly-trained model using the embeddings
140+
// from mobilenet as input.
141+
const predictions = model.predict(embeddings);
148142

143+
// Returns the index with the maximum probability. This number corresponds
144+
// to the class the model thinks is the most probable given the input.
145+
const predictedClass = predictions.as1D().argMax();
149146
const classId = (await predictedClass.data())[0];
150-
predictedClass.dispose();
147+
img.dispose();
151148

152149
ui.predictClass(classId);
153150
await tf.nextFrame();
154151
}
155152
ui.donePredicting();
156153
}
157154

155+
/**
156+
* Captures a frame from the webcam and normalizes it between -1 and 1.
157+
* Returns a batched image (1-element batch) of shape [1, w, h, c].
158+
*/
159+
async function getImage() {
160+
const img = await webcam.capture();
161+
const processedImg =
162+
tf.tidy(() => img.expandDims(0).toFloat().div(127).sub(1));
163+
img.dispose();
164+
return processedImg;
165+
}
166+
158167
document.getElementById('train').addEventListener('click', async () => {
159168
ui.trainStatus('Training...');
160169
await tf.nextFrame();
@@ -170,18 +179,21 @@ document.getElementById('predict').addEventListener('click', () => {
170179

171180
async function init() {
172181
try {
173-
await webcam.setup();
182+
webcam = await tfd.webcam(document.getElementById('webcam'));
174183
} catch (e) {
184+
console.log(e);
175185
document.getElementById('no-webcam').style.display = 'block';
176186
}
177187
truncatedMobileNet = await loadTruncatedMobileNet();
178188

189+
ui.init();
190+
179191
// Warm up the model. This uploads weights to the GPU and compiles the WebGL
180192
// programs so the first time we collect data from the webcam it will be
181193
// quick.
182-
tf.tidy(() => truncatedMobileNet.predict(webcam.capture()));
183-
184-
ui.init();
194+
const screenShot = await webcam.capture();
195+
truncatedMobileNet.predict(screenShot.expandDims(0));
196+
screenShot.dispose();
185197
}
186198

187199
// Initialize the application.

webcam-transfer-learning/package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"node": ">=8.9.0"
1010
},
1111
"dependencies": {
12-
"@tensorflow/tfjs": "^1.0.4",
12+
"@tensorflow/tfjs": "^1.1.0",
1313
"vega-embed": "^3.0.0"
1414
},
1515
"scripts": {

webcam-transfer-learning/webcam.js

-106
This file was deleted.

0 commit comments

Comments
 (0)