Skip to content

Commit 0a90db9

Browse files
authored
Merge pull request #2369 from Bensuo/ben/kernel-binary-update-l0
[CMDBUF] Implement kernel binary update for L0 adapter
2 parents c4d9fdb + 6e0bdeb commit 0a90db9

File tree

6 files changed

+173
-37
lines changed

6 files changed

+173
-37
lines changed

scripts/core/EXP-COMMAND-BUFFER.rst

+5
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,11 @@ ${x}CommandBufferAppendKernelLaunchExp. The command can then be updated
256256
to use the new kernel handle by passing it to
257257
${x}CommandBufferUpdateKernelLaunchExp.
258258

259+
.. important::
260+
When updating the kernel handle of a command, all required arguments to the
261+
new kernel must be provided in the update descriptor. Failure to do so will
262+
result in undefined behavior.
263+
259264
.. parsed-literal::
260265
261266
// Create a command-buffer with update enabled.

source/adapters/level_zero/command_buffer.cpp

+110-33
Original file line numberDiff line numberDiff line change
@@ -476,21 +476,14 @@ void ur_exp_command_buffer_handle_t_::cleanupCommandBufferResources() {
476476

477477
ur_exp_command_buffer_command_handle_t_::
478478
ur_exp_command_buffer_command_handle_t_(
479-
ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId,
480-
uint32_t WorkDim, bool UserDefinedLocalSize,
481-
ur_kernel_handle_t Kernel = nullptr)
482-
: CommandBuffer(CommandBuffer), CommandId(CommandId), WorkDim(WorkDim),
483-
UserDefinedLocalSize(UserDefinedLocalSize), Kernel(Kernel) {
479+
ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId)
480+
: CommandBuffer(CommandBuffer), CommandId(CommandId) {
484481
ur::level_zero::urCommandBufferRetainExp(CommandBuffer);
485-
if (Kernel)
486-
ur::level_zero::urKernelRetain(Kernel);
487482
}
488483

489484
ur_exp_command_buffer_command_handle_t_::
490485
~ur_exp_command_buffer_command_handle_t_() {
491486
ur::level_zero::urCommandBufferReleaseExp(CommandBuffer);
492-
if (Kernel)
493-
ur::level_zero::urKernelRelease(Kernel);
494487
}
495488

496489
void ur_exp_command_buffer_handle_t_::registerSyncPoint(
@@ -527,6 +520,31 @@ ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
527520
return UR_RESULT_SUCCESS;
528521
}
529522

523+
kernel_command_handle::kernel_command_handle(
524+
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
525+
uint64_t CommandId, uint32_t WorkDim, bool UserDefinedLocalSize,
526+
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives)
527+
: ur_exp_command_buffer_command_handle_t_(CommandBuffer, CommandId),
528+
WorkDim(WorkDim), UserDefinedLocalSize(UserDefinedLocalSize),
529+
Kernel(Kernel) {
530+
// Add the default kernel to the list of valid kernels
531+
ur::level_zero::urKernelRetain(Kernel);
532+
ValidKernelHandles.insert(Kernel);
533+
// Add alternative kernels if provided
534+
if (KernelAlternatives) {
535+
for (size_t i = 0; i < NumKernelAlternatives; i++) {
536+
ur::level_zero::urKernelRetain(KernelAlternatives[i]);
537+
ValidKernelHandles.insert(KernelAlternatives[i]);
538+
}
539+
}
540+
}
541+
542+
kernel_command_handle::~kernel_command_handle() {
543+
for (const ur_kernel_handle_t &KernelHandle : ValidKernelHandles) {
544+
ur::level_zero::urKernelRelease(KernelHandle);
545+
}
546+
}
547+
530548
namespace ur::level_zero {
531549

532550
/**
@@ -906,7 +924,8 @@ setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
906924
ur_result_t
907925
createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
908926
ur_kernel_handle_t Kernel, uint32_t WorkDim,
909-
const size_t *LocalWorkSize,
927+
const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
928+
ur_kernel_handle_t *KernelAlternatives,
910929
ur_exp_command_buffer_command_handle_t &Command) {
911930

912931
assert(CommandBuffer->IsUpdatable);
@@ -923,14 +942,41 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
923942
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
924943

925944
auto Platform = CommandBuffer->Context->getPlatform();
926-
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
927-
(CommandBuffer->ZeComputeCommandListTranslated,
928-
&ZeMutableCommandDesc, &CommandId));
945+
if (NumKernelAlternatives > 0) {
946+
ZeMutableCommandDesc.flags |=
947+
ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
948+
949+
std::vector<ze_kernel_handle_t> TranslatedKernelHandles(
950+
NumKernelAlternatives + 1, nullptr);
951+
952+
// Translate main kernel first
953+
ZE2UR_CALL(zelLoaderTranslateHandle,
954+
(ZEL_HANDLE_KERNEL, Kernel->ZeKernel,
955+
(void **)&TranslatedKernelHandles[0]));
956+
957+
for (size_t i = 0; i < NumKernelAlternatives; i++) {
958+
ZE2UR_CALL(zelLoaderTranslateHandle,
959+
(ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel,
960+
(void **)&TranslatedKernelHandles[i + 1]));
961+
}
962+
963+
ZE2UR_CALL(Platform->ZeMutableCmdListExt
964+
.zexCommandListGetNextCommandIdWithKernelsExp,
965+
(CommandBuffer->ZeComputeCommandListTranslated,
966+
&ZeMutableCommandDesc, NumKernelAlternatives + 1,
967+
TranslatedKernelHandles.data(), &CommandId));
968+
969+
} else {
970+
ZE2UR_CALL(Platform->ZeMutableCmdListExt.zexCommandListGetNextCommandIdExp,
971+
(CommandBuffer->ZeComputeCommandListTranslated,
972+
&ZeMutableCommandDesc, &CommandId));
973+
}
929974
DEBUG_LOG(CommandId);
930975

931976
try {
932-
Command = new ur_exp_command_buffer_command_handle_t_(
933-
CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr, Kernel);
977+
Command = new kernel_command_handle(
978+
CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr,
979+
NumKernelAlternatives, KernelAlternatives);
934980
} catch (const std::bad_alloc &) {
935981
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
936982
} catch (...) {
@@ -944,8 +990,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
944990
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
945991
uint32_t WorkDim, const size_t *GlobalWorkOffset,
946992
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
947-
uint32_t /*numKernelAlternatives*/,
948-
ur_kernel_handle_t * /*phKernelAlternatives*/,
993+
uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
949994
uint32_t NumSyncPointsInWaitList,
950995
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
951996
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
@@ -960,6 +1005,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
9601005
UR_ASSERT(!(Command && !CommandBuffer->IsUpdatable),
9611006
UR_RESULT_ERROR_INVALID_OPERATION);
9621007

1008+
for (uint32_t i = 0; i < NumKernelAlternatives; ++i) {
1009+
UR_ASSERT(KernelAlternatives[i] != Kernel, UR_RESULT_ERROR_INVALID_VALUE);
1010+
}
1011+
9631012
// Lock automatically releases when this goes out of scope.
9641013
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
9651014
Kernel->Mutex, Kernel->Program->Mutex, CommandBuffer->Mutex);
@@ -983,18 +1032,21 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
9831032
ZE2UR_CALL(zeKernelSetGroupSize, (Kernel->ZeKernel, WG[0], WG[1], WG[2]));
9841033

9851034
CommandBuffer->KernelsList.push_back(Kernel);
1035+
for (size_t i = 0; i < NumKernelAlternatives; i++) {
1036+
CommandBuffer->KernelsList.push_back(KernelAlternatives[i]);
1037+
}
9861038

987-
// Increment the reference count of the Kernel and indicate that the Kernel
988-
// is in use. Once the event has been signaled, the code in
989-
// CleanupCompletedEvent(Event) will do a urKernelRelease to update the
990-
// reference count on the kernel, using the kernel saved in CommandData.
991-
UR_CALL(ur::level_zero::urKernelRetain(Kernel));
1039+
ur::level_zero::urKernelRetain(Kernel);
1040+
// Retain alternative kernels if provided
1041+
for (size_t i = 0; i < NumKernelAlternatives; i++) {
1042+
ur::level_zero::urKernelRetain(KernelAlternatives[i]);
1043+
}
9921044

9931045
if (Command) {
9941046
UR_CALL(createCommandHandle(CommandBuffer, Kernel, WorkDim, LocalWorkSize,
1047+
NumKernelAlternatives, KernelAlternatives,
9951048
*Command));
9961049
}
997-
9981050
std::vector<ze_event_handle_t> ZeEventList;
9991051
ze_event_handle_t ZeLaunchEvent = nullptr;
10001052
UR_CALL(createSyncPointAndGetZeEvents(
@@ -1690,7 +1742,7 @@ ur_result_t urCommandBufferReleaseCommandExp(
16901742
* @return UR_RESULT_SUCCESS or an error code on failure
16911743
*/
16921744
ur_result_t validateCommandDesc(
1693-
ur_exp_command_buffer_command_handle_t Command,
1745+
kernel_command_handle *Command,
16941746
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
16951747

16961748
auto CommandBuffer = Command->CommandBuffer;
@@ -1699,9 +1751,14 @@ ur_result_t validateCommandDesc(
16991751
->mutableCommandFlags;
17001752
logger::debug("Mutable features supported by device {}", SupportedFeatures);
17011753

1702-
// Kernel handle updates are not yet supported.
1703-
if (CommandDesc->hNewKernel && CommandDesc->hNewKernel != Command->Kernel) {
1704-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1754+
UR_ASSERT(
1755+
!CommandDesc->hNewKernel ||
1756+
(SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION),
1757+
UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1758+
// Check if the provided new kernel is in the list of valid alternatives.
1759+
if (CommandDesc->hNewKernel &&
1760+
!Command->ValidKernelHandles.count(CommandDesc->hNewKernel)) {
1761+
return UR_RESULT_ERROR_INVALID_VALUE;
17051762
}
17061763

17071764
if (CommandDesc->newWorkDim != Command->WorkDim &&
@@ -1754,7 +1811,7 @@ ur_result_t validateCommandDesc(
17541811
* @return UR_RESULT_SUCCESS or an error code on failure
17551812
*/
17561813
ur_result_t updateKernelCommand(
1757-
ur_exp_command_buffer_command_handle_t Command,
1814+
kernel_command_handle *Command,
17581815
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
17591816

17601817
// We need the created descriptors to live till the point when
@@ -1769,12 +1826,29 @@ ur_result_t updateKernelCommand(
17691826

17701827
const auto CommandBuffer = Command->CommandBuffer;
17711828
const void *NextDesc = nullptr;
1829+
auto Platform = CommandBuffer->Context->getPlatform();
17721830

17731831
uint32_t Dim = CommandDesc->newWorkDim;
17741832
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset;
17751833
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize;
17761834
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize;
17771835

1836+
// Kernel handle must be updated first for a given CommandId if required
1837+
ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel;
1838+
if (NewKernel && Command->Kernel != NewKernel) {
1839+
ze_kernel_handle_t ZeKernelTranslated = nullptr;
1840+
ZE2UR_CALL(
1841+
zelLoaderTranslateHandle,
1842+
(ZEL_HANDLE_KERNEL, NewKernel->ZeKernel, (void **)&ZeKernelTranslated));
1843+
1844+
ZE2UR_CALL(Platform->ZeMutableCmdListExt
1845+
.zexCommandListUpdateMutableCommandKernelsExp,
1846+
(CommandBuffer->ZeComputeCommandListTranslated, 1,
1847+
&Command->CommandId, &ZeKernelTranslated));
1848+
// Set current kernel to be the new kernel
1849+
Command->Kernel = NewKernel;
1850+
}
1851+
17781852
// Check if a new global offset is provided.
17791853
if (NewGlobalWorkOffset && Dim > 0) {
17801854
auto MutableGroupOffestDesc =
@@ -1973,7 +2047,6 @@ ur_result_t updateKernelCommand(
19732047
MutableCommandDesc.pNext = NextDesc;
19742048
MutableCommandDesc.flags = 0;
19752049

1976-
auto Platform = CommandBuffer->Context->getPlatform();
19772050
ZE2UR_CALL(
19782051
Platform->ZeMutableCmdListExt.zexCommandListUpdateMutableCommandsExp,
19792052
(CommandBuffer->ZeComputeCommandListTranslated, &MutableCommandDesc));
@@ -2009,18 +2082,22 @@ ur_result_t urCommandBufferUpdateKernelLaunchExp(
20092082
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
20102083
UR_ASSERT(Command->CommandBuffer->IsUpdatable,
20112084
UR_RESULT_ERROR_INVALID_OPERATION);
2012-
UR_ASSERT(Command->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
2085+
2086+
auto KernelCommandHandle = static_cast<kernel_command_handle *>(Command);
2087+
2088+
UR_ASSERT(KernelCommandHandle->Kernel, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
20132089

20142090
// Lock command, kernel and command buffer for update.
20152091
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard(
2016-
Command->Mutex, Command->CommandBuffer->Mutex, Command->Kernel->Mutex);
2092+
Command->Mutex, Command->CommandBuffer->Mutex,
2093+
KernelCommandHandle->Kernel->Mutex);
20172094

20182095
UR_ASSERT(Command->CommandBuffer->IsFinalized,
20192096
UR_RESULT_ERROR_INVALID_OPERATION);
20202097

2021-
UR_CALL(validateCommandDesc(Command, CommandDesc));
2098+
UR_CALL(validateCommandDesc(KernelCommandHandle, CommandDesc));
20222099
UR_CALL(waitForOngoingExecution(Command->CommandBuffer));
2023-
UR_CALL(updateKernelCommand(Command, CommandDesc));
2100+
UR_CALL(updateKernelCommand(KernelCommandHandle, CommandDesc));
20242101

20252102
ZE2UR_CALL(zeCommandListClose,
20262103
(Command->CommandBuffer->ZeComputeCommandList));

source/adapters/level_zero/command_buffer.hpp

+17-4
Original file line numberDiff line numberDiff line change
@@ -145,18 +145,31 @@ struct ur_exp_command_buffer_handle_t_ : public _ur_object {
145145

146146
struct ur_exp_command_buffer_command_handle_t_ : public _ur_object {
147147
ur_exp_command_buffer_command_handle_t_(ur_exp_command_buffer_handle_t,
148-
uint64_t, uint32_t, bool,
149-
ur_kernel_handle_t);
148+
uint64_t);
150149

151-
~ur_exp_command_buffer_command_handle_t_();
150+
virtual ~ur_exp_command_buffer_command_handle_t_();
152151

153152
// Command-buffer of this command.
154153
ur_exp_command_buffer_handle_t CommandBuffer;
155-
154+
// L0 command ID identifying this command
156155
uint64_t CommandId;
156+
};
157+
158+
struct kernel_command_handle : public ur_exp_command_buffer_command_handle_t_ {
159+
kernel_command_handle(ur_exp_command_buffer_handle_t CommandBuffer,
160+
ur_kernel_handle_t Kernel, uint64_t CommandId,
161+
uint32_t WorkDim, bool UserDefinedLocalSize,
162+
uint32_t NumKernelAlternatives,
163+
ur_kernel_handle_t *KernelAlternatives);
164+
165+
~kernel_command_handle();
166+
157167
// Work-dimension the command was originally created with.
158168
uint32_t WorkDim;
159169
// Set to true if the user set the local work size on command creation.
160170
bool UserDefinedLocalSize;
171+
// Currently active kernel handle
161172
ur_kernel_handle_t Kernel;
173+
// Storage for valid kernel alternatives for this command.
174+
std::unordered_set<ur_kernel_handle_t> ValidKernelHandles;
162175
};

source/adapters/level_zero/device.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,10 @@ ur_result_t urDeviceGetInfo(
10481048
UpdateCapabilities |=
10491049
UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_GLOBAL_WORK_OFFSET;
10501050
}
1051+
if (supportsFlags(ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION)) {
1052+
UpdateCapabilities |=
1053+
UR_DEVICE_COMMAND_BUFFER_UPDATE_CAPABILITY_FLAG_KERNEL_HANDLE;
1054+
}
10511055
return ReturnValue(UpdateCapabilities);
10521056
}
10531057
case UR_DEVICE_INFO_COMMAND_BUFFER_EVENT_SUPPORT_EXP:

source/adapters/level_zero/platform.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,22 @@ ur_result_t ur_platform_handle_t_::initialize() {
319319
ZeMutableCmdListExt.Supported |=
320320
ZeMutableCmdListExt.zexCommandListUpdateMutableCommandWaitEventsExp !=
321321
nullptr;
322+
ZeMutableCmdListExt.zexCommandListUpdateMutableCommandKernelsExp =
323+
(ze_pfnCommandListUpdateMutableCommandKernelsExp_t)
324+
ur_loader::LibLoader::getFunctionPtr(
325+
GlobalAdapter->processHandle,
326+
"zeCommandListUpdateMutableCommandKernelsExp");
327+
ZeMutableCmdListExt.Supported |=
328+
ZeMutableCmdListExt.zexCommandListUpdateMutableCommandKernelsExp !=
329+
nullptr;
330+
ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp =
331+
(ze_pfnCommandListGetNextCommandIdWithKernelsExp_t)
332+
ur_loader::LibLoader::getFunctionPtr(
333+
GlobalAdapter->processHandle,
334+
"zeCommandListGetNextCommandIdWithKernelsExp");
335+
ZeMutableCmdListExt.Supported |=
336+
ZeMutableCmdListExt.zexCommandListGetNextCommandIdWithKernelsExp !=
337+
nullptr;
322338
} else {
323339
ZeMutableCmdListExt.Supported |=
324340
(ZE_CALL_NOCHECK(
@@ -353,6 +369,21 @@ ur_result_t ur_platform_handle_t_::initialize() {
353369
&ZeMutableCmdListExt
354370
.zexCommandListUpdateMutableCommandWaitEventsExp))) ==
355371
0);
372+
ZeMutableCmdListExt.Supported &=
373+
(ZE_CALL_NOCHECK(
374+
zeDriverGetExtensionFunctionAddress,
375+
(ZeDriver, "zeCommandListUpdateMutableCommandKernelsExp",
376+
reinterpret_cast<void **>(
377+
&ZeMutableCmdListExt
378+
.zexCommandListUpdateMutableCommandKernelsExp))) == 0);
379+
380+
ZeMutableCmdListExt.Supported &=
381+
(ZE_CALL_NOCHECK(
382+
zeDriverGetExtensionFunctionAddress,
383+
(ZeDriver, "zeCommandListGetNextCommandIdWithKernelsExp",
384+
reinterpret_cast<void **>(
385+
&ZeMutableCmdListExt
386+
.zexCommandListGetNextCommandIdWithKernelsExp))) == 0);
356387
}
357388
return UR_RESULT_SUCCESS;
358389
}

source/adapters/level_zero/platform.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,11 @@ struct ur_platform_handle_t_ : public _ur_platform {
107107
ze_result_t (*zexCommandListUpdateMutableCommandWaitEventsExp)(
108108
ze_command_list_handle_t, uint64_t, uint32_t,
109109
ze_event_handle_t *) = nullptr;
110+
ze_result_t (*zexCommandListUpdateMutableCommandKernelsExp)(
111+
ze_command_list_handle_t, uint32_t, uint64_t *,
112+
ze_kernel_handle_t *) = nullptr;
113+
ze_result_t (*zexCommandListGetNextCommandIdWithKernelsExp)(
114+
ze_command_list_handle_t, const ze_mutable_command_id_exp_desc_t *,
115+
uint32_t, ze_kernel_handle_t *, uint64_t *) = nullptr;
110116
} ZeMutableCmdListExt;
111117
};

0 commit comments

Comments
 (0)