
Commit ab97e06

Add object detection and image segmentation models to the task API (#705)
* Add more models * update * update * update * test * test * test * test * test * test * test * test * test * test * update * update * address comments
1 parent 59bd40b commit ab97e06

25 files changed: +7,353 −106 lines

coco-ssd/src/index.ts

+11 −8
@@ -24,6 +24,7 @@ const BASE_PATH = 'https://storage.googleapis.com/tfjs-models/savedmodel/';
 
 export {version} from './version';
 
+/** @docinline */
 export type ObjectDetectionBaseModel =
     'mobilenet_v1'|'mobilenet_v2'|'lite_mobilenet_v2';
 
@@ -35,17 +36,19 @@ export interface DetectedObject {
 
 /**
  * Coco-ssd model loading is configurable using the following config dictionary.
- *
- * `base`: ObjectDetectionBaseModel. It determines wich PoseNet architecture
- * to load. The supported architectures are: 'mobilenet_v1', 'mobilenet_v2' and
- * 'lite_mobilenet_v2'. It is default to 'lite_mobilenet_v2'.
- *
- * `modelUrl`: An optional string that specifies custom url of the model. This
- * is useful for area/countries that don't have access to the model hosted on
- * GCP.
  */
 export interface ModelConfig {
+  /**
+   * It determines which object detection architecture to load. The supported
+   * architectures are: 'mobilenet_v1', 'mobilenet_v2' and 'lite_mobilenet_v2'.
+   * It defaults to 'lite_mobilenet_v2'.
+   */
   base?: ObjectDetectionBaseModel;
+  /**
+   * An optional string that specifies a custom URL of the model. This is
+   * useful for areas/countries that don't have access to the model hosted
+   * on GCP.
+   */
   modelUrl?: string;
 }
 
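For quick reference, the `ModelConfig` options documented above are what coco-ssd's `load()` call accepts. A minimal usage sketch (assuming a browser page with an `<img id="img">` element; illustrative only, not part of this commit):

```js
import * as cocoSsd from '@tensorflow-models/coco-ssd';

async function run() {
  // `base` picks the architecture; 'lite_mobilenet_v2' is the default.
  const model = await cocoSsd.load({base: 'mobilenet_v2'});

  // `detect` accepts an image-like element and returns the detected objects
  // (bounding box, class name, and score for each).
  const img = document.getElementById('img');
  const predictions = await model.detect(img);
  console.log(predictions);
}

run();
```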

coco-ssd/yarn.lock

+12 −12
@@ -64,23 +64,23 @@
     estree-walker "^1.0.1"
     picomatch "^2.2.2"
 
-"@tensorflow/tfjs-backend-cpu@^3.0.0-rc.1":
-  version "3.3.0"
-  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-3.3.0.tgz#aa0a3ed2c6237a6e0c169678c5bd4b5a88766b1c"
-  integrity sha512-DLctv+PUZni26kQW1hq8jwQQ8u+GGc/p764WQIC4/IDagGtfGAUW1mHzWcTxtni2l4re1VrwE41ogWLhv4sGHg==
+"@tensorflow/tfjs-backend-cpu@^3.3.0":
+  version "3.6.0"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-3.6.0.tgz#4e64a7cf1c33b203f71f8f77cd7b0ac1ef25a871"
+  integrity sha512-ZpAs17hPdKXadbtNjAsymYUILe8V7+pY4fYo8j25nfDTW/HfBpyAwsHPbMcA/n5zyJ7ZJtGKFcCUv1sl24KL1Q==
   dependencies:
     "@types/seedrandom" "2.4.27"
     seedrandom "2.4.3"
 
-"@tensorflow/tfjs-converter@^3.0.0-rc.1":
-  version "3.3.0"
-  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-3.3.0.tgz#d9f2ffd0fbdbb47c07d5fd7c3e5dc180cff317aa"
-  integrity sha512-k57wN4yelePhmO9orcT/wzGMIuyedrMpVtg0FhxpV6BQu0+TZ/ti3W4Kb97GWJsoHKXMoing9SnioKfVnBW6hw==
+"@tensorflow/tfjs-converter@^3.3.0":
+  version "3.6.0"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-3.6.0.tgz#32b3ff31b47e29630a82e30fbe01708facad7fd6"
+  integrity sha512-9MtatbTSvo3gpEulYI6+byTA3OeXSMT2lzyGAegXO9nMxsvjR01zBvlZ5SmsNyecNh6fMSzdL2+cCdQfQtsIBg==
 
-"@tensorflow/tfjs-core@^3.0.0-rc.1":
-  version "3.3.0"
-  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-3.3.0.tgz#3d26bd03cb58e0ecf46c96d118c39c4a90b7f5ed"
-  integrity sha512-6G+LcCiQBl4Kza5mDbWbf8QSWBTW3l7SDjGhQzMO1ITtQatHzxkuHGHcJ4CTUJvNA0JmKf4QJWOvlFqEmxwyLQ==
+"@tensorflow/tfjs-core@^3.3.0":
+  version "3.6.0"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-3.6.0.tgz#6b4d8175790bdff78868eabe6adc6442eb4dc276"
+  integrity sha512-bb2c3zwK4SgXZRvkTiC7EhCpWbCGp0GMd+1/3Vo2/Z54jiLB/h3sXIgHQrTNiWwhKPtst/xxA+MsslFlvD0A5w==
   dependencies:
     "@types/offscreencanvas" "~2019.3.0"
     "@types/seedrandom" "2.4.27"

deeplab/demo/yarn.lock

+6,063
Large diffs are not rendered by default.

deeplab/src/index.ts

+8 −1
@@ -22,7 +22,14 @@ import {DeepLabInput, DeepLabOutput, ModelArchitecture, ModelConfig, PredictionC
 import {getColormap, getLabels, getURL, toInputTensor, toSegmentationImage} from './utils';
 
 export {version} from './version';
-export {getColormap, getLabels, getURL, toSegmentationImage};
+export {
+  getColormap,
+  getLabels,
+  getURL,
+  ModelConfig,
+  PredictionConfig,
+  toSegmentationImage
+};
 
 /**
  * Initializes the DeepLab model and returns a `SemanticSegmentation` object.

deeplab/src/types.ts

+35 −40
@@ -34,67 +34,62 @@ export interface Legend {
   [name: string]: Color;
 }
 
-/*
-The model supports quantization to 1 and 2 bytes, leaving 4 for the
-non-quantized variant.
-*/
+/**
+ * The model supports quantization to 1 and 2 bytes, leaving 4 for the
+ * non-quantized variant.
+ *
+ * @docinline
+ */
 export type QuantizationBytes = 1|2|4;
-/*
-Three types of pre-trained weights are available, trained on Pascal, Cityscapes
-and ADE20K datasets. Each dataset has its own colormap and labelling scheme.
-*/
+/**
+ * Three types of pre-trained weights are available, trained on Pascal,
+ * Cityscapes and ADE20K datasets. Each dataset has its own colormap and
+ * labelling scheme.
+ *
+ * @docinline
+ */
 export type ModelArchitecture = 'pascal'|'cityscapes'|'ade20k';
 
 export type DeepLabInput =
     |ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement|tf.Tensor3D;
 
 /*
  * The model can be configured with any of the following attributes:
- *
- * * quantizationBytes (optional) :: `QuantizationBytes`
- *
- * The degree to which weights are quantized (either 1, 2 or 4).
- * Setting this attribute to 1 or 2 will load the model with int32 and
- * float32 compressed to 1 or 2 bytes respectively.
- * Set it to 4 to disable quantization.
- *
- * * base (optional) :: `ModelArchitecture`
- *
- * The type of model to load (either `pascal`, `cityscapes` or `ade20k`).
- *
- * * modelUrl (optional) :: `string`
- *
- * The URL from which to load the TF.js GraphModel JSON.
- * Inferred from `base` and `quantizationBytes` if undefined.
  */
 export interface ModelConfig {
+  /**
+   * The degree to which weights are quantized (either 1, 2 or 4).
+   * Setting this attribute to 1 or 2 will load the model with int32 and
+   * float32 compressed to 1 or 2 bytes respectively.
+   * Set it to 4 to disable quantization.
+   */
   quantizationBytes?: QuantizationBytes;
+  /**
+   * The type of model to load (either `pascal`, `cityscapes` or `ade20k`).
+   */
   base?: ModelArchitecture;
+  /**
+   *
+   * The URL from which to load the TF.js GraphModel JSON.
+   * Inferred from `base` and `quantizationBytes` if undefined.
+   */
   modelUrl?: string;
 }
 
 /*
- *
  * Segmentation can be fine-tuned with three parameters:
- *
- * - **canvas** (optional) :: `HTMLCanvasElement`
- *
- * The canvas where to draw the output
- *
- * - **colormap** (optional) :: `[number, number, number][]`
- *
- * The array of RGB colors corresponding to labels
- *
- * - **labels** (optional) :: `string[]`
- *
- * The array of names corresponding to labels
- *
- * By [default](./src/index.ts#L81), `colormap` and `labels` are set
- * according to the `base` model attribute passed during initialization.
  */
 export interface PredictionConfig {
+  /** The canvas where to draw the output. */
   canvas?: HTMLCanvasElement;
+  /** The array of RGB colors corresponding to labels. */
   colormap?: Color[];
+  /**
+   * The array of names corresponding to labels.
+   *
+   * By [default](./src/index.ts#L81), `colormap` and `labels` are set
+   * according to the `base` model attribute passed during initialization.
+   */
   labels?: string[];
 }
 
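For context, `ModelConfig` above configures deeplab's `load()` call, while `PredictionConfig` fine-tunes a segmentation call (for example by supplying a target canvas). A minimal sketch of the package's documented flow (assuming an `<img id="img">` element; illustrative only, not part of this commit):

```js
import * as deeplab from '@tensorflow-models/deeplab';

async function run() {
  // ModelConfig: choose the pre-trained weights and the quantization level.
  const model = await deeplab.load({base: 'pascal', quantizationBytes: 2});

  // Segment an image; the result contains a legend (label -> color) plus the
  // dimensions and the colorized segmentation map.
  const img = document.getElementById('img');
  const {legend, width, height, segmentationMap} = await model.segment(img);
  console.log(legend, width, height, segmentationMap);
}

run();
```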

tasks/README.md

+124 −3
@@ -9,7 +9,7 @@ for JS developers without ML knowledge. It has the following features:
 - **Easy-to-discover models**
 
   Models from different runtime systems (e.g. [TFJS][tfjs], [TFLite][tflite],
-  [MediaPipe][mediapipe], etc) are grouped by popular ML tasks, such as.
+  [MediaPipe][mediapipe], etc) are grouped by popular ML tasks, such as
   sentiment detection, image classification, pose detection, etc.
 
 - **Clean and powerful APIs**
@@ -28,7 +28,128 @@ for JS developers without ML knowledge. It has the following features:
 
 The following table summarizes all the supported tasks and their models:
 
-(TODO)
+<table>
+  <thead>
+    <tr>
+      <th>Task</th>
+      <th>Model</th>
+      <th>Supported runtimes · Docs · Resources</th>
+    </tr>
+  </thead>
+  <tbody>
+    <!-- Image classification -->
+    <tr>
+      <td rowspan="2">
+        <b>Image Classification</b>
+        <br>
+        Identify images into predefined classes.
+        <br>
+        <a href="https://codepen.io/jinjingforever/pen/VwPOePq">Demo</a>
+      </td>
+      <td>Mobilenet</td>
+      <td>
+        <div>
+          <span><code>TFJS  </code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+        </div>
+        <div>
+          <span><code>TFLite</code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+        </div>
+      </td>
+    </tr>
+    <tr>
+      <td>Custom model</td>
+      <td>
+        <div>
+          <span><code>TFLite</code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+          <span>·</span>
+          <a href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements">Model requirements</a>
+          <span>·</span>
+          <a href="https://tfhub.dev/tensorflow/collections/lite/task-library/image-classifier/1">Model collection</a>
+        </div>
+      </td>
+    </tr>
+    <!-- Object detection -->
+    <tr>
+      <td rowspan="2">
+        <b>Object Detection</b>
+        <br>
+        Localize and identify multiple objects in a single image.
+        <br>
+        <a href="https://codepen.io/jinjingforever/pen/PopPPXo">Demo</a>
+      </td>
+      <td>COCO-SSD</td>
+      <td>
+        <div>
+          <span><code>TFJS  </code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+        </div>
+        <div>
+          <span><code>TFLite</code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+        </div>
+      </td>
+    </tr>
+    <tr>
+      <td>Custom model</td>
+      <td>
+        <div>
+          <span><code>TFLite</code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+          <span>·</span>
+          <a href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector#model_compatibility_requirements">Model requirements</a>
+          <span>·</span>
+          <a href="https://tfhub.dev/tensorflow/collections/lite/task-library/object-detector/1">Model collection</a>
+        </div>
+      </td>
+    </tr>
+    <!-- Image Segmentation -->
+    <tr>
+      <td rowspan="2">
+        <b>Image Segmentation</b>
+        <br>
+        Predict associated class for each pixel of an image.
+        <br>
+        <a href="https://codepen.io/jinjingforever/pen/yLMYVJw">Demo</a>
+      </td>
+      <td>Deeplab</td>
+      <td>
+        <div>
+          <span><code>TFJS  </code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+        </div>
+        <div>
+          <span><code>TFLite</code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+        </div>
+      </td>
+    </tr>
+    <tr>
+      <td>Custom model</td>
+      <td>
+        <div>
+          <span><code>TFLite</code></span>
+          <span>·</span>
+          <a href="#">API doc</a>
+          <span>·</span>
+          <a href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_segmenter#model_compatibility_requirements">Model requirements</a>
+          <span>·</span>
+          <a href="https://tfhub.dev/tensorflow/collections/lite/task-library/image-segmenter/1">Model collection</a>
+        </div>
+      </td>
+    </tr>
+  </tbody>
+</table>
 
 (The initial version only supports the web browser environment. NodeJS support is
 coming soon)
@@ -78,7 +199,7 @@ const model3 = await tfTask.ImageClassification.CustomModel.TFLite.load({
 Since all these models are for the `Image Classification` task, they will have
 the same task model type: [`ImageClassifier`][image classifier interface] in
 this case. Each task model's `predict` inference method has an unique and
-easy-to-use API interface. For example, in `ImageClassiier`, the method takes an
+easy-to-use API interface. For example, in `ImageClassifier`, the method takes an
 image-like element and returns the predicted classes:
 
 ```js
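The README excerpt above is cut off at the opening of its code sample. To illustrate the point it is making (models under one task share a single task-model type and a uniform `predict` method), here is a hedged sketch for the Object Detection task added in this commit; the package import path and the exact result shape are assumptions, not taken from the source:

```js
import * as tfTask from '@tensorflow-models/tasks';

async function run() {
  // Both loaders sit under the Object Detection task, so they produce task
  // models with the same `predict` interface.
  const tfjsModel = await tfTask.ObjectDetection.CocoSsd.TFJS.load();
  const tfliteModel = await tfTask.ObjectDetection.CocoSsd.TFLite.load();

  const img = document.querySelector('img');
  // The results are logged as-is here; consult the generated API docs for the
  // actual fields of the object detection result.
  console.log(await tfjsModel.predict(img));
  console.log(await tfliteModel.predict(img));
}

run();
```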

tasks/package.json

+3 −1
@@ -27,7 +27,9 @@
     "@tensorflow/tfjs-converter": "^3.5.0",
     "@tensorflow/tfjs-core": "^3.5.0",
     "@tensorflow/tfjs-tflite": "0.0.1-alpha.3",
-    "@tensorflow-models/mobilenet": "^2.1.0",
+    "@tensorflow-models/mobilenet": "link:../mobilenet",
+    "@tensorflow-models/coco-ssd": "link:../coco-ssd",
+    "@tensorflow-models/deeplab": "link:../deeplab",
     "@types/jasmine": "~3.6.9",
     "clang-format": "~1.5.0",
     "jasmine": "~3.7.0",

tasks/src/tasks/all_tasks.ts

+26 −0
@@ -20,6 +20,12 @@ import {Runtime, Task} from './common';
 import {imageClassificationCustomModelTfliteLoader} from './image_classification/custom_model_tflite';
 import {mobilenetTfjsLoader} from './image_classification/mobilenet_tfjs';
 import {mobilenetTfliteLoader} from './image_classification/mobilenet_tflite';
+import {imageSegmenterCustomModelTfliteLoader} from './image_segmentation/custom_model_tflite';
+import {deeplabTfjsLoader} from './image_segmentation/deeplab_tfjs';
+import {deeplabTfliteLoader} from './image_segmentation/deeplab_tflite';
+import {cocoSsdTfjsLoader} from './object_detection/cocossd_tfjs';
+import {cocoSsdTfliteLoader} from './object_detection/cocossd_tflite';
+import {objectDetectorCustomModelTfliteLoader} from './object_detection/custom_model_tflite';
 
 /**
  * The main model index.
@@ -48,11 +54,31 @@ const modelIndex = {
       [Runtime.TFLITE]: imageClassificationCustomModelTfliteLoader,
     },
   },
+  [Task.OBJECT_DETECTION]: {
+    CocoSsd: {
+      [Runtime.TFJS]: cocoSsdTfjsLoader,
+      [Runtime.TFLITE]: cocoSsdTfliteLoader,
+    },
+    CustomModel: {
+      [Runtime.TFLITE]: objectDetectorCustomModelTfliteLoader,
+    },
+  },
+  [Task.IMAGE_SEGMENTATION]: {
+    Deeplab: {
+      [Runtime.TFJS]: deeplabTfjsLoader,
+      [Runtime.TFLITE]: deeplabTfliteLoader,
+    },
+    CustomModel: {
+      [Runtime.TFLITE]: imageSegmenterCustomModelTfliteLoader,
+    },
+  },
 };
 
 // Export each task individually.
 
 export const ImageClassification = modelIndex[Task.IMAGE_CLASSIFICATION];
+export const ObjectDetection = modelIndex[Task.OBJECT_DETECTION];
+export const ImageSegmentation = modelIndex[Task.IMAGE_SEGMENTATION];
 
 /**
  * Filter model loaders by runtimes.
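The new index entries and exports give Image Segmentation the same access pattern as the other tasks. A hedged sketch of that pattern (import path and result shape are assumptions, as above):

```js
import * as tfTask from '@tensorflow-models/tasks';

async function run() {
  // Deeplab registered under the Image Segmentation task, TFJS runtime.
  const segmenter = await tfTask.ImageSegmentation.Deeplab.TFJS.load();

  // Inference goes through the shared `predict` method; the concrete result
  // fields are defined by the task's API docs.
  const result = await segmenter.predict(document.querySelector('img'));
  console.log(result);
}

run();
```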
