Skip to content

Commit 59bd40b

Browse files
authored
[pose-detection]Use f16 model. (#708)
FEATURE
1 parent fde8eab commit 59bd40b

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

pose-detection/src/blazepose_tfjs/blazepose_test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util';
2424
import * as poseDetection from '../index';
2525
import {getXYPerFrame, KARMA_SERVER, loadImage, loadVideo} from '../test_util';
2626

27-
const EPSILON_IMAGE = 18;
27+
const EPSILON_IMAGE = 19;
2828
const EPSILON_VIDEO = 15;
2929

3030
// ref:

pose-detection/src/blazepose_tfjs/constants.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
import {BlazePoseTfjsModelConfig} from './types';
1919

2020
export const DEFAULT_BLAZEPOSE_DETECTOR_MODEL_URL =
21-
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/detector/heatmap/model.json';
21+
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/detector/f16/model.json';
2222
export const DEFAULT_BLAZEPOSE_LANDMARK_MODEL_URL_FULL =
23-
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/full/model.json';
23+
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/full-f16/model.json';
2424
export const DEFAULT_BLAZEPOSE_LANDMARK_MODEL_URL_LITE =
25-
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/lite/model.json';
25+
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/lite-f16/model.json';
2626
export const DEFAULT_BLAZEPOSE_LANDMARK_MODEL_URL_HEAVY =
27-
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/heavy/model.json';
27+
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/heavy-f16/model.json';
2828
export const BLAZEPOSE_DETECTOR_ANCHOR_CONFIGURATION = {
2929
reduceBoxesInLowestlayer: false,
3030
interpolatedScaleAspectRatio: 1.0,

pose-detection/src/blazepose_tfjs/detector.ts

+13-13
Original file line numberDiff line numberDiff line change
@@ -324,18 +324,18 @@ export class BlazePoseTfjsDetector extends BasePoseDetector {
324324
// Output[3]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
325325
// The first 33 refer to the keypoints. The final 6 key points refer to
326326
// the alignment points from the detector model and the hands.)
327-
// Output [0]: This tensor (shape: [1, 1]) represents the confidence
327+
// Output [4]: This tensor (shape: [1, 1]) represents the confidence
328328
// score.
329-
// Output [2]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
329+
// Output [1]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
330330
// the 39 landmarks.
331331
// Lite model:
332-
// Output[1]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
333-
// Output[2]: This tensor (shape: [1, 1]) represents the confidence score.
334-
// Output[4]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
332+
// Output[4]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
333+
// Output[3]: This tensor (shape: [1, 1]) represents the confidence score.
334+
// Output[1]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
335335
// the 39 landmarks.
336336
// Heavy model:
337337
// Output[3]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
338-
// Output[2]: This tensor (shape: [1, 1]) represents the confidence score.
338+
// Output[1]: This tensor (shape: [1, 1]) represents the confidence score.
339339
// Output[4]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
340340
// the 39 landmarks.
341341
const landmarkResult =
@@ -345,18 +345,18 @@ export class BlazePoseTfjsDetector extends BasePoseDetector {
345345

346346
switch (this.modelType) {
347347
case 'lite':
348-
landmarkTensor = landmarkResult[1] as tf.Tensor2D;
349-
poseFlagTensor = landmarkResult[2] as tf.Tensor2D;
350-
heatmapTensor = landmarkResult[4] as tf.Tensor4D;
348+
landmarkTensor = landmarkResult[3] as tf.Tensor2D;
349+
poseFlagTensor = landmarkResult[4] as tf.Tensor2D;
350+
heatmapTensor = landmarkResult[1] as tf.Tensor4D;
351351
break;
352352
case 'full':
353-
landmarkTensor = landmarkResult[3] as tf.Tensor2D;
354-
poseFlagTensor = landmarkResult[0] as tf.Tensor2D;
355-
heatmapTensor = landmarkResult[2] as tf.Tensor4D;
353+
landmarkTensor = landmarkResult[4] as tf.Tensor2D;
354+
poseFlagTensor = landmarkResult[3] as tf.Tensor2D;
355+
heatmapTensor = landmarkResult[1] as tf.Tensor4D;
356356
break;
357357
case 'heavy':
358358
landmarkTensor = landmarkResult[3] as tf.Tensor2D;
359-
poseFlagTensor = landmarkResult[2] as tf.Tensor2D;
359+
poseFlagTensor = landmarkResult[1] as tf.Tensor2D;
360360
heatmapTensor = landmarkResult[4] as tf.Tensor4D;
361361
break;
362362
default:

0 commit comments

Comments
 (0)