@@ -2385,7 +2385,10 @@ static ur_result_t SetKernelParamsAndLaunch(
2385
2385
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2386
2386
bool IsCooperative, bool KernelUsesClusterLaunch,
2387
2387
uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
2388
- KernelNameStrRefT KernelName) {
2388
+ KernelNameStrRefT KernelName, void *KernelFuncPtr = nullptr,
2389
+ int KernelNumArgs = 0,
2390
+ detail::kernel_param_desc_t (*KernelParamDescGetter)(int ) = nullptr,
2391
+ bool KernelHasSpecialCaptures = false) {
2389
2392
assert (Queue && " Kernel submissions should have an associated queue" );
2390
2393
const AdapterPtr &Adapter = Queue->getAdapter ();
2391
2394
@@ -2397,13 +2400,37 @@ static ur_result_t SetKernelParamsAndLaunch(
2397
2400
: Empty);
2398
2401
}
2399
2402
2400
- auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2401
- &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2402
- SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2403
- Queue->getContextImplPtr (), Arg, NextTrueIndex);
2404
- };
2405
-
2406
- applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
2403
+ if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2404
+ // TODO: Refactor to avoid SetArgBasedOnType duplication
2405
+ // TODO: Find a way to use the built-ins instead of variables.
2406
+ for (int I = 0 ; I < KernelNumArgs; ++I) {
2407
+ auto ParamDesc = KernelParamDescGetter (I);
2408
+ const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
2409
+ switch (ParamDesc.kind ) {
2410
+ case kernel_param_kind_t ::kind_std_layout: {
2411
+ int Size = ParamDesc.info ;
2412
+ Adapter->call <UrApiKind::urKernelSetArgValue>(Kernel, I, Size, nullptr ,
2413
+ ArgPtr);
2414
+ break ;
2415
+ }
2416
+ case kernel_param_kind_t ::kind_pointer: {
2417
+ const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2418
+ Adapter->call <UrApiKind::urKernelSetArgPointer>(Kernel, I, nullptr ,
2419
+ Ptr);
2420
+ break ;
2421
+ }
2422
+ default :
2423
+ throw std::runtime_error (" Direct kernel argument copy failed." );
2424
+ }
2425
+ }
2426
+ } else {
2427
+ auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2428
+ &Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2429
+ SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2430
+ Queue->getContextImplPtr (), Arg, NextTrueIndex);
2431
+ };
2432
+ applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
2433
+ }
2407
2434
2408
2435
std::optional<int > ImplicitLocalArg =
2409
2436
ProgramManager::getInstance ().kernelImplicitLocalArgPos (KernelName);
@@ -2655,7 +2682,9 @@ void enqueueImpKernel(
2655
2682
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2656
2683
ur_kernel_cache_config_t KernelCacheConfig, const bool KernelIsCooperative,
2657
2684
const bool KernelUsesClusterLaunch, const size_t WorkGroupMemorySize,
2658
- const RTDeviceBinaryImage *BinImage) {
2685
+ const RTDeviceBinaryImage *BinImage, void *KernelFuncPtr, int KernelNumArgs,
2686
+ detail::kernel_param_desc_t (*KernelParamDescGetter)(int ),
2687
+ bool KernelHasSpecialCaptures) {
2659
2688
assert (Queue && " Kernel submissions should have an associated queue" );
2660
2689
// Run OpenCL kernel
2661
2690
auto &ContextImpl = Queue->getContextImplPtr ();
@@ -2739,7 +2768,8 @@ void enqueueImpKernel(
2739
2768
Queue, Args, DeviceImageImpl, Kernel, NDRDesc, EventsWaitList,
2740
2769
OutEventImpl, EliminatedArgMask, getMemAllocationFunc,
2741
2770
KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
2742
- BinImage, KernelName);
2771
+ BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
2772
+ KernelParamDescGetter, KernelHasSpecialCaptures);
2743
2773
2744
2774
const AdapterPtr &Adapter = Queue->getAdapter ();
2745
2775
if (!SyclKernelImpl && !MSyclKernel) {
0 commit comments