Skip to content

Commit 5354e81

Browse files
authored
Remove segmentation output from execution when segmentation is disabled (#937)
1 parent 60a5226 commit 5354e81

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

pose-detection/src/blazepose_tfjs/detector.ts

+12-7
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,14 @@ class BlazePoseTfjsDetector implements PoseDetector {
427427
// activation_heatmap: This tensor (shape: [1, 64, 64, 39]) represents
428428
// heatmap for the 39 landmarks.
429429
// world_3d: This tensor (shape: [1, 117]) represents 39 3DWorld keypoints.
430-
const outputTensor = this.landmarkModel.execute(imageValueShifted, [
431-
'ld_3d', 'output_poseflag', 'activation_segmentation',
432-
'activation_heatmap', 'world_3d'
433-
]) as tf.Tensor[];
430+
const outputs =
431+
['ld_3d', 'output_poseflag', 'activation_heatmap', 'world_3d'];
432+
if (this.enableSegmentation) {
433+
outputs.push('activation_segmentation');
434+
}
435+
436+
const outputTensor =
437+
this.landmarkModel.execute(imageValueShifted, outputs) as tf.Tensor[];
434438

435439
// Decodes the tensors into the corresponding landmark and segmentation mask
436440
// representation.
@@ -543,9 +547,10 @@ class BlazePoseTfjsDetector implements PoseDetector {
543547
// TensorsToPoseLandmarksAndSegmentation: SplitTensorVectorCalculator.
544548
const landmarkTensor = tensors[0] as tf.Tensor2D,
545549
poseFlagTensor = tensors[1] as tf.Tensor2D,
546-
segmentationTensor = tensors[2] as tf.Tensor4D,
547-
heatmapTensor = tensors[3] as tf.Tensor4D,
548-
worldLandmarkTensor = tensors[4] as tf.Tensor2D;
550+
heatmapTensor = tensors[2] as tf.Tensor4D,
551+
worldLandmarkTensor = tensors[3] as tf.Tensor2D,
552+
segmentationTensor =
553+
(this.enableSegmentation ? tensors[4] : null) as tf.Tensor4D;
549554

550555
// Converts the pose-flag tensor into a float that represents the
551556
// confidence score of pose presence.

0 commit comments

Comments
 (0)