@@ -591,14 +591,36 @@ def build(
591591 ),
592592 )
593593
594+ dataset_root = self .dataset_root
595+ mask_annotation = frame_annotation .mask
596+ depth_annotation = frame_annotation .depth
597+ image_path : str | None = None
598+ mask_path : str | None = None
599+ depth_path : str | None = None
600+ pcl_path : str | None = None
601+ if dataset_root is not None : # set all paths even if we won't load blobs
602+ if frame_annotation .image .path is not None :
603+ image_path = os .path .join (dataset_root , frame_annotation .image .path )
604+ frame_data .image_path = image_path
605+
606+ if mask_annotation is not None and mask_annotation .path :
607+ mask_path = os .path .join (dataset_root , mask_annotation .path )
608+ frame_data .mask_path = mask_path
609+
610+ if depth_annotation is not None and depth_annotation .path is not None :
611+ depth_path = os .path .join (dataset_root , depth_annotation .path )
612+ frame_data .depth_path = depth_path
613+
614+ if point_cloud is not None :
615+ pcl_path = os .path .join (dataset_root , point_cloud .path )
616+ frame_data .sequence_point_cloud_path = pcl_path
617+
594618 fg_mask_np : np .ndarray | None = None
595619 bbox_xywh : tuple [float , float , float , float ] | None = None
596- mask_annotation = frame_annotation .mask
597620
598621 if mask_annotation is not None :
599- if load_blobs and self .load_masks :
600- fg_mask_np , mask_path = self ._load_fg_probability (frame_annotation )
601- frame_data .mask_path = mask_path
622+ if load_blobs and self .load_masks and mask_path :
623+ fg_mask_np = self ._load_fg_probability (frame_annotation , mask_path )
602624 frame_data .fg_probability = safe_as_tensor (fg_mask_np , torch .float )
603625
604626 bbox_xywh = mask_annotation .bounding_box_xywh
@@ -608,11 +630,6 @@ def build(
608630 frame_data .image_size_hw = image_size_hw # original image size
609631 # image size after crop/resize
610632 frame_data .effective_image_size_hw = image_size_hw
611- image_path = None
612- dataset_root = self .dataset_root
613- if frame_annotation .image .path is not None and dataset_root is not None :
614- image_path = os .path .join (dataset_root , frame_annotation .image .path )
615- frame_data .image_path = image_path
616633
617634 if load_blobs and self .load_images :
618635 if image_path is None :
@@ -639,25 +656,16 @@ def build(
639656 bbox_xywh = get_bbox_from_mask (fg_mask_np , self .box_crop_mask_thr )
640657 frame_data .bbox_xywh = safe_as_tensor (bbox_xywh , torch .float )
641658
642- depth_annotation = frame_annotation .depth
643- if (
644- load_blobs
645- and self .load_depths
646- and depth_annotation is not None
647- and depth_annotation .path is not None
648- ):
649- (
650- frame_data .depth_map ,
651- frame_data .depth_path ,
652- frame_data .depth_mask ,
653- ) = self ._load_mask_depth (frame_annotation , fg_mask_np )
659+ if load_blobs and self .load_depths and depth_path is not None :
660+ frame_data .depth_map , frame_data .depth_mask = self ._load_mask_depth (
661+ frame_annotation , depth_path , fg_mask_np
662+ )
654663
655664 if load_blobs and self .load_point_clouds and point_cloud is not None :
656- pcl_path = self . _fix_point_cloud_path ( point_cloud . path )
665+ assert pcl_path is not None
657666 frame_data .sequence_point_cloud = load_pointcloud (
658667 self ._local_path (pcl_path ), max_points = self .max_points
659668 )
660- frame_data .sequence_point_cloud_path = pcl_path
661669
662670 if frame_annotation .viewpoint is not None :
663671 frame_data .camera = self ._get_pytorch3d_camera (frame_annotation )
@@ -673,17 +681,14 @@ def build(
673681
674682 return frame_data
675683
676- def _load_fg_probability (self , entry : FrameAnnotationT ) -> Tuple [np .ndarray , str ]:
677- mask_annotation = entry .mask
678- assert self .dataset_root is not None and mask_annotation is not None
679- full_path = os .path .join (self .dataset_root , mask_annotation .path )
680- fg_probability = load_mask (self ._local_path (full_path ))
684+ def _load_fg_probability (self , entry : FrameAnnotationT , path : str ) -> np .ndarray :
685+ fg_probability = load_mask (self ._local_path (path ))
681686 if fg_probability .shape [- 2 :] != entry .image .size :
682687 raise ValueError (
683688 f"bad mask size: { fg_probability .shape [- 2 :]} vs { entry .image .size } !"
684689 )
685690
686- return fg_probability , full_path
691+ return fg_probability
687692
688693 def _postprocess_image (
689694 self ,
@@ -705,13 +710,13 @@ def _postprocess_image(
705710 def _load_mask_depth (
706711 self ,
707712 entry : FrameAnnotationT ,
713+ path : str ,
708714 fg_mask : Optional [np .ndarray ],
709- ) -> Tuple [torch .Tensor , str , torch .Tensor ]:
715+ ) -> tuple [torch .Tensor , torch .Tensor ]:
710716 entry_depth = entry .depth
711717 dataset_root = self .dataset_root
712718 assert dataset_root is not None
713- assert entry_depth is not None and entry_depth .path is not None
714- path = os .path .join (dataset_root , entry_depth .path )
719+ assert entry_depth is not None
715720 depth_map = load_depth (self ._local_path (path ), entry_depth .scale_adjustment )
716721
717722 if self .mask_depths :
@@ -725,7 +730,7 @@ def _load_mask_depth(
725730 else :
726731 depth_mask = (depth_map > 0.0 ).astype (np .float32 )
727732
728- return torch .tensor (depth_map ), path , torch .tensor (depth_mask )
733+ return torch .tensor (depth_map ), torch .tensor (depth_mask )
729734
730735 def _get_pytorch3d_camera (
731736 self ,
@@ -758,19 +763,6 @@ def _get_pytorch3d_camera(
758763 T = torch .tensor (entry_viewpoint .T , dtype = torch .float )[None ],
759764 )
760765
761- def _fix_point_cloud_path (self , path : str ) -> str :
762- """
763- Fix up a point cloud path from the dataset.
764- Some files in Co3Dv2 have an accidental absolute path stored.
765- """
766- unwanted_prefix = (
767- "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
768- )
769- if path .startswith (unwanted_prefix ):
770- path = path [len (unwanted_prefix ) :]
771- assert self .dataset_root is not None
772- return os .path .join (self .dataset_root , path )
773-
774766 def _local_path (self , path : str ) -> str :
775767 if self .path_manager is None :
776768 return path
0 commit comments