1
1
import numpy as np
2
2
import cv2
3
+ import timeit
3
4
# Tensorflow
4
5
5
6
import tensorflow .compat .v1 as tf
6
7
tf .disable_eager_execution ()
7
8
physical_devices = tf .config .experimental .list_physical_devices ('GPU' )
8
9
tf .config .experimental .set_memory_growth (physical_devices [0 ], True )
9
-
10
+ import matplotlib . pyplot as plt
10
11
import sys
11
12
sys .path .insert (0 , '/root/git/contact_graspnet/contact_graspnet' )
12
13
15
16
from contact_grasp_estimator import GraspEstimator
16
17
from visualization_utils import visualize_grasps , show_image
17
18
18
- from ..camera .camera import get_bbox_annotations , get_segmap_from_bbox
19
+ from ..camera .camera import get_bbox_annotations , get_segmap_from_bbox_with_depth , get_segmap_from_bbox
19
20
from ._grasp_predictor import GraspPredictor
21
+ from ..logger import logger
20
22
class ContactGraspNet (GraspPredictor ):#(object)
21
23
def __init__ (self ,
22
24
global_config ,
@@ -47,11 +49,18 @@ def __init__(self,
47
49
48
50
pass
49
51
50
- def generate_grasps (self , rgbd_camera ):
52
+ def generate_grasps (self , rgbd_camera , is_visualize_grasps = False , use_depth_for_seg = False ):
51
53
rgb_image , depth_image = rgbd_camera .get_current_rgbd_frames ()
52
54
bbox , _ = get_bbox_annotations (rgb_image )
53
- segmap = get_segmap_from_bbox (rgb_image , bbox )
55
+ if use_depth_for_seg :
56
+ segmap = get_segmap_from_bbox_with_depth (rgb_image , depth_image , bbox )
57
+ else :
58
+ segmap = get_segmap_from_bbox (rgb_image , bbox )
59
+ #
60
+ plt .imshow (segmap )
61
+ plt .show ()
54
62
cam_k = rgbd_camera .camera_matrix
63
+ start = timeit .timeit ()
55
64
pc_full , pc_segments , pc_colors = self .grasp_estimator .extract_point_clouds (depth_image ,
56
65
cam_k ,
57
66
segmap = segmap ,
@@ -63,10 +72,15 @@ def generate_grasps(self, rgbd_camera):
63
72
local_regions = self .local_regions ,
64
73
filter_grasps = self .filter_grasps ,
65
74
forward_passes = self .forward_passes )
75
+ end = timeit .timeit ()
76
+ logger .info (f"elapsed time { end - start } " )
66
77
67
78
show_image (rgb_image , segmap )
68
- visualize_grasps (pc_full , pred_grasps_cam , scores , plot_opencv_cam = True , pc_colors = pc_colors )
69
-
79
+ plt .imshow (segmap )
80
+ plt .show ()
81
+ if is_visualize_grasps :
82
+ visualize_grasps (pc_full , pred_grasps_cam , scores , plot_opencv_cam = True , pc_colors = pc_colors )
83
+ return pred_grasps_cam , scores , contact_pts #, _
70
84
# get_segmap
71
85
# rgb image
72
86
# depth image
0 commit comments