diff --git a/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc b/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc index bc4ba688b8..72e55cee3a 100644 --- a/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc +++ b/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc @@ -461,6 +461,7 @@ class UnifiedMemoryCudaMemoryAllocator if (p && s > 0) _applyHint(ptr.baseAddress(), ptr.size(), new_args); } + eMemoryResource memoryResource() const override { return eMemoryResource::UnifiedMemory; } protected: @@ -542,6 +543,7 @@ class HostPinnedCudaMemoryAllocator _setUseMemoryPool(use_memory_pool); m_block_wrapper.initialize(128, use_memory_pool); } + eMemoryResource memoryResource() const override { return eMemoryResource::HostPinned; } }; /*---------------------------------------------------------------------------*/ @@ -607,6 +609,7 @@ class DeviceCudaMemoryAllocator _setUseMemoryPool(use_memory_pool); m_block_wrapper.initialize(128, use_memory_pool); } + eMemoryResource memoryResource() const override { return eMemoryResource::Device; } }; /*---------------------------------------------------------------------------*/ diff --git a/arcane/src/arcane/accelerator/hip/HipAccelerator.cc b/arcane/src/arcane/accelerator/hip/HipAccelerator.cc index bc3a4dfcd9..6987d29886 100644 --- a/arcane/src/arcane/accelerator/hip/HipAccelerator.cc +++ b/arcane/src/arcane/accelerator/hip/HipAccelerator.cc @@ -106,6 +106,7 @@ class UnifiedMemoryHipMemoryAllocator { return ::hipFree(ptr); } + eMemoryResource memoryResource() const override { return eMemoryResource::UnifiedMemory; } }; /*---------------------------------------------------------------------------*/ @@ -124,6 +125,7 @@ class HostPinnedHipMemoryAllocator { return ::hipHostFree(ptr); } + eMemoryResource memoryResource() const override { return eMemoryResource::HostPinned; } }; /*---------------------------------------------------------------------------*/ @@ -142,6 +144,7 @@ class DeviceHipMemoryAllocator { return ::hipFree(ptr); } + eMemoryResource memoryResource() const override { return eMemoryResource::Device; } }; /*---------------------------------------------------------------------------*/ diff --git a/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc b/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc index f2132a0a5e..750c0186e2 100644 --- a/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc +++ b/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc @@ -107,6 +107,7 @@ class UnifiedMemorySyclMemoryAllocator { sycl::free(ptr, q); } + eMemoryResource memoryResource() const override { return eMemoryResource::UnifiedMemory; } }; /*---------------------------------------------------------------------------*/ @@ -126,6 +127,7 @@ class HostPinnedSyclMemoryAllocator { sycl::free(ptr, q); } + eMemoryResource memoryResource() const override { return eMemoryResource::HostPinned; } }; /*---------------------------------------------------------------------------*/ @@ -144,6 +146,7 @@ class DeviceSyclMemoryAllocator { sycl::free(ptr, q); } + eMemoryResource memoryResource() const override { return eMemoryResource::Device; } }; /*---------------------------------------------------------------------------*/ diff --git a/arcane/src/arcane/accelerator/tests/TestInit.cc b/arcane/src/arcane/accelerator/tests/TestInit.cc index 60889cd07b..262e2c1e62 100644 --- a/arcane/src/arcane/accelerator/tests/TestInit.cc +++ b/arcane/src/arcane/accelerator/tests/TestInit.cc @@ -7,6 +7,8 @@ #include +#include "arcane/utils/MemoryUtils.h" +#include "arcane/utils/MemoryAllocator.h" #include "arcane/accelerator/core/Runner.h" #include "arcane/accelerator/core/RunQueue.h" @@ -38,6 +40,20 @@ void _doTest1() Runner runner(exec_policy); RunQueue queue(makeQueue(runner)); ASSERT_TRUE(queue.executionPolicy() == exec_policy); + + eMemoryResource mr = eMemoryResource::Host; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + + if (queue.isAcceleratorPolicy()){ + mr = eMemoryResource::HostPinned; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + + mr = eMemoryResource::Device; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + + mr = eMemoryResource::UnifiedMemory; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + } } /*---------------------------------------------------------------------------*/