Skip to content

Commit 0bcbc73

Browse files
authored
Merge pull request #3516 from cudawarped:cuda_moments
`cuda`: add `moments`
2 parents faa5468 + eca7f1c commit 0bcbc73

File tree

8 files changed

+548
-5
lines changed

8 files changed

+548
-5
lines changed

modules/cudaimgproc/include/opencv2/cudaimgproc.hpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
@{
5858
@defgroup cudaimgproc_color Color space processing
5959
@defgroup cudaimgproc_hist Histogram Calculation
60+
@defgroup cudaimgproc_shape Structural Analysis and Shape Descriptors
6061
@defgroup cudaimgproc_hough Hough Transform
6162
@defgroup cudaimgproc_feature Feature Detection
6263
@}
@@ -779,9 +780,84 @@ CV_EXPORTS_AS(connectedComponentsWithAlgorithm) void connectedComponents(InputAr
779780
CV_EXPORTS_W void connectedComponents(InputArray image, OutputArray labels,
780781
int connectivity = 8, int ltype = CV_32S);
781782

782-
783783
//! @}
784784

785+
//! @addtogroup cudaimgproc_shape
786+
//! @{
787+
788+
/** @brief Order of image moments.
789+
* @param FIRST_ORDER_MOMENTS First order moments
790+
* @param SECOND_ORDER_MOMENTS Second order moments.
791+
* @param THIRD_ORDER_MOMENTS Third order moments.
792+
* */
793+
enum MomentsOrder {
794+
FIRST_ORDER_MOMENTS = 1,
795+
SECOND_ORDER_MOMENTS = 2,
796+
THIRD_ORDER_MOMENTS = 3
797+
};
798+
799+
/** @brief Returns the number of image moments whose order is less than or equal to \a order.
800+
@param order Order of largest moments to calculate with lower order moments requiring less computation.
801+
@returns number of image moments.
802+
803+
@sa cuda::moments, cuda::spatialMoments, cuda::MomentsOrder
804+
*/
805+
CV_EXPORTS_W int numMoments(const MomentsOrder order);
806+
807+
/** @brief Calculates all of the spatial moments up to the 3rd order of a rasterized shape.
808+
809+
Asynchronous version of cuda::moments() which only calculates the spatial (not centralized or normalized) moments, up to the 3rd order, of a rasterized shape.
810+
Each moment is returned as a column entry in the 1D \a moments array.
811+
812+
@param src Raster image (single-channel 2D array).
813+
@param [out] moments 1D array with each column entry containing a spatial image moment.
814+
@param binaryImage If it is true, all non-zero image pixels are treated as 1's.
815+
@param order Order of largest moments to calculate with lower order moments requiring less computation.
816+
@param momentsType Precision to use when calculating moments. Available types are `CV_32F` and `CV_64F` with the performance of `CV_32F` an order of magnitude greater than `CV_64F`. If the image is small the accuracy from `CV_32F` can be equal or very close to `CV_64F`.
817+
@param stream Stream for the asynchronous version.
818+
819+
@note For maximum performance pre-allocate a 1D GpuMat for \a moments of the correct type and of a size large enough to store all the image moments up to the desired \a order. e.g. With \a order == MomentsOrder::SECOND_ORDER_MOMENTS and \a momentsType == `CV_32F` \a moments can be allocated as
820+
```
821+
GpuMat momentsDevice(1,numMoments(MomentsOrder::SECOND_ORDER_MOMENTS),CV_32F)
822+
```
823+
The central and normalized moments can easily be calculated on the host by downloading the \a moments array and using the cv::Moments constructor. e.g.
824+
```
825+
HostMem momentsHostMem(1, numMoments(MomentsOrder::SECOND_ORDER_MOMENTS), CV_32F);
826+
momentsDevice.download(momentsHostMem, stream);
827+
stream.waitForCompletion();
828+
Mat momentsMat = momentsHostMem.createMatHeader();
829+
cv::Moments cvMoments(momentsMat.at<float>(0), momentsMat.at<float>(1), momentsMat.at<float>(2), momentsMat.at<float>(3), momentsMat.at<float>(4), momentsMat.at<float>(5), momentsMat.at<float>(6), momentsMat.at<float>(7), momentsMat.at<float>(8), momentsMat.at<float>(9));
830+
```
831+
see the \a CUDA_TEST_P(Moments, Async) test inside opencv_contrib_source_code/modules/cudaimgproc/test/test_moments.cpp for an example.
832+
The computed spatial moments are written to \a moments; the function itself returns nothing.
833+
@sa cuda::moments
834+
*/
835+
CV_EXPORTS_W void spatialMoments(InputArray src, OutputArray moments, const bool binaryImage = false, const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS, const int momentsType = CV_64F, Stream& stream = Stream::Null());
836+
837+
/** @brief Calculates all of the moments up to the 3rd order of a rasterized shape.
838+
839+
The function computes moments, up to the 3rd order, of a rasterized shape. The
840+
results are returned in the structure cv::Moments.
841+
842+
@param src Raster image (single-channel 2D array).
843+
@param binaryImage If it is true, all non-zero image pixels are treated as 1's.
844+
@param order Order of largest moments to calculate with lower order moments requiring less computation.
845+
@param momentsType Precision to use when calculating moments. Available types are `CV_32F` and `CV_64F` with the performance of `CV_32F` an order of magnitude greater than `CV_64F`. If the image is small the accuracy from `CV_32F` can be equal or very close to `CV_64F`.
846+
847+
@note For maximum performance use the asynchronous version cuda::spatialMoments() as this version internally allocates and deallocates both GpuMat and HostMem to respectively perform the calculation on the device and download the result to the host.
848+
The costly HostMem allocation cannot be avoided however the GpuMat device allocation can be by using BufferPool, e.g.
849+
```
850+
setBufferPoolUsage(true);
851+
setBufferPoolConfig(getDevice(), numMoments(order) * ((momentsType == CV_64F) ? sizeof(double) : sizeof(float)), 1);
852+
```
853+
see the \a CUDA_TEST_P(Moments, Accuracy) test inside opencv_contrib_source_code/modules/cudaimgproc/test/test_moments.cpp for an example.
854+
@returns cv::Moments.
855+
@sa cuda::spatialMoments
856+
*/
857+
CV_EXPORTS_W Moments moments(InputArray src, const bool binaryImage = false, const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS, const int momentsType = CV_64F);
858+
859+
//! @} cudaimgproc_shape
860+
785861
}} // namespace cv { namespace cuda {
786862

787863
#endif /* OPENCV_CUDAIMGPROC_HPP */

modules/cudaimgproc/misc/python/test/test_cudaimgproc.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,30 @@ def test_cvtColor(self):
8989
self.assertTrue(np.allclose(cv.cuda.cvtColor(cuMat, cv.COLOR_BGR2HSV).download(),
9090
cv.cvtColor(npMat, cv.COLOR_BGR2HSV)))
9191

92+
def test_moments(self):
    """Compare cv.cuda.moments / cv.cuda.spatialMoments with cv.moments on a 10x10 all-255 image."""
    # Reference result from the CPU implementation (binary image mode).
    src_host = (np.ones([10,10])).astype(np.uint8)*255
    cpu_moments = cv.moments(src_host, True)
    moments_order = cv.cuda.THIRD_ORDER_MOMENTS
    n_moments = cv.cuda.numMoments(cv.cuda.THIRD_ORDER_MOMENTS)
    src_device = cv.cuda.GpuMat(src_host)

    # Synchronous API: size the buffer pool for n_moments values of np.dtype(float).
    cv.cuda.setBufferPoolUsage(True)
    pool_bytes = n_moments * np.dtype(float).itemsize
    cv.cuda.setBufferPoolConfig(cv.cuda.getDevice(), pool_bytes, 1)
    gpu_moments = cv.cuda.moments(src_device, True, moments_order, cv.CV_64F)
    n_equal = 0
    for moment_type in cpu_moments:
        if moment_type in gpu_moments and cpu_moments[moment_type] == gpu_moments[moment_type]:
            n_equal += 1
    # Every entry of the CPU result must be present and identical in the GPU result.
    self.assertTrue(n_equal == 24)

    # Asynchronous API: compute into a device array, then download into page-locked host memory.
    stream = cv.cuda.Stream()
    moments_array_host = np.empty([1, n_moments], np.float64)
    cv.cuda.registerPageLocked(moments_array_host)
    moments_array_device = cv.cuda.GpuMat(1, n_moments, cv.CV_64F)
    cv.cuda.spatialMoments(src_device, moments_array_device, True, moments_order, cv.CV_64F, stream)
    moments_array_device.download(stream, moments_array_host)
    stream.waitForCompletion()
    cv.cuda.unregisterPageLocked(moments_array_host)
    # Pair the CPU entries, in iteration order, with the downloaded spatial moments.
    n_equal_async = 0
    for moment_type, gpu_moment in zip(cpu_moments, moments_array_host[0]):
        if cpu_moments[moment_type] == gpu_moment:
            n_equal_async += 1
    self.assertTrue(n_equal_async == 10)
116+
92117
if __name__ == '__main__':
93118
NewOpenCVTests.bootstrap()
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#include "perf_precomp.hpp"
6+
7+
namespace opencv_test { namespace {
8+
// Clears dst to zero and draws one anti-aliased circle with value 255.
// circle holds (centre x, centre y, radius); fill selects a filled disc (thickness -1)
// or a one-pixel outline (thickness 1).
static void drawCircle(cv::Mat& dst, const cv::Vec3i& circle, bool fill)
{
    const Point2i centre(circle[0], circle[1]);
    const int radius = circle[2];
    const int thickness = fill ? -1 : 1;
    dst.setTo(Scalar::all(0));
    cv::circle(dst, centre, radius, Scalar::all(255), thickness, cv::LINE_AA);
}
13+
14+
DEF_PARAM_TEST(Sz_Depth, Size, MatDepth);
15+
// Times cuda::spatialMoments (pre-allocated output, third-order moments) against cv::moments
// on a filled circle, for each typical image size and for CV_32F / CV_64F moment precision.
PERF_TEST_P(Sz_Depth, SpatialMoments, Combine(CUDA_TYPICAL_MAT_SIZES, Values(MatDepth(CV_32F), MatDepth(CV_64F))))
{
    const cv::Size size = GET_PARAM(0);
    const int momentsType = GET_PARAM(1);

    // Test image: a filled circle centred in the frame with radius 0.9 * (width / 2).
    Mat imgHost(size, CV_8U);
    const int radius = static_cast<int>(static_cast<float>(size.width / 2) * 0.9);
    const Vec3i circle(size.width / 2, size.height / 2, radius);
    drawCircle(imgHost, circle, true);

    if (!PERF_RUN_CUDA()) {
        // CPU reference timing.
        cv::Moments momentsHost;
        TEST_CYCLE() momentsHost = cv::moments(imgHost, false);
    }
    else {
        const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS;
        const int nMoments = numMoments(order);
        // The output is allocated once, outside the timed loop.
        GpuMat momentsDevice(1, nMoments, momentsType);
        const GpuMat imgDevice(imgHost);
        TEST_CYCLE() cuda::spatialMoments(imgDevice, momentsDevice, false, order, momentsType);
    }
    SANITY_CHECK_NOTHING();
}
36+
37+
// Times the synchronous cuda::moments (third-order moments, buffer pool enabled) against
// cv::moments on a filled circle, for each typical image size and CV_32F / CV_64F precision.
PERF_TEST_P(Sz_Depth, Moments, Combine(CUDA_TYPICAL_MAT_SIZES, Values(MatDepth(CV_32F), MatDepth(CV_64F))))
{
    const cv::Size size = GET_PARAM(0);
    const int momentsType = GET_PARAM(1);

    // Test image: a filled circle centred in the frame with radius 0.9 * (width / 2).
    Mat imgHost(size, CV_8U);
    const int radius = static_cast<int>(static_cast<float>(size.width / 2) * 0.9);
    const Vec3i circle(size.width / 2, size.height / 2, radius);
    drawCircle(imgHost, circle, true);

    cv::Moments momentsHost;
    if (!PERF_RUN_CUDA()) {
        // CPU reference timing.
        TEST_CYCLE() momentsHost = cv::moments(imgHost, false);
    }
    else {
        const MomentsOrder order = MomentsOrder::THIRD_ORDER_MOMENTS;
        const int nMoments = numMoments(order);
        // Size the buffer pool for nMoments values of the selected precision.
        const size_t elemSize = (momentsType == CV_64F) ? sizeof(double) : sizeof(float);
        setBufferPoolUsage(true);
        setBufferPoolConfig(getDevice(), nMoments * elemSize, 1);
        const GpuMat imgDevice(imgHost);
        TEST_CYCLE() momentsHost = cuda::moments(imgDevice, false, order, momentsType);
    }
    SANITY_CHECK_NOTHING();
}
59+
}
60+
61+
}}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// This file is part of OpenCV project.
2+
// It is subject to the license terms in the LICENSE file found in the top-level directory
3+
// of this distribution and at http://opencv.org/license.html.
4+
5+
#if !defined CUDA_DISABLER
6+
7+
#include <opencv2/core/cuda/common.hpp>
8+
#include <opencv2/cudev/util/atomic.hpp>
9+
#include "moments.cuh"
10+
11+
namespace cv { namespace cuda { namespace device { namespace imgproc {
12+
13+
constexpr int blockSizeX = 32;
14+
constexpr int blockSizeY = 16;
15+
16+
// Sums `value` over the 32 lanes of a warp with XOR (butterfly) shuffles, using lane
// distances 16, 8, 4, 2, 1. The full mask 0xffffffff is passed to __shfl_xor_sync, so
// all 32 lanes of the warp must call this together; the combined sum is returned.
template <typename T>
__device__ T butterflyWarpReduction(T value) {
    int laneDistance = 16;
    while (laneDistance >= 1) {
        value += __shfl_xor_sync(0xffffffff, value, laneDistance, 32);
        laneDistance /= 2;
    }
    return value;
}
22+
23+
// Sums `value` over lanes 0..15 of a warp with XOR (butterfly) shuffles, using lane
// distances 8, 4, 2, 1. The mask 0xffff names only the lower 16 lanes, so exactly those
// lanes are expected to call this together; the combined sum is returned.
template <typename T>
__device__ T butterflyHalfWarpReduction(T value) {
    int laneDistance = 8;
    while (laneDistance >= 1) {
        value += __shfl_xor_sync(0xffff, value, laneDistance, 32);
        laneDistance /= 2;
    }
    return value;
}
29+
30+
// Adds one pixel's contribution to the calling thread's running row sums of val * x^k:
//   r[0] += val
//   r[1] += val * x
//   r[2] += val * x^2   (only when nMoments >= n12)
//   r[3] += val * x^3   (only when nMoments >= n123)
// val : pixel value, already converted to the accumulator type T.
// x   : column index of the pixel.
// r   : the four running sums owned by the calling thread.
template<typename T, int nMoments>
__device__ void updateSums(const T val, const unsigned int x, T r[4]) {
    // Convert x to T before squaring: x * x evaluated in unsigned int wraps once x > 65535.
    const T xT = static_cast<T>(x);
    const T x2 = xT * xT;
    const T x3 = xT * x2;
    r[0] += val;
    r[1] += val * xT;
    if (nMoments >= n12) r[2] += val * x2;
    if (nMoments >= n123) r[3] += val * x3;
}
39+
40+
// Generic per-row accumulation: each thread visits columns threadIdx.x, threadIdx.x + blockDim.x, ...
// of row y and folds every pixel into its private sums r via updateSums.
// binary : when true, any non-zero pixel contributes 1 instead of its value.
// smem   : not referenced by this function.
template<typename TSrc, typename TMoments, int nMoments>
__device__ void rowReductions(const PtrStepSz<TSrc> img, const bool binary, const unsigned int y, TMoments r[4], TMoments smem[][nMoments + 1]) {
    int col = threadIdx.x;
    while (col < img.cols) {
        TMoments pixel;
        if (binary && img(y, col) != 0)
            pixel = 1;
        else
            pixel = img(y, col);
        updateSums<TMoments, nMoments>(pixel, col, r);
        col += blockDim.x;
    }
}
47+
48+
// Per-row accumulation for one-byte pixel types that reads the row as 32-bit words.
// The row is processed in three parts:
//   head : the first (4 - offsetX) pixels, only when !fourByteAligned, read one by one by thread 0;
//   body : whole 32-bit words, word index strided over threadIdx.x, four pixels per word;
//   tail : pixels left after the last whole word, read one by one by thread 0.
// img     : source image (TSrc is expected to be a one-byte type; the body unpacks bytes).
// binary  : when true, any non-zero pixel contributes 1 instead of its value.
// y       : row to process.
// r       : the calling thread's four running sums (see updateSums).
// offsetX : used only when !fourByteAligned, as the head length 4 - offsetX.
//           NOTE(review): presumably the row start's byte offset modulo 4 - confirm against the caller.
// smem    : not referenced by this function.
template<typename TSrc, typename TMoments, bool fourByteAligned, int nMoments>
__device__ void rowReductionsCoalesced(const PtrStepSz<TSrc> img, const bool binary, const unsigned int y, TMoments r[4], const int offsetX, TMoments smem[][nMoments + 1]) {
    const int alignedOffset = fourByteAligned ? 0 : 4 - offsetX;

    // load uncoalesced head
    if (!fourByteAligned && threadIdx.x == 0) {
        const int headEnd = ::min(alignedOffset, static_cast<int>(img.cols));
        for (int x = 0; x < headEnd; x++) {
            const TSrc pixel = img(y, x);
            const TMoments val = (!binary || pixel == 0) ? pixel : 1;
            updateSums<TMoments, nMoments>(val, x, r);
        }
    }

    // coalesced loads
    const unsigned int* rowPtrIntAligned = (const unsigned int*)(fourByteAligned ? img.ptr(y) : img.ptr(y) + alignedOffset);
    const int cols4 = fourByteAligned ? img.cols / 4 : (img.cols - alignedOffset) / 4;
    for (int x = threadIdx.x; x < cols4; x += blockDim.x) {
        const unsigned int data = rowPtrIntAligned[x];
#pragma unroll 4
        for (int i = 0; i < 4; i++) {
            const int iX = alignedOffset + 4 * x + i;
            // Byte i of the word (least significant first) is pixel iX. Reinterpret it as TSrc
            // so signed sources yield the same value here as the img(y, x) reads in head and tail.
            const TSrc pixel = static_cast<TSrc>((data >> (i * 8)) & 0xFFU);
            const TMoments val = (!binary || pixel == 0) ? pixel : 1;
            updateSums<TMoments, nMoments>(val, iX, r);
        }
    }

    // load uncoalesced tail
    if (threadIdx.x == 0) {
        const int iTailStart = fourByteAligned ? cols4 * 4 : cols4 * 4 + alignedOffset;
        for (int x = iTailStart; x < img.cols; x++) {
            const TSrc pixel = img(y, x);
            const TMoments val = (!binary || pixel == 0) ? pixel : 1;
            updateSums<TMoments, nMoments>(val, x, r);
        }
    }
}
82+
83+
// Accumulates the first nMoments spatial moments of img into moments[0 .. nMoments-1].
//
// Launch layout (as used by the dispatchers in this file): blockDim = (blockSizeX, blockSizeY),
// a 1D grid of divUp(rows, blockSizeY) blocks. Each warp (one threadIdx.y) handles one image row.
//
// Steps:
//   1. every thread accumulates partial row sums of val * x^k in registers;
//   2. each warp reduces those sums and stores its row's moment contributions in smem[row][0..9];
//   3. the lower 16 lanes of warps 0..nMoments-1 sum one moment each over the block's rows;
//   4. the first nMoments threads of warp 0 atomically add the block totals to `moments`.
//
// img     : source image.
// binary  : when true, any non-zero pixel contributes 1 instead of its value.
// moments : output accumulator; results are added with atomicAdd, so it is presumably
//           zero-initialised by the caller - TODO confirm.
// offsetX : forwarded to rowReductionsCoalesced when `coalesced` is set.
//
// NOTE(review): a row whose 0th-order sum is exactly zero is skipped entirely (`if (res)`);
// for signed sources such a row could still have non-zero higher moments - confirm this is intended.
template <typename TSrc, typename TMoments, bool coalesced = false, bool fourByteAligned = false, int nMoments>
__global__ void spatialMoments(const PtrStepSz<TSrc> img, const bool binary, TMoments* moments, const int offsetX = 0) {
    const unsigned int y = blockIdx.x * blockDim.y + threadIdx.y;
    __shared__ TMoments smem[blockSizeY][nMoments + 1];
    // Clear the per-row moment slots; rows that are skipped below must contribute zero.
    if (threadIdx.y < nMoments && threadIdx.x < blockSizeY)
        smem[threadIdx.x][threadIdx.y] = 0;
    __syncthreads();

    TMoments r[4] = { 0 };
    if (y < img.rows) {
        if (coalesced)
            rowReductionsCoalesced<TSrc, TMoments, fourByteAligned, nMoments>(img, binary, y, r, offsetX, smem);
        else
            rowReductions<TSrc, TMoments, nMoments>(img, binary, y, r, smem);
    }

    // Powers of y are formed in TMoments: y * y in 32-bit unsigned arithmetic can wrap.
    const TMoments yT = static_cast<TMoments>(y);
    const TMoments y2 = yT * yT;
    const TMoments y3 = y2 * yT;
    // Reduce in the accumulator type so double precision is not truncated to float.
    const TMoments res = butterflyWarpReduction(r[0]);
    // res is identical in every lane of the warp, so this branch is warp-uniform.
    if (res) {
        smem[threadIdx.y][0] = res; //0th
        smem[threadIdx.y][1] = butterflyWarpReduction(r[1]); //1st
        smem[threadIdx.y][2] = yT * res; //1st
        if (nMoments >= n12) {
            smem[threadIdx.y][3] = butterflyWarpReduction(r[2]); //2nd
            smem[threadIdx.y][4] = smem[threadIdx.y][1] * yT; //2nd
            smem[threadIdx.y][5] = y2 * res; //2nd
        }
        if (nMoments >= n123) {
            smem[threadIdx.y][6] = butterflyWarpReduction(r[3]); //3rd
            smem[threadIdx.y][7] = smem[threadIdx.y][3] * yT; //3rd
            smem[threadIdx.y][8] = smem[threadIdx.y][1] * y2; //3rd
            smem[threadIdx.y][9] = y3 * res; //3rd
        }
    }
    __syncthreads();

    // Warp m (m < nMoments) sums moment m over the block's rows, one row per lane, and
    // parks the total in the extra column smem[m][nMoments].
    if (threadIdx.x < blockSizeY && threadIdx.y < nMoments)
        smem[threadIdx.y][nMoments] = butterflyHalfWarpReduction(smem[threadIdx.x][threadIdx.y]);
    __syncthreads();

    // One atomic per moment per block; zero totals are skipped.
    if (threadIdx.y == 0 && threadIdx.x < nMoments) {
        if (smem[threadIdx.x][nMoments])
            cudev::atomicAdd(&moments[threadIdx.x], smem[threadIdx.x][nMoments]);
    }
}
129+
130+
// Host-side launcher that runs the spatialMoments kernel with the generic
// (non-coalesced) row reduction.
template <typename TSrc, typename TMoments, int nMoments> struct momentsDispatcherNonChar {
    // src     : input image.
    // moments : device buffer whose first row receives the nMoments values accumulated by the kernel.
    // binary  : when true every non-zero pixel counts as 1.
    // offsetX : not used by this dispatcher.
    // stream  : CUDA stream for the launch; the call waits for completion only when stream == 0.
    static void call(const PtrStepSz<TSrc> src, PtrStepSz<TMoments> moments, const bool binary, const int offsetX, const cudaStream_t stream) {
        // A zero-row image would produce a zero-sized grid, which is an invalid launch configuration.
        if (src.rows <= 0)
            return;
        const dim3 blockSize(blockSizeX, blockSizeY);
        const dim3 gridSize(divUp(src.rows, blockSizeY));
        spatialMoments<TSrc, TMoments, false, false, nMoments><<<gridSize, blockSize, 0, stream>>>(src, binary, moments.ptr());
        // A kernel launch reports no error by itself; surface launch failures here.
        cudaSafeCall(cudaGetLastError());
        if (stream == 0)
            cudaSafeCall(cudaStreamSynchronize(stream));
    }
};
139+
140+
// Host-side launcher that runs the spatialMoments kernel with the coalesced row
// reduction, accumulating in float.
template <typename TSrc, int nMoments> struct momentsDispatcherChar {
    // src     : input image.
    // moments : device buffer whose first row receives the nMoments values accumulated by the kernel.
    // binary  : when true every non-zero pixel counts as 1.
    // offsetX : zero selects the fourByteAligned kernel variant; any other value selects the
    //           variant with an unaligned head and is forwarded to the kernel.
    // stream  : CUDA stream for the launch; the call waits for completion only when stream == 0.
    static void call(const PtrStepSz<TSrc> src, PtrStepSz<float> moments, const bool binary, const int offsetX, const cudaStream_t stream) {
        // A zero-row image would produce a zero-sized grid, which is an invalid launch configuration.
        if (src.rows <= 0)
            return;
        const dim3 blockSize(blockSizeX, blockSizeY);
        const dim3 gridSize(divUp(src.rows, blockSizeY));
        if (offsetX)
            spatialMoments<TSrc, float, true, false, nMoments><<<gridSize, blockSize, 0, stream>>>(src, binary, moments.ptr(), offsetX);
        else
            spatialMoments<TSrc, float, true, true, nMoments><<<gridSize, blockSize, 0, stream>>>(src, binary, moments.ptr());
        // A kernel launch reports no error by itself; surface launch failures here.
        cudaSafeCall(cudaGetLastError());

        if (stream == 0)
            cudaSafeCall(cudaStreamSynchronize(stream));
    }
};
153+
154+
// Selects the launcher for a (source type, accumulator type) pair.
// Default: the generic, non-coalesced launcher.
template <typename TSrc, typename TMoments, int nMoments>
struct momentsDispatcher
    : momentsDispatcherNonChar<TSrc, TMoments, nMoments> {};

// uchar source with float accumulation: coalesced launcher.
template <int nMoments>
struct momentsDispatcher<uchar, float, nMoments>
    : momentsDispatcherChar<uchar, nMoments> {};

// schar source with float accumulation: coalesced launcher.
template <int nMoments>
struct momentsDispatcher<schar, float, nMoments>
    : momentsDispatcherChar<schar, nMoments> {};
157+
158+
// Host entry point: reinterprets the byte-typed src / moments views as TSrc / TMoments and
// forwards to the dispatcher instantiated for the requested moment order.
// order : 1, 2 or 3 selects n1, n12 or n123 moments; any other value does nothing.
// binary, offsetX and stream are passed through to the dispatcher unchanged.
template <typename TSrc, typename TMoments>
void moments(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream) {
    const PtrStepSz<TSrc> typedSrc = static_cast<PtrStepSz<TSrc>>(src);
    const PtrStepSz<TMoments> typedMoments = static_cast<PtrStepSz<TMoments>>(moments);
    switch (order) {
    case 1:
        momentsDispatcher<TSrc, TMoments, n1>::call(typedSrc, typedMoments, binary, offsetX, stream);
        break;
    case 2:
        momentsDispatcher<TSrc, TMoments, n12>::call(typedSrc, typedMoments, binary, offsetX, stream);
        break;
    case 3:
        momentsDispatcher<TSrc, TMoments, n123>::call(typedSrc, typedMoments, binary, offsetX, stream);
        break;
    default:
        break;
    }
}
167+
168+
template void moments<uchar, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
169+
template void moments<schar, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
170+
template void moments<ushort, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
171+
template void moments<short, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
172+
template void moments<int, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
173+
template void moments<float, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
174+
template void moments<double, float>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
175+
176+
template void moments<uchar, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
177+
template void moments<schar, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
178+
template void moments<ushort, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
179+
template void moments<short, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
180+
template void moments<int, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
181+
template void moments<float, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
182+
template void moments<double, double>(const PtrStepSzb src, PtrStepSzb moments, const bool binary, const int order, const int offsetX, const cudaStream_t stream);
183+
184+
}}}}
185+
186+
#endif /* CUDA_DISABLER */
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
namespace cv {
namespace cuda {
namespace device {
namespace imgproc {

// Number of spatial moments of order <= 1 (one 0th-order plus two 1st-order).
constexpr int n1 = 3;
// Number of spatial moments of order <= 2 (the above plus three 2nd-order).
constexpr int n12 = 6;
// Number of spatial moments of order <= 3 (the above plus four 3rd-order).
constexpr int n123 = 10;

} // namespace imgproc
} // namespace device
} // namespace cuda
} // namespace cv

0 commit comments

Comments
 (0)