@@ -427,10 +427,14 @@ class BlazePoseTfjsDetector implements PoseDetector {
427427 // activation_heatmap: This tensor (shape: [1, 64, 64, 39]) represents
428428 // heatmap for the 39 landmarks.
429429 // world_3d: This tensor (shape: [1, 117]) represents 39 3DWorld keypoints.
430- const outputTensor = this . landmarkModel . execute ( imageValueShifted , [
431- 'ld_3d' , 'output_poseflag' , 'activation_segmentation' ,
432- 'activation_heatmap' , 'world_3d'
433- ] ) as tf . Tensor [ ] ;
430+ const outputs =
431+ [ 'ld_3d' , 'output_poseflag' , 'activation_heatmap' , 'world_3d' ] ;
432+ if ( this . enableSegmentation ) {
433+ outputs . push ( 'activation_segmentation' ) ;
434+ }
435+
436+ const outputTensor =
437+ this . landmarkModel . execute ( imageValueShifted , outputs ) as tf . Tensor [ ] ;
434438
435439 // Decodes the tensors into the corresponding landmark and segmentation mask
436440 // representation.
@@ -543,9 +547,10 @@ class BlazePoseTfjsDetector implements PoseDetector {
543547 // TensorsToPoseLandmarksAndSegmentation: SplitTensorVectorCalculator.
544548 const landmarkTensor = tensors [ 0 ] as tf . Tensor2D ,
545549 poseFlagTensor = tensors [ 1 ] as tf . Tensor2D ,
546- segmentationTensor = tensors [ 2 ] as tf . Tensor4D ,
547- heatmapTensor = tensors [ 3 ] as tf . Tensor4D ,
548- worldLandmarkTensor = tensors [ 4 ] as tf . Tensor2D ;
550+ heatmapTensor = tensors [ 2 ] as tf . Tensor4D ,
551+ worldLandmarkTensor = tensors [ 3 ] as tf . Tensor2D ,
552+ segmentationTensor =
553+ ( this . enableSegmentation ? tensors [ 4 ] : null ) as tf . Tensor4D ;
549554
550555 // Converts the pose-flag tensor into a float that represents the
551556 // confidence score of pose presence.
0 commit comments