Skip to content

Commit c2b8f29

Browse files
committed
Adding functions specific to CUDA backend
1 parent 56897a0 commit c2b8f29

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

arrayfire/cuda.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#######################################################
2+
# Copyright (c) 2015, ArrayFire
3+
# All rights reserved.
4+
#
5+
# This file is distributed under 3-clause BSD license.
6+
# The complete license agreement can be obtained at:
7+
# http://arrayfire.com/licenses/BSD-3-Clause
8+
########################################################
9+
10+
"""
11+
Functions specific to CUDA backend.
12+
13+
This module provides interoperability with other CUDA libraries.
14+
"""
15+
16+
def get_stream(idx):
17+
"""
18+
Get the CUDA stream used for the device `idx` by ArrayFire.
19+
20+
Parameters
21+
----------
22+
23+
idx : int.
24+
Specifies the index of the device.
25+
26+
Returns
27+
-----------
28+
stream : integer denoting the stream id.
29+
"""
30+
31+
import ctypes as ct
32+
from .util import safe_call as safe_call
33+
from .library import backend as backend
34+
35+
if (backend.name() != "cuda"):
36+
raise RuntimeError("Invalid backend loaded")
37+
38+
stream = ct.c_void_p(0)
39+
safe_call(backend.get().afcu_get_stream(ct.pointer(stream), idx))
40+
return stream.value
41+
42+
def get_native_id(idx):
43+
"""
44+
Get native (unsorted) CUDA device ID
45+
46+
Parameters
47+
----------
48+
49+
idx : int.
50+
Specifies the (sorted) index of the device.
51+
52+
Returns
53+
-----------
54+
native_idx : integer denoting the native cuda id.
55+
"""
56+
57+
import ctypes as ct
58+
from .util import safe_call as safe_call
59+
from .library import backend as backend
60+
61+
if (backend.name() != "cuda"):
62+
raise RuntimeError("Invalid backend loaded")
63+
64+
native = ct.c_int(0)
65+
safe_call(backend.get().afcu_get_native_id(ct.pointer(native), idx))
66+
return native.value
67+
68+
def set_native_id(idx):
69+
"""
70+
Set native (unsorted) CUDA device ID
71+
72+
Parameters
73+
----------
74+
75+
idx : int.
76+
Specifies the (unsorted) native index of the device.
77+
"""
78+
79+
import ctypes as ct
80+
from .util import safe_call as safe_call
81+
from .library import backend as backend
82+
83+
if (backend.name() != "cuda"):
84+
raise RuntimeError("Invalid backend loaded")
85+
86+
safe_call(backend.get().afcu_set_native_id(idx))
87+
return

0 commit comments

Comments
 (0)