Skip to content

Commit 2994564

Browse files
committed
[UR] Refactor CUDA adapter tests to work on a multi-device runner.
1 parent e40aa08 commit 2994564

File tree

4 files changed

+12
-34
lines changed

4 files changed

+12
-34
lines changed

test/adapters/cuda/context_tests.cpp

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,6 @@ TEST_P(cudaUrContextCreateTest, CreateWithChildThread) {
4343
callContextFromOtherThread.join();
4444
}
4545

46-
TEST_P(cudaUrContextCreateTest, ActiveContext) {
47-
uur::raii::Context context = nullptr;
48-
ASSERT_SUCCESS(urContextCreate(1, &device, nullptr, context.ptr()));
49-
ASSERT_NE(context, nullptr);
50-
51-
uur::raii::Queue queue = nullptr;
52-
ur_queue_properties_t queue_props{UR_STRUCTURE_TYPE_QUEUE_PROPERTIES, nullptr,
53-
0};
54-
ASSERT_SUCCESS(urQueueCreate(context, device, &queue_props, queue.ptr()));
55-
ASSERT_NE(queue, nullptr);
56-
57-
// check that the queue has the correct context
58-
ASSERT_EQ(context, queue->getContext());
59-
60-
// create a buffer
61-
uur::raii::Mem buffer = nullptr;
62-
ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, 1024,
63-
nullptr, buffer.ptr()));
64-
ASSERT_NE(buffer, nullptr);
65-
66-
// check that the context is now the active CUDA context
67-
CUcontext cudaCtx = nullptr;
68-
ASSERT_SUCCESS_CUDA(cuCtxGetCurrent(&cudaCtx));
69-
ASSERT_NE(cudaCtx, nullptr);
70-
71-
ur_native_handle_t native_context = 0;
72-
ASSERT_SUCCESS(urContextGetNativeHandle(context, &native_context));
73-
ASSERT_NE(reinterpret_cast<CUcontext>(native_context), nullptr);
74-
ASSERT_EQ(cudaCtx, reinterpret_cast<CUcontext>(native_context));
75-
}
76-
7746
TEST_P(cudaUrContextCreateTest, ContextLifetimeExisting) {
7847
// start by setting up a CUDA context on the thread
7948
CUcontext original;

test/adapters/cuda/memory_tests.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ TEST_P(cudaMemoryTest, urMemBufferNoActiveContext) {
1414
constexpr size_t memSize = 1024u;
1515

1616
CUcontext current = nullptr;
17-
do {
17+
ASSERT_SUCCESS_CUDA(cuCtxGetCurrent(&current));
18+
while (current != nullptr) {
1819
CUcontext oldContext = nullptr;
1920
ASSERT_SUCCESS_CUDA(cuCtxPopCurrent(&oldContext));
2021
ASSERT_SUCCESS_CUDA(cuCtxGetCurrent(&current));
21-
} while (current != nullptr);
22+
}
2223

2324
uur::raii::Mem mem;
2425
ASSERT_SUCCESS(urMemBufferCreate(context, UR_MEM_FLAG_READ_WRITE, memSize,

test/conformance/enqueue/urEnqueueKernelLaunch.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,9 @@ UUR_INSTANTIATE_PLATFORM_TEST_SUITE(urEnqueueKernelLaunchMultiDeviceTest);
566566
// (the context is only created for one device)
567567
TEST_P(urEnqueueKernelLaunchMultiDeviceTest, KernelLaunchReadDifferentQueues) {
568568
UUR_KNOWN_FAILURE_ON(uur::LevelZero{}, uur::LevelZeroV2{});
569-
569+
if (devices.size() > 1) {
570+
UUR_KNOWN_FAILURE_ON(uur::CUDA{});
571+
}
570572
uur::KernelLaunchHelper helper =
571573
uur::KernelLaunchHelper{platform, context, kernel, queues[0]};
572574

test/conformance/enqueue/urEnqueueKernelLaunchAndMemcpyInOrder.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ UUR_PLATFORM_TEST_SUITE_WITH_PARAM(
180180

181181
TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
182182
UUR_KNOWN_FAILURE_ON(uur::LevelZeroV2{});
183+
if (devices.size() > 1) {
184+
UUR_KNOWN_FAILURE_ON(uur::CUDA{});
185+
}
183186

184187
constexpr size_t global_offset = 0;
185188
constexpr size_t n_dimensions = 1;
@@ -358,6 +361,9 @@ UUR_PLATFORM_TEST_SUITE_WITH_PARAM(
358361
// Enqueue kernelLaunch concurrently from multiple threads
359362
// With !queuePerThread this becomes a test on a single device
360363
TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
364+
if (devices.size() > 1) {
365+
UUR_KNOWN_FAILURE_ON(uur::CUDA{});
366+
}
361367
auto useEvents = std::get<0>(getParam()).value;
362368
auto queuePerThread = std::get<1>(getParam()).value;
363369

0 commit comments

Comments
 (0)