Skip to content

Commit 6e41df8

Browse files
committed
Add backend validation for peer access
Peer access is only implemented for the HIP, CUDA, and Level Zero backends. Validating the backend first prevents crashes.
1 parent 74dab57 commit 6e41df8

File tree

2 files changed

+187
-5
lines changed

2 files changed

+187
-5
lines changed

dpctl/_sycl_device.pyx

Lines changed: 121 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,7 +1796,7 @@ cdef class SyclDevice(_SyclDevice):
17961796
raise ValueError("Internal error: NULL device vector encountered")
17971797
return _get_devices(cDVRef)
17981798

1799-
def can_access_peer(self, peer):
1799+
def can_access_peer_access_supported(self, peer):
18001800
""" Returns ``True`` if `self` can enable peer access
18011801
to `peer`, ``False`` otherwise.
18021802
@@ -1809,14 +1809,45 @@ cdef class SyclDevice(_SyclDevice):
18091809
bool:
18101810
``True`` if `self` can enable peer access
18111811
to `peer`, otherwise ``False``.
1812+
1813+
Raises:
1814+
TypeError:
1815+
If `peer` is not `dpctl.SyclDevice`.
1816+
ValueError:
1817+
If the backend associated with `self` or `peer` does not
1818+
support peer access.
18121819
"""
18131820
cdef SyclDevice p_dev
1821+
cdef _backend_type BTy1
1822+
cdef _backend_type BTy2
1823+
18141824
if not isinstance(peer, SyclDevice):
18151825
raise TypeError(
18161826
"second argument must be a `dpctl.SyclDevice`, got "
18171827
f"{type(peer)}"
18181828
)
18191829
p_dev = <SyclDevice>peer
1830+
BTy1 = DPCTLDevice_GetBackend(self._device_ref)
1831+
if (
1832+
BTy1 != _backend_type._CUDA and
1833+
BTy1 != _backend_type._HIP and
1834+
BTy1 != _backend_type._LEVEL_ZERO
1835+
):
1836+
raise ValueError(
1837+
"Peer access not supported for backend "
1838+
f"{_backend_type_to_filter_string_part(BTy1)}"
1839+
)
1840+
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
1841+
if (
1842+
BTy2 != _backend_type._CUDA and
1843+
BTy2 != _backend_type._HIP and
1844+
BTy2 != _backend_type._LEVEL_ZERO
1845+
):
1846+
raise ValueError(
1847+
"Peer access not supported for backend "
1848+
f"{_backend_type_to_filter_string_part(BTy2)}"
1849+
)
1850+
18201851
return DPCTLDevice_CanAccessPeer(
18211852
self._device_ref,
18221853
p_dev.get_device_ref(),
@@ -1837,14 +1868,45 @@ cdef class SyclDevice(_SyclDevice):
18371868
``True`` if `self` can enable peer access
18381869
to and can atomically modify memory on `peer`,
18391870
otherwise ``False``.
1871+
1872+
Raises:
1873+
TypeError:
1874+
If `peer` is not `dpctl.SyclDevice`.
1875+
ValueError:
1876+
If the backend associated with `self` or `peer` does not
1877+
support peer access.
18401878
"""
18411879
cdef SyclDevice p_dev
1880+
cdef _backend_type BTy1
1881+
cdef _backend_type BTy2
1882+
18421883
if not isinstance(peer, SyclDevice):
18431884
raise TypeError(
18441885
"second argument must be a `dpctl.SyclDevice`, got "
18451886
f"{type(peer)}"
18461887
)
18471888
p_dev = <SyclDevice>peer
1889+
BTy1 = DPCTLDevice_GetBackend(self._device_ref)
1890+
if (
1891+
BTy1 != _backend_type._CUDA and
1892+
BTy1 != _backend_type._HIP and
1893+
BTy1 != _backend_type._LEVEL_ZERO
1894+
):
1895+
raise ValueError(
1896+
"Peer access not supported for backend "
1897+
f"{_backend_type_to_filter_string_part(BTy1)}"
1898+
)
1899+
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
1900+
if (
1901+
BTy2 != _backend_type._CUDA and
1902+
BTy2 != _backend_type._HIP and
1903+
BTy2 != _backend_type._LEVEL_ZERO
1904+
):
1905+
raise ValueError(
1906+
"Peer access not supported for backend "
1907+
f"{_backend_type_to_filter_string_part(BTy2)}"
1908+
)
1909+
18481910
return DPCTLDevice_CanAccessPeer(
18491911
self._device_ref,
18501912
p_dev.get_device_ref(),
@@ -1861,17 +1923,45 @@ cdef class SyclDevice(_SyclDevice):
18611923
enable peer access to.
18621924
18631925
Raises:
1926+
TypeError:
1927+
If `peer` is not `dpctl.SyclDevice`.
18641928
ValueError:
1865-
If the ``DPCTLDevice_GetComponentDevices`` call returned
1866-
``NULL`` instead of a ``DPCTLDeviceVectorRef`` object.
1929+
If the backend associated with `self` or `peer` does not
1930+
support peer access.
18671931
"""
18681932
cdef SyclDevice p_dev
1933+
cdef _backend_type BTy1
1934+
cdef _backend_type BTy2
1935+
18691936
if not isinstance(peer, SyclDevice):
18701937
raise TypeError(
18711938
"second argument must be a `dpctl.SyclDevice`, got "
18721939
f"{type(peer)}"
18731940
)
18741941
p_dev = <SyclDevice>peer
1942+
BTy1 = (
1943+
DPCTLDevice_GetBackend(self._device_ref)
1944+
)
1945+
if (
1946+
BTy1 != _backend_type._CUDA and
1947+
BTy1 != _backend_type._HIP and
1948+
BTy1 != _backend_type._LEVEL_ZERO
1949+
):
1950+
raise ValueError(
1951+
"Peer access not supported for backend "
1952+
f"{_backend_type_to_filter_string_part(BTy1)}"
1953+
)
1954+
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
1955+
if (
1956+
BTy2 != _backend_type._CUDA and
1957+
BTy2 != _backend_type._HIP and
1958+
BTy2 != _backend_type._LEVEL_ZERO
1959+
):
1960+
raise ValueError(
1961+
"Peer access not supported for backend "
1962+
f"{_backend_type_to_filter_string_part(BTy2)}"
1963+
)
1964+
18751965
DPCTLDevice_EnablePeerAccess(self._device_ref, p_dev.get_device_ref())
18761966
return
18771967

@@ -1884,17 +1974,43 @@ cdef class SyclDevice(_SyclDevice):
18841974
disable peer access to.
18851975
18861976
Raises:
1977+
TypeError:
1978+
If `peer` is not `dpctl.SyclDevice`.
18871979
ValueError:
1888-
If the ``DPCTLDevice_GetComponentDevices`` call returned
1889-
``NULL`` instead of a ``DPCTLDeviceVectorRef`` object.
1980+
If the backend associated with `self` or `peer` does not
1981+
support peer access.
18901982
"""
18911983
cdef SyclDevice p_dev
1984+
cdef _backend_type BTy1
1985+
cdef _backend_type BTy2
1986+
18921987
if not isinstance(peer, SyclDevice):
18931988
raise TypeError(
18941989
"second argument must be a `dpctl.SyclDevice`, got "
18951990
f"{type(peer)}"
18961991
)
18971992
p_dev = <SyclDevice>peer
1993+
BTy1 = DPCTLDevice_GetBackend(self._device_ref)
1994+
if (
1995+
BTy1 != _backend_type._CUDA and
1996+
BTy1 != _backend_type._HIP and
1997+
BTy1 != _backend_type._LEVEL_ZERO
1998+
):
1999+
raise ValueError(
2000+
"Peer access not supported for backend "
2001+
f"{_backend_type_to_filter_string_part(BTy1)}"
2002+
)
2003+
BTy2 = DPCTLDevice_GetBackend(p_dev.get_device_ref())
2004+
if (
2005+
BTy2 != _backend_type._CUDA and
2006+
BTy2 != _backend_type._HIP and
2007+
BTy2 != _backend_type._LEVEL_ZERO
2008+
):
2009+
raise ValueError(
2010+
"Peer access not supported for backend "
2011+
f"{_backend_type_to_filter_string_part(BTy2)}"
2012+
)
2013+
18982014
DPCTLDevice_DisablePeerAccess(self._device_ref, p_dev.get_device_ref())
18992015
return
19002016

libsyclinterface/source/dpctl_sycl_device_interface.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -912,6 +912,28 @@ bool DPCTLDevice_CanAccessPeer(__dpctl_keep const DPCTLSyclDeviceRef DRef,
912912
auto D = unwrap<device>(DRef);
913913
auto PD = unwrap<device>(PDRef);
914914
if (D && PD) {
915+
auto BE1 = D->get_backend();
916+
auto BE2 = PD->get_backend();
917+
918+
if (BE1 != sycl::backend::ext_oneapi_level_zero &&
919+
BE1 != sycl::backend::ext_oneapi_cuda &&
920+
BE1 != sycl::backend::ext_oneapi_hip)
921+
{
922+
error_handler("Backend " + std::to_string(static_cast<int>(BE1)) +
923+
" does not support peer access",
924+
__FILE__, __func__, __LINE__);
925+
return false;
926+
}
927+
928+
if (BE2 != sycl::backend::ext_oneapi_level_zero &&
929+
BE2 != sycl::backend::ext_oneapi_cuda &&
930+
BE2 != sycl::backend::ext_oneapi_hip)
931+
{
932+
error_handler("Backend " + std::to_string(static_cast<int>(BE2)) +
933+
" does not support peer access",
934+
__FILE__, __func__, __LINE__);
935+
return false;
936+
}
915937
try {
916938
canAccess = D->ext_oneapi_can_access_peer(
917939
*PD, DPCTL_DPCTLPeerAccessTypeToSycl(PT));
@@ -928,6 +950,28 @@ void DPCTLDevice_EnablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
928950
auto D = unwrap<device>(DRef);
929951
auto PD = unwrap<device>(PDRef);
930952
if (D && PD) {
953+
auto BE1 = D->get_backend();
954+
auto BE2 = PD->get_backend();
955+
956+
if (BE1 != sycl::backend::ext_oneapi_level_zero &&
957+
BE1 != sycl::backend::ext_oneapi_cuda &&
958+
BE1 != sycl::backend::ext_oneapi_hip)
959+
{
960+
error_handler("Backend " + std::to_string(static_cast<int>(BE1)) +
961+
" does not support peer access",
962+
__FILE__, __func__, __LINE__);
963+
return;
964+
}
965+
966+
if (BE2 != sycl::backend::ext_oneapi_level_zero &&
967+
BE2 != sycl::backend::ext_oneapi_cuda &&
968+
BE2 != sycl::backend::ext_oneapi_hip)
969+
{
970+
error_handler("Backend " + std::to_string(static_cast<int>(BE2)) +
971+
" does not support peer access",
972+
__FILE__, __func__, __LINE__);
973+
return;
974+
}
931975
try {
932976
D->ext_oneapi_enable_peer_access(*PD);
933977
} catch (std::exception const &e) {
@@ -943,6 +987,28 @@ void DPCTLDevice_DisablePeerAccess(__dpctl_keep const DPCTLSyclDeviceRef DRef,
943987
auto D = unwrap<device>(DRef);
944988
auto PD = unwrap<device>(PDRef);
945989
if (D && PD) {
990+
auto BE1 = D->get_backend();
991+
auto BE2 = PD->get_backend();
992+
993+
if (BE1 != sycl::backend::ext_oneapi_level_zero &&
994+
BE1 != sycl::backend::ext_oneapi_cuda &&
995+
BE1 != sycl::backend::ext_oneapi_hip)
996+
{
997+
error_handler("Backend " + std::to_string(static_cast<int>(BE1)) +
998+
" does not support peer access",
999+
__FILE__, __func__, __LINE__);
1000+
return;
1001+
}
1002+
1003+
if (BE2 != sycl::backend::ext_oneapi_level_zero &&
1004+
BE2 != sycl::backend::ext_oneapi_cuda &&
1005+
BE2 != sycl::backend::ext_oneapi_hip)
1006+
{
1007+
error_handler("Backend " + std::to_string(static_cast<int>(BE2)) +
1008+
" does not support peer access",
1009+
__FILE__, __func__, __LINE__);
1010+
return;
1011+
}
9461012
try {
9471013
D->ext_oneapi_disable_peer_access(*PD);
9481014
} catch (std::exception const &e) {

0 commit comments

Comments
 (0)