@@ -42,8 +42,8 @@ CUstream ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
42
42
// The second check is done after mutex is locked so other threads can not
43
43
// change NumComputeStreams after that
44
44
if (NumComputeStreams < ComputeStreams.size ()) {
45
- UR_CHECK_ERROR (
46
- cuStreamCreate ( &ComputeStreams[NumComputeStreams++], Flags));
45
+ UR_CHECK_ERROR (cuStreamCreateWithPriority (
46
+ &ComputeStreams[NumComputeStreams++], Flags, Priority ));
47
47
}
48
48
}
49
49
Token = ComputeStreamIndex++;
@@ -103,8 +103,8 @@ CUstream ur_queue_handle_t_::getNextTransferStream() {
103
103
// The second check is done after mutex is locked so other threads can not
104
104
// change NumTransferStreams after that
105
105
if (NumTransferStreams < TransferStreams.size ()) {
106
- UR_CHECK_ERROR (
107
- cuStreamCreate ( &TransferStreams[NumTransferStreams++], Flags));
106
+ UR_CHECK_ERROR (cuStreamCreateWithPriority (
107
+ &TransferStreams[NumTransferStreams++], Flags, Priority ));
108
108
}
109
109
}
110
110
uint32_t StreamI = TransferStreamIndex++ % TransferStreams.size ();
@@ -130,6 +130,8 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
130
130
131
131
unsigned int Flags = CU_STREAM_NON_BLOCKING;
132
132
ur_queue_flags_t URFlags = 0 ;
133
+ // '0' is the default priority, per CUDA Toolkit 12.2 and earlier
134
+ int Priority = 0 ;
133
135
bool IsOutOfOrder = false ;
134
136
if (pProps && pProps->stype == UR_STRUCTURE_TYPE_QUEUE_PROPERTIES) {
135
137
URFlags = pProps->flags ;
@@ -142,6 +144,13 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
142
144
if (URFlags & UR_QUEUE_FLAG_OUT_OF_ORDER_EXEC_MODE_ENABLE) {
143
145
IsOutOfOrder = true ;
144
146
}
147
+ if (URFlags & UR_QUEUE_FLAG_PRIORITY_HIGH) {
148
+ ScopedContext Active (hContext);
149
+ UR_CHECK_ERROR (cuCtxGetStreamPriorityRange (nullptr , &Priority));
150
+ } else if (URFlags & UR_QUEUE_FLAG_PRIORITY_LOW) {
151
+ ScopedContext Active (hContext);
152
+ UR_CHECK_ERROR (cuCtxGetStreamPriorityRange (&Priority, nullptr ));
153
+ }
145
154
}
146
155
147
156
std::vector<CUstream> ComputeCuStreams (
@@ -151,7 +160,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
151
160
152
161
Queue = std::unique_ptr<ur_queue_handle_t_>(new ur_queue_handle_t_{
153
162
std::move (ComputeCuStreams), std::move (TransferCuStreams), hContext,
154
- hDevice, Flags, URFlags});
163
+ hDevice, Flags, URFlags, Priority });
155
164
156
165
*phQueue = Queue.release ();
157
166
0 commit comments