Skip to content

Commit 2ab6656

Browse files
authored
[Feature/Cleanup] Add CUDADepthImageSegmenter and Remove OpenCL Classes (#911)
* Added CUDADepthImageSegmenter and removed OpenCL classes
1 parent 53d9542 commit 2ab6656

File tree

12 files changed

+237
-282
lines changed

12 files changed

+237
-282
lines changed

ihmc-high-level-behaviors/src/libgdx/java/us/ihmc/rdx/ui/RDXRawImagePointCloudRendererDemo.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import imgui.type.ImFloat;
77
import imgui.type.ImInt;
88
import us.ihmc.perception.RawImage;
9-
import us.ihmc.perception.opencl.OpenCLPointCloudExtractor;
109
import us.ihmc.rdx.AbstractRDXPointCloudRenderer.ColoringMethod;
1110
import us.ihmc.rdx.Lwjgl3ApplicationAdapter;
1211
import us.ihmc.rdx.ui.gizmo.RDXPose3DGizmo;
@@ -27,7 +26,6 @@ public class RDXRawImagePointCloudRendererDemo
2726
private final float[] defaultColor = new float[] {1.0f, 1.0f, 1.0f, 1.0f};
2827

2928
private RDXRawImagePointCloudRenderer pointCloudRenderer;
30-
private final OpenCLPointCloudExtractor pointCloudExtractor = new OpenCLPointCloudExtractor();
3129

3230
private long lastGrabSequenceNumber = -1L;
3331
private ZEDColorDepthImageRetriever zed;
@@ -74,7 +72,7 @@ public void render()
7472
{
7573
pointCloudRenderer.updateMesh(depthImage, colorImage);
7674
}
77-
else // inputMethod == InputMethod.POINT_CLOUD
75+
else
7876
{
7977
pointCloudRenderer.updateMesh(depthImage);
8078
}
@@ -119,7 +117,6 @@ public void dispose()
119117
{
120118
zed.destroy();
121119
pointCloudRenderer.dispose();
122-
pointCloudExtractor.destroy();
123120
baseUI.dispose();
124121
}
125122
});

ihmc-high-level-behaviors/src/main/java/us/ihmc/perception/IterativeClosestPointManager.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import us.ihmc.euclid.tuple3D.Point3D32;
1010
import us.ihmc.euclid.tuple3D.Vector3D;
1111
import us.ihmc.log.LogTools;
12-
import us.ihmc.perception.opencl.OpenCLPointCloudExtractor;
12+
import us.ihmc.perception.cuda.CUDAPointCloudExtractor;
1313
import us.ihmc.perception.sceneGraph.SceneGraph;
1414
import us.ihmc.perception.sceneGraph.SceneNode;
1515
import us.ihmc.perception.sceneGraph.rigidBody.primitive.PrimitiveRigidBodyShape;
@@ -29,7 +29,7 @@ public class IterativeClosestPointManager
2929
private final ROS2Helper ros2Helper;
3030
private final SceneGraph sceneGraph;
3131

32-
private final OpenCLPointCloudExtractor pointCloudExtractor = new OpenCLPointCloudExtractor();
32+
private final CUDAPointCloudExtractor pointCloudExtractor = new CUDAPointCloudExtractor();
3333

3434
private final Random random = new Random(System.nanoTime());
3535
private final ConcurrentHashMap<Long, IterativeClosestPointWorker> nodeIDToWorkerMap = new ConcurrentHashMap<>();
@@ -141,6 +141,7 @@ public void destroy()
141141
{
142142
nodeIDToWorkerMap.clear();
143143
workerThread.blockingKill();
144+
pointCloudExtractor.close();
144145
}
145146

146147
/**

ihmc-high-level-behaviors/src/test/java/us/ihmc/rdx/perception/RDXIterativeClosestPointWorkerDemo.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import us.ihmc.euclid.tuple3D.interfaces.Point3DReadOnly;
2222
import us.ihmc.perception.IterativeClosestPointWorker;
2323
import us.ihmc.perception.RawImage;
24-
import us.ihmc.perception.opencl.OpenCLPointCloudExtractor;
24+
import us.ihmc.perception.cuda.CUDAPointCloudExtractor;
2525
import us.ihmc.perception.sceneGraph.rigidBody.primitive.PrimitiveRigidBodyShape;
2626
import us.ihmc.rdx.Lwjgl3ApplicationAdapter;
2727
import us.ihmc.rdx.RDXPointCloudRendererOld;
@@ -50,7 +50,7 @@ public class RDXIterativeClosestPointWorkerDemo
5050
private static final int MAX_ENVIRONMENT_SIZE = 1000;
5151
private static final int CORRESPONDENCES = 1000;
5252

53-
private final OpenCLPointCloudExtractor pointCloudExtractor = new OpenCLPointCloudExtractor();
53+
private final CUDAPointCloudExtractor pointCloudExtractor = new CUDAPointCloudExtractor();
5454
private final Random random = new Random(System.nanoTime());
5555

5656
private final ROS2Node node = new ROS2NodeBuilder().build("icp_worker_demo");
@@ -277,6 +277,7 @@ public void dispose()
277277
zedImageRetriever.destroy();
278278
perceptionVisualizerPanel.destroy();
279279
baseUI.dispose();
280+
pointCloudExtractor.close();
280281
}
281282

282283
private void renderSettings()

ihmc-high-level-behaviors/src/test/java/us/ihmc/rdx/perception/RDXYOLOv8PipelineDemo.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import us.ihmc.euclid.tuple3D.Point3D32;
1818
import us.ihmc.log.LogTools;
1919
import us.ihmc.perception.RawImage;
20+
import us.ihmc.perception.cuda.CUDADepthImageSegmenter;
2021
import us.ihmc.perception.cuda.CUDAPointCloudExtractor;
2122
import us.ihmc.perception.detections.yolo.YOLOv8Detection;
2223
import us.ihmc.perception.detections.yolo.YOLOv8DetectionList;
2324
import us.ihmc.perception.detections.yolo.YOLOv8Model;
2425
import us.ihmc.perception.detections.yolo.YOLOv8Tools;
2526
import us.ihmc.perception.imageMessage.PixelFormat;
26-
import us.ihmc.perception.opencl.OpenCLDepthImageSegmenter;
2727
import us.ihmc.rdx.Lwjgl3ApplicationAdapter;
2828
import us.ihmc.rdx.tools.RDXModelBuilder;
2929
import us.ihmc.rdx.ui.RDXBaseUI;
@@ -81,7 +81,7 @@ public class RDXYOLOv8PipelineDemo
8181
private final ImFloat maskThreshold = new ImFloat(0.0f);
8282
private final ImInt erosionKernelRadius = new ImInt(1);
8383

84-
private final OpenCLDepthImageSegmenter depthImageSegmenter = new OpenCLDepthImageSegmenter();
84+
private final CUDADepthImageSegmenter depthImageSegmenter;
8585
private RawImage segmentedDepth;
8686
private final RDXOpenCVVideoVisualizer segmentedDepthVisualizer = new RDXOpenCVVideoVisualizer("Segmented Depth", "Segmented Depth", false);
8787
private final RDXRawImagePointCloudVisualizer segmentedPointCloudVisualizer = new RDXRawImagePointCloudVisualizer("Segmented Point Cloud", true);
@@ -99,7 +99,7 @@ public class RDXYOLOv8PipelineDemo
9999

100100
private final ImInt frameToGrab = new ImInt(0);
101101

102-
private RDXYOLOv8PipelineDemo()
102+
private RDXYOLOv8PipelineDemo() throws Exception
103103
{
104104
for (URL yoloModelDirectory : YOLOv8Tools.getYOLOModelDirectories())
105105
{
@@ -112,6 +112,8 @@ private RDXYOLOv8PipelineDemo()
112112
availableModels.add(model.getName());
113113
}
114114

115+
depthImageSegmenter = new CUDADepthImageSegmenter();
116+
115117
zedPlaybackSensor.useTrackedPose(false);
116118
zedPlaybackSensor.run(true);
117119
try
@@ -463,21 +465,20 @@ private void destroy()
463465

464466
zedPointCloudVisualizer.destroy();
465467
segmentedPointCloudVisualizer.destroy();
466-
depthImageSegmenter.destroy();
467468
colorImageVisualizer.destroy();
468469
depthImageVisualizer.destroy();
469470
detectionMaskVisualizer.destroy();
470471
erodedMaskVisualizer.destroy();
471472
segmentedDepthVisualizer.destroy();
472473
annotatedImageVisualizer.destroy();
473474

474-
depthImageSegmenter.destroy();
475+
depthImageSegmenter.close();
475476
pointCloudExtractor.close();
476477
zedPlaybackSensor.close();
477478
ros2Node.destroy();
478479
}
479480

480-
public static void main(String[] args)
481+
public static void main(String[] args) throws Exception
481482
{
482483
new RDXYOLOv8PipelineDemo();
483484
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package us.ihmc.perception.cuda;
2+
3+
import org.bytedeco.cuda.cudart.CUstream_st;
4+
import org.bytedeco.cuda.cudart.dim3;
5+
import org.bytedeco.javacpp.FloatPointer;
6+
import org.bytedeco.opencv.opencv_core.GpuMat;
7+
import us.ihmc.euclid.transform.RigidBodyTransform;
8+
import us.ihmc.perception.RawImage;
9+
10+
import java.net.URL;
11+
12+
import static org.bytedeco.cuda.global.cudart.*;
13+
14+
public class CUDADepthImageSegmenter implements AutoCloseable
15+
{
16+
private static final int BLOCK_SIZE_XY = 16;
17+
18+
private final CUDAProgram program;
19+
private final CUDAKernel kernel;
20+
private final CUstream_st stream;
21+
22+
private final RigidBodyTransform depthToMaskTransform = new RigidBodyTransform();
23+
private final float[] transformArray = new float[16];
24+
private final FloatPointer transformPointer = new FloatPointer();
25+
26+
private final dim3 blockSize = new dim3(BLOCK_SIZE_XY, BLOCK_SIZE_XY, 1);
27+
private final dim3 gridSize = new dim3();
28+
29+
private int error;
30+
31+
public CUDADepthImageSegmenter() throws Exception
32+
{
33+
if (!CUDATools.hasCUDADevice())
34+
throw new Exception("CUDA unavailable.");
35+
36+
// Get the URLs to all the cu files
37+
URL segmentation = getClass().getResource("DepthImageSegmentation.cu");
38+
URL utils = getClass().getResource("Utils.cu");
39+
URL perceptionUtils = getClass().getResource("PerceptionUtils.cu");
40+
URL mathUtils = getClass().getResource("MathUtils.cuh");
41+
42+
// Compile the program and get the kernel
43+
program = new CUDAProgram(segmentation, utils, perceptionUtils, mathUtils);
44+
kernel = program.loadKernel("segmentDepthImage");
45+
46+
// Get a stream
47+
stream = CUDAStreamManager.getStream();
48+
49+
// Allocate fixed size page-locked memory (on host)
50+
error = cudaMallocHost(transformPointer, 16L * transformPointer.sizeof()); // 16 floats for transform matrix
51+
CUDATools.throwCUDAError(error);
52+
}
53+
54+
public RawImage removeBackground(RawImage depthImage, RawImage mask)
55+
{
56+
// Ensure we get the images
57+
if (depthImage.get() == null)
58+
return null;
59+
60+
if (mask.get() == null)
61+
{
62+
depthImage.release();
63+
return null;
64+
}
65+
66+
// Update the transform array
67+
depthImage.getTransformToWorld().inverseTransform(mask.getTransformToWorld(), depthToMaskTransform);
68+
depthToMaskTransform.get(transformArray);
69+
transformPointer.put(transformArray);
70+
71+
// Get the GPU mats
72+
GpuMat depthMat = depthImage.getGpuImageMat();
73+
GpuMat maskMat = mask.getGpuImageMat();
74+
GpuMat outputMat = new GpuMat(depthMat.size(), depthMat.type());
75+
76+
// Update the necessary grid size
77+
gridSize.x((depthImage.getWidth() + BLOCK_SIZE_XY - 1) / (BLOCK_SIZE_XY * 2));
78+
gridSize.y((depthImage.getHeight() + BLOCK_SIZE_XY - 1) / (BLOCK_SIZE_XY * 2));
79+
80+
// Run the kernel
81+
kernel.withPointer(depthMat.data()).withLong(depthMat.step())
82+
.withInt(depthImage.getWidth()).withInt(depthImage.getHeight())
83+
.withFloat(depthImage.getFocalLengthX()).withFloat(depthImage.getFocalLengthY())
84+
.withFloat(depthImage.getPrincipalPointX()).withFloat(depthImage.getPrincipalPointY())
85+
.withPointer(maskMat.data()).withLong(maskMat.step())
86+
.withInt(mask.getWidth()).withInt(mask.getHeight())
87+
.withFloat(mask.getFocalLengthX()).withFloat(mask.getFocalLengthY())
88+
.withFloat(mask.getPrincipalPointX()).withFloat(mask.getPrincipalPointY())
89+
.withFloat(depthImage.getDepthDiscretization())
90+
.withPointer(transformPointer)
91+
.withPointer(outputMat.data()).withLong(outputMat.step())
92+
.run(stream, gridSize, blockSize, 0);
93+
94+
// Synchronize to ensure the output data is contained in the output mat
95+
error = cudaStreamSynchronize(stream);
96+
CUDATools.checkCUDAError(error);
97+
98+
// Get the result as a RawImage
99+
RawImage result = depthImage.replaceImage(outputMat);
100+
101+
// Release the RawImages
102+
depthImage.release();
103+
mask.release();
104+
105+
// Return the result
106+
return result;
107+
}
108+
109+
@Override
110+
public void close()
111+
{
112+
error = cudaFreeHost(transformPointer);
113+
CUDATools.checkCUDAError(error);
114+
115+
blockSize.close();
116+
gridSize.close();
117+
118+
kernel.close();
119+
program.close();
120+
CUDAStreamManager.releaseStream(stream);
121+
}
122+
}

ihmc-perception/src/main/java/us/ihmc/perception/detections/yolo/YOLOv8DetectionExecutor.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
import us.ihmc.euclid.tuple3D.Point3D32;
2121
import us.ihmc.log.LogTools;
2222
import us.ihmc.perception.RawImage;
23+
import us.ihmc.perception.cuda.CUDADepthImageSegmenter;
2324
import us.ihmc.perception.cuda.CUDAPointCloudExtractor;
2425
import us.ihmc.perception.detections.InstantDetection;
2526
import us.ihmc.perception.imageMessage.CompressionType;
2627
import us.ihmc.perception.imageMessage.PixelFormat;
27-
import us.ihmc.perception.opencl.OpenCLDepthImageSegmenter;
2828
import us.ihmc.perception.tools.PerceptionMessageTools;
2929
import us.ihmc.ros2.ROS2Node;
3030
import us.ihmc.ros2.ROS2NodeBuilder;
@@ -46,8 +46,8 @@ public class YOLOv8DetectionExecutor
4646
{
4747
private final ROS2Node ros2Node = new ROS2NodeBuilder().build("yolo_detection_manager");
4848

49-
private final CUDAPointCloudExtractor extractor = new CUDAPointCloudExtractor();
50-
private final OpenCLDepthImageSegmenter segmenter = new OpenCLDepthImageSegmenter();
49+
private final CUDAPointCloudExtractor extractor;
50+
private final CUDADepthImageSegmenter segmenter;
5151

5252
private final Map<String, YOLOv8Model> availableModels = new LinkedHashMap<>();
5353
private final Map<YOLOv8Model, YOLOv8DetectionList> yoloDetectionResults = new ConcurrentHashMap<>();
@@ -73,6 +73,16 @@ public YOLOv8DetectionExecutor(CRDTInfo crdtInfo, BooleanSupplier annotatedImage
7373
{
7474
this.annotatedImageDemanded = annotatedImageDemanded;
7575

76+
try
77+
{
78+
extractor = new CUDAPointCloudExtractor();
79+
segmenter = new CUDADepthImageSegmenter();
80+
}
81+
catch (Exception e)
82+
{
83+
throw new RuntimeException(e);
84+
}
85+
7686
// Read available YOLO models
7787
for (URL yoloModelDirectory : YOLOv8Tools.getYOLOModelDirectories())
7888
{
@@ -280,7 +290,7 @@ public void destroy()
280290
yoloResults.destroy();
281291

282292
extractor.close();
283-
segmenter.destroy();
293+
segmenter.close();
284294

285295
System.out.println("Destroyed " + getClass().getSimpleName());
286296
}

0 commit comments

Comments
 (0)