@@ -31,6 +31,50 @@ ur_command_list_manager::~ur_command_list_manager() {
31
31
ur::level_zero::urDeviceRelease (device);
32
32
}
33
33
34
+ ur_result_t ur_command_list_manager::appendGenericFillUnlocked (
35
+ ur_mem_handle_t dst, size_t offset, size_t patternSize,
36
+ const void *pPattern, size_t size, uint32_t numEventsInWaitList,
37
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent,
38
+ ur_command_t commandType) {
39
+
40
+ auto zeSignalEvent = getSignalEvent (phEvent, commandType);
41
+
42
+ auto waitListView = getWaitListView (phEventWaitList, numEventsInWaitList);
43
+
44
+ auto pDst = ur_cast<char *>(dst->getDevicePtr (
45
+ device, ur_mem_handle_t_::device_access_mode_t ::read_only, offset, size,
46
+ [&](void *src, void *dst, size_t size) {
47
+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
48
+ (zeCommandList.get (), dst, src, size, nullptr ,
49
+ waitListView.num , waitListView.handles ));
50
+ waitListView.clear ();
51
+ }));
52
+
53
+ // PatternSize must be a power of two for zeCommandListAppendMemoryFill.
54
+ // When it's not, the fill is emulated with zeCommandListAppendMemoryCopy.
55
+ if (isPowerOf2 (patternSize)) {
56
+ ZE2UR_CALL (zeCommandListAppendMemoryFill,
57
+ (zeCommandList.get (), pDst, pPattern, patternSize, size,
58
+ zeSignalEvent, waitListView.num , waitListView.handles ));
59
+ } else {
60
+ // Copy pattern into every entry in memory array pointed by Ptr.
61
+ uint32_t numOfCopySteps = size / patternSize;
62
+ const void *src = pPattern;
63
+
64
+ for (uint32_t step = 0 ; step < numOfCopySteps; ++step) {
65
+ void *dst = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(pDst) +
66
+ step * patternSize);
67
+ ZE2UR_CALL (zeCommandListAppendMemoryCopy,
68
+ (zeCommandList.get (), dst, src, patternSize,
69
+ step == numOfCopySteps - 1 ? zeSignalEvent : nullptr ,
70
+ waitListView.num , waitListView.handles ));
71
+ waitListView.clear ();
72
+ }
73
+ }
74
+
75
+ return UR_RESULT_SUCCESS;
76
+ }
77
+
34
78
ur_result_t ur_command_list_manager::appendGenericCopyUnlocked (
35
79
ur_mem_buffer_t *src, ur_mem_buffer_t *dst, bool blocking, size_t srcOffset,
36
80
size_t dstOffset, size_t size, uint32_t numEventsInWaitList,
@@ -209,6 +253,95 @@ ur_result_t ur_command_list_manager::appendUSMMemcpy(
209
253
return UR_RESULT_SUCCESS;
210
254
}
211
255
256
+ ur_result_t ur_command_list_manager::appendMemBufferFill (
257
+ ur_mem_handle_t hBuffer, const void *pPattern, size_t patternSize,
258
+ size_t offset, size_t size, uint32_t numEventsInWaitList,
259
+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
260
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendMemBufferFill" );
261
+
262
+ UR_ASSERT (offset + size <= hBuffer->getSize (), UR_RESULT_ERROR_INVALID_SIZE);
263
+
264
+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> lock (this ->Mutex ,
265
+ hBuffer->getMutex ());
266
+
267
+ return appendGenericFillUnlocked (hBuffer, offset, patternSize, pPattern, size,
268
+ numEventsInWaitList, phEventWaitList,
269
+ phEvent, UR_COMMAND_MEM_BUFFER_FILL);
270
+ }
271
+
272
+ ur_result_t ur_command_list_manager::appendUSMFill (
273
+ void *pMem, size_t patternSize, const void *pPattern, size_t size,
274
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
275
+ ur_event_handle_t *phEvent) {
276
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMFill" );
277
+
278
+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
279
+
280
+ ur_usm_handle_t_ dstHandle (context, size, pMem);
281
+ return appendGenericFillUnlocked (&dstHandle, 0 , patternSize, pPattern, size,
282
+ numEventsInWaitList, phEventWaitList,
283
+ phEvent, UR_COMMAND_USM_FILL);
284
+ }
285
+
286
+ ur_result_t ur_command_list_manager::appendUSMPrefetch (
287
+ const void *pMem, size_t size, ur_usm_migration_flags_t flags,
288
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
289
+ ur_event_handle_t *phEvent) {
290
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMPrefetch" );
291
+
292
+ std::ignore = flags;
293
+
294
+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
295
+
296
+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_USM_PREFETCH);
297
+
298
+ auto [pWaitEvents, numWaitEvents] =
299
+ getWaitListView (phEventWaitList, numEventsInWaitList);
300
+
301
+ if (pWaitEvents) {
302
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
303
+ (zeCommandList.get (), numWaitEvents, pWaitEvents));
304
+ }
305
+ // TODO: figure out how to translate "flags"
306
+ ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
307
+ (zeCommandList.get (), pMem, size));
308
+ if (zeSignalEvent) {
309
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
310
+ (zeCommandList.get (), zeSignalEvent));
311
+ }
312
+
313
+ return UR_RESULT_SUCCESS;
314
+ }
315
+
316
+ ur_result_t
317
+ ur_command_list_manager::appendUSMAdvise (const void *pMem, size_t size,
318
+ ur_usm_advice_flags_t advice,
319
+ ur_event_handle_t *phEvent) {
320
+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMAdvise" );
321
+
322
+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
323
+
324
+ auto zeAdvice = ur_cast<ze_memory_advice_t >(advice);
325
+
326
+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_USM_ADVISE);
327
+
328
+ auto [pWaitEvents, numWaitEvents] = getWaitListView (nullptr , 0 );
329
+
330
+ if (pWaitEvents) {
331
+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
332
+ (zeCommandList.get (), numWaitEvents, pWaitEvents));
333
+ }
334
+
335
+ ZE2UR_CALL (zeCommandListAppendMemAdvise,
336
+ (zeCommandList.get (), device->ZeDevice , pMem, size, zeAdvice));
337
+
338
+ if (zeSignalEvent) {
339
+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
340
+ (zeCommandList.get (), zeSignalEvent));
341
+ }
342
+ return UR_RESULT_SUCCESS;
343
+ }
344
+
212
345
ur_result_t ur_command_list_manager::appendMemBufferRead (
213
346
ur_mem_handle_t hMem, bool blockingRead, size_t offset, size_t size,
214
347
void *pDst, uint32_t numEventsInWaitList,
0 commit comments