Skip to content

Commit 099446e

Browse files
authored
amrex::Gpu::DeviceVector et al. (#545)
Follow-up to #540, renaming the `PODVector_{real,int,uint64}_default` to `DeviceVector_{real,int,uint64}` et al. Implement the aliases for `AMReX_GpuContainers.H` - [x] resolve inconsistency mentioned in AMReX-Codes/amrex#5123 --------- Signed-off-by: Axel Huebl <axel.huebl@plasma.ninja>
1 parent a1e7a2c commit 099446e

File tree

1 file changed

+23
-2
lines changed

1 file changed

+23
-2
lines changed

src/Base/PODVector.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,29 @@ void make_PODVector(py::module &m, std::string typestr)
168168
#endif
169169
make_PODVector<T, amrex::PolymorphicArenaAllocator<T>> (m, typestr, "polymorphic");
170170

171-
// Alias matching Gpu::DeviceVector<T> — resolves per platform:
172-
// CPU: PODVector_<type>_std, GPU: PODVector_<type>_arena
171+
// Implement AMReX_GpuContainers.H
172+
// Alias matching Gpu::DeviceVector<T> etc. — resolves per platform:
173+
// CPU: PODVector_<type>_std, GPU: PODVector_<type>_arena
174+
constexpr auto cstr = [](std::string const & a, std::string const & b) { return a + "_" + b; };
175+
#ifdef AMREX_USE_GPU
176+
m.attr(cstr("DeviceVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "arena").c_str());
177+
m.attr(cstr("NonManagedDeviceVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "device").c_str());
178+
m.attr(cstr("ManagedVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "managed").c_str());
179+
m.attr(cstr("ManagedDeviceVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "managed").c_str());
180+
m.attr(cstr("PinnedVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "pinned").c_str());
181+
m.attr(cstr("AsyncVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "async").c_str());
182+
m.attr(cstr("HostVector", typestr).c_str()) = m.attr(str_PODVector(typestr, "pinned").c_str());
183+
#else
184+
py::object const std_pod = m.attr(str_PODVector(typestr, "std").c_str());
185+
m.attr(cstr("DeviceVector", typestr).c_str()) = std_pod;
186+
m.attr(cstr("NonManagedDeviceVector", typestr).c_str()) = std_pod;
187+
m.attr(cstr("ManagedVector", typestr).c_str()) = std_pod;
188+
m.attr(cstr("ManagedDeviceVector", typestr).c_str()) = std_pod;
189+
m.attr(cstr("PinnedVector", typestr).c_str()) = std_pod;
190+
m.attr(cstr("AsyncVector", typestr).c_str()) = std_pod;
191+
m.attr(cstr("HostVector", typestr).c_str()) = std_pod;
192+
#endif
193+
173194
auto const default_name = str_PODVector(typestr, "default");
174195
m.attr(default_name.c_str()) =
175196
#ifdef AMREX_USE_GPU

0 commit comments

Comments
 (0)