Skip to content

Commit 45bf4cf

Browse files
authored
Add lite hands model support to TFJS API (#848)
* Add lite hands model support to TFJS API * Move config model type check to validator
1 parent 2cbf343 commit 45bf4cf

File tree

5 files changed

+77
-50
lines changed

5 files changed

+77
-50
lines changed

hand-detection/src/tfjs/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ Pass in `handDetection.SupportedModels.MediaPipeHands` from the
6262

6363
* *maxHands*: Defaults to 2. The maximum number of hands that will be detected by the model. The number of returned hands can be less than the maximum (for example when no hands are present in the input).
6464

65+
* *modelType*: specify which variant to load from `MediaPipeHandsModelType` (i.e.,
66+
'lite', 'full'). If unset, the default is 'full'.
67+
6568
* *detectorModelUrl*: An optional string that specifies custom url of
6669
the detector model. This is useful for area/countries that don't have access to the model hosted on tf.hub.
6770
* *landmarkModelUrl* An optional string that specifies custom url of

hand-detection/src/tfjs/constants.ts

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
* =============================================================================
1616
*/
1717

18-
import { AnchorConfig, ImageToTensorConfig, RectTransformationConfig, TensorsToDetectionsConfig, TensorsToLandmarksConfig } from '../shared/calculators/interfaces/config_interfaces';
19-
import { MediaPipeHandsTfjsEstimationConfig, MediaPipeHandsTfjsModelConfig } from './types';
18+
import {AnchorConfig, ImageToTensorConfig, RectTransformationConfig, TensorsToDetectionsConfig, TensorsToLandmarksConfig} from '../shared/calculators/interfaces/config_interfaces';
19+
import {MediaPipeHandsTfjsEstimationConfig, MediaPipeHandsTfjsModelConfig} from './types';
2020

2121
export const DEFAULT_MPHANDS_DETECTOR_MODEL_URL =
22-
'https://storage.googleapis.com/tfjs-testing/hand-detection/handdetector/model.json';
23-
export const DEFAULT_MPHANDS_LANDMARK_MODEL_URL =
24-
'https://storage.googleapis.com/tfjs-testing/hand-detection/handskeleton/model.json';
22+
'https://storage.googleapis.com/tfjs-testing/hand-detection/handdetector/model.json';
23+
export const DEFAULT_MPHANDS_LANDMARK_MODEL_URL_LITE =
24+
'https://storage.googleapis.com/tfjs-testing/hand-detection/handskeleton_lite/model.json';
25+
export const DEFAULT_MPHANDS_LANDMARK_MODEL_URL_FULL =
26+
'https://storage.googleapis.com/tfjs-testing/hand-detection/handskeleton_full/model.json';
2527
export const MPHANDS_DETECTOR_ANCHOR_CONFIGURATION: AnchorConfig = {
2628
reduceBoxesInLowestLayer: false,
2729
interpolatedScaleAspectRatio: 1.0,
@@ -40,62 +42,63 @@ export const MPHANDS_DETECTOR_ANCHOR_CONFIGURATION: AnchorConfig = {
4042
};
4143
export const DEFAULT_MPHANDS_MODEL_CONFIG: MediaPipeHandsTfjsModelConfig = {
4244
runtime: 'tfjs',
45+
modelType: 'full',
4346
maxHands: 2,
4447
detectorModelUrl: DEFAULT_MPHANDS_DETECTOR_MODEL_URL,
45-
landmarkModelUrl: DEFAULT_MPHANDS_LANDMARK_MODEL_URL
48+
landmarkModelUrl: DEFAULT_MPHANDS_LANDMARK_MODEL_URL_FULL
4649
};
4750
export const DEFAULT_MPHANDS_ESTIMATION_CONFIG:
48-
MediaPipeHandsTfjsEstimationConfig = {
49-
flipHorizontal: false,
50-
staticImageMode: false
51-
};
51+
MediaPipeHandsTfjsEstimationConfig = {
52+
flipHorizontal: false,
53+
staticImageMode: false
54+
};
5255
export const MPHANDS_TENSORS_TO_DETECTION_CONFIGURATION:
53-
TensorsToDetectionsConfig = {
54-
applyExponentialOnBoxSize: false,
55-
flipVertically: false,
56-
ignoreClasses: [] as number[],
57-
numClasses: 1,
58-
numBoxes: 896,
59-
numCoords: 18,
60-
boxCoordOffset: 0,
61-
keypointCoordOffset: 4,
62-
numKeypoints: 7,
63-
numValuesPerKeypoint: 2,
64-
sigmoidScore: true,
65-
scoreClippingThresh: 100.0,
66-
reverseOutputOrder: true,
67-
xScale: 128.0,
68-
yScale: 128.0,
69-
hScale: 128.0,
70-
wScale: 128.0,
71-
minScoreThresh: 0.5
72-
};
56+
TensorsToDetectionsConfig = {
57+
applyExponentialOnBoxSize: false,
58+
flipVertically: false,
59+
ignoreClasses: [] as number[],
60+
numClasses: 1,
61+
numBoxes: 896,
62+
numCoords: 18,
63+
boxCoordOffset: 0,
64+
keypointCoordOffset: 4,
65+
numKeypoints: 7,
66+
numValuesPerKeypoint: 2,
67+
sigmoidScore: true,
68+
scoreClippingThresh: 100.0,
69+
reverseOutputOrder: true,
70+
xScale: 128.0,
71+
yScale: 128.0,
72+
hScale: 128.0,
73+
wScale: 128.0,
74+
minScoreThresh: 0.5
75+
};
7376
export const MPHANDS_DETECTOR_NON_MAX_SUPPRESSION_CONFIGURATION = {
7477
minScoreThreshold: -1.0,
7578
minSuppressionThreshold: 0.3
7679
};
7780
export const MPHANDS_DETECTOR_RECT_TRANSFORMATION_CONFIG:
78-
RectTransformationConfig = {
79-
shiftX: 0,
80-
shiftY: -0.5,
81-
scaleX: 2.6,
82-
scaleY: 2.6,
83-
squareLong: true
84-
};
81+
RectTransformationConfig = {
82+
shiftX: 0,
83+
shiftY: -0.5,
84+
scaleX: 2.6,
85+
scaleY: 2.6,
86+
squareLong: true
87+
};
8588
export const MPHANDS_LANDMARK_RECT_TRANSFORMATION_CONFIG:
86-
RectTransformationConfig = {
87-
shiftX: 0,
88-
shiftY: -0.1,
89-
scaleX: 2.0,
90-
scaleY: 2.0,
91-
squareLong: true
92-
};
89+
RectTransformationConfig = {
90+
shiftX: 0,
91+
shiftY: -0.1,
92+
scaleX: 2.0,
93+
scaleY: 2.0,
94+
squareLong: true
95+
};
9396
export const MPHANDS_DETECTOR_IMAGE_TO_TENSOR_CONFIG: ImageToTensorConfig = {
94-
inputResolution: { width: 128, height: 128 },
97+
inputResolution: {width: 128, height: 128},
9598
keepAspectRatio: true
9699
};
97100
export const MPHANDS_LANDMARK_IMAGE_TO_TENSOR_CONFIG: ImageToTensorConfig = {
98-
inputResolution: { width: 224, height: 224 },
101+
inputResolution: {width: 224, height: 224},
99102
keepAspectRatio: true
100103
};
101104
export const MPHANDS_HAND_PRESENCE_SCORE = 0.5;

hand-detection/src/tfjs/detector.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,14 @@ class MediaPipeHandsTfjsDetector implements HandDetector {
301301
// outputs a list of tensors representing, for instance, detection
302302
// boxes/keypoints and scores.
303303
// The model returns 3 tensors with the following shape:
304-
// conv_landmarks: This tensor (shape: [1, 63]) represents 21 3-d
304+
// Identity_2:0: This tensor (shape: [1, 63]) represents 21 3-d
305305
// keypoints.
306306
// Identity_1:0: This tensor (shape: [1, 1]) represents the
307307
// confidence score of the presence of a hand.
308308
// Identity:0: This tensor (shape: [1, 1]) represents the classication
309309
// score of handedness
310310
const landmarkResult = this.landmarkModel.execute(imageValueShifted, [
311-
'conv_landmarks', 'Identity_1:0', 'Identity:0'
311+
'Identity_2:0', 'Identity_1:0', 'Identity:0'
312312
]) as tf.Tensor[];
313313

314314
const landmarkTensor = landmarkResult[0] as tf.Tensor2D,

hand-detection/src/tfjs/detector_utils.ts

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import {DEFAULT_MPHANDS_ESTIMATION_CONFIG, DEFAULT_MPHANDS_MODEL_CONFIG} from './constants';
18+
import {DEFAULT_MPHANDS_ESTIMATION_CONFIG, DEFAULT_MPHANDS_LANDMARK_MODEL_URL_FULL, DEFAULT_MPHANDS_LANDMARK_MODEL_URL_LITE, DEFAULT_MPHANDS_MODEL_CONFIG} from './constants';
1919
import {MediaPipeHandsTfjsEstimationConfig, MediaPipeHandsTfjsModelConfig} from './types';
2020

2121
export function validateModelConfig(modelConfig: MediaPipeHandsTfjsModelConfig):
@@ -32,12 +32,29 @@ export function validateModelConfig(modelConfig: MediaPipeHandsTfjsModelConfig):
3232
config.maxHands = DEFAULT_MPHANDS_MODEL_CONFIG.maxHands;
3333
}
3434

35+
if (config.modelType == null) {
36+
config.modelType = DEFAULT_MPHANDS_MODEL_CONFIG.modelType;
37+
}
38+
39+
if (config.modelType !== 'lite' && config.modelType !== 'full') {
40+
throw new Error(
41+
`Model type must be one of lite or full, but got ${config.modelType}`);
42+
}
43+
3544
if (config.detectorModelUrl == null) {
3645
config.detectorModelUrl = DEFAULT_MPHANDS_MODEL_CONFIG.detectorModelUrl;
3746
}
3847

3948
if (config.landmarkModelUrl == null) {
40-
config.landmarkModelUrl = DEFAULT_MPHANDS_MODEL_CONFIG.landmarkModelUrl;
49+
switch (config.modelType) {
50+
case 'lite':
51+
config.landmarkModelUrl = DEFAULT_MPHANDS_LANDMARK_MODEL_URL_LITE;
52+
break;
53+
case 'full':
54+
default:
55+
config.landmarkModelUrl = DEFAULT_MPHANDS_LANDMARK_MODEL_URL_FULL;
56+
break;
57+
}
4158
}
4259

4360
return config;

hand-detection/src/tfjs/types.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ import {MediaPipeHandsEstimationConfig, MediaPipeHandsModelConfig} from '../medi
2222
*
2323
* `runtime`: Must set to be 'tfjs'.
2424
*
25+
* `modelType`: Optional. Possible values: 'lite'|'full'. Defaults to
26+
* 'full'. Landmark accuracy as well as inference latency generally go up with
27+
* the increasing model complexity (lite to full).
28+
*
2529
* `detectorModelUrl`: Optional. An optional string that specifies custom url of
2630
* the detector model. This is useful for area/countries that don't have access
2731
* to the model hosted on tf.hub.

0 commit comments

Comments
 (0)