Skip to content

Commit 59ac208

Browse files
authored
Support passing IOHandler to model urls of selfie_segmentation (#935)
1 parent ede445b commit 59ac208

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

body-segmentation/src/selfie_segmentation_tfjs/README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ Pass in `bodySegmentation.SupportedModels.MediaPipeSelfieSegmentation` from the
6969
'general', 'landscape'). If unset, the default is 'general'.
7070

7171
* *modelUrl*: An optional string that specifies custom url of
72-
the segmentation model. This is useful for area/countries that don't have access to the model hosted on tf.hub.
72+
the segmentation model. This is useful for area/countries that don't have access to the model hosted on tf.hub. It also accepts `io.IOHandler` which can be used with
73+
[tfjs-react-native](https://github.com/tensorflow/tfjs/tree/master/tfjs-react-native)
74+
to load model from app bundle directory using
75+
[bundleResourceIO](https://github.com/tensorflow/tfjs/blob/master/tfjs-react-native/src/bundle_resource_io.ts#L169).
7376

7477
```javascript
7578
const model = bodySegmentation.SupportedModels.MediaPipeSelfieSegmentation;

body-segmentation/src/selfie_segmentation_tfjs/segmenter.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ export async function load(
160160
Promise<BodySegmenter> {
161161
const config = validateModelConfig(modelConfig);
162162

163-
const modelFromTFHub = (config.modelUrl.indexOf('https://tfhub.dev') > -1);
163+
const modelFromTFHub = typeof config.modelUrl === 'string' &&
164+
(config.modelUrl.indexOf('https://tfhub.dev') > -1);
164165

165166
const model =
166167
await tfconv.loadGraphModel(config.modelUrl, {fromTFHub: modelFromTFHub});

body-segmentation/src/selfie_segmentation_tfjs/types.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18+
import {io} from '@tensorflow/tfjs-core';
1819
import {MediaPipeSelfieSegmentationModelConfig, MediaPipeSelfieSegmentationSegmentationConfig} from '../selfie_segmentation_mediapipe/types';
1920

2021
/**
@@ -35,7 +36,7 @@ import {MediaPipeSelfieSegmentationModelConfig, MediaPipeSelfieSegmentationSegme
3536
export interface MediaPipeSelfieSegmentationTfjsModelConfig extends
3637
MediaPipeSelfieSegmentationModelConfig {
3738
runtime: 'tfjs';
38-
modelUrl?: string;
39+
modelUrl?: string|io.IOHandler;
3940
}
4041

4142
/**

0 commit comments

Comments
 (0)