16
16
ur_mem_handle_t_::ur_mem_handle_t_ (ur_context_handle_t hContext, size_t size)
17
17
: hContext(hContext), size(size) {}
18
18
19
- ur_host_mem_handle_t ::ur_host_mem_handle_t (ur_context_handle_t hContext,
20
- void *hostPtr, size_t size,
21
- host_ptr_action_t hostPtrAction)
19
+ ur_usm_handle_t_::ur_usm_handle_t_ (ur_context_handle_t hContext, size_t size,
20
+ const void *ptr)
21
+ : ur_mem_handle_t_(hContext, size), ptr(const_cast <void *>(ptr)) {}
22
+
23
+ ur_usm_handle_t_::~ur_usm_handle_t_ () {}
24
+
25
+ void *ur_usm_handle_t_::getDevicePtr (
26
+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
27
+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
28
+ std::ignore = hDevice;
29
+ std::ignore = access ;
30
+ std::ignore = offset;
31
+ std::ignore = size;
32
+ std::ignore = migrate;
33
+ return ptr;
34
+ }
35
+
36
+ void *ur_usm_handle_t_::mapHostPtr (
37
+ access_mode_t access, size_t offset, size_t size,
38
+ std::function<void (void *src, void *dst, size_t )>) {
39
+ std::ignore = access ;
40
+ std::ignore = offset;
41
+ std::ignore = size;
42
+ return ptr;
43
+ }
44
+
45
+ void ur_usm_handle_t_::unmapHostPtr (
46
+ void *pMappedPtr, std::function<void (void *src, void *dst, size_t )>) {
47
+ std::ignore = pMappedPtr;
48
+ /* nop */
49
+ }
50
+
51
+ ur_integrated_mem_handle_t ::ur_integrated_mem_handle_t (
52
+ ur_context_handle_t hContext, void *hostPtr, size_t size,
53
+ host_ptr_action_t hostPtrAction)
22
54
: ur_mem_handle_t_(hContext, size) {
23
55
bool hostPtrImported = false ;
24
56
if (hostPtrAction == host_ptr_action_t ::import) {
@@ -37,7 +69,7 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
37
69
}
38
70
}
39
71
40
- ur_host_mem_handle_t ::~ur_host_mem_handle_t () {
72
+ ur_integrated_mem_handle_t ::~ur_integrated_mem_handle_t () {
41
73
if (ptr) {
42
74
auto ret = hContext->getDefaultUSMPool ()->free (ptr);
43
75
if (ret != UR_RESULT_SUCCESS) {
@@ -46,21 +78,36 @@ ur_host_mem_handle_t::~ur_host_mem_handle_t() {
46
78
}
47
79
}
48
80
49
- void *ur_host_mem_handle_t ::getPtr(ur_device_handle_t hDevice) {
81
+ void *ur_integrated_mem_handle_t ::getDevicePtr(
82
+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
83
+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
50
84
std::ignore = hDevice;
85
+ std::ignore = access ;
86
+ std::ignore = offset;
87
+ std::ignore = size;
88
+ std::ignore = migrate;
51
89
return ptr;
52
90
}
53
91
54
- ur_result_t ur_device_mem_handle_t::migrateBufferTo (ur_device_handle_t hDevice,
55
- void *src, size_t size) {
56
- auto Id = hDevice->Id .value ();
92
+ void *ur_integrated_mem_handle_t ::mapHostPtr(
93
+ access_mode_t access, size_t offset, size_t size,
94
+ std::function<void (void *src, void *dst, size_t )> migrate) {
95
+ std::ignore = access ;
96
+ std::ignore = offset;
97
+ std::ignore = size;
98
+ std::ignore = migrate;
99
+ return ptr;
100
+ }
57
101
58
- if (!deviceAllocations[Id]) {
59
- UR_CALL (hContext-> getDefaultUSMPool ()-> allocate (hContext, hDevice, nullptr ,
60
- UR_USM_TYPE_DEVICE, size,
61
- &deviceAllocations[Id]));
62
- }
102
+ void ur_integrated_mem_handle_t::unmapHostPtr (
103
+ void *pMappedPtr, std::function< void ( void *src, void *dst, size_t )>) {
104
+ std::ignore = pMappedPtr;
105
+ /* nop */
106
+ }
63
107
108
+ static ur_result_t synchronousZeCopy (ur_context_handle_t hContext,
109
+ ur_device_handle_t hDevice, void *dst,
110
+ const void *src, size_t size) {
64
111
auto commandList = hContext->commandListCache .getImmediateCommandList (
65
112
hDevice->ZeDevice , true ,
66
113
hDevice
@@ -70,26 +117,42 @@ ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
70
117
std::nullopt);
71
118
72
119
ZE2UR_CALL (zeCommandListAppendMemoryCopy,
73
- (commandList.get (), deviceAllocations[Id], src, size, nullptr , 0 ,
74
- nullptr ));
120
+ (commandList.get (), dst, src, size, nullptr , 0 , nullptr ));
121
+
122
+ return UR_RESULT_SUCCESS;
123
+ }
124
+
125
+ ur_result_t
126
+ ur_discrete_mem_handle_t ::migrateBufferTo(ur_device_handle_t hDevice, void *src,
127
+ size_t size) {
128
+ auto Id = hDevice->Id .value ();
129
+
130
+ if (!deviceAllocations[Id]) {
131
+ UR_CALL (hContext->getDefaultUSMPool ()->allocate (hContext, hDevice, nullptr ,
132
+ UR_USM_TYPE_DEVICE, size,
133
+ &deviceAllocations[Id]));
134
+ }
135
+
136
+ UR_CALL (
137
+ synchronousZeCopy (hContext, hDevice, deviceAllocations[Id], src, size));
75
138
76
139
activeAllocationDevice = hDevice;
77
140
78
141
return UR_RESULT_SUCCESS;
79
142
}
80
143
81
- ur_device_mem_handle_t :: ur_device_mem_handle_t (ur_context_handle_t hContext,
82
- void *hostPtr, size_t size)
144
+ ur_discrete_mem_handle_t :: ur_discrete_mem_handle_t (ur_context_handle_t hContext,
145
+ void *hostPtr, size_t size)
83
146
: ur_mem_handle_t_(hContext, size),
84
147
deviceAllocations (hContext->getPlatform ()->getNumDevices()),
85
- activeAllocationDevice(nullptr ) {
148
+ activeAllocationDevice(nullptr ), hostAllocations() {
86
149
if (hostPtr) {
87
150
auto initialDevice = hContext->getDevices ()[0 ];
88
151
UR_CALL_THROWS (migrateBufferTo (initialDevice, hostPtr, size));
89
152
}
90
153
}
91
154
92
- ur_device_mem_handle_t ::~ur_device_mem_handle_t () {
155
+ ur_discrete_mem_handle_t ::~ur_discrete_mem_handle_t () {
93
156
for (auto &ptr : deviceAllocations) {
94
157
if (ptr) {
95
158
auto ret = hContext->getDefaultUSMPool ()->free (ptr);
@@ -100,8 +163,12 @@ ur_device_mem_handle_t::~ur_device_mem_handle_t() {
100
163
}
101
164
}
102
165
103
- void *ur_device_mem_handle_t ::getPtr(ur_device_handle_t hDevice) {
104
- std::lock_guard lock (this ->Mutex );
166
+ void *ur_discrete_mem_handle_t ::getDevicePtrUnlocked(
167
+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
168
+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
169
+ std::ignore = access ;
170
+ std::ignore = size;
171
+ std::ignore = migrate;
105
172
106
173
if (!activeAllocationDevice) {
107
174
UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
@@ -110,8 +177,10 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
110
177
activeAllocationDevice = hDevice;
111
178
}
112
179
180
+ char *ptr;
113
181
if (activeAllocationDevice == hDevice) {
114
- return deviceAllocations[hDevice->Id .value ()];
182
+ ptr = ur_cast<char *>(deviceAllocations[hDevice->Id .value ()]);
183
+ return ptr + offset;
115
184
}
116
185
117
186
auto &p2pDevices = hContext->getP2PDevices (hDevice);
@@ -124,7 +193,71 @@ void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
124
193
}
125
194
126
195
// TODO: see if it's better to migrate the memory to the specified device
127
- return deviceAllocations[activeAllocationDevice->Id .value ()];
196
+ return ur_cast<char *>(
197
+ deviceAllocations[activeAllocationDevice->Id .value ()]) +
198
+ offset;
199
+ }
200
+
201
+ void *ur_discrete_mem_handle_t ::getDevicePtr(
202
+ ur_device_handle_t hDevice, access_mode_t access, size_t offset,
203
+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
204
+ std::lock_guard lock (this ->Mutex );
205
+ return getDevicePtrUnlocked (hDevice, access , offset, size, migrate);
206
+ }
207
+
208
+ void *ur_discrete_mem_handle_t ::mapHostPtr(
209
+ access_mode_t access, size_t offset, size_t size,
210
+ std::function<void (void *src, void *dst, size_t )> migrate) {
211
+ std::lock_guard lock (this ->Mutex );
212
+
213
+ // TODO: use async alloc?
214
+
215
+ void *ptr;
216
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
217
+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
218
+
219
+ hostAllocations.emplace_back (ptr, size, offset, access );
220
+
221
+ if (activeAllocationDevice && access != access_mode_t ::write_only) {
222
+ auto srcPtr =
223
+ ur_cast<char *>(deviceAllocations[activeAllocationDevice->Id .value ()]) +
224
+ offset;
225
+ migrate (srcPtr, hostAllocations.back ().ptr , size);
226
+ }
227
+
228
+ return hostAllocations.back ().ptr ;
229
+ }
230
+
231
+ void ur_discrete_mem_handle_t::unmapHostPtr (
232
+ void *pMappedPtr,
233
+ std::function<void (void *src, void *dst, size_t )> migrate) {
234
+ std::lock_guard lock (this ->Mutex );
235
+
236
+ for (auto &hostAllocation : hostAllocations) {
237
+ if (hostAllocation.ptr == pMappedPtr) {
238
+ void *devicePtr = nullptr ;
239
+ if (activeAllocationDevice) {
240
+ devicePtr = ur_cast<char *>(
241
+ deviceAllocations[activeAllocationDevice->Id .value ()]) +
242
+ hostAllocation.offset ;
243
+ } else if (hostAllocation.access != access_mode_t ::write_invalidate) {
244
+ devicePtr = ur_cast<char *>(getDevicePtrUnlocked (
245
+ hContext->getDevices ()[0 ], access_mode_t ::read_only,
246
+ hostAllocation.offset , hostAllocation.size , migrate));
247
+ }
248
+
249
+ if (devicePtr) {
250
+ migrate (hostAllocation.ptr , devicePtr, hostAllocation.size );
251
+ }
252
+
253
+ // TODO: use async free here?
254
+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->free (hostAllocation.ptr ));
255
+ return ;
256
+ }
257
+ }
258
+
259
+ // No mapping found
260
+ throw UR_RESULT_ERROR_INVALID_ARGUMENT;
128
261
}
129
262
130
263
namespace ur ::level_zero {
@@ -155,13 +288,14 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
155
288
if (useHostBuffer) {
156
289
// TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
157
290
// or UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER is set?
158
- auto hostPtrAction = flags & UR_MEM_FLAG_USE_HOST_POINTER
159
- ? ur_host_mem_handle_t ::host_ptr_action_t ::import
160
- : ur_host_mem_handle_t ::host_ptr_action_t ::copy;
291
+ auto hostPtrAction =
292
+ flags & UR_MEM_FLAG_USE_HOST_POINTER
293
+ ? ur_integrated_mem_handle_t ::host_ptr_action_t ::import
294
+ : ur_integrated_mem_handle_t ::host_ptr_action_t ::copy;
161
295
*phBuffer =
162
- new ur_host_mem_handle_t (hContext, hostPtr, size, hostPtrAction);
296
+ new ur_integrated_mem_handle_t (hContext, hostPtr, size, hostPtrAction);
163
297
} else {
164
- *phBuffer = new ur_device_mem_handle_t (hContext, hostPtr, size);
298
+ *phBuffer = new ur_discrete_mem_handle_t (hContext, hostPtr, size);
165
299
}
166
300
167
301
return UR_RESULT_SUCCESS;
0 commit comments