Skip to content

Commit

Permalink
Refactor gt vs exte detections + add visibility threshold on gt detec…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
MedericFourmy committed Nov 14, 2023
1 parent 6800d30 commit 2f45689
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 27 deletions.
4 changes: 1 addition & 3 deletions happypose/pose_estimators/megapose/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def run_eval(
detector_model = happypose.toolbox.inference.utils.load_detector(
cfg.detector_run_id,
)
elif cfg.inference.detection_type == "gt":
detector_model = None
elif cfg.inference.detection_type == "exte":
elif cfg.inference.detection_type in ["gt", "exte"]:
detector_model = None
else:
msg = f"Unknown detection_type={cfg.inference.detection_type}"
Expand Down
46 changes: 22 additions & 24 deletions happypose/pose_estimators/megapose/evaluation/prediction_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def run_inference_pipeline(
self,
pose_estimator: PoseEstimator,
obs_tensor: ObservationTensor,
gt_detections: DetectionsType,
exte_detections: DetectionsType,
detections: DetectionsType,
initial_estimates: Optional[PoseEstimatesType] = None,
) -> Dict[str, PoseEstimatesType]:
"""Runs inference pipeline, extracts the results.
Expand All @@ -95,22 +94,13 @@ def run_inference_pipeline(
- 'refiner/final': preds at final refiner iteration (before depth
refinement).
- 'depth_refinement': preds after depth refinement.
"""
# TODO: this check could be done outside of run_inference_pipeline
# and then only check if detections are None
if self.inference_cfg.detection_type == "gt":
detections = gt_detections
run_detector = False
elif self.inference_cfg.detection_type == "exte":
# print("exte_detections =", exte_detections.bboxes)
detections = exte_detections

if self.inference_cfg.detection_type in ["gt", "exte"]:
run_detector = False
elif self.inference_cfg.detection_type == "detector":
detections = None
run_detector = True

else:
msg = f"Unknown detection type {self.inference_cfg.detection_type}"
raise ValueError(msg)
Expand Down Expand Up @@ -194,7 +184,6 @@ def get_predictions(
# This section opens the detections stored in "baseline.json"
# format it and store it in a dataframe that will be accessed later
######
# Temporary solution
if self.inference_cfg.detection_type == "exte":
df_all_dets, df_targets = load_external_detections(self.scene_ds.ds_dir)

Expand All @@ -209,16 +198,27 @@ def get_predictions(
# Dirty but avoids creating error when running with real detector
dt_det_exte = 0

# Temporary solution
# Select view detections depending detection type
if self.inference_cfg.detection_type == "exte":
exte_detections = filter_detections_scene_view(
detections = filter_detections_scene_view(
scene_id, view_id, df_all_dets, df_targets
)
if len(exte_detections) > 0:
dt_det_exte += exte_detections.infos["time"].iloc[0]
if len(detections) > 0:
dt_det_exte += detections.infos["time"].iloc[0]
elif self.inference_cfg.detection_type == "gt":
detections = data["gt_detections"].cuda()
"""
Some groundtruth detections have non zero visibility and
zero sized bounding boxes
-> remove them to avoid division by zero errors later
"""
min_visibility_fract = 0.05
detections.infos = detections.infos[
detections.infos["visib_fract"] > min_visibility_fract
]
else:
exte_detections = None
gt_detections = data["gt_detections"].cuda()
detections = None

initial_data = None
if data["initial_data"]:
initial_data = data["initial_data"].cuda()
Expand All @@ -232,17 +232,15 @@ def get_predictions(
self.run_inference_pipeline(
pose_estimator,
obs_tensor,
gt_detections,
exte_detections,
detections,
initial_estimates=initial_data,
)

with torch.no_grad():
all_preds, all_preds_data = self.run_inference_pipeline(
pose_estimator,
obs_tensor,
gt_detections,
exte_detections,
detections,
initial_estimates=initial_data,
)

Expand Down
5 changes: 5 additions & 0 deletions happypose/pose_estimators/megapose/models/pose_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,9 @@ def forward_coarse(
def has_nans(tensor):
return bool(tensor.isnan().any())

def has_infs(tensor):
return bool(tensor.isinf().any())

assert (
self.predict_rendered_views_logits
), "Method only valid if coarse classification model"
Expand All @@ -763,6 +766,8 @@ def has_nans(tensor):
print("K has NANS")
if has_nans(TCO_input):
print("TCO_input has NANS")
if has_infs(TCO_input):
print("TCO_input has INFS")

TCO_input = normalize_T(TCO_input).detach()
if has_nans(TCO_input):
Expand Down

0 comments on commit 2f45689

Please sign in to comment.