From 1bbc84c46f28734f4ab91bf1ef65cde087026bce Mon Sep 17 00:00:00 2001 From: Gilles Grospellier Date: Sun, 1 Dec 2024 09:50:16 +0100 Subject: [PATCH] =?UTF-8?q?[arcane,accelerator]=20Ajoute=20surcharge=20de?= =?UTF-8?q?=20'IMemoryAllocator::memoryResource()'=20pour=20les=20allocate?= =?UTF-8?q?urs=20des=20acc=C3=A9l=C3=A9rateurs.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../arcane/accelerator/cuda/CudaAccelerator.cc | 3 +++ .../src/arcane/accelerator/hip/HipAccelerator.cc | 3 +++ .../arcane/accelerator/sycl/SyclAccelerator.cc | 3 +++ arcane/src/arcane/accelerator/tests/TestInit.cc | 16 ++++++++++++++++ 4 files changed, 25 insertions(+) diff --git a/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc b/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc index bc4ba688b8..72e55cee3a 100644 --- a/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc +++ b/arcane/src/arcane/accelerator/cuda/CudaAccelerator.cc @@ -461,6 +461,7 @@ class UnifiedMemoryCudaMemoryAllocator if (p && s > 0) _applyHint(ptr.baseAddress(), ptr.size(), new_args); } + eMemoryResource memoryResource() const override { return eMemoryResource::UnifiedMemory; } protected: @@ -542,6 +543,7 @@ class HostPinnedCudaMemoryAllocator _setUseMemoryPool(use_memory_pool); m_block_wrapper.initialize(128, use_memory_pool); } + eMemoryResource memoryResource() const override { return eMemoryResource::HostPinned; } }; /*---------------------------------------------------------------------------*/ @@ -607,6 +609,7 @@ class DeviceCudaMemoryAllocator _setUseMemoryPool(use_memory_pool); m_block_wrapper.initialize(128, use_memory_pool); } + eMemoryResource memoryResource() const override { return eMemoryResource::Device; } }; /*---------------------------------------------------------------------------*/ diff --git a/arcane/src/arcane/accelerator/hip/HipAccelerator.cc b/arcane/src/arcane/accelerator/hip/HipAccelerator.cc index bc3a4dfcd9..6987d29886 100644 --- a/arcane/src/arcane/accelerator/hip/HipAccelerator.cc +++ b/arcane/src/arcane/accelerator/hip/HipAccelerator.cc @@ -106,6 +106,7 @@ class UnifiedMemoryHipMemoryAllocator { return ::hipFree(ptr); } + eMemoryResource memoryResource() const override { return eMemoryResource::UnifiedMemory; } }; /*---------------------------------------------------------------------------*/ @@ -124,6 +125,7 @@ class HostPinnedHipMemoryAllocator { return ::hipHostFree(ptr); } + eMemoryResource memoryResource() const override { return eMemoryResource::HostPinned; } }; /*---------------------------------------------------------------------------*/ @@ -142,6 +144,7 @@ class DeviceHipMemoryAllocator { return ::hipFree(ptr); } + eMemoryResource memoryResource() const override { return eMemoryResource::Device; } }; /*---------------------------------------------------------------------------*/ diff --git a/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc b/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc index f2132a0a5e..750c0186e2 100644 --- a/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc +++ b/arcane/src/arcane/accelerator/sycl/SyclAccelerator.cc @@ -107,6 +107,7 @@ class UnifiedMemorySyclMemoryAllocator { sycl::free(ptr, q); } + eMemoryResource memoryResource() const override { return eMemoryResource::UnifiedMemory; } }; /*---------------------------------------------------------------------------*/ @@ -126,6 +127,7 @@ class HostPinnedSyclMemoryAllocator { sycl::free(ptr, q); } + eMemoryResource memoryResource() const override { return eMemoryResource::HostPinned; } }; /*---------------------------------------------------------------------------*/ @@ -144,6 +146,7 @@ class DeviceSyclMemoryAllocator { sycl::free(ptr, q); } + eMemoryResource memoryResource() const override { return eMemoryResource::Device; } }; /*---------------------------------------------------------------------------*/ diff --git a/arcane/src/arcane/accelerator/tests/TestInit.cc b/arcane/src/arcane/accelerator/tests/TestInit.cc index 60889cd07b..262e2c1e62 100644 --- a/arcane/src/arcane/accelerator/tests/TestInit.cc +++ b/arcane/src/arcane/accelerator/tests/TestInit.cc @@ -7,6 +7,8 @@ #include +#include "arcane/utils/MemoryUtils.h" +#include "arcane/utils/MemoryAllocator.h" #include "arcane/accelerator/core/Runner.h" #include "arcane/accelerator/core/RunQueue.h" @@ -38,6 +40,20 @@ void _doTest1() Runner runner(exec_policy); RunQueue queue(makeQueue(runner)); ASSERT_TRUE(queue.executionPolicy() == exec_policy); + + eMemoryResource mr = eMemoryResource::Host; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + + if (queue.isAcceleratorPolicy()){ + mr = eMemoryResource::HostPinned; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + + mr = eMemoryResource::Device; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + + mr = eMemoryResource::UnifiedMemory; + ASSERT_EQ(MemoryUtils::getAllocator(mr)->memoryResource(),mr); + } } /*---------------------------------------------------------------------------*/