Skip to content

Commit 45d6f0e

Browse files
authored
[SYCL][NativeCPU] Fix kernel argument passing. (#16995)
We were reading the kernel arguments at kernel execution time, but kernel arguments are allowed to change between enqueue and execution. Make sure to create a copy of the kernel arguments at enqueue time instead. This was previously approved as a unified-runtime PR: oneapi-src/unified-runtime#2700
1 parent 42a9485 commit 45d6f0e

File tree

4 files changed

+87
-60
lines changed

4 files changed

+87
-60
lines changed

unified-runtime/source/adapters/native_cpu/enqueue.cpp

+23-20
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
106106
pLocalWorkSize);
107107
auto &tp = hQueue->getDevice()->tp;
108108
const size_t numParallelThreads = tp.num_threads();
109-
hKernel->updateMemPool(numParallelThreads);
110109
std::vector<std::future<void>> futures;
111-
std::vector<std::function<void(size_t, ur_kernel_handle_t_)>> groups;
110+
std::vector<std::function<void(size_t, ur_kernel_handle_t_ &)>> groups;
112111
auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
113112
auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
114113
auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
@@ -119,16 +118,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
119118
auto event = new ur_event_handle_t_(hQueue, UR_COMMAND_KERNEL_LAUNCH);
120119
event->tick_start();
121120

121+
// Create a copy of the kernel and its arguments.
122+
auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
123+
kernel->updateMemPool(numParallelThreads);
124+
122125
#ifndef NATIVECPU_USE_OCK
123-
hKernel->handleLocalArgs(1, 0);
124126
for (unsigned g2 = 0; g2 < numWG2; g2++) {
125127
for (unsigned g1 = 0; g1 < numWG1; g1++) {
126128
for (unsigned g0 = 0; g0 < numWG0; g0++) {
127129
for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
128130
for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
129131
for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
130132
state.update(g0, g1, g2, local0, local1, local2);
131-
hKernel->_subhandler(hKernel->getArgs().data(), &state);
133+
kernel->_subhandler(kernel->getArgs(1, 0).data(), &state);
132134
}
133135
}
134136
}
@@ -139,7 +141,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
139141
bool isLocalSizeOne =
140142
ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
141143
if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
142-
!hKernel->hasLocalArgs()) {
144+
!kernel->hasLocalArgs()) {
143145
// If the local size is one, we make the assumption that we are running a
144146
// parallel_for over a sycl::range.
145147
// Todo: we could add more compiler checks and
@@ -160,7 +162,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
160162
for (unsigned g1 = 0; g1 < numWG1; g1++) {
161163
for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
162164
futures.emplace_back(tp.schedule_task(
163-
[ndr, itemsPerThread, kernel = *hKernel, g0, g1, g2](size_t) {
165+
[ndr, itemsPerThread, &kernel = *kernel, g0, g1, g2](size_t) {
164166
native_cpu::state resized_state =
165167
getResizedState(ndr, itemsPerThread);
166168
resized_state.update(g0, g1, g2);
@@ -172,7 +174,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
172174
for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
173175
g0++) {
174176
state.update(g0, g1, g2);
175-
hKernel->_subhandler(hKernel->getArgs().data(), &state);
177+
kernel->_subhandler(kernel->getArgs().data(), &state);
176178
}
177179
}
178180
}
@@ -185,12 +187,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
185187
for (unsigned g2 = 0; g2 < numWG2; g2++) {
186188
for (unsigned g1 = 0; g1 < numWG1; g1++) {
187189
futures.emplace_back(
188-
tp.schedule_task([state, kernel = *hKernel, numWG0, g1, g2,
190+
tp.schedule_task([state, &kernel = *kernel, numWG0, g1, g2,
189191
numParallelThreads](size_t threadId) mutable {
190192
for (unsigned g0 = 0; g0 < numWG0; g0++) {
191-
kernel.handleLocalArgs(numParallelThreads, threadId);
192193
state.update(g0, g1, g2);
193-
kernel._subhandler(kernel.getArgs().data(), &state);
194+
kernel._subhandler(
195+
kernel.getArgs(numParallelThreads, threadId).data(),
196+
&state);
194197
}
195198
}));
196199
}
@@ -202,13 +205,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
202205
for (unsigned g2 = 0; g2 < numWG2; g2++) {
203206
for (unsigned g1 = 0; g1 < numWG1; g1++) {
204207
for (unsigned g0 = 0; g0 < numWG0; g0++) {
205-
groups.push_back(
206-
[state, g0, g1, g2, numParallelThreads](
207-
size_t threadId, ur_kernel_handle_t_ kernel) mutable {
208-
kernel.handleLocalArgs(numParallelThreads, threadId);
209-
state.update(g0, g1, g2);
210-
kernel._subhandler(kernel.getArgs().data(), &state);
211-
});
208+
groups.push_back([state, g0, g1, g2, numParallelThreads](
209+
size_t threadId,
210+
ur_kernel_handle_t_ &kernel) mutable {
211+
state.update(g0, g1, g2);
212+
kernel._subhandler(
213+
kernel.getArgs(numParallelThreads, threadId).data(), &state);
214+
});
212215
}
213216
}
214217
}
@@ -218,7 +221,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
218221
for (unsigned thread = 0; thread < numParallelThreads; thread++) {
219222
futures.emplace_back(
220223
tp.schedule_task([groups, thread, groupsPerThread,
221-
kernel = *hKernel](size_t threadId) {
224+
&kernel = *kernel](size_t threadId) {
222225
for (unsigned i = 0; i < groupsPerThread; i++) {
223226
auto index = thread * groupsPerThread + i;
224227
groups[index](threadId, kernel);
@@ -231,7 +234,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
231234
futures.emplace_back(
232235
tp.schedule_task([groups, remainder,
233236
scheduled = numParallelThreads * groupsPerThread,
234-
kernel = *hKernel](size_t threadId) {
237+
&kernel = *kernel](size_t threadId) {
235238
for (unsigned i = 0; i < remainder; i++) {
236239
auto index = scheduled + i;
237240
groups[index](threadId, kernel);
@@ -247,7 +250,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
247250
if (phEvent) {
248251
*phEvent = event;
249252
}
250-
event->set_callback([hKernel, event]() {
253+
event->set_callback([kernel = std::move(kernel), hKernel, event]() {
251254
event->tick_end();
252255
// TODO: avoid calling clear() here.
253256
hKernel->_localArgInfo.clear();

unified-runtime/source/adapters/native_cpu/event.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ void ur_event_handle_t_::wait() {
146146
// The callback may need to acquire the lock, so we unlock it here
147147
lock.unlock();
148148

149-
if (callback)
149+
if (callback.valid())
150150
callback();
151151
}
152152

unified-runtime/source/adapters/native_cpu/event.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ struct ur_event_handle_t_ : RefCounted {
2121

2222
~ur_event_handle_t_();
2323

24-
void set_callback(const std::function<void()> &cb) { callback = cb; }
24+
template <typename T> auto set_callback(T &&cb) {
25+
callback = std::packaged_task<void()>(std::forward<T>(cb));
26+
}
2527

2628
void wait();
2729

@@ -60,7 +62,7 @@ struct ur_event_handle_t_ : RefCounted {
6062
bool done;
6163
std::mutex mutex;
6264
std::vector<std::future<void>> futures;
63-
std::function<void()> callback;
65+
std::packaged_task<void()> callback;
6466
uint64_t timestamp_start = 0;
6567
uint64_t timestamp_end = 0;
6668
};

unified-runtime/source/adapters/native_cpu/kernel.hpp

+59-37
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,9 @@ struct ur_kernel_handle_t_ : RefCounted {
3535
ur_kernel_handle_t_(const ur_kernel_handle_t_ &other)
3636
: Args(other.Args), hProgram(other.hProgram), _name(other._name),
3737
_subhandler(other._subhandler), _localArgInfo(other._localArgInfo),
38-
_localMemPool(other._localMemPool),
39-
_localMemPoolSize(other._localMemPoolSize),
40-
ReqdWGSize(other.ReqdWGSize) {
41-
incrementReferenceCount();
42-
}
38+
ReqdWGSize(other.ReqdWGSize) {}
4339

44-
~ur_kernel_handle_t_() {
45-
if (decrementReferenceCount() == 0) {
46-
free(_localMemPool);
47-
Args.deallocate();
48-
}
49-
}
40+
~ur_kernel_handle_t_() { free(_localMemPool); }
5041

5142
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
5243
nativecpu_task_t subhandler,
@@ -64,27 +55,62 @@ struct ur_kernel_handle_t_ : RefCounted {
6455
std::vector<bool> OwnsMem;
6556
static constexpr size_t MaxAlign = 16 * sizeof(double);
6657

58+
arguments() = default;
59+
60+
arguments(const arguments &Other)
61+
: Indices(Other.Indices), ParamSizes(Other.ParamSizes),
62+
OwnsMem(Other.OwnsMem.size(), false) {
63+
for (size_t Index = 0; Index < Indices.size(); Index++) {
64+
if (!Other.OwnsMem[Index]) {
65+
continue;
66+
}
67+
addArg(Index, ParamSizes[Index], Indices[Index]);
68+
}
69+
}
70+
71+
arguments(arguments &&Other) : arguments() {
72+
std::swap(Indices, Other.Indices);
73+
std::swap(ParamSizes, Other.ParamSizes);
74+
std::swap(OwnsMem, Other.OwnsMem);
75+
}
76+
77+
~arguments() {
78+
assert(OwnsMem.size() == Indices.size() && "Size mismatch");
79+
for (size_t Index = 0; Index < Indices.size(); Index++) {
80+
if (!OwnsMem[Index]) {
81+
continue;
82+
}
83+
native_cpu::aligned_free(Indices[Index]);
84+
}
85+
}
86+
6787
/// Add an argument to the kernel.
6888
/// If the argument existed before, it is replaced.
6989
/// Otherwise, it is added.
7090
/// Gaps are filled with empty arguments.
7191
/// Implicit offset argument is kept at the back of the indices collection.
7292
void addArg(size_t Index, size_t Size, const void *Arg) {
93+
bool NeedAlloc = true;
7394
if (Index + 1 > Indices.size()) {
7495
Indices.resize(Index + 1);
7596
OwnsMem.resize(Index + 1);
7697
ParamSizes.resize(Index + 1);
77-
78-
// Update the stored value for the argument
79-
Indices[Index] = native_cpu::aligned_malloc(MaxAlign, Size);
80-
OwnsMem[Index] = true;
81-
ParamSizes[Index] = Size;
82-
} else {
83-
if (ParamSizes[Index] != Size) {
84-
Indices[Index] = realloc(Indices[Index], Size);
85-
ParamSizes[Index] = Size;
98+
} else if (OwnsMem[Index]) {
99+
if (ParamSizes[Index] == Size) {
100+
NeedAlloc = false;
101+
} else {
102+
native_cpu::aligned_free(Indices[Index]);
86103
}
87104
}
105+
if (NeedAlloc) {
106+
size_t Align = MaxAlign;
107+
while (Align > Size) {
108+
Align >>= 1;
109+
}
110+
Indices[Index] = native_cpu::aligned_malloc(Align, Size);
111+
ParamSizes[Index] = Size;
112+
OwnsMem[Index] = true;
113+
}
88114
std::memcpy(Indices[Index], Arg, Size);
89115
}
90116

@@ -100,17 +126,6 @@ struct ur_kernel_handle_t_ : RefCounted {
100126
Indices[Index] = Arg;
101127
}
102128

103-
// This is called by the destructor of ur_kernel_handle_t_, since
104-
// ur_kernel_handle_t_ implements reference counting and we want
105-
// to deallocate only when the reference count is 0.
106-
void deallocate() {
107-
assert(OwnsMem.size() == Indices.size() && "Size mismatch");
108-
for (size_t Index = 0; Index < Indices.size(); Index++) {
109-
if (OwnsMem[Index])
110-
native_cpu::aligned_free(Indices[Index]);
111-
}
112-
}
113-
114129
const args_index_t &getIndices() const noexcept { return Indices; }
115130

116131
} Args;
@@ -144,19 +159,26 @@ struct ur_kernel_handle_t_ : RefCounted {
144159

145160
bool hasLocalArgs() const { return !_localArgInfo.empty(); }
146161

147-
// To be called before executing a work group if local args are present
148-
void handleLocalArgs(size_t numParallelThread, size_t threadId) {
162+
const std::vector<void *> &getArgs() const {
163+
assert(!hasLocalArgs() && "For kernels with local arguments, thread "
164+
"information must be supplied.");
165+
return Args.getIndices();
166+
}
167+
168+
std::vector<void *> getArgs(size_t numThreads, size_t threadId) const {
169+
auto Result = Args.getIndices();
170+
149171
// For each local argument we have size*numthreads
150172
size_t offset = 0;
151173
for (auto &entry : _localArgInfo) {
152-
Args.Indices[entry.argIndex] =
174+
Result[entry.argIndex] =
153175
_localMemPool + offset + (entry.argSize * threadId);
154176
// update offset in the memory pool
155-
offset += entry.argSize * numParallelThread;
177+
offset += entry.argSize * numThreads;
156178
}
157-
}
158179

159-
const std::vector<void *> &getArgs() const { return Args.getIndices(); }
180+
return Result;
181+
}
160182

161183
void addArg(const void *Ptr, size_t Index, size_t Size) {
162184
Args.addArg(Index, Size, Ptr);

0 commit comments

Comments (0)