Skip to content

[SYCL] Extract args directly from kernel if we can #18387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: sycl
Choose a base branch
from
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/kernel_desc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ template <typename KernelNameType> constexpr int64_t getKernelSize() {

template <typename KernelNameType> constexpr bool hasSpecialCaptures() {
bool FoundSpecialCapture = false;
for (int I = 0; I < getKernelNumParams<KernelNameType>(); ++I) {
for (unsigned I = 0; I < getKernelNumParams<KernelNameType>(); ++I) {
auto ParamDesc = getKernelParamDesc<KernelNameType>(I);
bool IsSpecialCapture =
(ParamDesc.kind != kernel_param_kind_t::kind_std_layout &&
Expand Down
20 changes: 13 additions & 7 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,8 @@ class __SYCL_EXPORT handler {

/// Stores lambda to the template-free object
///
/// Also initializes kernel name, list of arguments and requirements using
/// information from the integration header/built-ins.
/// Also initializes the kernel name and prepares for arguments to
/// be extracted from the lambda in handler::finalize().
///
/// \param KernelFunc is a SYCL kernel function
/// \param ParamDescs is the vector of kernel parameter descriptors.
Expand Down Expand Up @@ -796,11 +796,13 @@ class __SYCL_EXPORT handler {
if constexpr (KernelHasName) {
// TODO support ESIMD in no-integration-header case too.

clearArgs();
extractArgsAndReqsFromLambda(MHostKernel->getPtr(),
&(detail::getKernelParamDesc<KernelName>),
detail::getKernelNumParams<KernelName>(),
detail::isKernelESIMD<KernelName>());
// Force hasSpecialCaptures to be evaluated at compile-time.
constexpr bool HasSpecialCapt = detail::hasSpecialCaptures<KernelName>();
setKernelInfo((void *)MHostKernel->getPtr(),
detail::getKernelNumParams<KernelName>(),
&(detail::getKernelParamDesc<KernelName>),
detail::isKernelESIMD<KernelName>(), HasSpecialCapt);

MKernelName = detail::getKernelName<KernelName>();
} else {
// In case w/o the integration header it is necessary to process
Expand Down Expand Up @@ -3761,6 +3763,10 @@ class __SYCL_EXPORT handler {
sycl::range<3> LocalSize, sycl::id<3> Offset,
int Dims);

void setKernelInfo(void *KernelFuncPtr, int KernelNumArgs,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
bool KernelIsESIMD, bool KernelHasSpecialCaptures);

friend class detail::HandlerAccess;

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
Expand Down
7 changes: 7 additions & 0 deletions sycl/source/detail/handler_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,13 @@ class handler_impl {

// Allocation ptr to be freed asynchronously.
void *MFreePtr = nullptr;

// Store information about the kernel arguments.
void *MKernelFuncPtr = nullptr;
int MKernelNumArgs = 0;
detail::kernel_param_desc_t (*MKernelParamDescGetter)(int) = nullptr;
bool MKernelIsESIMD = false;
bool MKernelHasSpecialCaptures = true;
};

} // namespace detail
Expand Down
51 changes: 41 additions & 10 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2385,7 +2385,10 @@ static ur_result_t SetKernelParamsAndLaunch(
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
bool IsCooperative, bool KernelUsesClusterLaunch,
uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
KernelNameStrRefT KernelName) {
KernelNameStrRefT KernelName, void *KernelFuncPtr = nullptr,
int KernelNumArgs = 0,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = nullptr,
bool KernelHasSpecialCaptures = true) {
assert(Queue && "Kernel submissions should have an associated queue");
const AdapterPtr &Adapter = Queue->getAdapter();

Expand All @@ -2397,13 +2400,38 @@ static ur_result_t SetKernelParamsAndLaunch(
: Empty);
}

auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
Queue->getContextImplPtr(), Arg, NextTrueIndex);
};

applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
if (KernelFuncPtr && !KernelHasSpecialCaptures) {
auto setFunc = [&Adapter, Kernel,
KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
size_t NextTrueIndex) {
const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset;
switch (ParamDesc.kind) {
case kernel_param_kind_t::kind_std_layout: {
int Size = ParamDesc.info;
Adapter->call<UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
Size, nullptr, ArgPtr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SetArgBasedOnType does

    if (Arg.MPtr) {
      Adapter->call<UrApiKind::urKernelSetArgValue>(
          Kernel, NextTrueIndex, Arg.MSize, nullptr, Arg.MPtr);
    } else {
      Adapter->call<UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
                                                    Arg.MSize, nullptr);
    }

Is there a reason we don't do that here? Is the else-case falling under "special captures"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm honest, I don't really understand how the old MArgs code works. You can see that each argument in MArgs is represented by a pointer to an argument, and this else branch only triggers when that pointer is null.

On the fast path, I'm extracting a standard layout argument directly from the function object. Since it's a standard layout object and not a pointer to one, it can never be null, so I removed the branch.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this is some special case for some local memory accessors. I wonder if these changes can handle that correctly. Let's hope testing is good enough!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be some overloading of the meaning of the kernel parameter kind. The vector contains the arguments after decomposition, which means that there may indeed be some special fields that need handling.

The array I'm working with contains a description of the original arguments, and so anything that's captured as a "standard layout" class really needs to be one -- a local accessor is identified in the array as an accessor, so it will be counted as a special capture.

break;
}
case kernel_param_kind_t::kind_pointer: {
const void *Ptr = *static_cast<const void *const *>(ArgPtr);
Adapter->call<UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
nullptr, Ptr);
break;
}
default:
throw std::runtime_error("Direct kernel argument copy failed.");
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it make sense to add a SetArgBasedOnType overload or variant that the old one can call if the arguments don't fall under the special types? Avoids a little code replication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not with the current design. This is what I was alluding to when I said it would be a good idea to try and unify the MArgs design with what I've done here, though.

The main problem is that we have two different ways to represent what an argument is. MArgs is a vector of detail::ArgDesc objects, but what I'm reading from the integration header is an array of detail::kernel_param_desc_t objects. Submission uses either the vector or the array (and the vector doesn't exist on the fast path) so we can't currently mix and match.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see what you mean. Yeah, I am not sure it's worth trying to repack the arguments.

applyFuncOnFilteredArgs(EliminatedArgMask, KernelNumArgs,
KernelParamDescGetter, setFunc);
} else {
auto setFunc = [&Adapter, Kernel, &DeviceImageImpl, &getMemAllocationFunc,
&Queue](detail::ArgDesc &Arg, size_t NextTrueIndex) {
SetArgBasedOnType(Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
Queue->getContextImplPtr(), Arg, NextTrueIndex);
};
applyFuncOnFilteredArgs(EliminatedArgMask, Args, setFunc);
}

std::optional<int> ImplicitLocalArg =
ProgramManager::getInstance().kernelImplicitLocalArgPos(KernelName);
Expand Down Expand Up @@ -2655,7 +2683,9 @@ void enqueueImpKernel(
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
ur_kernel_cache_config_t KernelCacheConfig, const bool KernelIsCooperative,
const bool KernelUsesClusterLaunch, const size_t WorkGroupMemorySize,
const RTDeviceBinaryImage *BinImage) {
const RTDeviceBinaryImage *BinImage, void *KernelFuncPtr, int KernelNumArgs,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
bool KernelHasSpecialCaptures) {
assert(Queue && "Kernel submissions should have an associated queue");
// Run OpenCL kernel
auto &ContextImpl = Queue->getContextImplPtr();
Expand Down Expand Up @@ -2739,7 +2769,8 @@ void enqueueImpKernel(
Queue, Args, DeviceImageImpl, Kernel, NDRDesc, EventsWaitList,
OutEventImpl, EliminatedArgMask, getMemAllocationFunc,
KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
BinImage, KernelName);
BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
KernelParamDescGetter, KernelHasSpecialCaptures);

const AdapterPtr &Adapter = Queue->getAdapter();
if (!SyclKernelImpl && !MSyclKernel) {
Expand Down
26 changes: 25 additions & 1 deletion sycl/source/detail/scheduler/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,10 @@ void enqueueImpKernel(
const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
ur_kernel_cache_config_t KernelCacheConfig, bool KernelIsCooperative,
const bool KernelUsesClusterLaunch, const size_t WorkGroupMemorySize,
const RTDeviceBinaryImage *BinImage = nullptr);
const RTDeviceBinaryImage *BinImage = nullptr,
void *KernelFuncPtr = nullptr, int KernelNumArgs = 0,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int) = nullptr,
bool KernelHasSpecialCaptures = true);

/// The exec CG command enqueues execution of kernel or explicit memory
/// operation.
Expand Down Expand Up @@ -780,6 +783,27 @@ void applyFuncOnFilteredArgs(const KernelArgMask *EliminatedArgMask,
}
}

template <typename FuncT>
void applyFuncOnFilteredArgs(
const KernelArgMask *EliminatedArgMask, int KernelNumArgs,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int), FuncT Func) {
if (!EliminatedArgMask || EliminatedArgMask->size() == 0) {
for (int I = 0; I < KernelNumArgs; ++I) {
const detail::kernel_param_desc_t &Param = KernelParamDescGetter(I);
Func(Param, I);
}
} else {
size_t NextTrueIndex = 0;
for (int I = 0; I < KernelNumArgs; ++I) {
const detail::kernel_param_desc_t &Param = KernelParamDescGetter(I);
if ((*EliminatedArgMask)[I])
continue;
Func(Param, NextTrueIndex);
++NextTrueIndex;
}
}
}

void ReverseRangeDimensionsForKernel(NDRDescT &NDR);

} // namespace detail
Expand Down
51 changes: 38 additions & 13 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,24 @@ event handler::finalize() {
return MLastEvent;
MIsFinalized = true;

const auto &type = getType();
const bool KernelFastPath =
(MQueue && !impl->MGraph && !impl->MSubgraphNode &&
!MQueue->hasCommandGraph() && !impl->CGData.MRequirements.size() &&
!MStreamStorage.size() &&
detail::Scheduler::areEventsSafeForSchedulerBypass(
impl->CGData.MEvents, MQueue->getContextImplPtr()));

// Extract arguments from the kernel lambda, if required.
// Skipping this is currently limited to simple kernels on the fast path.
if (type == detail::CGType::Kernel && impl->MKernelFuncPtr &&
(!KernelFastPath || impl->MKernelHasSpecialCaptures)) {
clearArgs();
extractArgsAndReqsFromLambda((char *)impl->MKernelFuncPtr,
impl->MKernelParamDescGetter,
impl->MKernelNumArgs, impl->MKernelIsESIMD);
}

// According to 4.7.6.9 of SYCL2020 spec, if a placeholder accessor is passed
// to a command without being bound to a command group, an exception should
// be thrown.
Expand Down Expand Up @@ -448,7 +466,6 @@ event handler::finalize() {
}
}

const auto &type = getType();
if (type == detail::CGType::Kernel) {
// If there were uses of set_specialization_constant build the kernel_bundle
std::shared_ptr<detail::kernel_bundle_impl> KernelBundleImpPtr =
Expand Down Expand Up @@ -507,11 +524,7 @@ event handler::finalize() {
}
}

if (MQueue && !impl->MGraph && !impl->MSubgraphNode &&
!MQueue->hasCommandGraph() && !impl->CGData.MRequirements.size() &&
!MStreamStorage.size() &&
detail::Scheduler::areEventsSafeForSchedulerBypass(
impl->CGData.MEvents, MQueue->getContextImplPtr())) {
if (KernelFastPath) {
// if user does not add a new dependency to the dependency graph, i.e.
// the graph is not changed, then this faster path is used to submit
// kernel bypassing scheduler and avoiding CommandGroup, Command objects
Expand Down Expand Up @@ -557,13 +570,14 @@ event handler::finalize() {
detail::retrieveKernelBinary(MQueue, MKernelName.data());
assert(BinImage && "Failed to obtain a binary image.");
}
enqueueImpKernel(MQueue, impl->MNDRDesc, impl->MArgs,
KernelBundleImpPtr, MKernel.get(), MKernelName.data(),
RawEvents,
DiscardEvent ? nullptr : LastEventImpl.get(), nullptr,
impl->MKernelCacheConfig, impl->MKernelIsCooperative,
impl->MKernelUsesClusterLaunch,
impl->MKernelWorkGroupMemorySize, BinImage);
enqueueImpKernel(
MQueue, impl->MNDRDesc, impl->MArgs, KernelBundleImpPtr,
MKernel.get(), MKernelName.data(), RawEvents,
DiscardEvent ? nullptr : LastEventImpl.get(), nullptr,
impl->MKernelCacheConfig, impl->MKernelIsCooperative,
impl->MKernelUsesClusterLaunch, impl->MKernelWorkGroupMemorySize,
BinImage, impl->MKernelFuncPtr, impl->MKernelNumArgs,
impl->MKernelParamDescGetter, impl->MKernelHasSpecialCaptures);
#ifdef XPTI_ENABLE_INSTRUMENTATION
if (xptiEnabled) {
// Emit signal only when event is created
Expand Down Expand Up @@ -2254,6 +2268,17 @@ void handler::setNDRangeDescriptorPadded(sycl::range<3> NumWorkItems,
impl->MNDRDesc = NDRDescT{NumWorkItems, LocalSize, Offset, Dims};
}

void handler::setKernelInfo(
void *KernelFuncPtr, int KernelNumArgs,
detail::kernel_param_desc_t (*KernelParamDescGetter)(int),
bool KernelIsESIMD, bool KernelHasSpecialCaptures) {
impl->MKernelFuncPtr = KernelFuncPtr;
impl->MKernelNumArgs = KernelNumArgs;
impl->MKernelParamDescGetter = KernelParamDescGetter;
impl->MKernelIsESIMD = KernelIsESIMD;
impl->MKernelHasSpecialCaptures = KernelHasSpecialCaptures;
}

void handler::saveCodeLoc(detail::code_location CodeLoc, bool IsTopCodeLoc) {
MCodeLoc = CodeLoc;
impl->MIsTopCodeLoc = IsTopCodeLoc;
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_linux.dump
Original file line number Diff line number Diff line change
Expand Up @@ -3523,6 +3523,7 @@ _ZN4sycl3_V17handler12addReductionERKSt10shared_ptrIKvE
_ZN4sycl3_V17handler12setArgHelperEiRNS0_3ext6oneapi12experimental6detail30dynamic_work_group_memory_baseE
_ZN4sycl3_V17handler12setArgHelperEiRNS0_6detail22work_group_memory_implE
_ZN4sycl3_V17handler13getKernelNameEv
_ZN4sycl3_V17handler13setKernelInfoEPviPFNS0_6detail19kernel_param_desc_tEiEbb
_ZN4sycl3_V17handler14addAccessorReqESt10shared_ptrINS0_6detail16AccessorImplHostEE
_ZN4sycl3_V17handler14setNDRangeUsedEb
_ZN4sycl3_V17handler15ext_oneapi_copyENS0_3ext6oneapi12experimental16image_mem_handleENS0_5rangeILi3EEERKNS4_16image_descriptorEPvS7_S7_S7_
Expand Down
1 change: 1 addition & 0 deletions sycl/test/abi/sycl_symbols_windows.dump
Original file line number Diff line number Diff line change
Expand Up @@ -4377,6 +4377,7 @@
?setHandlerKernelBundle@handler@_V1@sycl@@AEAAXVkernel@23@@Z
?setKernelCacheConfig@handler@_V1@sycl@@AEAAXW4StableKernelCacheConfig@123@@Z
?setKernelClusterLaunch@handler@_V1@sycl@@AEAAXV?$range@$02@23@H@Z
?setKernelInfo@handler@_V1@sycl@@AEAAXPEAXHP6A?AUkernel_param_desc_t@detail@23@H@Z_N2@Z
?setKernelIsCooperative@handler@_V1@sycl@@AEAAX_N@Z
?setKernelWorkGroupMem@handler@_V1@sycl@@AEAAX_K@Z
?setLocalAccessorArgHelper@handler@_V1@sycl@@AEAAXHAEAVLocalAccessorBaseHost@detail@23@@Z
Expand Down
Loading