@@ -588,124 +588,257 @@ ur_result_t urUSMReleaseExp(ur_context_handle_t Context, void *HostPtr) {
588
588
return UR_RESULT_SUCCESS;
589
589
}
590
590
591
/// Selects which family of USM memory an asynchronous allocation request
/// targets. The numeric values are spelled out explicitly.
enum class USMAllocType {
  Host = 0,   ///< host USM allocation
  Device = 1, ///< device USM allocation
  Shared = 2  ///< shared USM allocation
};
592
+
593
+ static ur_result_t USMAllocHelper (ur_context_handle_t Context,
594
+ ur_device_handle_t Device, size_t Size ,
595
+ void **RetMem, USMAllocType Type) {
596
+ auto &Platform = Device->Platform ;
597
+
598
+ // TODO: Should alignemnt be passed in 'ur_exp_async_usm_alloc_properties_t'?
599
+ uint32_t Alignment = 0 ;
600
+
601
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex ,
602
+ std::defer_lock);
603
+ std::unique_lock<ur_shared_mutex> IndirectAccessTrackingLock (
604
+ Platform->ContextsMutex , std::defer_lock);
605
+ if (IndirectAccessTrackingEnabled) {
606
+ IndirectAccessTrackingLock.lock ();
607
+ UR_CALL (ur::level_zero::urContextRetain (Context));
608
+ } else {
609
+ ContextLock.lock ();
610
+ }
611
+
612
+ umf_memory_pool_handle_t hPoolInternal = nullptr ;
613
+ switch (Type) {
614
+ case USMAllocType::Host:
615
+ hPoolInternal = Context->AsyncHostMemPool .get ();
616
+ break ;
617
+ case USMAllocType::Device: {
618
+ auto It = Context->AsyncDeviceMemPools .find (Device->ZeDevice );
619
+ if (It == Context->AsyncDeviceMemPools .end ()) {
620
+ return UR_RESULT_ERROR_INVALID_VALUE;
621
+ }
622
+ hPoolInternal = It->second .get ();
623
+ } break ;
624
+ case USMAllocType::Shared: {
625
+ auto It = Context->AsyncSharedMemPools .find (Device->ZeDevice );
626
+ if (It == Context->AsyncSharedMemPools .end ()) {
627
+ return UR_RESULT_ERROR_INVALID_VALUE;
628
+ }
629
+ hPoolInternal = It->second .get ();
630
+ } break ;
631
+ };
632
+
633
+ *RetMem = umfPoolAlignedMalloc (hPoolInternal, Size , Alignment);
634
+ if (*RetMem == nullptr ) {
635
+ auto umfRet = umfPoolGetLastAllocationError (hPoolInternal);
636
+ return umf2urResult (umfRet);
637
+ }
638
+
639
+ if (IndirectAccessTrackingEnabled) {
640
+ // Keep track of all memory allocations in the context
641
+ Context->MemAllocs .emplace (std::piecewise_construct,
642
+ std::forward_as_tuple (*RetMem),
643
+ std::forward_as_tuple (Context));
644
+ }
645
+
646
+ return UR_RESULT_SUCCESS;
647
+ }
648
+
649
+ static ur_result_t enqueueUSMAllocHelper (
650
+ ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size ,
651
+ const ur_exp_enqueue_usm_alloc_properties_t *Properties,
652
+ uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
653
+ void **RetMem, ur_event_handle_t *OutEvent, USMAllocType Type) {
654
+ std::ignore = Pool;
655
+ std::ignore = Properties;
656
+
657
+ std::scoped_lock<ur_shared_mutex> lock (Queue->Mutex );
658
+
659
+ bool UseCopyEngine = false ;
660
+ _ur_ze_event_list_t TmpWaitList;
661
+ UR_CALL (TmpWaitList.createAndRetainUrZeEventList (
662
+ NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
663
+
664
+ bool OkToBatch = true ;
665
+ // Get a new command list to be used on this call
666
+ ur_command_list_ptr_t CommandList{};
667
+ UR_CALL (Queue->Context ->getAvailableCommandList (
668
+ Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
669
+ OkToBatch, nullptr /* ForcedCmdQueue*/ ));
670
+
671
+ ze_event_handle_t ZeEvent = nullptr ;
672
+ ur_event_handle_t InternalEvent{};
673
+ bool IsInternal = OutEvent == nullptr ;
674
+ ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
675
+
676
+ ur_command_t CommandType = UR_COMMAND_FORCE_UINT32;
677
+ switch (Type) {
678
+ case USMAllocType::Host:
679
+ CommandType = UR_COMMAND_ENQUEUE_USM_HOST_ALLOC_EXP;
680
+ break ;
681
+ case USMAllocType::Device:
682
+ CommandType = UR_COMMAND_ENQUEUE_USM_DEVICE_ALLOC_EXP;
683
+ break ;
684
+ case USMAllocType::Shared:
685
+ CommandType = UR_COMMAND_ENQUEUE_USM_SHARED_ALLOC_EXP;
686
+ break ;
687
+ }
688
+ UR_CALL (createEventAndAssociateQueue (Queue, Event, CommandType, CommandList,
689
+ IsInternal, false ));
690
+ ZeEvent = (*Event)->ZeEvent ;
691
+ (*Event)->WaitList = TmpWaitList;
692
+
693
+ // Allocate USM memory
694
+ auto Ret = USMAllocHelper (Queue->Context , Queue->Device , Size , RetMem, Type);
695
+ if (Ret) {
696
+ return Ret;
697
+ }
698
+
699
+ // Signal that USM allocation event was finished
700
+ ZE2UR_CALL (zeCommandListAppendSignalEvent, (CommandList->first , ZeEvent));
701
+
702
+ UR_CALL (Queue->executeCommandList (CommandList, false , OkToBatch));
703
+
704
+ return UR_RESULT_SUCCESS;
705
+ }
706
+
591
707
/// Enqueues an asynchronous USM device allocation.
///
/// Thin forwarder to enqueueUSMAllocHelper with USMAllocType::Device.
///
/// @param Queue               [in] handle of the queue object
/// @param Pool                [in][optional] USM pool descriptor
/// @param Size                [in] minimum size in bytes of the USM memory
///                            object to be allocated
/// @param Properties          [in][optional] pointer to the enqueue async
///                            alloc properties
/// @param NumEventsInWaitList [in] size of the event wait list
/// @param EventWaitList       [in][optional][range(0, NumEventsInWaitList)]
///                            events that must be complete first; if nullptr,
///                            NumEventsInWaitList must be 0
/// @param Mem                 [out] pointer to USM memory object
/// @param OutEvent            [out][optional] event object that identifies
///                            the async alloc
/// @return whatever enqueueUSMAllocHelper returns.
ur_result_t urEnqueueUSMDeviceAllocExp(
    ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size,
    const ur_exp_enqueue_usm_alloc_properties_t *Properties,
    uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
    void **Mem, ur_event_handle_t *OutEvent) {
  constexpr USMAllocType Kind = USMAllocType::Device;
  return enqueueUSMAllocHelper(Queue, Pool, Size, Properties,
                               NumEventsInWaitList, EventWaitList, Mem,
                               OutEvent, Kind);
}
622
730
623
731
/// Enqueues an asynchronous USM shared allocation.
///
/// Thin forwarder to enqueueUSMAllocHelper with USMAllocType::Shared.
///
/// @param Queue               [in] handle of the queue object
/// @param Pool                [in][optional] USM pool descriptor
/// @param Size                [in] minimum size in bytes of the USM memory
///                            object to be allocated
/// @param Properties          [in][optional] pointer to the enqueue async
///                            alloc properties
/// @param NumEventsInWaitList [in] size of the event wait list
/// @param EventWaitList       [in][optional][range(0, NumEventsInWaitList)]
///                            events that must be complete first; if nullptr,
///                            NumEventsInWaitList must be 0
/// @param Mem                 [out] pointer to USM memory object
/// @param OutEvent            [out][optional] event object that identifies
///                            the async alloc
/// @return whatever enqueueUSMAllocHelper returns.
ur_result_t urEnqueueUSMSharedAllocExp(
    ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size,
    const ur_exp_enqueue_usm_alloc_properties_t *Properties,
    uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
    void **Mem, ur_event_handle_t *OutEvent) {
  constexpr USMAllocType Kind = USMAllocType::Shared;
  return enqueueUSMAllocHelper(Queue, Pool, Size, Properties,
                               NumEventsInWaitList, EventWaitList, Mem,
                               OutEvent, Kind);
}
654
754
655
755
/// Enqueues an asynchronous USM host allocation.
///
/// Thin forwarder to enqueueUSMAllocHelper with USMAllocType::Host.
///
/// @param Queue               [in] handle of the queue object
/// @param Pool                [in][optional] USM pool descriptor
/// @param Size                [in] minimum size in bytes of the USM memory
///                            object to be allocated
/// @param Properties          [in][optional] pointer to the enqueue async
///                            alloc properties
/// @param NumEventsInWaitList [in] size of the event wait list
/// @param EventWaitList       [in][optional][range(0, NumEventsInWaitList)]
///                            events that must be complete first; if nullptr,
///                            NumEventsInWaitList must be 0
/// @param Mem                 [out] pointer to USM memory object
/// @param OutEvent            [out][optional] event object that identifies
///                            the async alloc
/// @return whatever enqueueUSMAllocHelper returns.
ur_result_t urEnqueueUSMHostAllocExp(
    ur_queue_handle_t Queue, ur_usm_pool_handle_t Pool, const size_t Size,
    const ur_exp_enqueue_usm_alloc_properties_t *Properties,
    uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
    void **Mem, ur_event_handle_t *OutEvent) {
  constexpr USMAllocType Kind = USMAllocType::Host;
  return enqueueUSMAllocHelper(Queue, Pool, Size, Properties,
                               NumEventsInWaitList, EventWaitList, Mem,
                               OutEvent, Kind);
}
686
778
687
779
/// Enqueues a USM free.
///
/// Sequence, all under the queue mutex: retain the wait list, obtain a
/// command list, create the command's event, append a wait on the caller's
/// events, execute the command list in blocking mode, free the memory on the
/// host through USMFreeHelper, then append a signal of the command's event
/// and execute the command list again.
///
/// @param Queue               [in] handle of the queue object
/// @param Pool                [in][optional] USM pool descriptor (ignored)
/// @param Mem                 [in] pointer to USM memory object
/// @param NumEventsInWaitList [in] size of the event wait list
/// @param EventWaitList       [in][optional][range(0, NumEventsInWaitList)]
///                            events that must be complete first; if nullptr,
///                            NumEventsInWaitList must be 0
/// @param OutEvent            [out][optional] event object that identifies
///                            this command; an internal event is used when
///                            nullptr
/// @return UR_RESULT_SUCCESS or the first error reported by a sub-step.
ur_result_t urEnqueueUSMFreeExp(ur_queue_handle_t Queue,
                                ur_usm_pool_handle_t Pool, void *Mem,
                                uint32_t NumEventsInWaitList,
                                const ur_event_handle_t *EventWaitList,
                                ur_event_handle_t *OutEvent) {
  std::ignore = Pool;

  std::scoped_lock<ur_shared_mutex> QueueGuard(Queue->Mutex);

  // This command never targets the copy engine and is never batched.
  const bool UseCopyEngine = false;
  bool OkToBatch = false;

  _ur_ze_event_list_t RetainedWaitList;
  UR_CALL(RetainedWaitList.createAndRetainUrZeEventList(
      NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));

  // Get a new command list to be used on this call
  ur_command_list_ptr_t CommandList{};
  UR_CALL(Queue->Context->getAvailableCommandList(
      Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
      OkToBatch, nullptr /*ForcedCmdQueue*/));

  // Use a private event when the caller does not want one back.
  ur_event_handle_t PrivateEvent{};
  const bool IsInternal = (OutEvent == nullptr);
  ur_event_handle_t *Event = IsInternal ? &PrivateEvent : OutEvent;

  UR_CALL(createEventAndAssociateQueue(Queue, Event,
                                       UR_COMMAND_ENQUEUE_USM_FREE_EXP,
                                       CommandList, IsInternal, false));
  ze_event_handle_t ZeEvent = (*Event)->ZeEvent;
  (*Event)->WaitList = RetainedWaitList;

  const auto &ZeCommandList = CommandList->first;
  const auto &Deps = (*Event)->WaitList;
  if (Deps.Length != 0) {
    ZE2UR_CALL(zeCommandListAppendWaitOnEvents,
               (ZeCommandList, Deps.Length, Deps.ZeEventList));
  }

  // Block until the queued work is done so the memory is safe to release.
  UR_CALL(Queue->executeCommandList(CommandList, true, OkToBatch));

  // Free USM memory
  auto FreeResult = USMFreeHelper(Queue->Context, Mem);
  if (FreeResult != UR_RESULT_SUCCESS) {
    return FreeResult;
  }

  // Signal that USM free event was finished
  // NOTE(review): this appends to and re-executes the command list that was
  // already executed above - confirm this is valid for every command list
  // kind the queue can hand out.
  ZE2UR_CALL(zeCommandListAppendSignalEvent, (ZeCommandList, ZeEvent));

  UR_CALL(Queue->executeCommandList(CommandList, false, OkToBatch));

  return UR_RESULT_SUCCESS;
}
710
843
} // namespace ur::level_zero
711
844
0 commit comments