Skip to content

Commit 3db3a5e

Browse files
authored
Support UR program creation from multiple device binaries (#2147)
To support multi-device AOT scenario in SYCL we need an ability to create UR program from multiple device binaries. Changes in this PR: * Modify the core function `urProgramCreateWithBinary` to support program creation from multiple device binaries. * Add methods to ur_program to get/set per-device data like L0 module handle, build log etc. Otherwise any change in the structure of the class requires multiple changes in all UR functions which work with ur_program, in addition to this it allows to handle interop case (which implementation is an exception right now) in a single place. Also make State and some other info per-device because it is possible that UR program is associated with multiple devices and urProgramBuildExp is called multiple times for subsets, and we have to know the state per-device.
1 parent e3eeb4e commit 3db3a5e

34 files changed

+1129
-561
lines changed

.github/workflows/multi_device.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,4 @@ jobs:
6363

6464
- name: Test adapters
6565
working-directory: ${{github.workspace}}/build
66-
run: env UR_CTS_ADAPTER_PLATFORM="${{matrix.adapter.platform}}" ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" -E "enqueue|kernel|program|integration|exp_command_buffer|exp_enqueue_native|exp_launch_properties|exp_usm_p2p" --timeout 180
66+
run: env UR_CTS_ADAPTER_PLATFORM="${{matrix.adapter.platform}}" ctest -C ${{matrix.build_type}} --output-on-failure -L "conformance" -E "enqueue|kernel|integration|exp_command_buffer|exp_enqueue_native|exp_launch_properties|exp_usm_p2p" --timeout 180

include/ur_api.h

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4202,17 +4202,19 @@ urProgramCreateWithIL(
42024202
);
42034203

42044204
///////////////////////////////////////////////////////////////////////////////
4205-
/// @brief Create a program object from device native binary.
4205+
/// @brief Create a program object from native binaries for the specified
4206+
/// devices.
42064207
///
42074208
/// @details
42084209
/// - The application may call this function from simultaneous threads.
42094210
/// - Following a successful call to this entry point, `phProgram` will
4210-
/// contain a binary of type ::UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or
4211-
/// ::UR_PROGRAM_BINARY_TYPE_LIBRARY for `hDevice`.
4212-
/// - The device specified by `hDevice` must be device associated with
4211+
/// contain binaries of type ::UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or
4212+
/// ::UR_PROGRAM_BINARY_TYPE_LIBRARY for the specified devices in
4213+
/// `phDevices`.
4214+
/// - The devices specified by `phDevices` must be associated with the
42134215
/// context.
42144216
/// - The adapter may (but is not required to) perform validation of the
4215-
/// provided module during this call.
4217+
/// provided modules during this call.
42164218
///
42174219
/// @remarks
42184220
/// _Analogues_
@@ -4225,21 +4227,27 @@ urProgramCreateWithIL(
42254227
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
42264228
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
42274229
/// + `NULL == hContext`
4228-
/// + `NULL == hDevice`
42294230
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
4230-
/// + `NULL == pBinary`
4231+
/// + `NULL == phDevices`
4232+
/// + `NULL == pLengths`
4233+
/// + `NULL == ppBinaries`
42314234
/// + `NULL == phProgram`
42324235
/// + `NULL != pProperties && pProperties->count > 0 && NULL == pProperties->pMetadatas`
42334236
/// - ::UR_RESULT_ERROR_INVALID_SIZE
42344237
/// + `NULL != pProperties && NULL != pProperties->pMetadatas && pProperties->count == 0`
4238+
/// + `numDevices == 0`
42354239
/// - ::UR_RESULT_ERROR_INVALID_NATIVE_BINARY
4236-
/// + If `pBinary` isn't a valid binary for `hDevice.`
4240+
/// + If any binary in `ppBinaries` isn't a valid binary for the corresponding device in `phDevices.`
42374241
UR_APIEXPORT ur_result_t UR_APICALL
42384242
urProgramCreateWithBinary(
42394243
ur_context_handle_t hContext, ///< [in] handle of the context instance
4240-
ur_device_handle_t hDevice, ///< [in] handle to device associated with binary.
4241-
size_t size, ///< [in] size in bytes.
4242-
const uint8_t *pBinary, ///< [in] pointer to binary.
4244+
uint32_t numDevices, ///< [in] number of devices
4245+
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] a pointer to a list of device handles. The
4246+
///< binaries are loaded for devices specified in this list.
4247+
size_t *pLengths, ///< [in][range(0, numDevices)] array of sizes of program binaries
4248+
///< specified by `pBinaries` (in bytes).
4249+
const uint8_t **ppBinaries, ///< [in][range(0, numDevices)] pointer to program binaries to be loaded
4250+
///< for devices specified by `phDevices`.
42434251
const ur_program_properties_t *pProperties, ///< [in][optional] pointer to program creation properties.
42444252
ur_program_handle_t *phProgram ///< [out] pointer to handle of Program object created.
42454253
);
@@ -10325,9 +10333,10 @@ typedef struct ur_program_create_with_il_params_t {
1032510333
/// allowing the callback the ability to modify the parameter's value
1032610334
typedef struct ur_program_create_with_binary_params_t {
1032710335
ur_context_handle_t *phContext;
10328-
ur_device_handle_t *phDevice;
10329-
size_t *psize;
10330-
const uint8_t **ppBinary;
10336+
uint32_t *pnumDevices;
10337+
ur_device_handle_t **pphDevices;
10338+
size_t **ppLengths;
10339+
const uint8_t ***pppBinaries;
1033110340
const ur_program_properties_t **ppProperties;
1033210341
ur_program_handle_t **pphProgram;
1033310342
} ur_program_create_with_binary_params_t;

include/ur_ddi.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,9 +284,10 @@ typedef ur_result_t(UR_APICALL *ur_pfnProgramCreateWithIL_t)(
284284
/// @brief Function-pointer for urProgramCreateWithBinary
285285
typedef ur_result_t(UR_APICALL *ur_pfnProgramCreateWithBinary_t)(
286286
ur_context_handle_t,
287-
ur_device_handle_t,
288-
size_t,
289-
const uint8_t *,
287+
uint32_t,
288+
ur_device_handle_t *,
289+
size_t *,
290+
const uint8_t **,
290291
const ur_program_properties_t *,
291292
ur_program_handle_t *);
292293

include/ur_print.hpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11179,21 +11179,44 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1117911179
*(params->phContext));
1118011180

1118111181
os << ", ";
11182-
os << ".hDevice = ";
11182+
os << ".numDevices = ";
1118311183

11184-
ur::details::printPtr(os,
11185-
*(params->phDevice));
11184+
os << *(params->pnumDevices);
1118611185

1118711186
os << ", ";
11188-
os << ".size = ";
11187+
os << ".phDevices = {";
11188+
for (size_t i = 0; *(params->pphDevices) != NULL && i < *params->pnumDevices; ++i) {
11189+
if (i != 0) {
11190+
os << ", ";
11191+
}
1118911192

11190-
os << *(params->psize);
11193+
ur::details::printPtr(os,
11194+
(*(params->pphDevices))[i]);
11195+
}
11196+
os << "}";
1119111197

1119211198
os << ", ";
11193-
os << ".pBinary = ";
11199+
os << ".pLengths = {";
11200+
for (size_t i = 0; *(params->ppLengths) != NULL && i < *params->pnumDevices; ++i) {
11201+
if (i != 0) {
11202+
os << ", ";
11203+
}
1119411204

11195-
ur::details::printPtr(os,
11196-
*(params->ppBinary));
11205+
os << (*(params->ppLengths))[i];
11206+
}
11207+
os << "}";
11208+
11209+
os << ", ";
11210+
os << ".ppBinaries = {";
11211+
for (size_t i = 0; *(params->pppBinaries) != NULL && i < *params->pnumDevices; ++i) {
11212+
if (i != 0) {
11213+
os << ", ";
11214+
}
11215+
11216+
ur::details::printPtr(os,
11217+
(*(params->pppBinaries))[i]);
11218+
}
11219+
os << "}";
1119711220

1119811221
os << ", ";
1119911222
os << ".pProperties = ";

scripts/core/program.yml

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ returns:
119119
- "`length == 0`"
120120
--- #--------------------------------------------------------------------------
121121
type: function
122-
desc: "Create a program object from device native binary."
122+
desc: "Create a program object from native binaries for the specified devices."
123123
class: $xProgram
124124
name: CreateWithBinary
125125
decl: static
@@ -128,22 +128,25 @@ analogue:
128128
- "**clCreateProgramWithBinary**"
129129
details:
130130
- "The application may call this function from simultaneous threads."
131-
- "Following a successful call to this entry point, `phProgram` will contain a binary of type $X_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or $X_PROGRAM_BINARY_TYPE_LIBRARY for `hDevice`."
132-
- "The device specified by `hDevice` must be device associated with context."
133-
- "The adapter may (but is not required to) perform validation of the provided module during this call."
131+
- "Following a successful call to this entry point, `phProgram` will contain binaries of type $X_PROGRAM_BINARY_TYPE_COMPILED_OBJECT or $X_PROGRAM_BINARY_TYPE_LIBRARY for the specified devices in `phDevices`."
132+
- "The devices specified by `phDevices` must be associated with the context."
133+
- "The adapter may (but is not required to) perform validation of the provided modules during this call."
134134
params:
135135
- type: $x_context_handle_t
136136
name: hContext
137137
desc: "[in] handle of the context instance"
138-
- type: $x_device_handle_t
139-
name: hDevice
140-
desc: "[in] handle to device associated with binary."
141-
- type: size_t
142-
name: size
143-
desc: "[in] size in bytes."
144-
- type: const uint8_t*
145-
name: pBinary
146-
desc: "[in] pointer to binary."
138+
- type: uint32_t
139+
name: numDevices
140+
desc: "[in] number of devices"
141+
- type: $x_device_handle_t*
142+
name: phDevices
143+
desc: "[in][range(0, numDevices)] a pointer to a list of device handles. The binaries are loaded for devices specified in this list."
144+
- type: size_t*
145+
name: pLengths
146+
desc: "[in][range(0, numDevices)] array of sizes of program binaries specified by `pBinaries` (in bytes)."
147+
- type: const uint8_t**
148+
name: ppBinaries
149+
desc: "[in][range(0, numDevices)] pointer to program binaries to be loaded for devices specified by `phDevices`."
147150
- type: const $x_program_properties_t*
148151
name: pProperties
149152
desc: "[in][optional] pointer to program creation properties."
@@ -155,8 +158,9 @@ returns:
155158
- "`NULL != pProperties && pProperties->count > 0 && NULL == pProperties->pMetadatas`"
156159
- $X_RESULT_ERROR_INVALID_SIZE:
157160
- "`NULL != pProperties && NULL != pProperties->pMetadatas && pProperties->count == 0`"
161+
- "`numDevices == 0`"
158162
- $X_RESULT_ERROR_INVALID_NATIVE_BINARY:
159-
- "If `pBinary` isn't a valid binary for `hDevice.`"
163+
- "If any binary in `ppBinaries` isn't a valid binary for the corresponding device in `phDevices.`"
160164
--- #--------------------------------------------------------------------------
161165
type: function
162166
desc: "Produces an executable program from one program, negates need for the linking step."

source/adapters/cuda/program.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,12 +493,15 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
493493
}
494494

495495
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
496-
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
497-
const uint8_t *pBinary, const ur_program_properties_t *pProperties,
496+
ur_context_handle_t hContext, uint32_t numDevices,
497+
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
498+
const ur_program_properties_t *pProperties,
498499
ur_program_handle_t *phProgram) {
500+
if (numDevices > 1)
501+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
499502

500-
UR_CHECK_ERROR(
501-
createProgram(hContext, hDevice, size, pBinary, pProperties, phProgram));
503+
UR_CHECK_ERROR(createProgram(hContext, phDevices[0], pLengths[0],
504+
ppBinaries[0], pProperties, phProgram));
502505
(*phProgram)->BinaryType = UR_PROGRAM_BINARY_TYPE_COMPILED_OBJECT;
503506

504507
return UR_RESULT_SUCCESS;

source/adapters/hip/program.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,9 +480,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetNativeHandle(
480480
///
481481
/// Note: Only supports one device
482482
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
483-
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
484-
const uint8_t *pBinary, const ur_program_properties_t *pProperties,
483+
ur_context_handle_t hContext, uint32_t numDevices,
484+
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
485+
const ur_program_properties_t *pProperties,
485486
ur_program_handle_t *phProgram) {
487+
if (numDevices > 1)
488+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
489+
490+
auto hDevice = phDevices[0];
491+
auto pBinary = ppBinaries[0];
492+
auto size = pLengths[0];
486493
UR_ASSERT(std::find(hContext->getDevices().begin(),
487494
hContext->getDevices().end(),
488495
hDevice) != hContext->getDevices().end(),

source/adapters/level_zero/kernel.cpp

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -495,18 +495,11 @@ ur_result_t urEnqueueDeviceGlobalVariableWrite(
495495
///< this particular kernel execution instance.
496496
) {
497497
std::scoped_lock<ur_shared_mutex> lock(Queue->Mutex);
498-
499-
ze_module_handle_t ZeModule{};
500-
auto It = Program->ZeModuleMap.find(Queue->Device->ZeDevice);
501-
if (It != Program->ZeModuleMap.end()) {
502-
ZeModule = It->second;
503-
} else {
504-
ZeModule = Program->ZeModule;
505-
}
506-
507498
// Find global variable pointer
508499
size_t GlobalVarSize = 0;
509500
void *GlobalVarPtr = nullptr;
501+
ze_module_handle_t ZeModule =
502+
Program->getZeModuleHandle(Queue->Device->ZeDevice);
510503
ZE2UR_CALL(zeModuleGetGlobalPointer,
511504
(ZeModule, Name, &GlobalVarSize, &GlobalVarPtr));
512505
if (GlobalVarSize < Offset + Count) {
@@ -557,15 +550,8 @@ ur_result_t urEnqueueDeviceGlobalVariableRead(
557550
///< this particular kernel execution instance.
558551
) {
559552
std::scoped_lock<ur_shared_mutex> lock(Queue->Mutex);
560-
561-
ze_module_handle_t ZeModule{};
562-
auto It = Program->ZeModuleMap.find(Queue->Device->ZeDevice);
563-
if (It != Program->ZeModuleMap.end()) {
564-
ZeModule = It->second;
565-
} else {
566-
ZeModule = Program->ZeModule;
567-
}
568-
553+
ze_module_handle_t ZeModule =
554+
Program->getZeModuleHandle(Queue->Device->ZeDevice);
569555
// Find global variable pointer
570556
size_t GlobalVarSize = 0;
571557
void *GlobalVarPtr = nullptr;
@@ -603,10 +589,6 @@ ur_result_t urKernelCreate(
603589
*RetKernel ///< [out] pointer to handle of kernel object created.
604590
) {
605591
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
606-
if (Program->State != ur_program_handle_t_::state::Exe) {
607-
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
608-
}
609-
610592
try {
611593
ur_kernel_handle_t_ *UrKernel = new ur_kernel_handle_t_(true, Program);
612594
*RetKernel = reinterpret_cast<ur_kernel_handle_t>(UrKernel);
@@ -616,8 +598,14 @@ ur_result_t urKernelCreate(
616598
return UR_RESULT_ERROR_UNKNOWN;
617599
}
618600

619-
for (auto It : Program->ZeModuleMap) {
620-
auto ZeModule = It.second;
601+
for (auto &Dev : Program->AssociatedDevices) {
602+
auto ZeDevice = Dev->ZeDevice;
603+
// Program may be associated with all devices from the context but built
604+
// only for subset of devices.
605+
if (Program->getState(ZeDevice) != ur_program_handle_t_::state::Exe)
606+
continue;
607+
608+
auto ZeModule = Program->getZeModuleHandle(ZeDevice);
621609
ZeStruct<ze_kernel_desc_t> ZeKernelDesc;
622610
ZeKernelDesc.flags = 0;
623611
ZeKernelDesc.pKernelName = KernelName;
@@ -632,8 +620,6 @@ ur_result_t urKernelCreate(
632620
return ze2urResult(ZeResult);
633621
}
634622

635-
auto ZeDevice = It.first;
636-
637623
// Store the kernel in the ZeKernelMap so the correct
638624
// kernel can be retrieved later for a specific device
639625
// where a queue is being submitted.
@@ -651,6 +637,9 @@ ur_result_t urKernelCreate(
651637
(*RetKernel)->ZeKernelMap[ZeSubDevice] = ZeKernel;
652638
}
653639
}
640+
// There is no any successfully built executable for program.
641+
if ((*RetKernel)->ZeKernelMap.empty())
642+
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
654643

655644
(*RetKernel)->ZeKernel = (*RetKernel)->ZeKernelMap.begin()->second;
656645

0 commit comments

Comments
 (0)