Skip to content

Commit 3fd11f1

Browse files
authored
Merge pull request #1246 from 0x12CC/cooperative_kernel_functions
[UR] Add default implementation for cooperative kernel functions
2 parents 24078c2 + 8a8d704 commit 3fd11f1

21 files changed

+176
-24
lines changed

include/ur_api.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -8692,8 +8692,12 @@ urEnqueueCooperativeKernelLaunchExp(
86928692
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
86938693
UR_APIEXPORT ur_result_t UR_APICALL
86948694
urKernelSuggestMaxCooperativeGroupCountExp(
8695-
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8696-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
8695+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8696+
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8697+
///< kernel is launched
8698+
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
8699+
///< that will be used when the kernel is launched
8700+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
86978701
);
86988702

86998703
#if !defined(__GNUC__)
@@ -9641,6 +9645,8 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
96419645
/// allowing the callback the ability to modify the parameter's value
96429646
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
96439647
ur_kernel_handle_t *phKernel;
9648+
size_t *plocalWorkSize;
9649+
size_t *pdynamicSharedMemorySize;
96449650
uint32_t **ppGroupCountRet;
96459651
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;
96469652

include/ur_ddi.h

+2
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
627627
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
628628
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
629629
ur_kernel_handle_t,
630+
size_t,
631+
size_t,
630632
uint32_t *);
631633

632634
///////////////////////////////////////////////////////////////////////////////

include/ur_print.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -11399,6 +11399,16 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1139911399
ur::details::printPtr(os,
1140011400
*(params->phKernel));
1140111401

11402+
os << ", ";
11403+
os << ".localWorkSize = ";
11404+
11405+
os << *(params->plocalWorkSize);
11406+
11407+
os << ", ";
11408+
os << ".dynamicSharedMemorySize = ";
11409+
11410+
os << *(params->pdynamicSharedMemorySize);
11411+
1140211412
os << ", ";
1140311413
os << ".pGroupCountRet = ";
1140411414

scripts/core/exp-cooperative-kernels.yml

+6
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81+
- type: size_t
82+
name: localWorkSize
83+
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"
84+
- type: size_t
85+
name: dynamicSharedMemorySize
86+
desc: "[in] size of dynamic shared memory, for each work-group, in bytes, that will be used when the kernel is launched"
8187
- type: "uint32_t*"
8288
name: "pGroupCountRet"
8389
desc: "[out] pointer to maximum number of groups"

source/adapters/cuda/enqueue.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
494494
return Result;
495495
}
496496

497+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
498+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
499+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
500+
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
501+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
502+
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
503+
pGlobalWorkSize, pLocalWorkSize,
504+
numEventsInWaitList, phEventWaitList, phEvent);
505+
}
506+
497507
/// Set parameters for general 3D memory copy.
498508
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
499509
/// must be a pointer to a CUdeviceptr

source/adapters/cuda/kernel.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
169169
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
170170
}
171171

172+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
173+
ur_kernel_handle_t hKernel, size_t localWorkSize,
174+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
175+
(void)hKernel;
176+
(void)localWorkSize;
177+
(void)dynamicSharedMemorySize;
178+
*pGroupCountRet = 1;
179+
return UR_RESULT_SUCCESS;
180+
}
181+
172182
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
173183
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
174184
const ur_kernel_arg_value_properties_t *pProperties,

source/adapters/cuda/ur_interface_loader.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
404404
return result;
405405
}
406406

407-
pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
407+
pDdiTable->pfnCooperativeKernelLaunchExp =
408+
urEnqueueCooperativeKernelLaunchExp;
408409

409410
return UR_RESULT_SUCCESS;
410411
}
@@ -416,7 +417,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
416417
return result;
417418
}
418419

419-
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
420+
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
421+
urKernelSuggestMaxCooperativeGroupCountExp;
420422

421423
return UR_RESULT_SUCCESS;
422424
}

source/adapters/hip/enqueue.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
465465
return Result;
466466
}
467467

468+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
469+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
470+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
471+
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
472+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
473+
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
474+
pGlobalWorkSize, pLocalWorkSize,
475+
numEventsInWaitList, phEventWaitList, phEvent);
476+
}
477+
468478
/// Enqueues a wait on the given queue for all events.
469479
/// See \ref enqueueEventWait
470480
///

source/adapters/hip/kernel.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
158158
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
159159
}
160160

161+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
162+
ur_kernel_handle_t hKernel, size_t localWorkSize,
163+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
164+
(void)hKernel;
165+
(void)localWorkSize;
166+
(void)dynamicSharedMemorySize;
167+
*pGroupCountRet = 1;
168+
return UR_RESULT_SUCCESS;
169+
}
170+
161171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
162172
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
163173
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {

source/adapters/hip/ur_interface_loader.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
374374
return result;
375375
}
376376

377-
pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
377+
pDdiTable->pfnCooperativeKernelLaunchExp =
378+
urEnqueueCooperativeKernelLaunchExp;
378379

379380
return UR_RESULT_SUCCESS;
380381
}
@@ -386,7 +387,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
386387
return result;
387388
}
388389

389-
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
390+
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
391+
urKernelSuggestMaxCooperativeGroupCountExp;
390392

391393
return UR_RESULT_SUCCESS;
392394
}

source/adapters/level_zero/kernel.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
264264
return UR_RESULT_SUCCESS;
265265
}
266266

267+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
268+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
269+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
270+
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
271+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
272+
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
273+
pGlobalWorkSize, pLocalWorkSize,
274+
numEventsInWaitList, phEventWaitList, phEvent);
275+
}
276+
267277
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
268278
ur_queue_handle_t Queue, ///< [in] handle of the queue to submit to.
269279
ur_program_handle_t Program, ///< [in] handle of the program containing the
@@ -787,6 +797,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
787797
return UR_RESULT_SUCCESS;
788798
}
789799

800+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
801+
ur_kernel_handle_t hKernel, size_t localWorkSize,
802+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
803+
(void)hKernel;
804+
(void)localWorkSize;
805+
(void)dynamicSharedMemorySize;
806+
*pGroupCountRet = 1;
807+
return UR_RESULT_SUCCESS;
808+
}
809+
790810
UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
791811
ur_native_handle_t NativeKernel, ///< [in] the native handle of the kernel.
792812
ur_context_handle_t Context, ///< [in] handle of the context object

source/adapters/level_zero/ur_interface_loader.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
451451
return result;
452452
}
453453

454-
pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
454+
pDdiTable->pfnCooperativeKernelLaunchExp =
455+
urEnqueueCooperativeKernelLaunchExp;
455456

456457
return UR_RESULT_SUCCESS;
457458
}
@@ -463,7 +464,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
463464
return result;
464465
}
465466

466-
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
467+
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
468+
urKernelSuggestMaxCooperativeGroupCountExp;
467469

468470
return UR_RESULT_SUCCESS;
469471
}

source/adapters/null/ur_nullddi.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -5443,15 +5443,22 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
54435443
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
54445444
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
54455445
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
5446-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
5446+
size_t
5447+
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
5448+
///< kernel is launched
5449+
size_t
5450+
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
5451+
///< that will be used when the kernel is launched
5452+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
54475453
) try {
54485454
ur_result_t result = UR_RESULT_SUCCESS;
54495455

54505456
// if the driver has created a custom function, then call it instead of using the generic path
54515457
auto pfnSuggestMaxCooperativeGroupCountExp =
54525458
d_context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
54535459
if (nullptr != pfnSuggestMaxCooperativeGroupCountExp) {
5454-
result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
5460+
result = pfnSuggestMaxCooperativeGroupCountExp(
5461+
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
54555462
} else {
54565463
// generic implementation
54575464
}

source/adapters/opencl/enqueue.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
4141
return UR_RESULT_SUCCESS;
4242
}
4343

44+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
45+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
46+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
47+
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
48+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
49+
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
50+
pGlobalWorkSize, pLocalWorkSize,
51+
numEventsInWaitList, phEventWaitList, phEvent);
52+
}
53+
4454
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWait(
4555
ur_queue_handle_t hQueue, uint32_t numEventsInWaitList,
4656
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

source/adapters/opencl/kernel.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "common.hpp"
1111

1212
#include <algorithm>
13+
#include <cstddef>
1314
#include <memory>
1415

1516
UR_APIEXPORT ur_result_t UR_APICALL
@@ -376,6 +377,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
376377
return UR_RESULT_SUCCESS;
377378
}
378379

380+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
381+
ur_kernel_handle_t hKernel, size_t localWorkSize,
382+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
383+
(void)hKernel;
384+
(void)localWorkSize;
385+
(void)dynamicSharedMemorySize;
386+
*pGroupCountRet = 1;
387+
return UR_RESULT_SUCCESS;
388+
}
389+
379390
UR_APIEXPORT ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
380391
ur_native_handle_t hNativeKernel, ur_context_handle_t, ur_program_handle_t,
381392
const ur_kernel_native_properties_t *pProperties,

source/adapters/opencl/ur_interface_loader.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
395395
return result;
396396
}
397397

398-
pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
398+
pDdiTable->pfnCooperativeKernelLaunchExp =
399+
urEnqueueCooperativeKernelLaunchExp;
399400

400401
return UR_RESULT_SUCCESS;
401402
}
@@ -407,7 +408,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
407408
return result;
408409
}
409410

410-
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
411+
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
412+
urKernelSuggestMaxCooperativeGroupCountExp;
411413

412414
return UR_RESULT_SUCCESS;
413415
}

source/loader/layers/tracing/ur_trcddi.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -6037,7 +6037,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
60376037
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
60386038
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
60396039
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
6040-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
6040+
size_t
6041+
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
6042+
///< kernel is launched
6043+
size_t
6044+
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
6045+
///< that will be used when the kernel is launched
6046+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
60416047
) {
60426048
auto pfnSuggestMaxCooperativeGroupCountExp =
60436049
context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
@@ -6047,13 +6053,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
60476053
}
60486054

60496055
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
6050-
&hKernel, &pGroupCountRet};
6056+
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
60516057
uint64_t instance = context.notify_begin(
60526058
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,
60536059
"urKernelSuggestMaxCooperativeGroupCountExp", &params);
60546060

6055-
ur_result_t result =
6056-
pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
6061+
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
6062+
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
60576063

60586064
context.notify_end(
60596065
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP,

source/loader/layers/validation/ur_valddi.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -8827,7 +8827,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
88278827
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
88288828
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
88298829
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8830-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
8830+
size_t
8831+
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8832+
///< kernel is launched
8833+
size_t
8834+
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
8835+
///< that will be used when the kernel is launched
8836+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
88318837
) {
88328838
auto pfnSuggestMaxCooperativeGroupCountExp =
88338839
context.urDdiTable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp;
@@ -8851,8 +8857,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
88518857
refCountContext.logInvalidReference(hKernel);
88528858
}
88538859

8854-
ur_result_t result =
8855-
pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
8860+
ur_result_t result = pfnSuggestMaxCooperativeGroupCountExp(
8861+
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
88568862

88578863
return result;
88588864
}

source/loader/ur_ldrddi.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -7571,7 +7571,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
75717571
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
75727572
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
75737573
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
7574-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
7574+
size_t
7575+
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
7576+
///< kernel is launched
7577+
size_t
7578+
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
7579+
///< that will be used when the kernel is launched
7580+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
75757581
) {
75767582
ur_result_t result = UR_RESULT_SUCCESS;
75777583

@@ -7587,7 +7593,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
75877593
hKernel = reinterpret_cast<ur_kernel_object_t *>(hKernel)->handle;
75887594

75897595
// forward to device-platform
7590-
result = pfnSuggestMaxCooperativeGroupCountExp(hKernel, pGroupCountRet);
7596+
result = pfnSuggestMaxCooperativeGroupCountExp(
7597+
hKernel, localWorkSize, dynamicSharedMemorySize, pGroupCountRet);
75917598

75927599
return result;
75937600
}

0 commit comments

Comments
 (0)