@@ -106,9 +106,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
                            pLocalWorkSize);
   auto &tp = hQueue->getDevice()->tp;
   const size_t numParallelThreads = tp.num_threads();
-  hKernel->updateMemPool(numParallelThreads);
   std::vector<std::future<void>> futures;
-  std::vector<std::function<void(size_t, ur_kernel_handle_t_)>> groups;
+  std::vector<std::function<void(size_t, ur_kernel_handle_t_ &)>> groups;
   auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
   auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
   auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
@@ -119,16 +118,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   auto event = new ur_event_handle_t_(hQueue, UR_COMMAND_KERNEL_LAUNCH);
   event->tick_start();
 
+  // Create a copy of the kernel and its arguments.
+  auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
+  kernel->updateMemPool(numParallelThreads);
+
 #ifndef NATIVECPU_USE_OCK
-  hKernel->handleLocalArgs(1, 0);
   for (unsigned g2 = 0; g2 < numWG2; g2++) {
     for (unsigned g1 = 0; g1 < numWG1; g1++) {
       for (unsigned g0 = 0; g0 < numWG0; g0++) {
         for (unsigned local2 = 0; local2 < ndr.LocalSize[2]; local2++) {
           for (unsigned local1 = 0; local1 < ndr.LocalSize[1]; local1++) {
             for (unsigned local0 = 0; local0 < ndr.LocalSize[0]; local0++) {
               state.update(g0, g1, g2, local0, local1, local2);
-              hKernel->_subhandler(hKernel->getArgs().data(), &state);
+              kernel->_subhandler(kernel->getArgs(1, 0).data(), &state);
             }
           }
         }
@@ -139,7 +141,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   bool isLocalSizeOne =
       ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
   if (isLocalSizeOne && ndr.GlobalSize[0] > numParallelThreads &&
-      !hKernel->hasLocalArgs()) {
+      !kernel->hasLocalArgs()) {
     // If the local size is one, we make the assumption that we are running a
     // parallel_for over a sycl::range.
     // Todo: we could add more compiler checks and
@@ -160,7 +162,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
       for (unsigned g1 = 0; g1 < numWG1; g1++) {
         for (unsigned g0 = 0; g0 < new_num_work_groups_0; g0 += 1) {
           futures.emplace_back(tp.schedule_task(
-              [ndr, itemsPerThread, kernel = *hKernel, g0, g1, g2](size_t) {
+              [ndr, itemsPerThread, &kernel = *kernel, g0, g1, g2](size_t) {
                native_cpu::state resized_state =
                    getResizedState(ndr, itemsPerThread);
                resized_state.update(g0, g1, g2);
@@ -172,7 +174,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
         for (unsigned g0 = new_num_work_groups_0 * itemsPerThread; g0 < numWG0;
              g0++) {
           state.update(g0, g1, g2);
-          hKernel->_subhandler(hKernel->getArgs().data(), &state);
+          kernel->_subhandler(kernel->getArgs().data(), &state);
         }
       }
     }
@@ -185,12 +187,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
       for (unsigned g2 = 0; g2 < numWG2; g2++) {
         for (unsigned g1 = 0; g1 < numWG1; g1++) {
           futures.emplace_back(
-              tp.schedule_task([state, kernel = *hKernel, numWG0, g1, g2,
+              tp.schedule_task([state, &kernel = *kernel, numWG0, g1, g2,
                                 numParallelThreads](size_t threadId) mutable {
                 for (unsigned g0 = 0; g0 < numWG0; g0++) {
-                  kernel.handleLocalArgs(numParallelThreads, threadId);
                   state.update(g0, g1, g2);
-                  kernel._subhandler(kernel.getArgs().data(), &state);
+                  kernel._subhandler(
+                      kernel.getArgs(numParallelThreads, threadId).data(),
+                      &state);
                 }
               }));
         }
@@ -202,13 +205,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
       for (unsigned g2 = 0; g2 < numWG2; g2++) {
         for (unsigned g1 = 0; g1 < numWG1; g1++) {
           for (unsigned g0 = 0; g0 < numWG0; g0++) {
-            groups.push_back(
-                [state, g0, g1, g2, numParallelThreads](
-                    size_t threadId, ur_kernel_handle_t_ kernel) mutable {
-                  kernel.handleLocalArgs(numParallelThreads, threadId);
-                  state.update(g0, g1, g2);
-                  kernel._subhandler(kernel.getArgs().data(), &state);
-                });
+            groups.push_back([state, g0, g1, g2, numParallelThreads](
+                                 size_t threadId,
+                                 ur_kernel_handle_t_ &kernel) mutable {
+              state.update(g0, g1, g2);
+              kernel._subhandler(
+                  kernel.getArgs(numParallelThreads, threadId).data(), &state);
+            });
           }
         }
       }
@@ -218,7 +221,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
       for (unsigned thread = 0; thread < numParallelThreads; thread++) {
         futures.emplace_back(
             tp.schedule_task([groups, thread, groupsPerThread,
-                              kernel = *hKernel](size_t threadId) {
+                              &kernel = *kernel](size_t threadId) {
               for (unsigned i = 0; i < groupsPerThread; i++) {
                 auto index = thread * groupsPerThread + i;
                 groups[index](threadId, kernel);
@@ -231,7 +234,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
         futures.emplace_back(
             tp.schedule_task([groups, remainder,
                               scheduled = numParallelThreads * groupsPerThread,
-                              kernel = *hKernel](size_t threadId) {
+                              &kernel = *kernel](size_t threadId) {
               for (unsigned i = 0; i < remainder; i++) {
                 auto index = scheduled + i;
                 groups[index](threadId, kernel);
@@ -247,7 +250,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   if (phEvent) {
     *phEvent = event;
   }
-  event->set_callback([hKernel, event]() {
+  event->set_callback([kernel = std::move(kernel), hKernel, event]() {
     event->tick_end();
     // TODO: avoid calling clear() here.
     hKernel->_localArgInfo.clear();
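
For context, a rough sketch of the pattern the change relies on: each launch copies the kernel handle, the copy sizes one memory pool for all worker threads, and getArgs(numThreads, threadId) presumably hands each thread argument pointers into its own slice of that pool, replacing the old handleLocalArgs() call that mutated the shared handle. The diff does not show ur_kernel_handle_t_'s internals, so the types and member names below (MiniKernel, LocalArg, MemPool) are illustrative assumptions, not the adapter's actual implementation.

// Illustrative sketch only; ur_kernel_handle_t_'s real layout is not shown in
// the diff above. It models a kernel whose local-memory arguments are backed
// by one pool sized for all threads, with per-thread offsets computed in
// getArgs(numThreads, threadId) instead of mutating shared argument state.
#include <cstddef>
#include <cstdint>
#include <vector>

struct LocalArg {
  size_t argIndex; // position of this local-memory argument in the arg list
  size_t size;     // bytes of local memory it needs per thread
};

struct MiniKernel {
  std::vector<void *> Args;        // argument pointers passed to the kernel body
  std::vector<LocalArg> LocalArgs; // which of those args are local-memory args
  std::vector<uint8_t> MemPool;    // backing store for local args, all threads

  // Grow the pool so every worker thread owns a private slice of each
  // local-memory argument (mirrors kernel->updateMemPool(numParallelThreads)).
  void updateMemPool(size_t numThreads) {
    size_t bytesPerThread = 0;
    for (const auto &la : LocalArgs)
      bytesPerThread += la.size;
    MemPool.resize(bytesPerThread * numThreads);
  }

  // Return the argument vector with each local-memory pointer redirected into
  // the slice owned by `threadId` (mirrors kernel->getArgs(numThreads, threadId)).
  std::vector<void *> getArgs(size_t numThreads, size_t threadId) {
    std::vector<void *> out = Args;
    size_t poolOffset = 0;
    for (const auto &la : LocalArgs) {
      uint8_t *argBase = MemPool.data() + poolOffset;
      out[la.argIndex] = argBase + threadId * la.size;
      poolOffset += la.size * numThreads;
    }
    return out;
  }
};

Because each thread asks for its own argument vector instead of calling handleLocalArgs() on a shared handle, no thread writes argument state that another thread is reading. That is also why the diff can capture the per-launch kernel copy by reference in the task lambdas and keep it alive by moving the unique_ptr into the event callback.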