@@ -787,6 +787,255 @@ ur_result_t urUSMReleaseExp(ur_context_handle_t Context, void *HostPtr) {
787787 Context->getPlatform ()->ZeDriverHandleExpTranslated , HostPtr);
788788 return UR_RESULT_SUCCESS;
789789}
790+
791+ enum class USMAllocType { Host = 0 , Device = 1 , Shared = 2 };
792+
793+ static ur_result_t USMAllocHelper (ur_context_handle_t Context,
794+ ur_device_handle_t Device, size_t Size,
795+ void **RetMem, USMAllocType Type) {
796+ auto &Platform = Device->Platform ;
797+
798+ // TODO: Should alignemnt be passed in 'ur_exp_async_usm_alloc_properties_t'?
799+ uint32_t Alignment = 0 ;
800+
801+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex ,
802+ std::defer_lock);
803+ std::unique_lock<ur_shared_mutex> IndirectAccessTrackingLock (
804+ Platform->ContextsMutex , std::defer_lock);
805+ if (IndirectAccessTrackingEnabled) {
806+ IndirectAccessTrackingLock.lock ();
807+ UR_CALL (ur::level_zero::urContextRetain (Context));
808+ } else {
809+ ContextLock.lock ();
810+ }
811+
812+ umf_memory_pool_handle_t hPoolInternal = nullptr ;
813+ switch (Type) {
814+ case USMAllocType::Host:
815+ hPoolInternal = Context->AsyncHostMemPool .get ();
816+ break ;
817+ case USMAllocType::Device: {
818+ auto It = Context->AsyncDeviceMemPools .find (Device->ZeDevice );
819+ if (It == Context->AsyncDeviceMemPools .end ()) {
820+ return UR_RESULT_ERROR_INVALID_VALUE;
821+ }
822+ hPoolInternal = It->second .get ();
823+ } break ;
824+ case USMAllocType::Shared: {
825+ auto It = Context->AsyncSharedMemPools .find (Device->ZeDevice );
826+ if (It == Context->AsyncSharedMemPools .end ()) {
827+ return UR_RESULT_ERROR_INVALID_VALUE;
828+ }
829+ hPoolInternal = It->second .get ();
830+ } break ;
831+ };
832+
833+ *RetMem = umfPoolAlignedMalloc (hPoolInternal, Size, Alignment);
834+ if (*RetMem == nullptr ) {
835+ auto umfRet = umfPoolGetLastAllocationError (hPoolInternal);
836+ return umf2urResult (umfRet);
837+ }
838+
839+ if (IndirectAccessTrackingEnabled) {
840+ // Keep track of all memory allocations in the context
841+ Context->MemAllocs .emplace (std::piecewise_construct,
842+ std::forward_as_tuple (*RetMem),
843+ std::forward_as_tuple (Context));
844+ }
845+
846+ return UR_RESULT_SUCCESS;
847+ }
848+
849+ static ur_result_t enqueueUSMAllocHelper (
850+ ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size,
851+ const ur_exp_async_usm_alloc_properties_t *Properties,
852+ uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
853+ void **RetMem, ur_event_handle_t *OutEvent, USMAllocType Type) {
854+ std::ignore = Pool;
855+ std::ignore = Properties;
856+
857+ std::scoped_lock<ur_shared_mutex> lock (Queue->Mutex );
858+
859+ bool UseCopyEngine = false ;
860+ _ur_ze_event_list_t TmpWaitList;
861+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
862+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
863+
864+ // Get a new command list to be used on this call
865+ ur_command_list_ptr_t CommandList{};
866+ UR_CALL (Queue->Context ->getAvailableCommandList (
867+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList));
868+
869+ ze_event_handle_t ZeEvent = nullptr ;
870+ ur_event_handle_t InternalEvent{};
871+ bool IsInternal = OutEvent == nullptr ;
872+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
873+
874+ ur_command_t CommandType = UR_COMMAND_FORCE_UINT32;
875+ switch (Type) {
876+ case USMAllocType::Host:
877+ CommandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
878+ break ;
879+ case USMAllocType::Device:
880+ CommandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
881+ break ;
882+ case USMAllocType::Shared:
883+ CommandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
884+ break ;
885+ }
886+ UR_CALL (createEventAndAssociateQueue (Queue, Event, CommandType, CommandList,
887+ IsInternal, false ));
888+ ZeEvent = (*Event)->ZeEvent ;
889+ (*Event)->WaitList = TmpWaitList;
890+
891+ // Allocate USM memory
892+ auto Ret = USMAllocHelper (Queue->Context , Queue->Device , Size, RetMem, Type);
893+ if (Ret) {
894+ return Ret;
895+ }
896+
897+ // Signal that USM allocation event was finished
898+ ZE2UR_CALL (zeCommandListAppendSignalEvent, (CommandList->first , ZeEvent));
899+
900+ UR_CALL (Queue->executeCommandList (CommandList, false ));
901+
902+ return UR_RESULT_SUCCESS;
903+ }
904+
905+ ur_result_t urEnqueueUSMDeviceAllocExp (
906+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
907+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
908+ const size_t Size, // /< [in] minimum size in bytes of the USM memory object
909+ // /< to be allocated
910+ const ur_exp_async_usm_alloc_properties_t
911+ *Properties, // /< [in][optional] pointer to the enqueue async alloc
912+ // /< properties
913+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
914+ const ur_event_handle_t
915+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
916+ // /< pointer to a list of events that must be complete
917+ // /< before the kernel execution. If nullptr, the
918+ // /< numEventsInWaitList must be 0, indicating no wait
919+ // /< events.
920+ void **Mem, // /< [out] pointer to USM memory object
921+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
922+ // /< identifies the async alloc
923+ ) {
924+ return enqueueUSMAllocHelper (Queue, Pool, Size, Properties,
925+ NumEventsInWaitList, EventWaitList, Mem,
926+ OutEvent, USMAllocType::Device);
927+ }
928+
929+ ur_result_t urEnqueueUSMSharedAllocExp (
930+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
931+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
932+ const size_t Size, // /< [in] minimum size in bytes of the USM memory object
933+ // /< to be allocated
934+ const ur_exp_async_usm_alloc_properties_t
935+ *Properties, // /< [in][optional] pointer to the enqueue async alloc
936+ // /< properties
937+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
938+ const ur_event_handle_t
939+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
940+ // /< pointer to a list of events that must be complete
941+ // /< before the kernel execution. If nullptr, the
942+ // /< numEventsInWaitList must be 0, indicating no wait
943+ // /< events.
944+ void **Mem, // /< [out] pointer to USM memory object
945+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
946+ // /< identifies the async alloc
947+ ) {
948+ return enqueueUSMAllocHelper (Queue, Pool, Size, Properties,
949+ NumEventsInWaitList, EventWaitList, Mem,
950+ OutEvent, USMAllocType::Shared);
951+ }
952+
953+ ur_result_t urEnqueueUSMHostAllocExp (
954+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
955+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
956+ const size_t Size, // /< [in] minimum size in bytes of the USM memory object
957+ // /< to be allocated
958+ const ur_exp_async_usm_alloc_properties_t
959+ *Properties, // /< [in][optional] pointer to the enqueue async alloc
960+ // /< properties
961+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
962+ const ur_event_handle_t
963+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
964+ // /< pointer to a list of events that must be complete
965+ // /< before the kernel execution. If nullptr, the
966+ // /< numEventsInWaitList must be 0, indicating no wait
967+ // /< events.
968+ void **Mem, // /< [out] pointer to USM memory object
969+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
970+ // /< identifies the async alloc
971+ ) {
972+ return enqueueUSMAllocHelper (Queue, Pool, Size, Properties,
973+ NumEventsInWaitList, EventWaitList, Mem,
974+ OutEvent, USMAllocType::Host);
975+ }
976+
977+ ur_result_t urEnqueueUSMFreeExp (
978+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
979+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
980+ void *Mem, // /< [in] pointer to USM memory object
981+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
982+ const ur_event_handle_t
983+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
984+ // /< pointer to a list of events that must be complete
985+ // /< before the kernel execution. If nullptr, the
986+ // /< numEventsInWaitList must be 0, indicating no wait
987+ // /< events.
988+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
989+ // /< identifies the async alloc
990+ ) {
991+ std::ignore = Pool;
992+
993+ std::scoped_lock<ur_shared_mutex> lock (Queue->Mutex );
994+
995+ bool UseCopyEngine = false ;
996+ _ur_ze_event_list_t TmpWaitList;
997+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
998+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
999+
1000+ // Get a new command list to be used on this call
1001+ ur_command_list_ptr_t CommandList{};
1002+ UR_CALL (Queue->Context ->getAvailableCommandList (
1003+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList));
1004+
1005+ ze_event_handle_t ZeEvent = nullptr ;
1006+ ur_event_handle_t InternalEvent{};
1007+ bool IsInternal = OutEvent == nullptr ;
1008+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
1009+
1010+ UR_CALL (createEventAndAssociateQueue (Queue, Event,
1011+ UR_COMMAND_ENQUEUE_USM_FREE_EXP,
1012+ CommandList, IsInternal, false ));
1013+ ZeEvent = (*Event)->ZeEvent ;
1014+ (*Event)->WaitList = TmpWaitList;
1015+
1016+ const auto &ZeCommandList = CommandList->first ;
1017+ const auto &WaitList = (*Event)->WaitList ;
1018+ if (WaitList.Length ) {
1019+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1020+ (ZeCommandList, WaitList.Length , WaitList.ZeEventList ));
1021+ }
1022+
1023+ // Wait for commands execution until USM can be freed
1024+ UR_CALL (Queue->executeCommandList (CommandList, true )); // Blocking
1025+
1026+ // Free USM memory
1027+ auto Ret = USMFreeHelper (Queue->Context , Mem);
1028+ if (Ret) {
1029+ return Ret;
1030+ }
1031+
1032+ // Signal that USM free event was finished
1033+ ZE2UR_CALL (zeCommandListAppendSignalEvent, (ZeCommandList, ZeEvent));
1034+
1035+ UR_CALL (Queue->executeCommandList (CommandList, false ));
1036+
1037+ return UR_RESULT_SUCCESS;
1038+ }
7901039} // namespace ur::level_zero
7911040
7921041static ur_result_t USMFreeImpl (ur_context_handle_t Context, void *Ptr) {
0 commit comments