@@ -31,6 +31,50 @@ ur_command_list_manager::~ur_command_list_manager() {
3131 ur::level_zero::urDeviceRelease (device);
3232}
3333
34+ ur_result_t ur_command_list_manager::appendGenericFillUnlocked (
35+ ur_mem_handle_t dst, size_t offset, size_t patternSize,
36+ const void *pPattern, size_t size, uint32_t numEventsInWaitList,
37+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent,
38+ ur_command_t commandType) {
39+
40+ auto zeSignalEvent = getSignalEvent (phEvent, commandType);
41+
42+ auto waitListView = getWaitListView (phEventWaitList, numEventsInWaitList);
43+
44+ auto pDst = ur_cast<char *>(dst->getDevicePtr (
45+ device, ur_mem_handle_t_::device_access_mode_t ::read_only, offset, size,
46+ [&](void *src, void *dst, size_t size) {
47+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
48+ (zeCommandList.get (), dst, src, size, nullptr ,
49+ waitListView.num , waitListView.handles ));
50+ waitListView.clear ();
51+ }));
52+
53+ // PatternSize must be a power of two for zeCommandListAppendMemoryFill.
54+ // When it's not, the fill is emulated with zeCommandListAppendMemoryCopy.
55+ if (isPowerOf2 (patternSize)) {
56+ ZE2UR_CALL (zeCommandListAppendMemoryFill,
57+ (zeCommandList.get (), pDst, pPattern, patternSize, size,
58+ zeSignalEvent, waitListView.num , waitListView.handles ));
59+ } else {
60+ // Copy pattern into every entry in memory array pointed by Ptr.
61+ uint32_t numOfCopySteps = size / patternSize;
62+ const void *src = pPattern;
63+
64+ for (uint32_t step = 0 ; step < numOfCopySteps; ++step) {
65+ void *dst = reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(pDst) +
66+ step * patternSize);
67+ ZE2UR_CALL (zeCommandListAppendMemoryCopy,
68+ (zeCommandList.get (), dst, src, patternSize,
69+ step == numOfCopySteps - 1 ? zeSignalEvent : nullptr ,
70+ waitListView.num , waitListView.handles ));
71+ waitListView.clear ();
72+ }
73+ }
74+
75+ return UR_RESULT_SUCCESS;
76+ }
77+
3478ur_result_t ur_command_list_manager::appendGenericCopyUnlocked (
3579 ur_mem_handle_t src, ur_mem_handle_t dst, bool blocking, size_t srcOffset,
3680 size_t dstOffset, size_t size, uint32_t numEventsInWaitList,
@@ -209,6 +253,95 @@ ur_result_t ur_command_list_manager::appendUSMMemcpy(
209253 return UR_RESULT_SUCCESS;
210254}
211255
256+ ur_result_t ur_command_list_manager::appendMemBufferFill (
257+ ur_mem_handle_t hBuffer, const void *pPattern, size_t patternSize,
258+ size_t offset, size_t size, uint32_t numEventsInWaitList,
259+ const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
260+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendMemBufferFill" );
261+
262+ UR_ASSERT (offset + size <= hBuffer->getSize (), UR_RESULT_ERROR_INVALID_SIZE);
263+
264+ std::scoped_lock<ur_shared_mutex, ur_shared_mutex> lock (this ->Mutex ,
265+ hBuffer->getMutex ());
266+
267+ return appendGenericFillUnlocked (hBuffer, offset, patternSize, pPattern, size,
268+ numEventsInWaitList, phEventWaitList,
269+ phEvent, UR_COMMAND_MEM_BUFFER_FILL);
270+ }
271+
272+ ur_result_t ur_command_list_manager::appendUSMFill (
273+ void *pMem, size_t patternSize, const void *pPattern, size_t size,
274+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
275+ ur_event_handle_t *phEvent) {
276+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMFill" );
277+
278+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
279+
280+ ur_usm_handle_t_ dstHandle (context, size, pMem);
281+ return appendGenericFillUnlocked (&dstHandle, 0 , patternSize, pPattern, size,
282+ numEventsInWaitList, phEventWaitList,
283+ phEvent, UR_COMMAND_USM_FILL);
284+ }
285+
286+ ur_result_t ur_command_list_manager::appendUSMPrefetch (
287+ const void *pMem, size_t size, ur_usm_migration_flags_t flags,
288+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
289+ ur_event_handle_t *phEvent) {
290+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMPrefetch" );
291+
292+ std::ignore = flags;
293+
294+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
295+
296+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_USM_PREFETCH);
297+
298+ auto [pWaitEvents, numWaitEvents] =
299+ getWaitListView (phEventWaitList, numEventsInWaitList);
300+
301+ if (pWaitEvents) {
302+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
303+ (zeCommandList.get (), numWaitEvents, pWaitEvents));
304+ }
305+ // TODO: figure out how to translate "flags"
306+ ZE2UR_CALL (zeCommandListAppendMemoryPrefetch,
307+ (zeCommandList.get (), pMem, size));
308+ if (zeSignalEvent) {
309+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
310+ (zeCommandList.get (), zeSignalEvent));
311+ }
312+
313+ return UR_RESULT_SUCCESS;
314+ }
315+
316+ ur_result_t
317+ ur_command_list_manager::appendUSMAdvise (const void *pMem, size_t size,
318+ ur_usm_advice_flags_t advice,
319+ ur_event_handle_t *phEvent) {
320+ TRACK_SCOPE_LATENCY (" ur_command_list_manager::appendUSMAdvise" );
321+
322+ std::scoped_lock<ur_shared_mutex> lock (this ->Mutex );
323+
324+ auto zeAdvice = ur_cast<ze_memory_advice_t >(advice);
325+
326+ auto zeSignalEvent = getSignalEvent (phEvent, UR_COMMAND_USM_ADVISE);
327+
328+ auto [pWaitEvents, numWaitEvents] = getWaitListView (nullptr , 0 );
329+
330+ if (pWaitEvents) {
331+ ZE2UR_CALL (zeCommandListAppendWaitOnEvents,
332+ (zeCommandList.get (), numWaitEvents, pWaitEvents));
333+ }
334+
335+ ZE2UR_CALL (zeCommandListAppendMemAdvise,
336+ (zeCommandList.get (), device->ZeDevice , pMem, size, zeAdvice));
337+
338+ if (zeSignalEvent) {
339+ ZE2UR_CALL (zeCommandListAppendSignalEvent,
340+ (zeCommandList.get (), zeSignalEvent));
341+ }
342+ return UR_RESULT_SUCCESS;
343+ }
344+
212345ur_result_t ur_command_list_manager::appendMemBufferRead (
213346 ur_mem_handle_t hBuffer, bool blockingRead, size_t offset, size_t size,
214347 void *pDst, uint32_t numEventsInWaitList,
0 commit comments