Commit 1924cfc

[L0 v2] implement urEnqueueMemBuffer[Map/Unmap]
and extend the ur_mem_handle_t implementations to support async memory migration (right now this is only used to keep data in sync between device and host allocations). Also, implement generic memcpy/fill functions in the queue that can be used by both USM and buffer operations.
1 parent 70c4980 commit 1924cfc
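
For context, the diffs below replace the old getPtr(hDevice) accessor with a richer interface on ur_mem_handle_t_. The following is a hedged sketch of roughly what that interface looks like, reconstructed only from the call sites and definitions in this commit; the actual declaration lives in memory.hpp (not shown here), and the exact enumerator list, virtual/override qualifiers, and member layout are assumptions.

// Sketch, not the real header: approximate shape of the buffer interface
// implemented in memory.cpp below. Handle typedefs come from ur_api.h.
#include <ur_api.h>
#include <cstddef>
#include <functional>

struct ur_mem_handle_t_ {
  // Access intent passed by enqueue/kernel code; the four values below all
  // appear in this commit, others may exist.
  enum class access_mode_t { read_write, read_only, write_only, write_invalidate };

  ur_mem_handle_t_(ur_context_handle_t hContext, size_t size)
      : hContext(hContext), size(size) {}
  virtual ~ur_mem_handle_t_() = default;

  // Return a pointer usable on hDevice; `migrate` is a caller-supplied copy
  // routine (e.g. one that appends a memcpy to a command list) invoked when
  // data has to be moved to satisfy the request.
  virtual void *
  getDevicePtr(ur_device_handle_t hDevice, access_mode_t access, size_t offset,
               size_t size,
               std::function<void(void *src, void *dst, size_t)> migrate) = 0;

  // Return a host-visible pointer for urEnqueueMemBufferMap; discrete buffers
  // copy device data into a host staging allocation unless the map is
  // write-only.
  virtual void *
  mapHostPtr(access_mode_t access, size_t offset, size_t size,
             std::function<void(void *src, void *dst, size_t)> migrate) = 0;

  // Write the mapped region back to the device allocation when required and
  // release the staging memory.
  virtual void
  unmapHostPtr(void *pMappedPtr,
               std::function<void(void *src, void *dst, size_t)> migrate) = 0;

protected:
  ur_context_handle_t hContext;
  size_t size;
};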

File tree

6 files changed: +525 -338 lines

source/adapters/level_zero/v2/kernel.cpp

Lines changed: 6 additions & 2 deletions
@@ -313,7 +313,9 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
 
   auto kernelDevices = hKernel->getDevices();
   if (kernelDevices.size() == 1) {
-    auto zePtr = hArgValue->getPtr(kernelDevices.front());
+    auto zePtr = hArgValue->getDevicePtr(
+        kernelDevices.front(), ur_mem_handle_t_::access_mode_t::read_write, 0,
+        hArgValue->getSize(), nullptr);
     return hKernel->setArgPointer(argIndex, nullptr, zePtr);
   } else {
     // TODO: if devices do not have p2p capabilities, we need to have allocation
@@ -324,7 +326,9 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
     // Get memory that is accessible by the first device.
     // If kernel is submitted to a different device the memory
     // will be accessed trough the link or migrated in enqueueKernelLaunch.
-    auto zePtr = hArgValue->getPtr(kernelDevices.front());
+    auto zePtr = hArgValue->getDevicePtr(
+        kernelDevices.front(), ur_mem_handle_t_::access_mode_t::read_write, 0,
+        hArgValue->getSize(), nullptr);
     return hKernel->setArgPointer(argIndex, nullptr, zePtr);
   }
 }

source/adapters/level_zero/v2/memory.cpp

Lines changed: 162 additions & 28 deletions
@@ -16,9 +16,41 @@
 ur_mem_handle_t_::ur_mem_handle_t_(ur_context_handle_t hContext, size_t size)
     : hContext(hContext), size(size) {}
 
-ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
-                                           void *hostPtr, size_t size,
-                                           host_ptr_action_t hostPtrAction)
+ur_usm_handle_t_::ur_usm_handle_t_(ur_context_handle_t hContext, size_t size,
+                                   const void *ptr)
+    : ur_mem_handle_t_(hContext, size), ptr(const_cast<void *>(ptr)) {}
+
+ur_usm_handle_t_::~ur_usm_handle_t_() {}
+
+void *ur_usm_handle_t_::getDevicePtr(
+    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
+  std::ignore = hDevice;
+  std::ignore = access;
+  std::ignore = offset;
+  std::ignore = size;
+  std::ignore = migrate;
+  return ptr;
+}
+
+void *ur_usm_handle_t_::mapHostPtr(
+    access_mode_t access, size_t offset, size_t size,
+    std::function<void(void *src, void *dst, size_t)>) {
+  std::ignore = access;
+  std::ignore = offset;
+  std::ignore = size;
+  return ptr;
+}
+
+void ur_usm_handle_t_::unmapHostPtr(
+    void *pMappedPtr, std::function<void(void *src, void *dst, size_t)>) {
+  std::ignore = pMappedPtr;
+  /* nop */
+}
+
+ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
+    ur_context_handle_t hContext, void *hostPtr, size_t size,
+    host_ptr_action_t hostPtrAction)
     : ur_mem_handle_t_(hContext, size) {
   bool hostPtrImported = false;
   if (hostPtrAction == host_ptr_action_t::import) {
@@ -37,7 +69,7 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
   }
 }
 
-ur_host_mem_handle_t::~ur_host_mem_handle_t() {
+ur_integrated_mem_handle_t::~ur_integrated_mem_handle_t() {
   if (ptr) {
     auto ret = hContext->getDefaultUSMPool()->free(ptr);
     if (ret != UR_RESULT_SUCCESS) {
@@ -46,21 +78,36 @@ ur_host_mem_handle_t::~ur_host_mem_handle_t() {
   }
 }
 
-void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
+void *ur_integrated_mem_handle_t::getDevicePtr(
+    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
   std::ignore = hDevice;
+  std::ignore = access;
+  std::ignore = offset;
+  std::ignore = size;
+  std::ignore = migrate;
   return ptr;
 }
 
-ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
-                                                    void *src, size_t size) {
-  auto Id = hDevice->Id.value();
+void *ur_integrated_mem_handle_t::mapHostPtr(
+    access_mode_t access, size_t offset, size_t size,
+    std::function<void(void *src, void *dst, size_t)> migrate) {
+  std::ignore = access;
+  std::ignore = offset;
+  std::ignore = size;
+  std::ignore = migrate;
+  return ptr;
+}
 
-  if (!deviceAllocations[Id]) {
-    UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
-                                                    UR_USM_TYPE_DEVICE, size,
-                                                    &deviceAllocations[Id]));
-  }
+void ur_integrated_mem_handle_t::unmapHostPtr(
+    void *pMappedPtr, std::function<void(void *src, void *dst, size_t)>) {
+  std::ignore = pMappedPtr;
+  /* nop */
+}
 
+static ur_result_t synchronousZeCopy(ur_context_handle_t hContext,
+                                     ur_device_handle_t hDevice, void *dst,
+                                     const void *src, size_t size) {
   auto commandList = hContext->commandListCache.getImmediateCommandList(
       hDevice->ZeDevice, true,
       hDevice
@@ -70,26 +117,42 @@ ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
       std::nullopt);
 
   ZE2UR_CALL(zeCommandListAppendMemoryCopy,
-             (commandList.get(), deviceAllocations[Id], src, size, nullptr, 0,
-              nullptr));
+             (commandList.get(), dst, src, size, nullptr, 0, nullptr));
+
+  return UR_RESULT_SUCCESS;
+}
+
+ur_result_t
+ur_discrete_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice, void *src,
+                                          size_t size) {
+  auto Id = hDevice->Id.value();
+
+  if (!deviceAllocations[Id]) {
+    UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
+                                                    UR_USM_TYPE_DEVICE, size,
+                                                    &deviceAllocations[Id]));
+  }
+
+  UR_CALL(
+      synchronousZeCopy(hContext, hDevice, deviceAllocations[Id], src, size));
 
   activeAllocationDevice = hDevice;
 
   return UR_RESULT_SUCCESS;
 }
 
-ur_device_mem_handle_t::ur_device_mem_handle_t(ur_context_handle_t hContext,
-                                               void *hostPtr, size_t size)
+ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
+                                                   void *hostPtr, size_t size)
     : ur_mem_handle_t_(hContext, size),
       deviceAllocations(hContext->getPlatform()->getNumDevices()),
-      activeAllocationDevice(nullptr) {
+      activeAllocationDevice(nullptr), hostAllocations() {
   if (hostPtr) {
     auto initialDevice = hContext->getDevices()[0];
     UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
   }
 }
 
-ur_device_mem_handle_t::~ur_device_mem_handle_t() {
+ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
   for (auto &ptr : deviceAllocations) {
     if (ptr) {
       auto ret = hContext->getDefaultUSMPool()->free(ptr);
@@ -100,8 +163,12 @@ ur_device_mem_handle_t::~ur_device_mem_handle_t() {
   }
 }
 
-void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
-  std::lock_guard lock(this->Mutex);
+void *ur_discrete_mem_handle_t::getDevicePtrUnlocked(
+    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
+  std::ignore = access;
+  std::ignore = size;
+  std::ignore = migrate;
 
   if (!activeAllocationDevice) {
     UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
@@ -110,8 +177,10 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
     activeAllocationDevice = hDevice;
   }
 
+  char *ptr;
   if (activeAllocationDevice == hDevice) {
-    return deviceAllocations[hDevice->Id.value()];
+    ptr = ur_cast<char *>(deviceAllocations[hDevice->Id.value()]);
+    return ptr + offset;
   }
 
   auto &p2pDevices = hContext->getP2PDevices(hDevice);
@@ -124,7 +193,71 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
   }
 
   // TODO: see if it's better to migrate the memory to the specified device
-  return deviceAllocations[activeAllocationDevice->Id.value()];
+  return ur_cast<char *>(
+             deviceAllocations[activeAllocationDevice->Id.value()]) +
+         offset;
+}
+
+void *ur_discrete_mem_handle_t::getDevicePtr(
+    ur_device_handle_t hDevice, access_mode_t access, size_t offset,
+    size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
+  std::lock_guard lock(this->Mutex);
+  return getDevicePtrUnlocked(hDevice, access, offset, size, migrate);
+}
+
+void *ur_discrete_mem_handle_t::mapHostPtr(
+    access_mode_t access, size_t offset, size_t size,
+    std::function<void(void *src, void *dst, size_t)> migrate) {
+  std::lock_guard lock(this->Mutex);
+
+  // TODO: use async alloc?
+
+  void *ptr;
+  UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
+      hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr));
+
+  hostAllocations.emplace_back(ptr, size, offset, access);
+
+  if (activeAllocationDevice && access != access_mode_t::write_only) {
+    auto srcPtr =
+        ur_cast<char *>(deviceAllocations[activeAllocationDevice->Id.value()]) +
+        offset;
+    migrate(srcPtr, hostAllocations.back().ptr, size);
+  }
+
+  return hostAllocations.back().ptr;
+}
+
+void ur_discrete_mem_handle_t::unmapHostPtr(
+    void *pMappedPtr,
+    std::function<void(void *src, void *dst, size_t)> migrate) {
+  std::lock_guard lock(this->Mutex);
+
+  for (auto &hostAllocation : hostAllocations) {
+    if (hostAllocation.ptr == pMappedPtr) {
+      void *devicePtr = nullptr;
+      if (activeAllocationDevice) {
+        devicePtr = ur_cast<char *>(
+                        deviceAllocations[activeAllocationDevice->Id.value()]) +
+                    hostAllocation.offset;
+      } else if (hostAllocation.access != access_mode_t::write_invalidate) {
+        devicePtr = ur_cast<char *>(getDevicePtrUnlocked(
+            hContext->getDevices()[0], access_mode_t::read_only,
+            hostAllocation.offset, hostAllocation.size, migrate));
+      }
+
+      if (devicePtr) {
+        migrate(hostAllocation.ptr, devicePtr, hostAllocation.size);
+      }
+
+      // TODO: use async free here?
+      UR_CALL_THROWS(hContext->getDefaultUSMPool()->free(hostAllocation.ptr));
+      return;
+    }
+  }
+
+  // No mapping found
+  throw UR_RESULT_ERROR_INVALID_ARGUMENT;
 }
 
 namespace ur::level_zero {
@@ -155,13 +288,14 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
   if (useHostBuffer) {
     // TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
    // or UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER is set?
-    auto hostPtrAction = flags & UR_MEM_FLAG_USE_HOST_POINTER
-                             ? ur_host_mem_handle_t::host_ptr_action_t::import
-                             : ur_host_mem_handle_t::host_ptr_action_t::copy;
+    auto hostPtrAction =
+        flags & UR_MEM_FLAG_USE_HOST_POINTER
+            ? ur_integrated_mem_handle_t::host_ptr_action_t::import
+            : ur_integrated_mem_handle_t::host_ptr_action_t::copy;
     *phBuffer =
-        new ur_host_mem_handle_t(hContext, hostPtr, size, hostPtrAction);
+        new ur_integrated_mem_handle_t(hContext, hostPtr, size, hostPtrAction);
   } else {
-    *phBuffer = new ur_device_mem_handle_t(hContext, hostPtr, size);
+    *phBuffer = new ur_discrete_mem_handle_t(hContext, hostPtr, size);
   }
 
   return UR_RESULT_SUCCESS;
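
The queue-side half of this commit (the generic memcpy/fill helpers mentioned in the commit message) is not among the files shown above. As a rough illustration of the contract between the queue and the buffer classes, the sketch below maps a buffer for host access by handing mapHostPtr a migrate callback that records the copy on a Level Zero command list. The helper name mapForHostAccess, the bare command-list parameter, and the omission of event and error handling are illustrative assumptions, not the commit's actual queue code; it also assumes the ur_mem_handle_t_ interface sketched near the top of this page.

// Illustration only: the queue supplies the copy routine, the buffer decides
// whether and what to copy (e.g. a discrete buffer copies device -> host
// staging memory unless the map is write-only).
#include <cstddef>
#include <functional>
#include <ze_api.h>

static void *mapForHostAccess(ur_mem_handle_t hBuffer,
                              ur_mem_handle_t_::access_mode_t access,
                              size_t offset, size_t size,
                              ze_command_list_handle_t zeList) {
  auto migrate = [zeList](void *src, void *dst, size_t bytes) {
    // Append the device<->host copy to the caller's (immediate) command list;
    // error handling is omitted in this sketch.
    zeCommandListAppendMemoryCopy(zeList, dst, src, bytes,
                                  /*hSignalEvent=*/nullptr,
                                  /*numWaitEvents=*/0,
                                  /*phWaitEvents=*/nullptr);
  };
  return hBuffer->mapHostPtr(access, offset, size, migrate);
}

// Unmapping mirrors this: unmapHostPtr receives the same kind of callback and
// uses it to write the staging data back to the device before freeing it.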
