@@ -187,8 +187,15 @@ static ur_result_t USMDeviceAllocImpl(void **ResultPtr,
187
187
ZeDesc.pNext = &RelaxedDesc;
188
188
}
189
189
190
- ZE2UR_CALL (zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size , Alignment,
191
- Device->ZeDevice , ResultPtr));
190
+ ze_result_t ZeResult = ZE_CALL_NOCHECK (
191
+ zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size , Alignment,
192
+ Device->ZeDevice , ResultPtr));
193
+ if (ZeResult != ZE_RESULT_SUCCESS) {
194
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
195
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
196
+ }
197
+ return ze2urResult (ZeResult);
198
+ }
192
199
193
200
UR_ASSERT (Alignment == 0 ||
194
201
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -226,8 +233,15 @@ static ur_result_t USMSharedAllocImpl(void **ResultPtr,
226
233
ZeDevDesc.pNext = &RelaxedDesc;
227
234
}
228
235
229
- ZE2UR_CALL (zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc,
230
- Size , Alignment, Device->ZeDevice , ResultPtr));
236
+ ze_result_t ZeResult = ZE_CALL_NOCHECK (
237
+ zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc, Size ,
238
+ Alignment, Device->ZeDevice , ResultPtr));
239
+ if (ZeResult != ZE_RESULT_SUCCESS) {
240
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
241
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
242
+ }
243
+ return ze2urResult (ZeResult);
244
+ }
231
245
232
246
UR_ASSERT (Alignment == 0 ||
233
247
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -254,8 +268,15 @@ static ur_result_t USMHostAllocImpl(void **ResultPtr,
254
268
// TODO: translate PI properties to Level Zero flags
255
269
ZeStruct<ze_host_mem_alloc_desc_t > ZeHostDesc;
256
270
ZeHostDesc.flags = 0 ;
257
- ZE2UR_CALL (zeMemAllocHost,
258
- (Context->ZeContext , &ZeHostDesc, Size , Alignment, ResultPtr));
271
+ ze_result_t ZeResult =
272
+ ZE_CALL_NOCHECK (zeMemAllocHost, (Context->ZeContext , &ZeHostDesc, Size ,
273
+ Alignment, ResultPtr));
274
+ if (ZeResult != ZE_RESULT_SUCCESS) {
275
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
276
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
277
+ }
278
+ return ze2urResult (ZeResult);
279
+ }
259
280
260
281
UR_ASSERT (Alignment == 0 ||
261
282
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -599,6 +620,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
599
620
ZE2UR_CALL (zeMemGetAddressRange, (Context->ZeContext , Ptr , nullptr , &Size ));
600
621
return ReturnValue (Size );
601
622
}
623
+ case UR_USM_ALLOC_INFO_POOL: {
624
+ auto UMFPool = umfPoolByPtr (Ptr );
625
+ if (!UMFPool) {
626
+ return UR_RESULT_ERROR_INVALID_VALUE;
627
+ }
628
+
629
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
630
+
631
+ auto SearchMatchingPool =
632
+ [](std::unordered_map<ur_device_handle_t , umf::pool_unique_handle_t >
633
+ &PoolMap,
634
+ umf_memory_pool_handle_t UMFPool) {
635
+ for (auto &PoolPair : PoolMap) {
636
+ if (PoolPair.second .get () == UMFPool) {
637
+ return true ;
638
+ }
639
+ }
640
+ return false ;
641
+ };
642
+
643
+ for (auto &Pool : Context->UsmPoolHandles ) {
644
+ if (SearchMatchingPool (Pool->DeviceMemPools , UMFPool)) {
645
+ return ReturnValue (Pool);
646
+ }
647
+ if (SearchMatchingPool (Pool->SharedMemPools , UMFPool)) {
648
+ return ReturnValue (Pool);
649
+ }
650
+ if (Pool->HostMemPool .get () == UMFPool) {
651
+ return ReturnValue (Pool);
652
+ }
653
+ }
654
+
655
+ return UR_RESULT_ERROR_INVALID_VALUE;
656
+ }
602
657
default :
603
658
urPrint (" urUSMGetMemAllocInfo: unsupported ParamName\n " );
604
659
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -748,6 +803,7 @@ ur_result_t L0HostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
748
803
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
749
804
ur_usm_pool_desc_t *PoolDesc) {
750
805
806
+ this ->Context = Context;
751
807
zeroInit = static_cast <uint32_t >(PoolDesc->flags &
752
808
UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
753
809
@@ -831,6 +887,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
831
887
try {
832
888
*Pool = reinterpret_cast <ur_usm_pool_handle_t >(
833
889
new ur_usm_pool_handle_t_ (Context, PoolDesc));
890
+
891
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
892
+ Context->UsmPoolHandles .insert (Context->UsmPoolHandles .cend (), *Pool);
893
+
834
894
} catch (const UsmAllocationException &Ex) {
835
895
return Ex.getError ();
836
896
}
@@ -848,6 +908,8 @@ ur_result_t
848
908
urUSMPoolRelease (ur_usm_pool_handle_t Pool // /< [in] pointer to USM memory pool
849
909
) {
850
910
if (Pool->RefCount .decrementAndTest ()) {
911
+ std::shared_lock<ur_shared_mutex> ContextLock (Pool->Context ->Mutex );
912
+ Pool->Context ->UsmPoolHandles .remove (Pool);
851
913
delete Pool;
852
914
}
853
915
return UR_RESULT_SUCCESS;
@@ -861,13 +923,19 @@ ur_result_t urUSMPoolGetInfo(
861
923
// /< property
862
924
size_t *PropSizeRet // /< [out] size in bytes returned in pool property value
863
925
) {
864
- std::ignore = Pool;
865
- std::ignore = PropName;
866
- std::ignore = PropSize;
867
- std::ignore = PropValue;
868
- std::ignore = PropSizeRet;
869
- urPrint (" [UR][L0] %s function not implemented!\n " , __FUNCTION__);
870
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
926
+ UrReturnHelper ReturnValue (PropSize, PropValue, PropSizeRet);
927
+
928
+ switch (PropName) {
929
+ case UR_USM_POOL_INFO_REFERENCE_COUNT: {
930
+ return ReturnValue (Pool->RefCount .load ());
931
+ }
932
+ case UR_USM_POOL_INFO_CONTEXT: {
933
+ return ReturnValue (Pool->Context );
934
+ }
935
+ default : {
936
+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
937
+ }
938
+ }
871
939
}
872
940
873
941
// If indirect access tracking is not enabled then this functions just performs
0 commit comments