@@ -427,10 +427,14 @@ class BlazePoseTfjsDetector implements PoseDetector {
427
427
// activation_heatmap: This tensor (shape: [1, 64, 64, 39]) represents
428
428
// heatmap for the 39 landmarks.
429
429
// world_3d: This tensor (shape: [1, 117]) represents 39 3DWorld keypoints.
430
- const outputTensor = this . landmarkModel . execute ( imageValueShifted , [
431
- 'ld_3d' , 'output_poseflag' , 'activation_segmentation' ,
432
- 'activation_heatmap' , 'world_3d'
433
- ] ) as tf . Tensor [ ] ;
430
+ const outputs =
431
+ [ 'ld_3d' , 'output_poseflag' , 'activation_heatmap' , 'world_3d' ] ;
432
+ if ( this . enableSegmentation ) {
433
+ outputs . push ( 'activation_segmentation' ) ;
434
+ }
435
+
436
+ const outputTensor =
437
+ this . landmarkModel . execute ( imageValueShifted , outputs ) as tf . Tensor [ ] ;
434
438
435
439
// Decodes the tensors into the corresponding landmark and segmentation mask
436
440
// representation.
@@ -543,9 +547,10 @@ class BlazePoseTfjsDetector implements PoseDetector {
543
547
// TensorsToPoseLandmarksAndSegmentation: SplitTensorVectorCalculator.
544
548
const landmarkTensor = tensors [ 0 ] as tf . Tensor2D ,
545
549
poseFlagTensor = tensors [ 1 ] as tf . Tensor2D ,
546
- segmentationTensor = tensors [ 2 ] as tf . Tensor4D ,
547
- heatmapTensor = tensors [ 3 ] as tf . Tensor4D ,
548
- worldLandmarkTensor = tensors [ 4 ] as tf . Tensor2D ;
550
+ heatmapTensor = tensors [ 2 ] as tf . Tensor4D ,
551
+ worldLandmarkTensor = tensors [ 3 ] as tf . Tensor2D ,
552
+ segmentationTensor =
553
+ ( this . enableSegmentation ? tensors [ 4 ] : null ) as tf . Tensor4D ;
549
554
550
555
// Converts the pose-flag tensor into a float that represents the
551
556
// confidence score of pose presence.
0 commit comments