Skip to content

Commit 942e682

Browse files
committed
Add try/catch blocks in command_buffer.cpp
Don't allow UR_CHECK_ERROR to be called outside of a try block.
1 parent f5ac85b commit 942e682

File tree

2 files changed

+199
-182
lines changed

2 files changed

+199
-182
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 127 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,15 @@ static ur_result_t enqueueCommandBufferFillHelper(
221221
ur_event_handle_t *RetEvent,
222222
ur_exp_command_buffer_command_handle_t *RetCommand) {
223223
std::vector<CUgraphNode> DepsList;
224-
UR_CHECK_ERROR(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
225-
SyncPointWaitList, DepsList));
224+
try {
225+
UR_CHECK_ERROR(getNodesFromSyncPoints(
226+
CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList, DepsList));
226227

227-
if (NumEventsInWaitList) {
228-
UR_CHECK_ERROR(CommandBuffer->addWaitNodes(DepsList, NumEventsInWaitList,
229-
EventWaitList));
230-
}
228+
if (NumEventsInWaitList) {
229+
UR_CHECK_ERROR(CommandBuffer->addWaitNodes(DepsList, NumEventsInWaitList,
230+
EventWaitList));
231+
}
231232

232-
try {
233233
// Graph node added to graph, if multiple nodes are created this will
234234
// be set to the leaf node
235235
CUgraphNode GraphNode;
@@ -566,15 +566,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
566566
ur_exp_command_buffer_command_handle_t *phCommand) {
567567
CUgraphNode GraphNode;
568568
std::vector<CUgraphNode> DepsList;
569-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
570-
pSyncPointWaitList, DepsList));
569+
try {
570+
UR_CHECK_ERROR(getNodesFromSyncPoints(
571+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
571572

572-
if (numEventsInWaitList) {
573-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
574-
phEventWaitList));
575-
}
573+
if (numEventsInWaitList) {
574+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
575+
phEventWaitList));
576+
}
576577

577-
try {
578578
CUDA_MEMCPY3D NodeParams = {};
579579
setCopyParams(pSrc, CU_MEMORYTYPE_HOST, pDst, CU_MEMORYTYPE_HOST, size,
580580
NodeParams);
@@ -629,15 +629,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
629629
UR_ASSERT(size + srcOffset <= std::get<BufferMem>(hSrcMem->Mem).getSize(),
630630
UR_RESULT_ERROR_INVALID_SIZE);
631631

632-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
633-
pSyncPointWaitList, DepsList));
632+
try {
633+
UR_CHECK_ERROR(getNodesFromSyncPoints(
634+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
634635

635-
if (numEventsInWaitList) {
636-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
637-
phEventWaitList));
638-
}
636+
if (numEventsInWaitList) {
637+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
638+
phEventWaitList));
639+
}
639640

640-
try {
641641
auto Src = std::get<BufferMem>(hSrcMem->Mem)
642642
.getPtrWithOffset(hCommandBuffer->Device, srcOffset);
643643
auto Dst = std::get<BufferMem>(hDstMem->Mem)
@@ -692,15 +692,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
692692
ur_exp_command_buffer_command_handle_t *phCommand) {
693693
CUgraphNode GraphNode;
694694
std::vector<CUgraphNode> DepsList;
695-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
696-
pSyncPointWaitList, DepsList));
695+
try {
696+
UR_CHECK_ERROR(getNodesFromSyncPoints(
697+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
697698

698-
if (numEventsInWaitList) {
699-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
700-
phEventWaitList));
701-
}
699+
if (numEventsInWaitList) {
700+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
701+
phEventWaitList));
702+
}
702703

703-
try {
704704
auto SrcPtr =
705705
std::get<BufferMem>(hSrcMem->Mem).getPtr(hCommandBuffer->Device);
706706
auto DstPtr =
@@ -756,15 +756,15 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
756756
ur_exp_command_buffer_command_handle_t *phCommand) {
757757
CUgraphNode GraphNode;
758758
std::vector<CUgraphNode> DepsList;
759-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
760-
pSyncPointWaitList, DepsList));
759+
try {
760+
UR_CHECK_ERROR(getNodesFromSyncPoints(
761+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
761762

762-
if (numEventsInWaitList) {
763-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
764-
phEventWaitList));
765-
}
763+
if (numEventsInWaitList) {
764+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
765+
phEventWaitList));
766+
}
766767

767-
try {
768768
auto Dst = std::get<BufferMem>(hBuffer->Mem)
769769
.getPtrWithOffset(hCommandBuffer->Device, offset);
770770

@@ -816,15 +816,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
816816
ur_exp_command_buffer_command_handle_t *phCommand) {
817817
CUgraphNode GraphNode;
818818
std::vector<CUgraphNode> DepsList;
819-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
820-
pSyncPointWaitList, DepsList));
821-
822-
if (numEventsInWaitList) {
823-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
824-
phEventWaitList));
825-
}
826819

827820
try {
821+
UR_CHECK_ERROR(getNodesFromSyncPoints(
822+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
823+
824+
if (numEventsInWaitList) {
825+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
826+
phEventWaitList));
827+
}
828+
828829
auto Src = std::get<BufferMem>(hBuffer->Mem)
829830
.getPtrWithOffset(hCommandBuffer->Device, offset);
830831

@@ -879,15 +880,15 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
879880
ur_exp_command_buffer_command_handle_t *phCommand) {
880881
CUgraphNode GraphNode;
881882
std::vector<CUgraphNode> DepsList;
882-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
883-
pSyncPointWaitList, DepsList));
883+
try {
884+
UR_CHECK_ERROR(getNodesFromSyncPoints(
885+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
884886

885-
if (numEventsInWaitList) {
886-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
887-
phEventWaitList));
888-
}
887+
if (numEventsInWaitList) {
888+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
889+
phEventWaitList));
890+
}
889891

890-
try {
891892
auto DstPtr =
892893
std::get<BufferMem>(hBuffer->Mem).getPtr(hCommandBuffer->Device);
893894
CUDA_MEMCPY3D NodeParams = {};
@@ -944,15 +945,15 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
944945
ur_exp_command_buffer_command_handle_t *phCommand) {
945946
CUgraphNode GraphNode;
946947
std::vector<CUgraphNode> DepsList;
947-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
948-
pSyncPointWaitList, DepsList));
948+
try {
949+
UR_CHECK_ERROR(getNodesFromSyncPoints(
950+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
949951

950-
if (numEventsInWaitList) {
951-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
952-
phEventWaitList));
953-
}
952+
if (numEventsInWaitList) {
953+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
954+
phEventWaitList));
955+
}
954956

955-
try {
956957
auto SrcPtr =
957958
std::get<BufferMem>(hBuffer->Mem).getPtr(hCommandBuffer->Device);
958959
CUDA_MEMCPY3D NodeParams = {};
@@ -1009,15 +1010,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
10091010
CUgraphNode GraphNode;
10101011

10111012
std::vector<CUgraphNode> DepsList;
1012-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
1013-
pSyncPointWaitList, DepsList));
1013+
try {
1014+
UR_CHECK_ERROR(getNodesFromSyncPoints(
1015+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
10141016

1015-
if (numEventsInWaitList) {
1016-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
1017-
phEventWaitList));
1018-
}
1017+
if (numEventsInWaitList) {
1018+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
1019+
phEventWaitList));
1020+
}
10191021

1020-
try {
10211022
// Add an empty node to preserve dependencies.
10221023
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
10231024
DepsList.data(), DepsList.size()));
@@ -1065,15 +1066,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
10651066
CUgraphNode GraphNode;
10661067

10671068
std::vector<CUgraphNode> DepsList;
1068-
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
1069-
pSyncPointWaitList, DepsList));
1069+
try {
1070+
UR_CHECK_ERROR(getNodesFromSyncPoints(
1071+
hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
10701072

1071-
if (numEventsInWaitList) {
1072-
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
1073-
phEventWaitList));
1074-
}
1073+
if (numEventsInWaitList) {
1074+
UR_CHECK_ERROR(hCommandBuffer->addWaitNodes(DepsList, numEventsInWaitList,
1075+
phEventWaitList));
1076+
}
10751077

1076-
try {
10771078
// Add an empty node to preserve dependencies.
10781079
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
10791080
DepsList.data(), DepsList.size()));
@@ -1361,49 +1362,55 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13611362
return UR_RESULT_ERROR_INVALID_VALUE;
13621363
}
13631364

1364-
auto KernelCommandHandle = static_cast<kernel_command_handle *>(hCommand);
1365+
try {
1366+
auto KernelCommandHandle = static_cast<kernel_command_handle *>(hCommand);
13651367

1366-
UR_CHECK_ERROR(validateCommandDesc(KernelCommandHandle, pUpdateKernelLaunch));
1367-
UR_CHECK_ERROR(
1368-
updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch));
1369-
UR_CHECK_ERROR(updateCommand(KernelCommandHandle, pUpdateKernelLaunch));
1370-
1371-
// If no work-size is provided make sure we pass nullptr to setKernelParams so
1372-
// it can guess the local work size.
1373-
const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize();
1374-
size_t *LocalWorkSize =
1375-
ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr;
1376-
1377-
// Set the number of threads per block to the number of threads per warp
1378-
// by default unless user has provided a better number.
1379-
size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
1380-
size_t BlocksPerGrid[3] = {1u, 1u, 1u};
1381-
CUfunction CuFunc = KernelCommandHandle->Kernel->get();
1382-
auto Result = setKernelParams(
1383-
CommandBuffer->Context, CommandBuffer->Device,
1384-
KernelCommandHandle->WorkDim, KernelCommandHandle->GlobalWorkOffset,
1385-
KernelCommandHandle->GlobalWorkSize, LocalWorkSize,
1386-
KernelCommandHandle->Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
1387-
if (Result != UR_RESULT_SUCCESS) {
1388-
return Result;
1389-
}
1368+
UR_CHECK_ERROR(
1369+
validateCommandDesc(KernelCommandHandle, pUpdateKernelLaunch));
1370+
UR_CHECK_ERROR(
1371+
updateKernelArguments(CommandBuffer->Device, pUpdateKernelLaunch));
1372+
UR_CHECK_ERROR(updateCommand(KernelCommandHandle, pUpdateKernelLaunch));
13901373

1391-
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params;
1374+
// If no work-size is provided make sure we pass nullptr to setKernelParams
1375+
// so it can guess the local work size.
1376+
const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize();
1377+
size_t *LocalWorkSize =
1378+
ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr;
13921379

1393-
Params.func = CuFunc;
1394-
Params.gridDimX = BlocksPerGrid[0];
1395-
Params.gridDimY = BlocksPerGrid[1];
1396-
Params.gridDimZ = BlocksPerGrid[2];
1397-
Params.blockDimX = ThreadsPerBlock[0];
1398-
Params.blockDimY = ThreadsPerBlock[1];
1399-
Params.blockDimZ = ThreadsPerBlock[2];
1400-
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
1401-
Params.kernelParams =
1402-
const_cast<void **>(KernelCommandHandle->Kernel->getArgIndices().data());
1380+
// Set the number of threads per block to the number of threads per warp
1381+
// by default unless user has provided a better number.
1382+
size_t ThreadsPerBlock[3] = {32u, 1u, 1u};
1383+
size_t BlocksPerGrid[3] = {1u, 1u, 1u};
1384+
CUfunction CuFunc = KernelCommandHandle->Kernel->get();
1385+
auto Result = setKernelParams(
1386+
CommandBuffer->Context, CommandBuffer->Device,
1387+
KernelCommandHandle->WorkDim, KernelCommandHandle->GlobalWorkOffset,
1388+
KernelCommandHandle->GlobalWorkSize, LocalWorkSize,
1389+
KernelCommandHandle->Kernel, CuFunc, ThreadsPerBlock, BlocksPerGrid);
1390+
if (Result != UR_RESULT_SUCCESS) {
1391+
return Result;
1392+
}
14031393

1404-
CUgraphNode Node = KernelCommandHandle->Node;
1405-
CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec;
1406-
UR_CHECK_ERROR(cuGraphExecKernelNodeSetParams(CudaGraphExec, Node, &Params));
1394+
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params;
1395+
1396+
Params.func = CuFunc;
1397+
Params.gridDimX = BlocksPerGrid[0];
1398+
Params.gridDimY = BlocksPerGrid[1];
1399+
Params.gridDimZ = BlocksPerGrid[2];
1400+
Params.blockDimX = ThreadsPerBlock[0];
1401+
Params.blockDimY = ThreadsPerBlock[1];
1402+
Params.blockDimZ = ThreadsPerBlock[2];
1403+
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
1404+
Params.kernelParams = const_cast<void **>(
1405+
KernelCommandHandle->Kernel->getArgIndices().data());
1406+
1407+
CUgraphNode Node = KernelCommandHandle->Node;
1408+
CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec;
1409+
UR_CHECK_ERROR(
1410+
cuGraphExecKernelNodeSetParams(CudaGraphExec, Node, &Params));
1411+
} catch (ur_result_t Err) {
1412+
return Err;
1413+
}
14071414
return UR_RESULT_SUCCESS;
14081415
}
14091416

@@ -1429,14 +1436,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateSignalEventExp(
14291436
return UR_RESULT_ERROR_INVALID_OPERATION;
14301437
}
14311438

1432-
CUevent SignalEvent;
1433-
UR_CHECK_ERROR(cuGraphEventRecordNodeGetEvent(SignalNode, &SignalEvent));
1439+
try {
1440+
CUevent SignalEvent;
1441+
UR_CHECK_ERROR(cuGraphEventRecordNodeGetEvent(SignalNode, &SignalEvent));
14341442

1435-
if (phEvent) {
1436-
*phEvent = std::unique_ptr<ur_event_handle_t_>(
1437-
ur_event_handle_t_::makeWithNative(CommandBuffer->Context,
1438-
SignalEvent))
1439-
.release();
1443+
if (phEvent) {
1444+
*phEvent = std::unique_ptr<ur_event_handle_t_>(
1445+
ur_event_handle_t_::makeWithNative(CommandBuffer->Context,
1446+
SignalEvent))
1447+
.release();
1448+
}
1449+
} catch (ur_result_t Err) {
1450+
return Err;
14401451
}
14411452

14421453
return UR_RESULT_SUCCESS;

0 commit comments

Comments
 (0)