Skip to content

Commit 6a0eb7e

Browse files
al42and authored and fabiomestre committed
[SYCL][CUDA] Implement ext_oneapi_queue_priority (#11296)
The E2E test assumes that the device supports priorities (sm_35 or newer).
1 parent 187633a commit 6a0eb7e

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

source/adapters/cuda/queue.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ CUstream ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
4242
// The second check is done after mutex is locked so other threads can not
4343
// change NumComputeStreams after that
4444
if (NumComputeStreams < ComputeStreams.size()) {
45-
UR_CHECK_ERROR(
46-
cuStreamCreate(&ComputeStreams[NumComputeStreams++], Flags));
45+
UR_CHECK_ERROR(cuStreamCreateWithPriority(
46+
&ComputeStreams[NumComputeStreams++], Flags, Priority));
4747
}
4848
}
4949
Token = ComputeStreamIndex++;
@@ -103,8 +103,8 @@ CUstream ur_queue_handle_t_::getNextTransferStream() {
103103
// The second check is done after mutex is locked so other threads can not
104104
// change NumTransferStreams after that
105105
if (NumTransferStreams < TransferStreams.size()) {
106-
UR_CHECK_ERROR(
107-
cuStreamCreate(&TransferStreams[NumTransferStreams++], Flags));
106+
UR_CHECK_ERROR(cuStreamCreateWithPriority(
107+
&TransferStreams[NumTransferStreams++], Flags, Priority));
108108
}
109109
}
110110
uint32_t StreamI = TransferStreamIndex++ % TransferStreams.size();
@@ -130,6 +130,8 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
130130

131131
unsigned int Flags = CU_STREAM_NON_BLOCKING;
132132
ur_queue_flags_t URFlags = 0;
133+
// '0' is the default priority, per CUDA Toolkit 12.2 and earlier
134+
int Priority = 0;
133135
bool IsOutOfOrder = false;
134136
if (pProps && pProps->stype == UR_STRUCTURE_TYPE_QUEUE_PROPERTIES) {
135137
URFlags = pProps->flags;
@@ -142,6 +144,13 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
142144
if (URFlags & UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE) {
143145
IsOutOfOrder = true;
144146
}
147+
if (URFlags & UR_QUEUE_FLAG_PRIORITY_HIGH) {
148+
ScopedContext Active(hContext);
149+
UR_CHECK_ERROR(cuCtxGetStreamPriorityRange(nullptr, &Priority));
150+
} else if (URFlags & UR_QUEUE_FLAG_PRIORITY_LOW) {
151+
ScopedContext Active(hContext);
152+
UR_CHECK_ERROR(cuCtxGetStreamPriorityRange(&Priority, nullptr));
153+
}
145154
}
146155

147156
std::vector<CUstream> ComputeCuStreams(
@@ -151,7 +160,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
151160

152161
Queue = std::unique_ptr<ur_queue_handle_t_>(new ur_queue_handle_t_{
153162
std::move(ComputeCuStreams), std::move(TransferCuStreams), hContext,
154-
hDevice, Flags, URFlags});
163+
hDevice, Flags, URFlags, Priority});
155164

156165
*phQueue = Queue.release();
157166

source/adapters/cuda/queue.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ struct ur_queue_handle_t_ {
4949
unsigned int LastSyncTransferStreams;
5050
unsigned int Flags;
5151
ur_queue_flags_t URFlags;
52+
int Priority;
5253
// When ComputeStreamSyncMutex and ComputeStreamMutex both need to be
5354
// locked at the same time, ComputeStreamSyncMutex should be locked first
5455
// to avoid deadlocks
@@ -61,7 +62,7 @@ struct ur_queue_handle_t_ {
6162
ur_queue_handle_t_(std::vector<CUstream> &&ComputeStreams,
6263
std::vector<CUstream> &&TransferStreams,
6364
ur_context_handle_t_ *Context, ur_device_handle_t_ *Device,
64-
unsigned int Flags, ur_queue_flags_t URFlags,
65+
unsigned int Flags, ur_queue_flags_t URFlags, int Priority,
6566
bool BackendOwns = true)
6667
: ComputeStreams{std::move(ComputeStreams)}, TransferStreams{std::move(
6768
TransferStreams)},
@@ -71,7 +72,7 @@ struct ur_queue_handle_t_ {
7172
Device{Device}, RefCount{1}, EventCount{0}, ComputeStreamIndex{0},
7273
TransferStreamIndex{0}, NumComputeStreams{0}, NumTransferStreams{0},
7374
LastSyncComputeStreams{0}, LastSyncTransferStreams{0}, Flags(Flags),
74-
URFlags(URFlags), HasOwnership{BackendOwns} {
75+
URFlags(URFlags), Priority(Priority), HasOwnership{BackendOwns} {
7576
urContextRetain(Context);
7677
urDeviceRetain(Device);
7778
}

0 commit comments

Comments
 (0)