Skip to content

Commit b25bb64

Browse files
authored
Merge pull request #1067 from alexbatashev/fix_native_handles
[UR][Loader] Fix handling of native handles
2 parents 0e281bc + 8b1bfc9 commit b25bb64

File tree

2 files changed

+16
-122
lines changed

2 files changed

+16
-122
lines changed

Diff for: scripts/templates/ldrddi.cpp.mako

+7-2
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,17 @@ namespace ur_loader
130130
%else:
131131
<%param_replacements={}%>
132132
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
133-
%if 0 == i:
133+
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
134134
// extract platform's function pointer table
135135
auto dditable = reinterpret_cast<${item['obj']}*>( ${item['pointer']}${item['name']} )->dditable;
136136
auto ${th.make_pfn_name(n, tags, obj)} = dditable->${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
137137
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
138138
return ${X}_RESULT_ERROR_UNINITIALIZED;
139139
140+
<%break%>
140141
%endif
142+
%endfor
143+
%for i, item in enumerate(th.get_loader_prologue(n, tags, obj, meta)):
141144
%if 'range' in item:
142145
<%
143146
add_local = True
@@ -146,13 +149,15 @@ namespace ur_loader
146149
for( size_t i = ${item['range'][0]}; i < ${item['range'][1]}; ++i )
147150
${item['name']}Local[ i ] = reinterpret_cast<${item['obj']}*>( ${item['name']}[ i ] )->handle;
148151
%else:
152+
%if not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
149153
// convert loader handle to platform handle
150154
%if item['optional']:
151155
${item['name']} = ( ${item['name']} ) ? reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle : nullptr;
152156
%else:
153157
${item['name']} = reinterpret_cast<${item['obj']}*>( ${item['name']} )->handle;
154158
%endif
155159
%endif
160+
%endif
156161
157162
%endfor
158163
// forward to device-platform
@@ -173,7 +178,7 @@ namespace ur_loader
173178
%if item['release']:
174179
// release loader handle
175180
${item['factory']}.release( ${item['name']} );
176-
%else:
181+
%elif not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
177182
try
178183
{
179184
%if 'range' in item:

Diff for: source/loader/ur_ldrddi.cpp

+9-120
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,6 @@ __urdlllocal ur_result_t UR_APICALL urPlatformGetNativeHandle(
352352
return result;
353353
}
354354

355-
try {
356-
// convert platform handle to loader handle
357-
*phNativePlatform = reinterpret_cast<ur_native_handle_t>(
358-
ur_native_factory.getInstance(*phNativePlatform, dditable));
359-
} catch (std::bad_alloc &) {
360-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
361-
}
362-
363355
return result;
364356
}
365357

@@ -673,14 +665,6 @@ __urdlllocal ur_result_t UR_APICALL urDeviceGetNativeHandle(
673665
return result;
674666
}
675667

676-
try {
677-
// convert platform handle to loader handle
678-
*phNativeDevice = reinterpret_cast<ur_native_handle_t>(
679-
ur_native_factory.getInstance(*phNativeDevice, dditable));
680-
} catch (std::bad_alloc &) {
681-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
682-
}
683-
684668
return result;
685669
}
686670

@@ -699,17 +683,13 @@ __urdlllocal ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
699683

700684
// extract platform's function pointer table
701685
auto dditable =
702-
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->dditable;
686+
reinterpret_cast<ur_platform_object_t *>(hPlatform)->dditable;
703687
auto pfnCreateWithNativeHandle =
704688
dditable->ur.Device.pfnCreateWithNativeHandle;
705689
if (nullptr == pfnCreateWithNativeHandle) {
706690
return UR_RESULT_ERROR_UNINITIALIZED;
707691
}
708692

709-
// convert loader handle to platform handle
710-
hNativeDevice =
711-
reinterpret_cast<ur_native_object_t *>(hNativeDevice)->handle;
712-
713693
// convert loader handle to platform handle
714694
hPlatform = reinterpret_cast<ur_platform_object_t *>(hPlatform)->handle;
715695

@@ -916,14 +896,6 @@ __urdlllocal ur_result_t UR_APICALL urContextGetNativeHandle(
916896
return result;
917897
}
918898

919-
try {
920-
// convert platform handle to loader handle
921-
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
922-
ur_native_factory.getInstance(*phNativeContext, dditable));
923-
} catch (std::bad_alloc &) {
924-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
925-
}
926-
927899
return result;
928900
}
929901

@@ -944,17 +916,13 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
944916

945917
// extract platform's function pointer table
946918
auto dditable =
947-
reinterpret_cast<ur_native_object_t *>(hNativeContext)->dditable;
919+
reinterpret_cast<ur_device_object_t *>(*phDevices)->dditable;
948920
auto pfnCreateWithNativeHandle =
949921
dditable->ur.Context.pfnCreateWithNativeHandle;
950922
if (nullptr == pfnCreateWithNativeHandle) {
951923
return UR_RESULT_ERROR_UNINITIALIZED;
952924
}
953925

954-
// convert loader handle to platform handle
955-
hNativeContext =
956-
reinterpret_cast<ur_native_object_t *>(hNativeContext)->handle;
957-
958926
// convert loader handles to platform handles
959927
auto phDevicesLocal = std::vector<ur_device_handle_t>(numDevices);
960928
for (size_t i = 0; i < numDevices; ++i) {
@@ -1207,14 +1175,6 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
12071175
return result;
12081176
}
12091177

1210-
try {
1211-
// convert platform handle to loader handle
1212-
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
1213-
ur_native_factory.getInstance(*phNativeMem, dditable));
1214-
} catch (std::bad_alloc &) {
1215-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
1216-
}
1217-
12181178
return result;
12191179
}
12201180

@@ -1232,17 +1192,13 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
12321192
ur_result_t result = UR_RESULT_SUCCESS;
12331193

12341194
// extract platform's function pointer table
1235-
auto dditable =
1236-
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
1195+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
12371196
auto pfnBufferCreateWithNativeHandle =
12381197
dditable->ur.Mem.pfnBufferCreateWithNativeHandle;
12391198
if (nullptr == pfnBufferCreateWithNativeHandle) {
12401199
return UR_RESULT_ERROR_UNINITIALIZED;
12411200
}
12421201

1243-
// convert loader handle to platform handle
1244-
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;
1245-
12461202
// convert loader handle to platform handle
12471203
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
12481204

@@ -1282,17 +1238,13 @@ __urdlllocal ur_result_t UR_APICALL urMemImageCreateWithNativeHandle(
12821238
ur_result_t result = UR_RESULT_SUCCESS;
12831239

12841240
// extract platform's function pointer table
1285-
auto dditable =
1286-
reinterpret_cast<ur_native_object_t *>(hNativeMem)->dditable;
1241+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
12871242
auto pfnImageCreateWithNativeHandle =
12881243
dditable->ur.Mem.pfnImageCreateWithNativeHandle;
12891244
if (nullptr == pfnImageCreateWithNativeHandle) {
12901245
return UR_RESULT_ERROR_UNINITIALIZED;
12911246
}
12921247

1293-
// convert loader handle to platform handle
1294-
hNativeMem = reinterpret_cast<ur_native_object_t *>(hNativeMem)->handle;
1295-
12961248
// convert loader handle to platform handle
12971249
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
12981250

@@ -1528,14 +1480,6 @@ __urdlllocal ur_result_t UR_APICALL urSamplerGetNativeHandle(
15281480
return result;
15291481
}
15301482

1531-
try {
1532-
// convert platform handle to loader handle
1533-
*phNativeSampler = reinterpret_cast<ur_native_handle_t>(
1534-
ur_native_factory.getInstance(*phNativeSampler, dditable));
1535-
} catch (std::bad_alloc &) {
1536-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
1537-
}
1538-
15391483
return result;
15401484
}
15411485

@@ -1553,18 +1497,13 @@ __urdlllocal ur_result_t UR_APICALL urSamplerCreateWithNativeHandle(
15531497
ur_result_t result = UR_RESULT_SUCCESS;
15541498

15551499
// extract platform's function pointer table
1556-
auto dditable =
1557-
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->dditable;
1500+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
15581501
auto pfnCreateWithNativeHandle =
15591502
dditable->ur.Sampler.pfnCreateWithNativeHandle;
15601503
if (nullptr == pfnCreateWithNativeHandle) {
15611504
return UR_RESULT_ERROR_UNINITIALIZED;
15621505
}
15631506

1564-
// convert loader handle to platform handle
1565-
hNativeSampler =
1566-
reinterpret_cast<ur_native_object_t *>(hNativeSampler)->handle;
1567-
15681507
// convert loader handle to platform handle
15691508
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
15701509

@@ -2604,14 +2543,6 @@ __urdlllocal ur_result_t UR_APICALL urProgramGetNativeHandle(
26042543
return result;
26052544
}
26062545

2607-
try {
2608-
// convert platform handle to loader handle
2609-
*phNativeProgram = reinterpret_cast<ur_native_handle_t>(
2610-
ur_native_factory.getInstance(*phNativeProgram, dditable));
2611-
} catch (std::bad_alloc &) {
2612-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
2613-
}
2614-
26152546
return result;
26162547
}
26172548

@@ -2629,18 +2560,13 @@ __urdlllocal ur_result_t UR_APICALL urProgramCreateWithNativeHandle(
26292560
ur_result_t result = UR_RESULT_SUCCESS;
26302561

26312562
// extract platform's function pointer table
2632-
auto dditable =
2633-
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->dditable;
2563+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
26342564
auto pfnCreateWithNativeHandle =
26352565
dditable->ur.Program.pfnCreateWithNativeHandle;
26362566
if (nullptr == pfnCreateWithNativeHandle) {
26372567
return UR_RESULT_ERROR_UNINITIALIZED;
26382568
}
26392569

2640-
// convert loader handle to platform handle
2641-
hNativeProgram =
2642-
reinterpret_cast<ur_native_object_t *>(hNativeProgram)->handle;
2643-
26442570
// convert loader handle to platform handle
26452571
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
26462572

@@ -3088,14 +3014,6 @@ __urdlllocal ur_result_t UR_APICALL urKernelGetNativeHandle(
30883014
return result;
30893015
}
30903016

3091-
try {
3092-
// convert platform handle to loader handle
3093-
*phNativeKernel = reinterpret_cast<ur_native_handle_t>(
3094-
ur_native_factory.getInstance(*phNativeKernel, dditable));
3095-
} catch (std::bad_alloc &) {
3096-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3097-
}
3098-
30993017
return result;
31003018
}
31013019

@@ -3115,18 +3033,13 @@ __urdlllocal ur_result_t UR_APICALL urKernelCreateWithNativeHandle(
31153033
ur_result_t result = UR_RESULT_SUCCESS;
31163034

31173035
// extract platform's function pointer table
3118-
auto dditable =
3119-
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->dditable;
3036+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
31203037
auto pfnCreateWithNativeHandle =
31213038
dditable->ur.Kernel.pfnCreateWithNativeHandle;
31223039
if (nullptr == pfnCreateWithNativeHandle) {
31233040
return UR_RESULT_ERROR_UNINITIALIZED;
31243041
}
31253042

3126-
// convert loader handle to platform handle
3127-
hNativeKernel =
3128-
reinterpret_cast<ur_native_object_t *>(hNativeKernel)->handle;
3129-
31303043
// convert loader handle to platform handle
31313044
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
31323045

@@ -3300,14 +3213,6 @@ __urdlllocal ur_result_t UR_APICALL urQueueGetNativeHandle(
33003213
return result;
33013214
}
33023215

3303-
try {
3304-
// convert platform handle to loader handle
3305-
*phNativeQueue = reinterpret_cast<ur_native_handle_t>(
3306-
ur_native_factory.getInstance(*phNativeQueue, dditable));
3307-
} catch (std::bad_alloc &) {
3308-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3309-
}
3310-
33113216
return result;
33123217
}
33133218

@@ -3326,17 +3231,13 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
33263231
ur_result_t result = UR_RESULT_SUCCESS;
33273232

33283233
// extract platform's function pointer table
3329-
auto dditable =
3330-
reinterpret_cast<ur_native_object_t *>(hNativeQueue)->dditable;
3234+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
33313235
auto pfnCreateWithNativeHandle =
33323236
dditable->ur.Queue.pfnCreateWithNativeHandle;
33333237
if (nullptr == pfnCreateWithNativeHandle) {
33343238
return UR_RESULT_ERROR_UNINITIALIZED;
33353239
}
33363240

3337-
// convert loader handle to platform handle
3338-
hNativeQueue = reinterpret_cast<ur_native_object_t *>(hNativeQueue)->handle;
3339-
33403241
// convert loader handle to platform handle
33413242
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
33423243

@@ -3573,14 +3474,6 @@ __urdlllocal ur_result_t UR_APICALL urEventGetNativeHandle(
35733474
return result;
35743475
}
35753476

3576-
try {
3577-
// convert platform handle to loader handle
3578-
*phNativeEvent = reinterpret_cast<ur_native_handle_t>(
3579-
ur_native_factory.getInstance(*phNativeEvent, dditable));
3580-
} catch (std::bad_alloc &) {
3581-
result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
3582-
}
3583-
35843477
return result;
35853478
}
35863479

@@ -3598,17 +3491,13 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
35983491
ur_result_t result = UR_RESULT_SUCCESS;
35993492

36003493
// extract platform's function pointer table
3601-
auto dditable =
3602-
reinterpret_cast<ur_native_object_t *>(hNativeEvent)->dditable;
3494+
auto dditable = reinterpret_cast<ur_context_object_t *>(hContext)->dditable;
36033495
auto pfnCreateWithNativeHandle =
36043496
dditable->ur.Event.pfnCreateWithNativeHandle;
36053497
if (nullptr == pfnCreateWithNativeHandle) {
36063498
return UR_RESULT_ERROR_UNINITIALIZED;
36073499
}
36083500

3609-
// convert loader handle to platform handle
3610-
hNativeEvent = reinterpret_cast<ur_native_object_t *>(hNativeEvent)->handle;
3611-
36123501
// convert loader handle to platform handle
36133502
hContext = reinterpret_cast<ur_context_object_t *>(hContext)->handle;
36143503

0 commit comments

Comments
 (0)