@@ -476,21 +476,14 @@ void ur_exp_command_buffer_handle_t_::cleanupCommandBufferResources() {
476
476
477
477
// Diff hunk for the base command-handle constructor.
// Added (+) form: takes only CommandBuffer and CommandId, stores both, and
// retains CommandBuffer through urCommandBufferRetainExp.
// Removed (-) form: additionally took WorkDim, UserDefinedLocalSize and an
// optional Kernel (default nullptr), stored them, and retained Kernel when it
// was non-null.
ur_exp_command_buffer_command_handle_t_::
478
478
ur_exp_command_buffer_command_handle_t_ (
479
- ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId,
480
- uint32_t WorkDim, bool UserDefinedLocalSize,
481
- ur_kernel_handle_t Kernel = nullptr )
482
- : CommandBuffer(CommandBuffer), CommandId(CommandId), WorkDim(WorkDim),
483
- UserDefinedLocalSize(UserDefinedLocalSize), Kernel(Kernel) {
479
+ ur_exp_command_buffer_handle_t CommandBuffer, uint64_t CommandId)
480
+ : CommandBuffer(CommandBuffer), CommandId(CommandId) {
484
481
ur::level_zero::urCommandBufferRetainExp (CommandBuffer);
485
- if (Kernel)
486
- ur::level_zero::urKernelRetain (Kernel);
487
482
}
488
483
489
484
// Diff hunk for the base command-handle destructor.
// It releases the CommandBuffer reference through urCommandBufferReleaseExp.
// The removed (-) lines additionally released Kernel when it was non-null;
// after this change the base destructor no longer touches any kernel handle.
ur_exp_command_buffer_command_handle_t_::
490
485
~ur_exp_command_buffer_command_handle_t_ () {
491
486
ur::level_zero::urCommandBufferReleaseExp (CommandBuffer);
492
- if (Kernel)
493
- ur::level_zero::urKernelRelease (Kernel);
494
487
}
495
488
496
489
void ur_exp_command_buffer_handle_t_::registerSyncPoint (
@@ -527,6 +520,31 @@ ur_result_t ur_exp_command_buffer_handle_t_::getFenceForQueue(
527
520
return UR_RESULT_SUCCESS;
528
521
}
529
522
523
+ kernel_command_handle::kernel_command_handle (
524
+ ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
525
+ uint64_t CommandId, uint32_t WorkDim, bool UserDefinedLocalSize,
526
+ uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives)
527
+ : ur_exp_command_buffer_command_handle_t_(CommandBuffer, CommandId),
528
+ WorkDim(WorkDim), UserDefinedLocalSize(UserDefinedLocalSize),
529
+ Kernel(Kernel) {
530
+ // Add the default kernel to the list of valid kernels
531
+ ur::level_zero::urKernelRetain (Kernel);
532
+ ValidKernelHandles.insert (Kernel);
533
+ // Add alternative kernels if provided
534
+ if (KernelAlternatives) {
535
+ for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
536
+ ur::level_zero::urKernelRetain (KernelAlternatives[i]);
537
+ ValidKernelHandles.insert (KernelAlternatives[i]);
538
+ }
539
+ }
540
+ }
541
+
542
+ kernel_command_handle::~kernel_command_handle () {
543
+ for (const ur_kernel_handle_t &KernelHandle : ValidKernelHandles) {
544
+ ur::level_zero::urKernelRelease (KernelHandle);
545
+ }
546
+ }
547
+
530
548
namespace ur ::level_zero {
531
549
532
550
/* *
@@ -906,7 +924,8 @@ setKernelPendingArguments(ur_exp_command_buffer_handle_t CommandBuffer,
906
924
ur_result_t
907
925
createCommandHandle (ur_exp_command_buffer_handle_t CommandBuffer,
908
926
ur_kernel_handle_t Kernel, uint32_t WorkDim,
909
- const size_t *LocalWorkSize,
927
+ const size_t *LocalWorkSize, uint32_t NumKernelAlternatives,
928
+ ur_kernel_handle_t *KernelAlternatives,
910
929
ur_exp_command_buffer_command_handle_t &Command) {
911
930
912
931
assert (CommandBuffer->IsUpdatable );
@@ -923,14 +942,41 @@ createCommandHandle(ur_exp_command_buffer_handle_t CommandBuffer,
923
942
ZE_MUTABLE_COMMAND_EXP_FLAG_GLOBAL_OFFSET;
924
943
925
944
auto Platform = CommandBuffer->Context ->getPlatform ();
926
- ZE2UR_CALL (Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdExp ,
927
- (CommandBuffer->ZeComputeCommandListTranslated ,
928
- &ZeMutableCommandDesc, &CommandId));
945
+ if (NumKernelAlternatives > 0 ) {
946
+ ZeMutableCommandDesc.flags |=
947
+ ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION;
948
+
949
+ std::vector<ze_kernel_handle_t > TranslatedKernelHandles (
950
+ NumKernelAlternatives + 1 , nullptr );
951
+
952
+ // Translate main kernel first
953
+ ZE2UR_CALL (zelLoaderTranslateHandle,
954
+ (ZEL_HANDLE_KERNEL, Kernel->ZeKernel ,
955
+ (void **)&TranslatedKernelHandles[0 ]));
956
+
957
+ for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
958
+ ZE2UR_CALL (zelLoaderTranslateHandle,
959
+ (ZEL_HANDLE_KERNEL, KernelAlternatives[i]->ZeKernel ,
960
+ (void **)&TranslatedKernelHandles[i + 1 ]));
961
+ }
962
+
963
+ ZE2UR_CALL (Platform->ZeMutableCmdListExt
964
+ .zexCommandListGetNextCommandIdWithKernelsExp ,
965
+ (CommandBuffer->ZeComputeCommandListTranslated ,
966
+ &ZeMutableCommandDesc, NumKernelAlternatives + 1 ,
967
+ TranslatedKernelHandles.data (), &CommandId));
968
+
969
+ } else {
970
+ ZE2UR_CALL (Platform->ZeMutableCmdListExt .zexCommandListGetNextCommandIdExp ,
971
+ (CommandBuffer->ZeComputeCommandListTranslated ,
972
+ &ZeMutableCommandDesc, &CommandId));
973
+ }
929
974
DEBUG_LOG (CommandId);
930
975
931
976
try {
932
- Command = new ur_exp_command_buffer_command_handle_t_ (
933
- CommandBuffer, CommandId, WorkDim, LocalWorkSize != nullptr , Kernel);
977
+ Command = new kernel_command_handle (
978
+ CommandBuffer, Kernel, CommandId, WorkDim, LocalWorkSize != nullptr ,
979
+ NumKernelAlternatives, KernelAlternatives);
934
980
} catch (const std::bad_alloc &) {
935
981
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
936
982
} catch (...) {
@@ -944,8 +990,7 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
944
990
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
945
991
uint32_t WorkDim, const size_t *GlobalWorkOffset,
946
992
const size_t *GlobalWorkSize, const size_t *LocalWorkSize,
947
- uint32_t /* numKernelAlternatives*/ ,
948
- ur_kernel_handle_t * /* phKernelAlternatives*/ ,
993
+ uint32_t NumKernelAlternatives, ur_kernel_handle_t *KernelAlternatives,
949
994
uint32_t NumSyncPointsInWaitList,
950
995
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
951
996
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
@@ -960,6 +1005,10 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
960
1005
UR_ASSERT (!(Command && !CommandBuffer->IsUpdatable ),
961
1006
UR_RESULT_ERROR_INVALID_OPERATION);
962
1007
1008
+ for (uint32_t i = 0 ; i < NumKernelAlternatives; ++i) {
1009
+ UR_ASSERT (KernelAlternatives[i] != Kernel, UR_RESULT_ERROR_INVALID_VALUE);
1010
+ }
1011
+
963
1012
// Lock automatically releases when this goes out of scope.
964
1013
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock (
965
1014
Kernel->Mutex , Kernel->Program ->Mutex , CommandBuffer->Mutex );
@@ -983,18 +1032,21 @@ ur_result_t urCommandBufferAppendKernelLaunchExp(
983
1032
ZE2UR_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
984
1033
985
1034
CommandBuffer->KernelsList .push_back (Kernel);
1035
+ for (size_t i = 0 ; i < NumKernelAlternatives; i++) {
1036
+ CommandBuffer->KernelsList .push_back (KernelAlternatives[i]);
1037
+ }
986
1038
987
- // Increment the reference count of the Kernel and indicate that the Kernel
988
- // is in use. Once the event has been signaled, the code in
989
- // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
990
- // reference count on the kernel, using the kernel saved in CommandData.
991
- UR_CALL ( ur::level_zero::urKernelRetain (Kernel));
1039
+ ur::level_zero::urKernelRetain ( Kernel);
1040
+ // Retain alternative kernels if provided
1041
+ for ( size_t i = 0 ; i < NumKernelAlternatives; i++) {
1042
+ ur::level_zero::urKernelRetain (KernelAlternatives[i]);
1043
+ }
992
1044
993
1045
if (Command) {
994
1046
UR_CALL (createCommandHandle (CommandBuffer, Kernel, WorkDim, LocalWorkSize,
1047
+ NumKernelAlternatives, KernelAlternatives,
995
1048
*Command));
996
1049
}
997
-
998
1050
std::vector<ze_event_handle_t > ZeEventList;
999
1051
ze_event_handle_t ZeLaunchEvent = nullptr ;
1000
1052
UR_CALL (createSyncPointAndGetZeEvents (
@@ -1690,7 +1742,7 @@ ur_result_t urCommandBufferReleaseCommandExp(
1690
1742
* @return UR_RESULT_SUCCESS or an error code on failure
1691
1743
*/
1692
1744
ur_result_t validateCommandDesc (
1693
- ur_exp_command_buffer_command_handle_t Command,
1745
+ kernel_command_handle * Command,
1694
1746
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1695
1747
1696
1748
auto CommandBuffer = Command->CommandBuffer ;
@@ -1699,9 +1751,14 @@ ur_result_t validateCommandDesc(
1699
1751
->mutableCommandFlags ;
1700
1752
logger::debug (" Mutable features supported by device {}" , SupportedFeatures);
1701
1753
1702
- // Kernel handle updates are not yet supported.
1703
- if (CommandDesc->hNewKernel && CommandDesc->hNewKernel != Command->Kernel ) {
1704
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1754
+ UR_ASSERT (
1755
+ !CommandDesc->hNewKernel ||
1756
+ (SupportedFeatures & ZE_MUTABLE_COMMAND_EXP_FLAG_KERNEL_INSTRUCTION),
1757
+ UR_RESULT_ERROR_UNSUPPORTED_FEATURE);
1758
+ // Check if the provided new kernel is in the list of valid alternatives.
1759
+ if (CommandDesc->hNewKernel &&
1760
+ !Command->ValidKernelHandles .count (CommandDesc->hNewKernel )) {
1761
+ return UR_RESULT_ERROR_INVALID_VALUE;
1705
1762
}
1706
1763
1707
1764
if (CommandDesc->newWorkDim != Command->WorkDim &&
@@ -1754,7 +1811,7 @@ ur_result_t validateCommandDesc(
1754
1811
* @return UR_RESULT_SUCCESS or an error code on failure
1755
1812
*/
1756
1813
ur_result_t updateKernelCommand (
1757
- ur_exp_command_buffer_command_handle_t Command,
1814
+ kernel_command_handle * Command,
1758
1815
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
1759
1816
1760
1817
// We need the created descriptors to live till the point when
@@ -1769,12 +1826,29 @@ ur_result_t updateKernelCommand(
1769
1826
1770
1827
const auto CommandBuffer = Command->CommandBuffer ;
1771
1828
const void *NextDesc = nullptr ;
1829
+ auto Platform = CommandBuffer->Context ->getPlatform ();
1772
1830
1773
1831
uint32_t Dim = CommandDesc->newWorkDim ;
1774
1832
size_t *NewGlobalWorkOffset = CommandDesc->pNewGlobalWorkOffset ;
1775
1833
size_t *NewLocalWorkSize = CommandDesc->pNewLocalWorkSize ;
1776
1834
size_t *NewGlobalWorkSize = CommandDesc->pNewGlobalWorkSize ;
1777
1835
1836
+ // Kernel handle must be updated first for a given CommandId if required
1837
+ ur_kernel_handle_t NewKernel = CommandDesc->hNewKernel ;
1838
+ if (NewKernel && Command->Kernel != NewKernel) {
1839
+ ze_kernel_handle_t ZeKernelTranslated = nullptr ;
1840
+ ZE2UR_CALL (
1841
+ zelLoaderTranslateHandle,
1842
+ (ZEL_HANDLE_KERNEL, NewKernel->ZeKernel , (void **)&ZeKernelTranslated));
1843
+
1844
+ ZE2UR_CALL (Platform->ZeMutableCmdListExt
1845
+ .zexCommandListUpdateMutableCommandKernelsExp ,
1846
+ (CommandBuffer->ZeComputeCommandListTranslated , 1 ,
1847
+ &Command->CommandId , &ZeKernelTranslated));
1848
+ // Set current kernel to be the new kernel
1849
+ Command->Kernel = NewKernel;
1850
+ }
1851
+
1778
1852
// Check if a new global offset is provided.
1779
1853
if (NewGlobalWorkOffset && Dim > 0 ) {
1780
1854
auto MutableGroupOffestDesc =
@@ -1973,7 +2047,6 @@ ur_result_t updateKernelCommand(
1973
2047
MutableCommandDesc.pNext = NextDesc;
1974
2048
MutableCommandDesc.flags = 0 ;
1975
2049
1976
- auto Platform = CommandBuffer->Context ->getPlatform ();
1977
2050
ZE2UR_CALL (
1978
2051
Platform->ZeMutableCmdListExt .zexCommandListUpdateMutableCommandsExp ,
1979
2052
(CommandBuffer->ZeComputeCommandListTranslated , &MutableCommandDesc));
@@ -2009,18 +2082,22 @@ ur_result_t urCommandBufferUpdateKernelLaunchExp(
2009
2082
const ur_exp_command_buffer_update_kernel_launch_desc_t *CommandDesc) {
2010
2083
UR_ASSERT (Command->CommandBuffer ->IsUpdatable ,
2011
2084
UR_RESULT_ERROR_INVALID_OPERATION);
2012
- UR_ASSERT (Command->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
2085
+
2086
+ auto KernelCommandHandle = static_cast <kernel_command_handle *>(Command);
2087
+
2088
+ UR_ASSERT (KernelCommandHandle->Kernel , UR_RESULT_ERROR_INVALID_NULL_HANDLE);
2013
2089
2014
2090
// Lock command, kernel and command buffer for update.
2015
2091
std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Guard (
2016
- Command->Mutex , Command->CommandBuffer ->Mutex , Command->Kernel ->Mutex );
2092
+ Command->Mutex , Command->CommandBuffer ->Mutex ,
2093
+ KernelCommandHandle->Kernel ->Mutex );
2017
2094
2018
2095
UR_ASSERT (Command->CommandBuffer ->IsFinalized ,
2019
2096
UR_RESULT_ERROR_INVALID_OPERATION);
2020
2097
2021
- UR_CALL (validateCommandDesc (Command , CommandDesc));
2098
+ UR_CALL (validateCommandDesc (KernelCommandHandle , CommandDesc));
2022
2099
UR_CALL (waitForOngoingExecution (Command->CommandBuffer ));
2023
- UR_CALL (updateKernelCommand (Command , CommandDesc));
2100
+ UR_CALL (updateKernelCommand (KernelCommandHandle , CommandDesc));
2024
2101
2025
2102
ZE2UR_CALL (zeCommandListClose,
2026
2103
(Command->CommandBuffer ->ZeComputeCommandList ));
0 commit comments