Skip to content

Commit c2813f0

Browse files
committed
Extract kernel arguments directly on the fast path
Signed-off-by: John Pennycook <[email protected]>
1 parent a92ff30 commit c2813f0

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

sycl/source/detail/scheduler/commands.cpp

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,7 +2385,10 @@ static ur_result_t SetKernelParamsAndLaunch(
23852385
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
23862386
bool IsCooperative, bool KernelUsesClusterLaunch,
23872387
uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
2388-
KernelNameStrRefT KernelName) {
2388+
KernelNameStrRefT KernelName, void *KernelFuncPtr = nullptr,
2389+
int KernelNumArgs = 0,
2390+
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = nullptr,
2391+
bool KernelHasSpecialCaptures = false) {
23892392
assert(Queue && "Kernel submissions should have an associated queue");
23902393
const AdapterPtr &Adapter = Queue->getAdapter();
23912394

@@ -2397,13 +2400,37 @@ static ur_result_t SetKernelParamsAndLaunch(
23972400
: Empty);
23982401
}
23992402

2400-
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2401-
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2402-
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2403-
Queue->getContextImplPtr(), Arg, NextTrueIndex);
2404-
};
2405-
2406-
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
2403+
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2404+
// TODO: Refactor to avoid SetArgBasedOnType duplication
2405+
// TODO: Find a way to use the built-ins instead of variables.
2406+
for (int I = 0; I < KernelNumArgs; ++I) {
2407+
auto ParamDesc = KernelParamDescGetter(I);
2408+
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
2409+
switch (ParamDesc.kind) {
2410+
case kernel_param_kind_t::kind_std_layout: {
2411+
int Size = ParamDesc.info;
2412+
Adapter->call<UrApiKind::urKernelSetArgValue>(Kernel, I, Size, nullptr,
2413+
ArgPtr);
2414+
break;
2415+
}
2416+
case kernel_param_kind_t::kind_pointer: {
2417+
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
2418+
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, I, nullptr,
2419+
Ptr);
2420+
break;
2421+
}
2422+
default:
2423+
throw std::runtime_error("Direct kernel argument copy failed.");
2424+
}
2425+
}
2426+
} else {
2427+
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
2428+
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2429+
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2430+
Queue->getContextImplPtr(), Arg, NextTrueIndex);
2431+
};
2432+
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
2433+
}
24072434

24082435
std::optional<int> ImplicitLocalArg =
24092436
ProgramManager::getInstance().kernelImplicitLocalArgPos(KernelName);
@@ -2655,7 +2682,9 @@ void enqueueImpKernel(
26552682
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
26562683
ur_kernel_cache_config_t KernelCacheConfig, const bool KernelIsCooperative,
26572684
const bool KernelUsesClusterLaunch, const size_t WorkGroupMemorySize,
2658-
const RTDeviceBinaryImage *BinImage) {
2685+
const RTDeviceBinaryImage *BinImage, void *KernelFuncPtr, int KernelNumArgs,
2686+
detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
2687+
bool KernelHasSpecialCaptures) {
26592688
assert(Queue && "Kernel submissions should have an associated queue");
26602689
// Run OpenCL kernel
26612690
auto &ContextImpl = Queue->getContextImplPtr();
@@ -2739,7 +2768,8 @@ void enqueueImpKernel(
27392768
Queue, Args, DeviceImageImpl, Kernel, NDRDesc, EventsWaitList,
27402769
OutEventImpl, EliminatedArgMask, getMemAllocationFunc,
27412770
KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
2742-
BinImage, KernelName);
2771+
BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
2772+
KernelParamDescGetter, KernelHasSpecialCaptures);
27432773

27442774
const AdapterPtr &Adapter = Queue->getAdapter();
27452775
if (!SyclKernelImpl && !MSyclKernel) {

sycl/source/detail/scheduler/commands.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,10 @@ void enqueueImpKernel(
626626
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
627627
ur_kernel_cache_config_t KernelCacheConfig, bool KernelIsCooperative,
628628
const bool KernelUsesClusterLaunch, const size_t WorkGroupMemorySize,
629-
const RTDeviceBinaryImage *BinImage = nullptr);
629+
const RTDeviceBinaryImage *BinImage = nullptr,
630+
void *KernelFuncPtr = nullptr, int KernelNumArgs = 0,
631+
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = nullptr,
632+
bool KernelHasSpecialCaptures = false);
630633

631634
/// The exec CG command enqueues execution of kernel or explicit memory
632635
/// operation.

sycl/source/handler.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,18 @@ event handler::finalize() {
411411
return MLastEvent;
412412
MIsFinalized = true;
413413

414-
// Extract arguments from the kernel lambda.
415-
// TODO: Skip this in simple cases.
416414
const auto &type = getType();
417-
if (type == detail::CGType::Kernel && impl->MKernelFuncPtr) {
415+
const bool KernelFastPath =
416+
(MQueue && !impl->MGraph && !impl->MSubgraphNode &&
417+
!MQueue->hasCommandGraph() && !impl->CGData.MRequirements.size() &&
418+
!MStreamStorage.size() &&
419+
detail::Scheduler::areEventsSafeForSchedulerBypass(
420+
impl->CGData.MEvents, MQueue->getContextImplPtr()));
421+
422+
// Extract arguments from the kernel lambda, if required.
423+
// Skipping this is currently limited to simple kernels on the fast path.
424+
if (type == detail::CGType::Kernel && impl->MKernelFuncPtr &&
425+
!(KernelFastPath && impl->MKernelHasSpecialCaptures)) {
418426
clearArgs();
419427
extractArgsAndReqsFromLambda((char *)impl->MKernelFuncPtr,
420428
impl->MKernelParamDescGetter,
@@ -516,11 +524,7 @@ event handler::finalize() {
516524
}
517525
}
518526

519-
if (MQueue && !impl->MGraph && !impl->MSubgraphNode &&
520-
!MQueue->hasCommandGraph() && !impl->CGData.MRequirements.size() &&
521-
!MStreamStorage.size() &&
522-
detail::Scheduler::areEventsSafeForSchedulerBypass(
523-
impl->CGData.MEvents, MQueue->getContextImplPtr())) {
527+
if (KernelFastPath) {
524528
// if user does not add a new dependency to the dependency graph, i.e.
525529
// the graph is not changed, then this faster path is used to submit
526530
// kernel bypassing scheduler and avoiding CommandGroup, Command objects
@@ -566,13 +570,14 @@ event handler::finalize() {
566570
detail::retrieveKernelBinary(MQueue, MKernelName.data());
567571
assert(BinImage && "Failed to obtain a binary image.");
568572
}
569-
enqueueImpKernel(MQueue, impl->MNDRDesc, impl->MArgs,
570-
KernelBundleImpPtr, MKernel.get(), MKernelName.data(),
571-
RawEvents,
572-
DiscardEvent ? nullptr : LastEventImpl.get(), nullptr,
573-
impl->MKernelCacheConfig, impl->MKernelIsCooperative,
574-
impl->MKernelUsesClusterLaunch,
575-
impl->MKernelWorkGroupMemorySize, BinImage);
573+
enqueueImpKernel(
574+
MQueue, impl->MNDRDesc, impl->MArgs, KernelBundleImpPtr,
575+
MKernel.get(), MKernelName.data(), RawEvents,
576+
DiscardEvent ? nullptr : LastEventImpl.get(), nullptr,
577+
impl->MKernelCacheConfig, impl->MKernelIsCooperative,
578+
impl->MKernelUsesClusterLaunch, impl->MKernelWorkGroupMemorySize,
579+
BinImage, impl->MKernelFuncPtr, impl->MKernelNumArgs,
580+
impl->MKernelParamDescGetter, impl->MKernelHasSpecialCaptures);
576581
#ifdef XPTI_ENABLE_INSTRUMENTATION
577582
if (xptiEnabled) {
578583
// Emit signal only when event is created

0 commit comments

Comments
 (0)