Skip to content

Commit dbf31c3

Browse files
authored
[SYCL] Add multi-device and multi-platform support for SYCL_DEVICE_ALLOWLIST (#2483)
* Added support for the case where multiple devices or platforms are listed in SYCL_DEVICE_ALLOWLIST. * Fixed a memory issue, where memory had not been allocated properly. This caused a seg fault. * Created a test case which tests both legal and illegal uses. * Updated the documentation. Signed-off-by: Gail Lyons <[email protected]>
1 parent a9839b0 commit dbf31c3

File tree

3 files changed

+1004
-101
lines changed

3 files changed

+1004
-101
lines changed

sycl/doc/EnvironmentVariables.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ subject to change. Do not rely on these variables in production code.
2323
| SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP | Any(\*) | Disable cleanup of finished command nodes at host-device synchronization points. |
2424
| SYCL_THROW_ON_BLOCK | Any(\*) | Throw an exception on attempt to wait for a blocked command. |
2525
| SYCL_DEVICELIB_INHIBIT_NATIVE | String of device library extensions (separated by a whitespace) | Do not rely on device native support for devicelib extensions listed in this option. |
26-
| SYCL_DEVICE_ALLOWLIST | A list of devices and their minimum driver version following the pattern: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}. Also may contain PlatformName and PlatformVersion | Filter out devices that do not match the pattern specified. Regular expression can be passed and the DPC++ runtime will select only those devices which satisfy the regex. |
26+
| SYCL_DEVICE_ALLOWLIST | A list of devices and their driver version following the pattern: DeviceName:{{XXX}},DriverVersion:{{X.Y.Z.W}}. Also may contain PlatformName and PlatformVersion | Filter out devices that do not match the pattern specified. Regular expression can be passed and the DPC++ runtime will select only those devices which satisfy the regex. Special characters, such as parenthesis, must be escaped. More than one device can be specified using the piping symbol "\|".|
2727
| SYCL_QUEUE_THREAD_POOL_SIZE | Positive integer | Number of threads in thread pool of queue. |
2828
| SYCL_DEVICELIB_NO_FALLBACK | Any(\*) | Disable loading and linking of device library images |
2929
| SYCL_PI_LEVEL0_MAX_COMMAND_LIST_CACHE | Positive integer | Maximum number of oneAPI Level Zero Command lists that can be allocated with no reuse before throwing an "out of resources" error. Default is 20000, threshold may be increased based on resource availabilty and workload demand. |

sycl/source/detail/platform_impl.cpp

+125-100
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <algorithm>
1717
#include <cstring>
1818
#include <regex>
19+
#include <string>
1920

2021
__SYCL_INLINE_NAMESPACE(cl) {
2122
namespace sycl {
@@ -120,104 +121,106 @@ vector_class<platform> platform_impl::get_platforms() {
120121
return Platforms;
121122
}
122123

123-
struct DevDescT {
124-
const char *devName = nullptr;
125-
int devNameSize = 0;
126-
const char *devDriverVer = nullptr;
127-
int devDriverVerSize = 0;
124+
std::string getValue(const std::string &AllowList, size_t &Pos,
125+
unsigned long int Size) {
126+
size_t Prev = Pos;
127+
if ((Pos = AllowList.find("{{", Pos)) == std::string::npos) {
128+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
129+
PI_INVALID_VALUE);
130+
}
131+
if (Pos > Prev + Size) {
132+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
133+
PI_INVALID_VALUE);
134+
}
128135

129-
const char *platformName = nullptr;
130-
int platformNameSize = 0;
136+
Pos = Pos + 2;
137+
size_t Start = Pos;
138+
if ((Pos = AllowList.find("}}", Pos)) == std::string::npos) {
139+
throw sycl::runtime_error("Malformed syntax in SYCL_DEVICE_ALLOWLIST",
140+
PI_INVALID_VALUE);
141+
}
142+
std::string Value = AllowList.substr(Start, Pos - Start);
143+
Pos = Pos + 2;
144+
return Value;
145+
}
131146

132-
const char *platformVer = nullptr;
133-
int platformVerSize = 0;
147+
struct DevDescT {
148+
std::string DevName;
149+
std::string DevDriverVer;
150+
std::string PlatName;
151+
std::string PlatVer;
134152
};
135153

136154
static std::vector<DevDescT> getAllowListDesc() {
137-
const char *str = SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get();
138-
if (!str)
155+
std::string AllowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get());
156+
if (AllowList.empty())
139157
return {};
140158

141-
std::vector<DevDescT> decDescs;
142-
const char devNameStr[] = "DeviceName";
143-
const char driverVerStr[] = "DriverVersion";
144-
const char platformNameStr[] = "PlatformName";
145-
const char platformVerStr[] = "PlatformVersion";
146-
decDescs.emplace_back();
147-
while ('\0' != *str) {
148-
const char **valuePtr = nullptr;
149-
int *size = nullptr;
150-
151-
// -1 to avoid comparing null terminator
152-
if (0 == strncmp(devNameStr, str, sizeof(devNameStr) - 1)) {
153-
valuePtr = &decDescs.back().devName;
154-
size = &decDescs.back().devNameSize;
155-
str += sizeof(devNameStr) - 1;
156-
} else if (0 ==
157-
strncmp(platformNameStr, str, sizeof(platformNameStr) - 1)) {
158-
valuePtr = &decDescs.back().platformName;
159-
size = &decDescs.back().platformNameSize;
160-
str += sizeof(platformNameStr) - 1;
161-
} else if (0 == strncmp(platformVerStr, str, sizeof(platformVerStr) - 1)) {
162-
valuePtr = &decDescs.back().platformVer;
163-
size = &decDescs.back().platformVerSize;
164-
str += sizeof(platformVerStr) - 1;
165-
} else if (0 == strncmp(driverVerStr, str, sizeof(driverVerStr) - 1)) {
166-
valuePtr = &decDescs.back().devDriverVer;
167-
size = &decDescs.back().devDriverVerSize;
168-
str += sizeof(driverVerStr) - 1;
169-
} else {
170-
throw sycl::runtime_error("Unrecognized key in device allowlist",
171-
PI_INVALID_VALUE);
159+
std::string DeviceName("DeviceName:");
160+
std::string DriverVersion("DriverVersion:");
161+
std::string PlatformName("PlatformName:");
162+
std::string PlatformVersion("PlatformVersion:");
163+
std::vector<DevDescT> DecDescs;
164+
DecDescs.emplace_back();
165+
166+
size_t Pos = 0;
167+
while (Pos < AllowList.size()) {
168+
if ((AllowList.compare(Pos, DeviceName.size(), DeviceName)) == 0) {
169+
DecDescs.back().DevName = getValue(AllowList, Pos, DeviceName.size());
170+
if (AllowList[Pos] == ',') {
171+
Pos++;
172+
}
172173
}
173174

174-
if (':' != *str)
175-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
176-
177-
// Skip ':'
178-
str += 1;
179-
180-
if ('{' != *str || '{' != *(str + 1))
181-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
182-
183-
// Skip opening sequence "{{"
184-
str += 2;
185-
186-
*valuePtr = str;
187-
188-
// Increment until closing sequence is encountered
189-
while (('\0' != *str) && ('}' != *str || '}' != *(str + 1)))
190-
++str;
191-
192-
if ('\0' == *str)
193-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
194-
195-
*size = str - *valuePtr;
196-
197-
// Skip closing sequence "}}"
198-
str += 2;
199-
200-
if ('\0' == *str)
201-
break;
175+
else if ((AllowList.compare(Pos, DriverVersion.size(), DriverVersion)) ==
176+
0) {
177+
DecDescs.back().DevDriverVer =
178+
getValue(AllowList, Pos, DriverVersion.size());
179+
if (AllowList[Pos] == ',') {
180+
Pos++;
181+
}
182+
}
202183

203-
// '|' means that the is another filter
204-
if ('|' == *str)
205-
decDescs.emplace_back();
206-
else if (',' != *str)
207-
throw sycl::runtime_error("Malformed device allowlist", PI_INVALID_VALUE);
184+
else if ((AllowList.compare(Pos, PlatformName.size(), PlatformName)) == 0) {
185+
DecDescs.back().PlatName = getValue(AllowList, Pos, PlatformName.size());
186+
if (AllowList[Pos] == ',') {
187+
Pos++;
188+
}
189+
}
208190

209-
++str;
210-
}
191+
else if ((AllowList.compare(Pos, PlatformVersion.size(),
192+
PlatformVersion)) == 0) {
193+
DecDescs.back().PlatVer =
194+
getValue(AllowList, Pos, PlatformVersion.size());
195+
} else if (AllowList.find('|', Pos) != std::string::npos) {
196+
Pos = AllowList.find('|') + 1;
197+
while (AllowList[Pos] == ' ') {
198+
Pos++;
199+
}
200+
DecDescs.emplace_back();
201+
}
211202

212-
return decDescs;
203+
else {
204+
throw sycl::runtime_error("Unrecognized key in device allowlist",
205+
PI_INVALID_VALUE);
206+
}
207+
} // while (Pos <= AllowList.size())
208+
return DecDescs;
213209
}
214210

211+
enum MatchState { UNKNOWN, MATCH, NOMATCH };
212+
215213
static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
216214
RT::PiPlatform PiPlatform, const plugin &Plugin) {
217215
const std::vector<DevDescT> AllowList(getAllowListDesc());
218216
if (AllowList.empty())
219217
return;
220218

219+
MatchState DevNameState = UNKNOWN;
220+
MatchState DevVerState = UNKNOWN;
221+
MatchState PlatNameState = UNKNOWN;
222+
MatchState PlatVerState = UNKNOWN;
223+
221224
const string_class PlatformName =
222225
sycl::detail::get_platform_info<string_class, info::platform::name>::get(
223226
PiPlatform, Plugin);
@@ -237,33 +240,55 @@ static void filterAllowList(vector_class<RT::PiDevice> &PiDevices,
237240
string_class, info::device::driver_version>::get(Device, Plugin);
238241

239242
for (const DevDescT &Desc : AllowList) {
240-
if (nullptr != Desc.platformName &&
241-
!std::regex_match(PlatformName,
242-
std::regex(std::string(Desc.platformName,
243-
Desc.platformNameSize))))
244-
continue;
245-
246-
if (nullptr != Desc.platformVer &&
247-
!std::regex_match(
248-
PlatformVer,
249-
std::regex(std::string(Desc.platformVer, Desc.platformVerSize))))
250-
continue;
251-
252-
if (nullptr != Desc.devName &&
253-
!std::regex_match(DeviceName, std::regex(std::string(
254-
Desc.devName, Desc.devNameSize))))
255-
continue;
256-
257-
if (nullptr != Desc.devDriverVer &&
258-
!std::regex_match(DeviceDriverVer,
259-
std::regex(std::string(Desc.devDriverVer,
260-
Desc.devDriverVerSize))))
261-
continue;
243+
if (!Desc.PlatName.empty()) {
244+
if (!std::regex_match(PlatformName, std::regex(Desc.PlatName))) {
245+
PlatNameState = MatchState::NOMATCH;
246+
continue;
247+
} else {
248+
PlatNameState = MatchState::MATCH;
249+
}
250+
}
251+
252+
if (!Desc.PlatVer.empty()) {
253+
if (!std::regex_match(PlatformVer, std::regex(Desc.PlatVer))) {
254+
PlatVerState = MatchState::NOMATCH;
255+
continue;
256+
} else {
257+
PlatVerState = MatchState::MATCH;
258+
}
259+
}
260+
261+
if (!Desc.DevName.empty()) {
262+
if (!std::regex_match(DeviceName, std::regex(Desc.DevName))) {
263+
DevNameState = MatchState::NOMATCH;
264+
continue;
265+
} else {
266+
DevNameState = MatchState::MATCH;
267+
}
268+
}
269+
270+
if (!Desc.DevDriverVer.empty()) {
271+
if (!std::regex_match(DeviceDriverVer, std::regex(Desc.DevDriverVer))) {
272+
DevVerState = MatchState::NOMATCH;
273+
continue;
274+
} else {
275+
DevVerState = MatchState::MATCH;
276+
}
277+
}
262278

263279
PiDevices[InsertIDx++] = Device;
264280
break;
265281
}
266282
}
283+
if (DevNameState == MatchState::MATCH && DevVerState == MatchState::NOMATCH) {
284+
throw sycl::runtime_error("Requested SYCL device not found",
285+
PI_DEVICE_NOT_FOUND);
286+
}
287+
if (PlatNameState == MatchState::MATCH &&
288+
PlatVerState == MatchState::NOMATCH) {
289+
throw sycl::runtime_error("Requested SYCL platform not found",
290+
PI_DEVICE_NOT_FOUND);
291+
}
267292
PiDevices.resize(InsertIDx);
268293
}
269294

0 commit comments

Comments
 (0)