@@ -787,6 +787,255 @@ ur_result_t urUSMReleaseExp(ur_context_handle_t Context, void *HostPtr) {
787
787
Context->getPlatform ()->ZeDriverHandleExpTranslated , HostPtr);
788
788
return UR_RESULT_SUCCESS;
789
789
}
790
+
791
/// Kind of USM allocation requested from the allocation helpers in this file.
/// The numeric values are spelled out explicitly and must stay stable.
enum class USMAllocType {
  Host = 0,   ///< Served from the context's host memory pool.
  Device = 1, ///< Served from the per-device device memory pool.
  Shared = 2, ///< Served from the per-device shared memory pool.
};
792
+
793
+ static ur_result_t USMAllocHelper (ur_context_handle_t Context,
794
+ ur_device_handle_t Device, size_t Size ,
795
+ void **RetMem, USMAllocType Type) {
796
+ auto &Platform = Device->Platform ;
797
+
798
+ // TODO: Should alignemnt be passed in 'ur_exp_async_usm_alloc_properties_t'?
799
+ uint32_t Alignment = 0 ;
800
+
801
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex ,
802
+ std::defer_lock);
803
+ std::unique_lock<ur_shared_mutex> IndirectAccessTrackingLock (
804
+ Platform->ContextsMutex , std::defer_lock);
805
+ if (IndirectAccessTrackingEnabled) {
806
+ IndirectAccessTrackingLock.lock ();
807
+ UR_CALL (ur::level_zero::urContextRetain (Context));
808
+ } else {
809
+ ContextLock.lock ();
810
+ }
811
+
812
+ umf_memory_pool_handle_t hPoolInternal = nullptr ;
813
+ switch (Type) {
814
+ case USMAllocType::Host:
815
+ hPoolInternal = Context->HostMemPool .get ();
816
+ break ;
817
+ case USMAllocType::Device: {
818
+ auto It = Context->DeviceMemPools .find (Device->ZeDevice );
819
+ if (It == Context->DeviceMemPools .end ()) {
820
+ return UR_RESULT_ERROR_INVALID_VALUE;
821
+ }
822
+ hPoolInternal = It->second .get ();
823
+ } break ;
824
+ case USMAllocType::Shared: {
825
+ auto It = Context->SharedMemPools .find (Device->ZeDevice );
826
+ if (It == Context->SharedMemPools .end ()) {
827
+ return UR_RESULT_ERROR_INVALID_VALUE;
828
+ }
829
+ hPoolInternal = It->second .get ();
830
+ } break ;
831
+ };
832
+
833
+ *RetMem = umfPoolAlignedMalloc (hPoolInternal, Size , Alignment);
834
+ if (*RetMem == nullptr ) {
835
+ auto umfRet = umfPoolGetLastAllocationError (hPoolInternal);
836
+ return umf2urResult (umfRet);
837
+ }
838
+
839
+ if (IndirectAccessTrackingEnabled) {
840
+ // Keep track of all memory allocations in the context
841
+ Context->MemAllocs .emplace (std::piecewise_construct,
842
+ std::forward_as_tuple (*RetMem),
843
+ std::forward_as_tuple (Context));
844
+ }
845
+
846
+ return UR_RESULT_SUCCESS;
847
+ }
848
+
849
+ static ur_result_t enqueueUSMAllocHelper (
850
+ ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size ,
851
+ const ur_exp_async_usm_alloc_properties_t *Properties,
852
+ uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
853
+ void **RetMem, ur_event_handle_t *OutEvent, USMAllocType Type) {
854
+ std::ignore = Pool;
855
+ std::ignore = Properties;
856
+
857
+ std::scoped_lock<ur_shared_mutex> lock (Queue->Mutex );
858
+
859
+ bool UseCopyEngine = false ;
860
+ _ur_ze_event_list_t TmpWaitList;
861
+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
862
+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
863
+
864
+ // Get a new command list to be used on this call
865
+ ur_command_list_ptr_t CommandList{};
866
+ UR_CALL (Queue->Context ->getAvailableCommandList (
867
+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList));
868
+
869
+ ze_event_handle_t ZeEvent = nullptr ;
870
+ ur_event_handle_t InternalEvent{};
871
+ bool IsInternal = OutEvent == nullptr ;
872
+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
873
+
874
+ ur_command_t CommandType = UR_COMMAND_FORCE_UINT32;
875
+ switch (Type) {
876
+ case USMAllocType::Host:
877
+ CommandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
878
+ break ;
879
+ case USMAllocType::Device:
880
+ CommandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
881
+ break ;
882
+ case USMAllocType::Shared:
883
+ CommandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
884
+ break ;
885
+ }
886
+ UR_CALL (createEventAndAssociateQueue (Queue, Event, CommandType, CommandList,
887
+ IsInternal, false ));
888
+ ZeEvent = (*Event)->ZeEvent ;
889
+ (*Event)->WaitList = TmpWaitList;
890
+
891
+ // Allocate USM memory
892
+ auto Ret = USMAllocHelper (Queue->Context , Queue->Device , Size , RetMem, Type);
893
+ if (Ret) {
894
+ return Ret;
895
+ }
896
+
897
+ // Signal that USM allocation event was finished
898
+ ZE2UR_CALL (zeCommandListAppendSignalEvent, (CommandList->first , ZeEvent));
899
+
900
+ UR_CALL (Queue->executeCommandList (CommandList, false ));
901
+
902
+ return UR_RESULT_SUCCESS;
903
+ }
904
+
905
+ ur_result_t urEnqueueUSMDeviceAllocExp (
906
+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
907
+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
908
+ const size_t Size , // /< [in] minimum size in bytes of the USM memory object
909
+ // /< to be allocated
910
+ const ur_exp_async_usm_alloc_properties_t
911
+ *Properties, // /< [in][optional] pointer to the enqueue async alloc
912
+ // /< properties
913
+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
914
+ const ur_event_handle_t
915
+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
916
+ // /< pointer to a list of events that must be complete
917
+ // /< before the kernel execution. If nullptr, the
918
+ // /< numEventsInWaitList must be 0, indicating no wait
919
+ // /< events.
920
+ void **Mem, // /< [out] pointer to USM memory object
921
+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
922
+ // /< identifies the async alloc
923
+ ) {
924
+ return enqueueUSMAllocHelper (Queue, Pool, Size , Properties,
925
+ NumEventsInWaitList, EventWaitList, Mem,
926
+ OutEvent, USMAllocType::Device);
927
+ }
928
+
929
+ ur_result_t urEnqueueUSMSharedAllocExp (
930
+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
931
+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
932
+ const size_t Size , // /< [in] minimum size in bytes of the USM memory object
933
+ // /< to be allocated
934
+ const ur_exp_async_usm_alloc_properties_t
935
+ *Properties, // /< [in][optional] pointer to the enqueue async alloc
936
+ // /< properties
937
+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
938
+ const ur_event_handle_t
939
+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
940
+ // /< pointer to a list of events that must be complete
941
+ // /< before the kernel execution. If nullptr, the
942
+ // /< numEventsInWaitList must be 0, indicating no wait
943
+ // /< events.
944
+ void **Mem, // /< [out] pointer to USM memory object
945
+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
946
+ // /< identifies the async alloc
947
+ ) {
948
+ return enqueueUSMAllocHelper (Queue, Pool, Size , Properties,
949
+ NumEventsInWaitList, EventWaitList, Mem,
950
+ OutEvent, USMAllocType::Shared);
951
+ }
952
+
953
+ ur_result_t urEnqueueUSMHostAllocExp (
954
+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
955
+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
956
+ const size_t Size , // /< [in] minimum size in bytes of the USM memory object
957
+ // /< to be allocated
958
+ const ur_exp_async_usm_alloc_properties_t
959
+ *Properties, // /< [in][optional] pointer to the enqueue async alloc
960
+ // /< properties
961
+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
962
+ const ur_event_handle_t
963
+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
964
+ // /< pointer to a list of events that must be complete
965
+ // /< before the kernel execution. If nullptr, the
966
+ // /< numEventsInWaitList must be 0, indicating no wait
967
+ // /< events.
968
+ void **Mem, // /< [out] pointer to USM memory object
969
+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
970
+ // /< identifies the async alloc
971
+ ) {
972
+ return enqueueUSMAllocHelper (Queue, Pool, Size , Properties,
973
+ NumEventsInWaitList, EventWaitList, Mem,
974
+ OutEvent, USMAllocType::Host);
975
+ }
976
+
977
+ ur_result_t urEnqueueUSMFreeExp (
978
+ ur_queue_handle_t Queue, // /< [in] handle of the queue object
979
+ ur_usm_pool_handle_t Pool, // /< [in][optional] USM pool descriptor
980
+ void *Mem, // /< [in] pointer to USM memory object
981
+ uint32_t NumEventsInWaitList, // /< [in] size of the event wait list
982
+ const ur_event_handle_t
983
+ *EventWaitList, // /< [in][optional][range(0, numEventsInWaitList)]
984
+ // /< pointer to a list of events that must be complete
985
+ // /< before the kernel execution. If nullptr, the
986
+ // /< numEventsInWaitList must be 0, indicating no wait
987
+ // /< events.
988
+ ur_event_handle_t *OutEvent // /< [out][optional] return an event object that
989
+ // /< identifies the async alloc
990
+ ) {
991
+ std::ignore = Pool;
992
+
993
+ std::scoped_lock<ur_shared_mutex> lock (Queue->Mutex );
994
+
995
+ bool UseCopyEngine = false ;
996
+ _ur_ze_event_list_t TmpWaitList;
997
+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
998
+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
999
+
1000
+ // Get a new command list to be used on this call
1001
+ ur_command_list_ptr_t CommandList{};
1002
+ UR_CALL (Queue->Context ->getAvailableCommandList (
1003
+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList));
1004
+
1005
+ ze_event_handle_t ZeEvent = nullptr ;
1006
+ ur_event_handle_t InternalEvent{};
1007
+ bool IsInternal = OutEvent == nullptr ;
1008
+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
1009
+
1010
+ UR_CALL (createEventAndAssociateQueue (Queue, Event,
1011
+ UR_COMMAND_ENQUEUE_USM_FREE_EXP,
1012
+ CommandList, IsInternal, false ));
1013
+ ZeEvent = (*Event)->ZeEvent ;
1014
+ (*Event)->WaitList = TmpWaitList;
1015
+
1016
+ const auto &ZeCommandList = CommandList->first ;
1017
+ const auto &WaitList = (*Event)->WaitList ;
1018
+ if (WaitList.Length ) {
1019
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
1020
+ (ZeCommandList, WaitList.Length , WaitList.ZeEventList ));
1021
+ }
1022
+
1023
+ // Wait for commands execution until USM can be freed
1024
+ UR_CALL (Queue->executeCommandList (CommandList, true )); // Blocking
1025
+
1026
+ // Free USM memory
1027
+ auto Ret = USMFreeHelper (Queue->Context , Mem);
1028
+ if (Ret) {
1029
+ return Ret;
1030
+ }
1031
+
1032
+ // Signal that USM free event was finished
1033
+ ZE2UR_CALL (zeCommandListAppendSignalEvent, (ZeCommandList, ZeEvent));
1034
+
1035
+ UR_CALL (Queue->executeCommandList (CommandList, false ));
1036
+
1037
+ return UR_RESULT_SUCCESS;
1038
+ }
790
1039
} // namespace ur::level_zero
791
1040
792
1041
static ur_result_t USMFreeImpl (ur_context_handle_t Context, void *Ptr ) {
0 commit comments