@@ -221,15 +221,15 @@ static ur_result_t enqueueCommandBufferFillHelper(
221
221
ur_event_handle_t *RetEvent,
222
222
ur_exp_command_buffer_command_handle_t *RetCommand) {
223
223
std::vector<CUgraphNode> DepsList;
224
- UR_CHECK_ERROR (getNodesFromSyncPoints (CommandBuffer, NumSyncPointsInWaitList,
225
- SyncPointWaitList, DepsList));
224
+ try {
225
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
226
+ CommandBuffer, NumSyncPointsInWaitList, SyncPointWaitList, DepsList));
226
227
227
- if (NumEventsInWaitList) {
228
- UR_CHECK_ERROR (CommandBuffer->addWaitNodes (DepsList, NumEventsInWaitList,
229
- EventWaitList));
230
- }
228
+ if (NumEventsInWaitList) {
229
+ UR_CHECK_ERROR (CommandBuffer->addWaitNodes (DepsList, NumEventsInWaitList,
230
+ EventWaitList));
231
+ }
231
232
232
- try {
233
233
// Graph node added to graph, if multiple nodes are created this will
234
234
// be set to the leaf node
235
235
CUgraphNode GraphNode;
@@ -566,15 +566,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
566
566
ur_exp_command_buffer_command_handle_t *phCommand) {
567
567
CUgraphNode GraphNode;
568
568
std::vector<CUgraphNode> DepsList;
569
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
570
- pSyncPointWaitList, DepsList));
569
+ try {
570
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
571
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
571
572
572
- if (numEventsInWaitList) {
573
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
574
- phEventWaitList));
575
- }
573
+ if (numEventsInWaitList) {
574
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
575
+ phEventWaitList));
576
+ }
576
577
577
- try {
578
578
CUDA_MEMCPY3D NodeParams = {};
579
579
setCopyParams (pSrc, CU_MEMORYTYPE_HOST, pDst, CU_MEMORYTYPE_HOST, size,
580
580
NodeParams);
@@ -629,15 +629,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
629
629
UR_ASSERT (size + srcOffset <= std::get<BufferMem>(hSrcMem->Mem ).getSize (),
630
630
UR_RESULT_ERROR_INVALID_SIZE);
631
631
632
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
633
- pSyncPointWaitList, DepsList));
632
+ try {
633
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
634
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
634
635
635
- if (numEventsInWaitList) {
636
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
637
- phEventWaitList));
638
- }
636
+ if (numEventsInWaitList) {
637
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
638
+ phEventWaitList));
639
+ }
639
640
640
- try {
641
641
auto Src = std::get<BufferMem>(hSrcMem->Mem )
642
642
.getPtrWithOffset (hCommandBuffer->Device , srcOffset);
643
643
auto Dst = std::get<BufferMem>(hDstMem->Mem )
@@ -692,15 +692,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
692
692
ur_exp_command_buffer_command_handle_t *phCommand) {
693
693
CUgraphNode GraphNode;
694
694
std::vector<CUgraphNode> DepsList;
695
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
696
- pSyncPointWaitList, DepsList));
695
+ try {
696
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
697
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
697
698
698
- if (numEventsInWaitList) {
699
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
700
- phEventWaitList));
701
- }
699
+ if (numEventsInWaitList) {
700
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
701
+ phEventWaitList));
702
+ }
702
703
703
- try {
704
704
auto SrcPtr =
705
705
std::get<BufferMem>(hSrcMem->Mem ).getPtr (hCommandBuffer->Device );
706
706
auto DstPtr =
@@ -756,15 +756,15 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
756
756
ur_exp_command_buffer_command_handle_t *phCommand) {
757
757
CUgraphNode GraphNode;
758
758
std::vector<CUgraphNode> DepsList;
759
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
760
- pSyncPointWaitList, DepsList));
759
+ try {
760
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
761
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
761
762
762
- if (numEventsInWaitList) {
763
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
764
- phEventWaitList));
765
- }
763
+ if (numEventsInWaitList) {
764
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
765
+ phEventWaitList));
766
+ }
766
767
767
- try {
768
768
auto Dst = std::get<BufferMem>(hBuffer->Mem )
769
769
.getPtrWithOffset (hCommandBuffer->Device , offset);
770
770
@@ -816,15 +816,16 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
816
816
ur_exp_command_buffer_command_handle_t *phCommand) {
817
817
CUgraphNode GraphNode;
818
818
std::vector<CUgraphNode> DepsList;
819
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
820
- pSyncPointWaitList, DepsList));
821
-
822
- if (numEventsInWaitList) {
823
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
824
- phEventWaitList));
825
- }
826
819
827
820
try {
821
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
822
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
823
+
824
+ if (numEventsInWaitList) {
825
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
826
+ phEventWaitList));
827
+ }
828
+
828
829
auto Src = std::get<BufferMem>(hBuffer->Mem )
829
830
.getPtrWithOffset (hCommandBuffer->Device , offset);
830
831
@@ -879,15 +880,15 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
879
880
ur_exp_command_buffer_command_handle_t *phCommand) {
880
881
CUgraphNode GraphNode;
881
882
std::vector<CUgraphNode> DepsList;
882
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
883
- pSyncPointWaitList, DepsList));
883
+ try {
884
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
885
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
884
886
885
- if (numEventsInWaitList) {
886
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
887
- phEventWaitList));
888
- }
887
+ if (numEventsInWaitList) {
888
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
889
+ phEventWaitList));
890
+ }
889
891
890
- try {
891
892
auto DstPtr =
892
893
std::get<BufferMem>(hBuffer->Mem ).getPtr (hCommandBuffer->Device );
893
894
CUDA_MEMCPY3D NodeParams = {};
@@ -944,15 +945,15 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
944
945
ur_exp_command_buffer_command_handle_t *phCommand) {
945
946
CUgraphNode GraphNode;
946
947
std::vector<CUgraphNode> DepsList;
947
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
948
- pSyncPointWaitList, DepsList));
948
+ try {
949
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
950
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
949
951
950
- if (numEventsInWaitList) {
951
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
952
- phEventWaitList));
953
- }
952
+ if (numEventsInWaitList) {
953
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
954
+ phEventWaitList));
955
+ }
954
956
955
- try {
956
957
auto SrcPtr =
957
958
std::get<BufferMem>(hBuffer->Mem ).getPtr (hCommandBuffer->Device );
958
959
CUDA_MEMCPY3D NodeParams = {};
@@ -1009,15 +1010,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
1009
1010
CUgraphNode GraphNode;
1010
1011
1011
1012
std::vector<CUgraphNode> DepsList;
1012
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
1013
- pSyncPointWaitList, DepsList));
1013
+ try {
1014
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
1015
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
1014
1016
1015
- if (numEventsInWaitList) {
1016
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
1017
- phEventWaitList));
1018
- }
1017
+ if (numEventsInWaitList) {
1018
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
1019
+ phEventWaitList));
1020
+ }
1019
1021
1020
- try {
1021
1022
// Add an empty node to preserve dependencies.
1022
1023
UR_CHECK_ERROR (cuGraphAddEmptyNode (&GraphNode, hCommandBuffer->CudaGraph ,
1023
1024
DepsList.data (), DepsList.size ()));
@@ -1065,15 +1066,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
1065
1066
CUgraphNode GraphNode;
1066
1067
1067
1068
std::vector<CUgraphNode> DepsList;
1068
- UR_CHECK_ERROR (getNodesFromSyncPoints (hCommandBuffer, numSyncPointsInWaitList,
1069
- pSyncPointWaitList, DepsList));
1069
+ try {
1070
+ UR_CHECK_ERROR (getNodesFromSyncPoints (
1071
+ hCommandBuffer, numSyncPointsInWaitList, pSyncPointWaitList, DepsList));
1070
1072
1071
- if (numEventsInWaitList) {
1072
- UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
1073
- phEventWaitList));
1074
- }
1073
+ if (numEventsInWaitList) {
1074
+ UR_CHECK_ERROR (hCommandBuffer->addWaitNodes (DepsList, numEventsInWaitList,
1075
+ phEventWaitList));
1076
+ }
1075
1077
1076
- try {
1077
1078
// Add an empty node to preserve dependencies.
1078
1079
UR_CHECK_ERROR (cuGraphAddEmptyNode (&GraphNode, hCommandBuffer->CudaGraph ,
1079
1080
DepsList.data (), DepsList.size ()));
@@ -1361,49 +1362,55 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
1361
1362
return UR_RESULT_ERROR_INVALID_VALUE;
1362
1363
}
1363
1364
1364
- auto KernelCommandHandle = static_cast <kernel_command_handle *>(hCommand);
1365
+ try {
1366
+ auto KernelCommandHandle = static_cast <kernel_command_handle *>(hCommand);
1365
1367
1366
- UR_CHECK_ERROR (validateCommandDesc (KernelCommandHandle, pUpdateKernelLaunch));
1367
- UR_CHECK_ERROR (
1368
- updateKernelArguments (CommandBuffer->Device , pUpdateKernelLaunch));
1369
- UR_CHECK_ERROR (updateCommand (KernelCommandHandle, pUpdateKernelLaunch));
1370
-
1371
- // If no work-size is provided make sure we pass nullptr to setKernelParams so
1372
- // it can guess the local work size.
1373
- const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize ();
1374
- size_t *LocalWorkSize =
1375
- ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr ;
1376
-
1377
- // Set the number of threads per block to the number of threads per warp
1378
- // by default unless user has provided a better number.
1379
- size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
1380
- size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
1381
- CUfunction CuFunc = KernelCommandHandle->Kernel ->get ();
1382
- auto Result = setKernelParams (
1383
- CommandBuffer->Context , CommandBuffer->Device ,
1384
- KernelCommandHandle->WorkDim , KernelCommandHandle->GlobalWorkOffset ,
1385
- KernelCommandHandle->GlobalWorkSize , LocalWorkSize,
1386
- KernelCommandHandle->Kernel , CuFunc, ThreadsPerBlock, BlocksPerGrid);
1387
- if (Result != UR_RESULT_SUCCESS) {
1388
- return Result;
1389
- }
1368
+ UR_CHECK_ERROR (
1369
+ validateCommandDesc (KernelCommandHandle, pUpdateKernelLaunch));
1370
+ UR_CHECK_ERROR (
1371
+ updateKernelArguments (CommandBuffer->Device , pUpdateKernelLaunch));
1372
+ UR_CHECK_ERROR (updateCommand (KernelCommandHandle, pUpdateKernelLaunch));
1390
1373
1391
- CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params ;
1374
+ // If no work-size is provided make sure we pass nullptr to setKernelParams
1375
+ // so it can guess the local work size.
1376
+ const bool ProvidedLocalSize = !KernelCommandHandle->isNullLocalSize ();
1377
+ size_t *LocalWorkSize =
1378
+ ProvidedLocalSize ? KernelCommandHandle->LocalWorkSize : nullptr ;
1392
1379
1393
- Params.func = CuFunc;
1394
- Params.gridDimX = BlocksPerGrid[0 ];
1395
- Params.gridDimY = BlocksPerGrid[1 ];
1396
- Params.gridDimZ = BlocksPerGrid[2 ];
1397
- Params.blockDimX = ThreadsPerBlock[0 ];
1398
- Params.blockDimY = ThreadsPerBlock[1 ];
1399
- Params.blockDimZ = ThreadsPerBlock[2 ];
1400
- Params.sharedMemBytes = KernelCommandHandle->Kernel ->getLocalSize ();
1401
- Params.kernelParams =
1402
- const_cast <void **>(KernelCommandHandle->Kernel ->getArgIndices ().data ());
1380
+ // Set the number of threads per block to the number of threads per warp
1381
+ // by default unless user has provided a better number.
1382
+ size_t ThreadsPerBlock[3 ] = {32u , 1u , 1u };
1383
+ size_t BlocksPerGrid[3 ] = {1u , 1u , 1u };
1384
+ CUfunction CuFunc = KernelCommandHandle->Kernel ->get ();
1385
+ auto Result = setKernelParams (
1386
+ CommandBuffer->Context , CommandBuffer->Device ,
1387
+ KernelCommandHandle->WorkDim , KernelCommandHandle->GlobalWorkOffset ,
1388
+ KernelCommandHandle->GlobalWorkSize , LocalWorkSize,
1389
+ KernelCommandHandle->Kernel , CuFunc, ThreadsPerBlock, BlocksPerGrid);
1390
+ if (Result != UR_RESULT_SUCCESS) {
1391
+ return Result;
1392
+ }
1403
1393
1404
- CUgraphNode Node = KernelCommandHandle->Node ;
1405
- CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec ;
1406
- UR_CHECK_ERROR (cuGraphExecKernelNodeSetParams (CudaGraphExec, Node, &Params));
1394
+ CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params ;
1395
+
1396
+ Params.func = CuFunc;
1397
+ Params.gridDimX = BlocksPerGrid[0 ];
1398
+ Params.gridDimY = BlocksPerGrid[1 ];
1399
+ Params.gridDimZ = BlocksPerGrid[2 ];
1400
+ Params.blockDimX = ThreadsPerBlock[0 ];
1401
+ Params.blockDimY = ThreadsPerBlock[1 ];
1402
+ Params.blockDimZ = ThreadsPerBlock[2 ];
1403
+ Params.sharedMemBytes = KernelCommandHandle->Kernel ->getLocalSize ();
1404
+ Params.kernelParams = const_cast <void **>(
1405
+ KernelCommandHandle->Kernel ->getArgIndices ().data ());
1406
+
1407
+ CUgraphNode Node = KernelCommandHandle->Node ;
1408
+ CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec ;
1409
+ UR_CHECK_ERROR (
1410
+ cuGraphExecKernelNodeSetParams (CudaGraphExec, Node, &Params));
1411
+ } catch (ur_result_t Err) {
1412
+ return Err;
1413
+ }
1407
1414
return UR_RESULT_SUCCESS;
1408
1415
}
1409
1416
@@ -1429,14 +1436,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateSignalEventExp(
1429
1436
return UR_RESULT_ERROR_INVALID_OPERATION;
1430
1437
}
1431
1438
1432
- CUevent SignalEvent;
1433
- UR_CHECK_ERROR (cuGraphEventRecordNodeGetEvent (SignalNode, &SignalEvent));
1439
+ try {
1440
+ CUevent SignalEvent;
1441
+ UR_CHECK_ERROR (cuGraphEventRecordNodeGetEvent (SignalNode, &SignalEvent));
1434
1442
1435
- if (phEvent) {
1436
- *phEvent = std::unique_ptr<ur_event_handle_t_>(
1437
- ur_event_handle_t_::makeWithNative (CommandBuffer->Context ,
1438
- SignalEvent))
1439
- .release ();
1443
+ if (phEvent) {
1444
+ *phEvent = std::unique_ptr<ur_event_handle_t_>(
1445
+ ur_event_handle_t_::makeWithNative (CommandBuffer->Context ,
1446
+ SignalEvent))
1447
+ .release ();
1448
+ }
1449
+ } catch (ur_result_t Err) {
1450
+ return Err;
1440
1451
}
1441
1452
1442
1453
return UR_RESULT_SUCCESS;
0 commit comments