Skip to content

Commit 49154ef

Browse files
authored
[tfjs-react-native] webcam demo (#1938)
INTERNAL
1 parent 19ef726 commit 49154ef

File tree

9 files changed

+678
-13
lines changed

9 files changed

+678
-13
lines changed

tfjs-react-native/integration_rn59/App.tsx

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,19 @@
1616
*/
1717

1818
import React, { Fragment } from 'react';
19-
import { Button, SafeAreaView, StyleSheet, ScrollView, View, Text, StatusBar } from 'react-native';
19+
import { Button, SafeAreaView, StyleSheet, View, Text, StatusBar } from 'react-native';
2020

2121
import * as tf from '@tensorflow/tfjs';
2222
import '@tensorflow/tfjs-react-native';
2323

2424
import { Diagnostic } from './components/diagnostic';
2525
import { MobilenetDemo } from './components/mobilenet_demo';
2626
import { TestRunner } from './components/tfjs_unit_test_runner';
27+
import { WebcamDemo } from './components/webcam/webcam_demo';
2728

28-
const BACKEND_TO_USE = 'cpu';
29+
const BACKEND_TO_USE = 'rn-webgl';
2930

30-
export type Screen = 'main' | 'diag' | 'demo' | 'test';
31+
export type Screen = 'main' | 'diag' | 'demo' | 'test' | 'webcam';
3132

3233
interface AppState {
3334
isTfReady: boolean;
@@ -46,6 +47,7 @@ export class App extends React.Component<{}, AppState> {
4647
this.showDemoScreen = this.showDemoScreen.bind(this);
4748
this.showMainScreen = this.showMainScreen.bind(this);
4849
this.showTestScreen = this.showTestScreen.bind(this);
50+
this.showWebcamDemo= this.showWebcamDemo.bind(this);
4951
}
5052

5153
async componentDidMount() {
@@ -72,6 +74,10 @@ export class App extends React.Component<{}, AppState> {
7274
this.setState({ currentScreen: 'test' });
7375
}
7476

77+
showWebcamDemo() {
78+
this.setState({ currentScreen: 'webcam' });
79+
}
80+
7581
renderMainScreen() {
7682
return <Fragment>
7783
<View style={styles.sectionContainer}>
@@ -95,6 +101,13 @@ export class App extends React.Component<{}, AppState> {
95101
title='Show Test Screen'
96102
/>
97103
</View>
104+
<View style={styles.sectionContainer}>
105+
<Text style={styles.sectionTitle}>Webcam Demo</Text>
106+
<Button
107+
onPress={this.showWebcamDemo}
108+
title='Show Webcam Demo'
109+
/>
110+
</View>
98111
</Fragment>;
99112
}
100113

@@ -119,6 +132,12 @@ export class App extends React.Component<{}, AppState> {
119132
</Fragment>;
120133
}
121134

135+
renderWebcamDemo() {
136+
return <Fragment>
137+
<WebcamDemo returnToMain={this.showMainScreen}/>
138+
</Fragment>;
139+
}
140+
122141
renderLoadingTF() {
123142
return <Fragment>
124143
<View style={styles.sectionContainer}>
@@ -139,6 +158,8 @@ export class App extends React.Component<{}, AppState> {
139158
return this.renderDemoScreen();
140159
case 'test':
141160
return this.renderTestScreen();
161+
case 'webcam':
162+
return this.renderWebcamDemo();
142163
default:
143164
return this.renderMainScreen();
144165
}
@@ -153,13 +174,9 @@ export class App extends React.Component<{}, AppState> {
153174
<Fragment>
154175
<StatusBar barStyle='dark-content' />
155176
<SafeAreaView>
156-
<ScrollView
157-
contentInsetAdjustmentBehavior='automatic'
158-
style={styles.scrollView}>
159-
<View style={styles.body}>
160-
{this.renderContent()}
161-
</View>
162-
</ScrollView>
177+
<View style={styles.body}>
178+
{this.renderContent()}
179+
</View>
163180
</SafeAreaView>
164181
</Fragment>
165182
);

tfjs-react-native/integration_rn59/android/app/src/main/java/com/integration_rn59/generated/BasePackageList.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ public class BasePackageList {
88
public List<Package> getPackageList() {
99
return Arrays.<Package>asList(
1010
new expo.modules.gl.GLPackage(),
11+
new expo.modules.camera.CameraPackage(),
1112
new expo.modules.constants.ConstantsPackage(),
1213
new expo.modules.filesystem.FileSystemPackage(),
14+
new expo.modules.imagemanipulator.ImageManipulatorPackage(),
1315
new expo.modules.permissions.PermissionsPackage()
1416
);
1517
}

tfjs-react-native/integration_rn59/android/build.gradle

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ allprojects {
4545
maven {
4646
// All of React Native (JS, Obj-C sources, Android binaries) is installed from npm
4747
url "$rootDir/../node_modules/react-native/android"
48+
49+
}
50+
maven {
51+
// expo-camera bundles a custom com.google.android:cameraview
52+
url "$rootDir/../node_modules/expo-camera/android/maven"
4853
}
54+
4955
}
5056
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
import * as ImageManipulator from 'expo-image-manipulator';
20+
import * as jpeg from 'jpeg-js';
21+
22+
export function toDataUri(base64: string): string {
23+
return `data:image/jpeg;base64,${base64}`;
24+
}
25+
26+
export async function resizeImage(
27+
imageUrl: string, width: number): Promise<ImageManipulator.ImageResult> {
28+
const actions = [{
29+
resize: {
30+
width,
31+
},
32+
}];
33+
const saveOptions = {
34+
compress: 0.75,
35+
format: ImageManipulator.SaveFormat.JPEG,
36+
base64: true,
37+
};
38+
const res =
39+
await ImageManipulator.manipulateAsync(imageUrl, actions, saveOptions);
40+
return res;
41+
}
42+
43+
export async function base64ImageToTensor(base64: string):
44+
Promise<tf.Tensor3D> {
45+
const rawImageData = tf.util.encodeString(base64, 'base64');
46+
const TO_UINT8ARRAY = true;
47+
const {width, height, data} = jpeg.decode(rawImageData, TO_UINT8ARRAY);
48+
// Drop the alpha channel info
49+
const buffer = new Uint8Array(width * height * 3);
50+
let offset = 0; // offset into original data
51+
for (let i = 0; i < buffer.length; i += 3) {
52+
buffer[i] = data[offset];
53+
buffer[i + 1] = data[offset + 1];
54+
buffer[i + 2] = data[offset + 2];
55+
56+
offset += 4;
57+
}
58+
return tf.tensor3d(buffer, [height, width, 3]);
59+
}
60+
61+
export async function tensorToImageUrl(imageTensor: tf.Tensor3D):
62+
Promise<string> {
63+
const [height, width] = imageTensor.shape;
64+
const buffer = await imageTensor.toInt().data();
65+
const frameData = new Uint8Array(width * height * 4);
66+
67+
let offset = 0;
68+
for (let i = 0; i < frameData.length; i += 4) {
69+
frameData[i] = buffer[offset];
70+
frameData[i + 1] = buffer[offset + 1];
71+
frameData[i + 2] = buffer[offset + 2];
72+
frameData[i + 3] = 0xFF;
73+
74+
offset += 3;
75+
}
76+
77+
const rawImageData = {
78+
data: frameData,
79+
width,
80+
height,
81+
};
82+
const jpegImageData = jpeg.encode(rawImageData, 75);
83+
const base64Encoding = tf.util.decodeString(jpegImageData.data, 'base64');
84+
return base64Encoding;
85+
}
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/**
2+
* @license
3+
* Copyright 2019 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
20+
const STYLENET_URL =
21+
'https://cdn.jsdelivr.net/gh/reiinakano/arbitrary-image-stylization-tfjs@master/saved_model_style_js/model.json';
22+
const TRANSFORMNET_URL =
23+
'https://cdn.jsdelivr.net/gh/reiinakano/arbitrary-image-stylization-tfjs@master/saved_model_transformer_separable_js/model.json';
24+
25+
export class StyleTranfer {
26+
private styleNet?: tf.GraphModel;
27+
private transformNet?: tf.GraphModel;
28+
29+
constructor() {}
30+
31+
async init() {
32+
await Promise.all([this.loadStyleModel(), this.loadTransformerModel()]);
33+
await this.warmup();
34+
}
35+
36+
async loadStyleModel() {
37+
if (this.styleNet == null) {
38+
this.styleNet = await tf.loadGraphModel(STYLENET_URL);
39+
console.log('stylenet loaded');
40+
}
41+
}
42+
43+
async loadTransformerModel() {
44+
if (this.transformNet == null) {
45+
this.transformNet = await tf.loadGraphModel(TRANSFORMNET_URL);
46+
console.log('transformnet loaded');
47+
}
48+
}
49+
50+
async warmup() {
51+
// Also warmup
52+
const input: tf.Tensor3D = tf.randomNormal([320, 240, 3]);
53+
const res = this.stylize(input, input);
54+
await res.data();
55+
tf.dispose([input, res]);
56+
}
57+
58+
/**
59+
* This function returns style bottleneck features for
60+
* the given image.
61+
*
62+
* @param style Style image to get 100D bottleneck features for
63+
*/
64+
private predictStyleParameters(styleImage: tf.Tensor3D): tf.Tensor4D {
65+
return tf.tidy(() => {
66+
if (this.styleNet == null) {
67+
throw new Error('Stylenet not loaded');
68+
}
69+
return this.styleNet.predict(
70+
styleImage.toFloat().div(tf.scalar(255)).expandDims());
71+
}) as tf.Tensor4D;
72+
}
73+
74+
/**
75+
* This function stylizes the content image given the bottleneck
76+
* features. It returns a tf.Tensor3D containing the stylized image.
77+
*
78+
* @param content Content image to stylize
79+
* @param bottleneck Bottleneck features for the style to use
80+
*/
81+
private produceStylized(contentImage: tf.Tensor3D, bottleneck: tf.Tensor4D):
82+
tf.Tensor3D {
83+
return tf.tidy(() => {
84+
if (this.transformNet == null) {
85+
throw new Error('Transformnet not loaded');
86+
}
87+
const input = contentImage.toFloat().div(tf.scalar(255)).expandDims();
88+
const image: tf.Tensor4D =
89+
this.transformNet.predict([input, bottleneck]) as tf.Tensor4D;
90+
return image.mul(255).squeeze();
91+
});
92+
}
93+
94+
public stylize(styleImage: tf.Tensor3D, contentImage: tf.Tensor3D):
95+
tf.Tensor3D {
96+
const start = Date.now();
97+
console.log(styleImage.shape, contentImage.shape);
98+
const styleRepresentation = this.predictStyleParameters(styleImage);
99+
const stylized = this.produceStylized(contentImage, styleRepresentation);
100+
tf.dispose([styleRepresentation]);
101+
const end = Date.now();
102+
console.log('stylization scheduled', end - start);
103+
return stylized;
104+
}
105+
}

0 commit comments

Comments
 (0)