Skip to content

Commit 9fd7358

Browse files
[SYCL] Avoid unnecessary kernel copies (#17584)
1 parent e59fcca commit 9fd7358

File tree

3 files changed

+71
-22
lines changed

3 files changed

+71
-22
lines changed

sycl/include/sycl/detail/cg_types.hpp

+2-1
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,8 @@ class HostKernel : public HostKernelBase {
167167
KernelType MKernel;
168168

169169
public:
170-
HostKernel(KernelType Kernel) : MKernel(Kernel) {}
170+
HostKernel(const KernelType &Kernel) : MKernel(Kernel) {}
171+
HostKernel(KernelType &&Kernel) : MKernel(std::move(Kernel)) {}
171172

172173
char *getPtr() override { return reinterpret_cast<char *>(&MKernel); }
173174

sycl/include/sycl/handler.hpp

+26-21
Original file line number | Diff line number | Diff line change
@@ -715,17 +715,17 @@ class __SYCL_EXPORT handler {
715715
/// \param KernelFunc is a SYCL kernel function
716716
/// \param ParamDescs is the vector of kernel parameter descriptors.
717717
template <typename KernelName, typename KernelType, int Dims,
718-
typename LambdaArgType>
719-
void StoreLambda(KernelType KernelFunc) {
718+
typename LambdaArgType, typename KernelTypeUniversalRef>
719+
void StoreLambda(KernelTypeUniversalRef &&KernelFunc) {
720720
constexpr bool IsCallableWithKernelHandler =
721721
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
722722
LambdaArgType>::value;
723723

724724
// Not using `std::make_unique` to avoid unnecessary instantiations of
725725
// `std::unique_ptr<HostKernel<...>>`. Only
726726
// `std::unique_ptr<HostKernelBase>` is necessary.
727-
MHostKernel.reset(
728-
new detail::HostKernel<KernelType, LambdaArgType, Dims>(KernelFunc));
727+
MHostKernel.reset(new detail::HostKernel<KernelType, LambdaArgType, Dims>(
728+
std::forward<KernelTypeUniversalRef>(KernelFunc)));
729729

730730
constexpr bool KernelHasName =
731731
detail::getKernelName<KernelName>() != nullptr &&
@@ -739,7 +739,7 @@ class __SYCL_EXPORT handler {
739739
#ifdef __INTEL_SYCL_USE_INTEGRATION_HEADERS
740740
static_assert(
741741
!KernelHasName ||
742-
sizeof(KernelFunc) == detail::getKernelSize<KernelName>(),
742+
sizeof(KernelType) == detail::getKernelSize<KernelName>(),
743743
"Unexpected kernel lambda size. This can be caused by an "
744744
"external host compiler producing a lambda with an "
745745
"unexpected layout. This is a limitation of the compiler."
@@ -1133,7 +1133,7 @@ class __SYCL_EXPORT handler {
11331133
typename KernelName, typename KernelType, int Dims,
11341134
typename PropertiesT = ext::oneapi::experimental::empty_properties_t>
11351135
void parallel_for_lambda_impl(range<Dims> UserRange, PropertiesT Props,
1136-
KernelType KernelFunc) {
1136+
const KernelType &KernelFunc) {
11371137
#ifndef __SYCL_DEVICE_ONLY__
11381138
throwIfActionIsCreated();
11391139
throwOnKernelParameterMisuse<KernelName, KernelType>();
@@ -1545,19 +1545,22 @@ class __SYCL_EXPORT handler {
15451545
// methods side.
15461546

15471547
template <typename... TypesToForward, typename... ArgsTy>
1548-
static void kernel_single_task_unpack(handler *h, ArgsTy... Args) {
1549-
h->kernel_single_task<TypesToForward..., Props...>(Args...);
1548+
static void kernel_single_task_unpack(handler *h, ArgsTy &&...Args) {
1549+
h->kernel_single_task<TypesToForward..., Props...>(
1550+
std::forward<ArgsTy>(Args)...);
15501551
}
15511552

15521553
template <typename... TypesToForward, typename... ArgsTy>
1553-
static void kernel_parallel_for_unpack(handler *h, ArgsTy... Args) {
1554-
h->kernel_parallel_for<TypesToForward..., Props...>(Args...);
1554+
static void kernel_parallel_for_unpack(handler *h, ArgsTy &&...Args) {
1555+
h->kernel_parallel_for<TypesToForward..., Props...>(
1556+
std::forward<ArgsTy>(Args)...);
15551557
}
15561558

15571559
template <typename... TypesToForward, typename... ArgsTy>
15581560
static void kernel_parallel_for_work_group_unpack(handler *h,
1559-
ArgsTy... Args) {
1560-
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(Args...);
1561+
ArgsTy &&...Args) {
1562+
h->kernel_parallel_for_work_group<TypesToForward..., Props...>(
1563+
std::forward<ArgsTy>(Args)...);
15611564
}
15621565
};
15631566

@@ -1622,9 +1625,9 @@ class __SYCL_EXPORT handler {
16221625
void kernel_single_task_wrapper(const KernelType &KernelFunc) {
16231626
unpack<KernelName, KernelType, PropertiesT,
16241627
detail::KernelLambdaHasKernelHandlerArgT<KernelType>::value>(
1625-
KernelFunc, [&](auto Unpacker, auto... args) {
1628+
KernelFunc, [&](auto Unpacker, auto &&...args) {
16261629
Unpacker.template kernel_single_task_unpack<KernelName, KernelType>(
1627-
args...);
1630+
std::forward<decltype(args)>(args)...);
16281631
});
16291632
}
16301633

@@ -1635,9 +1638,10 @@ class __SYCL_EXPORT handler {
16351638
unpack<KernelName, KernelType, PropertiesT,
16361639
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
16371640
ElementType>::value>(
1638-
KernelFunc, [&](auto Unpacker, auto... args) {
1641+
KernelFunc, [&](auto Unpacker, auto &&...args) {
16391642
Unpacker.template kernel_parallel_for_unpack<KernelName, ElementType,
1640-
KernelType>(args...);
1643+
KernelType>(
1644+
std::forward<decltype(args)>(args)...);
16411645
});
16421646
}
16431647

@@ -1648,9 +1652,10 @@ class __SYCL_EXPORT handler {
16481652
unpack<KernelName, KernelType, PropertiesT,
16491653
detail::KernelLambdaHasKernelHandlerArgT<KernelType,
16501654
ElementType>::value>(
1651-
KernelFunc, [&](auto Unpacker, auto... args) {
1655+
KernelFunc, [&](auto Unpacker, auto &&...args) {
16521656
Unpacker.template kernel_parallel_for_work_group_unpack<
1653-
KernelName, ElementType, KernelType>(args...);
1657+
KernelName, ElementType, KernelType>(
1658+
std::forward<decltype(args)>(args)...);
16541659
});
16551660
}
16561661

@@ -1900,21 +1905,21 @@ class __SYCL_EXPORT handler {
19001905
void parallel_for(range<1> NumWorkItems, const KernelType &KernelFunc) {
19011906
parallel_for_lambda_impl<KernelName>(
19021907
NumWorkItems, ext::oneapi::experimental::empty_properties_t{},
1903-
std::move(KernelFunc));
1908+
KernelFunc);
19041909
}
19051910

19061911
template <typename KernelName = detail::auto_name, typename KernelType>
19071912
void parallel_for(range<2> NumWorkItems, const KernelType &KernelFunc) {
19081913
parallel_for_lambda_impl<KernelName>(
19091914
NumWorkItems, ext::oneapi::experimental::empty_properties_t{},
1910-
std::move(KernelFunc));
1915+
KernelFunc);
19111916
}
19121917

19131918
template <typename KernelName = detail::auto_name, typename KernelType>
19141919
void parallel_for(range<3> NumWorkItems, const KernelType &KernelFunc) {
19151920
parallel_for_lambda_impl<KernelName>(
19161921
NumWorkItems, ext::oneapi::experimental::empty_properties_t{},
1917-
std::move(KernelFunc));
1922+
KernelFunc);
19181923
}
19191924

19201925
/// Enqueues a command to the SYCL runtime to invoke \p Func once.
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,43 @@
1+
// RUN: %{build} -o %t.out
2+
// RUN: %{run} %t.out
3+
4+
#include <sycl/detail/core.hpp>
5+
6+
size_t copy_count = 0;
7+
size_t move_count = 0;
8+
9+
template <int N> class kernel {
10+
public:
11+
kernel() {};
12+
kernel(const kernel &other) { copy_count++; };
13+
kernel(kernel &&other) { ++move_count; }
14+
15+
void operator()(sycl::id<1> id) const {}
16+
void operator()(sycl::nd_item<1> id) const {}
17+
void operator()() const {}
18+
};
19+
template <int N> struct sycl::is_device_copyable<kernel<N>> : std::true_type {};
20+
21+
int main(int argc, char **argv) {
22+
sycl::queue q;
23+
24+
kernel<0> krn0;
25+
q.parallel_for(sycl::range<1>{1}, krn0);
26+
assert(copy_count == 1);
27+
assert(move_count == 0);
28+
copy_count = 0;
29+
30+
kernel<1> krn1;
31+
q.parallel_for(sycl::nd_range<1>{1, 1}, krn1);
32+
assert(copy_count == 1);
33+
assert(move_count == 0);
34+
copy_count = 0;
35+
36+
kernel<2> krn2;
37+
q.single_task(krn2);
38+
assert(copy_count == 1);
39+
assert(move_count == 0);
40+
copy_count = 0;
41+
42+
return 0;
43+
}

0 commit comments

Comments
 (0)