Skip to content

Commit 3b38586

Browse files
committed
added contactgraspnet class and testscript
1 parent c8a5011 commit 3b38586

File tree

2 files changed

+46
-10
lines changed

2 files changed

+46
-10
lines changed

robot_toolkit/robot_arm_algos/src/inference/contact_graspnet.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import numpy as np
22
import cv2
3+
import timeit
34
# Tensorflow
45

56
import tensorflow.compat.v1 as tf
67
tf.disable_eager_execution()
78
physical_devices = tf.config.experimental.list_physical_devices('GPU')
89
tf.config.experimental.set_memory_growth(physical_devices[0], True)
9-
10+
import matplotlib.pyplot as plt
1011
import sys
1112
sys.path.insert(0, '/root/git/contact_graspnet/contact_graspnet')
1213

@@ -15,8 +16,9 @@
1516
from contact_grasp_estimator import GraspEstimator
1617
from visualization_utils import visualize_grasps, show_image
1718

18-
from ..camera.camera import get_bbox_annotations, get_segmap_from_bbox
19+
from ..camera.camera import get_bbox_annotations, get_segmap_from_bbox_with_depth , get_segmap_from_bbox
1920
from ._grasp_predictor import GraspPredictor
21+
from ..logger import logger
2022
class ContactGraspNet(GraspPredictor):#(object)
2123
def __init__(self,
2224
global_config,
@@ -47,11 +49,18 @@ def __init__(self,
4749

4850
pass
4951

50-
def generate_grasps(self, rgbd_camera):
52+
def generate_grasps(self, rgbd_camera, is_visualize_grasps = False, use_depth_for_seg = False):
5153
rgb_image, depth_image = rgbd_camera.get_current_rgbd_frames()
5254
bbox, _ = get_bbox_annotations(rgb_image)
53-
segmap = get_segmap_from_bbox(rgb_image, bbox)
55+
if use_depth_for_seg:
56+
segmap = get_segmap_from_bbox_with_depth(rgb_image, depth_image, bbox)
57+
else:
58+
segmap = get_segmap_from_bbox(rgb_image, bbox)
59+
#
60+
plt.imshow(segmap)
61+
plt.show()
5462
cam_k = rgbd_camera.camera_matrix
63+
start = timeit.timeit()
5564
pc_full, pc_segments, pc_colors = self.grasp_estimator.extract_point_clouds(depth_image,
5665
cam_k,
5766
segmap=segmap,
@@ -63,10 +72,15 @@ def generate_grasps(self, rgbd_camera):
6372
local_regions=self.local_regions,
6473
filter_grasps=self.filter_grasps,
6574
forward_passes=self.forward_passes)
75+
end = timeit.timeit()
76+
logger.info(f"elapsed time {end-start}")
6677

6778
show_image(rgb_image, segmap)
68-
visualize_grasps(pc_full, pred_grasps_cam, scores, plot_opencv_cam=True, pc_colors=pc_colors)
69-
79+
plt.imshow(segmap)
80+
plt.show()
81+
if is_visualize_grasps:
82+
visualize_grasps(pc_full, pred_grasps_cam, scores, plot_opencv_cam=True, pc_colors=pc_colors)
83+
return pred_grasps_cam, scores, contact_pts#, _
7084
# get_segmap
7185
# rgb image
7286
# depth image
+26-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,35 @@
1+
import numpy as np
2+
import open3d as o3d
13
from robot_arm_algos.src.camera.realsense_camera import RealSenseCamera
24
from robot_arm_algos.src.inference.contact_graspnet import ContactGraspNet as cgn
35
import robot_arm_algos.src.inference._contact_graspnet_config_utils as config_utils
46

7+
8+
test_config = {
9+
"use_depth" : False,
10+
"checkpoint_dir" : "/root/git/scratchpad/scene_test_2048_bs3_hor_sigma_0025"
11+
}
12+
513
def main():
614
rs_camera = RealSenseCamera()
7-
checkpoint_dir = "/root/git/scratchpad/scene_test_2048_bs3_hor_sigma_0025"
8-
global_config = config_utils.load_config(checkpoint_dir, batch_size=1, arg_configs=[])
9-
cgn_ = cgn(global_config, checkpoint_dir )
10-
cgn_.generate_grasps(rs_camera)
15+
global_config = config_utils.load_config(test_config["checkpoint_dir"], batch_size=1, arg_configs=[])
16+
cgn_ = cgn(global_config, test_config["checkpoint_dir"] )
17+
grasps, scores, contact_points = cgn_.generate_grasps(rs_camera, use_depth_for_seg = test_config["use_depth"])
18+
19+
# sorted_grasps = [x for _, x in sorted(zip(grasps[255.0], scores[255.0])) ]
20+
sorted_grasps = grasps[255.0][np.argsort(scores[255.0])]
21+
plot_grasps = sorted_grasps[:3]
22+
plot_grasp_frames = []
23+
for plot_grasp in plot_grasps:
24+
frame_mesh = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2).transform(plot_grasp)
25+
plot_grasp_frames.append(frame_mesh)
26+
27+
rgb_image, depth_image = rs_camera.get_current_rgbd_frames()
28+
pcd = rs_camera.get_pointcloud_rgbd(rgb_image, depth_image)
29+
plot_grasp_frames.append(pcd)
30+
o3d.visualization.draw_geometries(plot_grasp_frames)
31+
32+
1133

1234
if __name__ == "__main__":
1335
main()

0 commit comments

Comments
 (0)