Skip to content

Commit ced0322

Browse files
committed
FEAT: Adding functions to interact with other opencl libraries
1 parent 196f110 commit ced0322

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

arrayfire/opencl.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,30 @@
1313
This module provides interoperability with other OpenCL libraries.
1414
"""
1515

16+
from .util import *
17+
from .library import (_Enum, _Enum_Type)
18+
19+
class DEVICE_TYPE(_Enum):
20+
"""
21+
ArrayFire wrapper for CL_DEVICE_TYPE
22+
"""
23+
CPU = _Enum_Type(1<<1)
24+
GPU = _Enum_Type(1<<2)
25+
ACC = _Enum_Type(1<<3)
26+
UNKNOWN = _Enum_Type(-1)
27+
28+
class PLATFORM(_Enum):
29+
"""
30+
ArrayFire enum for common platforms
31+
"""
32+
AMD = _Enum_Type(0)
33+
APPLE = _Enum_Type(1)
34+
INTEL = _Enum_Type(2)
35+
NVIDIA = _Enum_Type(3)
36+
BEIGNET = _Enum_Type(4)
37+
POCL = _Enum_Type(5)
38+
UNKNOWN = _Enum_Type(-1)
39+
1640
def get_context(retain=False):
1741
"""
1842
Get the current OpenCL context being used by ArrayFire.
@@ -107,3 +131,87 @@ def set_device_id(idx):
107131

108132
safe_call(backend.get().afcl_set_device_id(idx))
109133
return
134+
135+
def add_device_context(dev, ctx, que):
136+
"""
137+
Add a new device to arrayfire opencl device manager
138+
139+
Parameters
140+
----------
141+
142+
dev : cl_device_id
143+
144+
ctx : cl_context
145+
146+
que : cl_command_queue
147+
148+
"""
149+
if (backend.name() != "opencl"):
150+
raise RuntimeError("Invalid backend loaded")
151+
152+
safe_call(backend.get().afcl_add_device_context(dev, ctx, que))
153+
154+
def set_device_context(dev, ctx):
155+
"""
156+
Set a device as current active device
157+
158+
Parameters
159+
----------
160+
161+
dev : cl_device_id
162+
163+
ctx : cl_context
164+
165+
"""
166+
if (backend.name() != "opencl"):
167+
raise RuntimeError("Invalid backend loaded")
168+
169+
safe_call(backend.get().afcl_set_device_context(dev, ctx))
170+
171+
def delete_device_context(dev, ctx):
172+
"""
173+
Delete a device
174+
175+
Parameters
176+
----------
177+
178+
dev : cl_device_id
179+
180+
ctx : cl_context
181+
182+
"""
183+
if (backend.name() != "opencl"):
184+
raise RuntimeError("Invalid backend loaded")
185+
186+
safe_call(backend.get().afcl_delete_device_context(dev, ctx))
187+
188+
189+
_to_device_type = {DEVICE_TYPE.CPU.value : DEVICE_TYPE.CPU,
190+
DEVICE_TYPE.GPU.value : DEVICE_TYPE.GPU,
191+
DEVICE_TYPE.ACC.value : DEVICE_TYPE.ACC,
192+
DEVICE_TYPE.UNKNOWN.value : DEVICE_TYPE.UNKNOWN}
193+
194+
_to_platform = {PLATFORM.AMD.value : PLATFORM.AMD,
195+
PLATFORM.APPLE.value : PLATFORM.APPLE,
196+
PLATFORM.INTEL.value : PLATFORM.INTEL,
197+
PLATFORM.NVIDIA.value : PLATFORM.NVIDIA,
198+
PLATFORM.BEIGNET.value : PLATFORM.BEIGNET,
199+
PLATFORM.POCL.value : PLATFORM.POCL,
200+
PLATFORM.UNKNOWN.value : PLATFORM.UNKNOWN}
201+
202+
203+
def get_device_type():
204+
"""
205+
Get opencl device type
206+
"""
207+
res = ct.c_int(DEVICE_TYPE.UNKNOWN.value)
208+
safe_call(backend.get().afcl_get_device_type(ct.pointer(res)))
209+
return _to_device_type[res.value]
210+
211+
def get_platform():
212+
"""
213+
Get opencl platform
214+
"""
215+
res = ct.c_int(PLATFORM.UNKNOWN.value)
216+
safe_call(backend.get().afcl_get_platform(ct.pointer(res)))
217+
return _to_platform[res.value]

0 commit comments

Comments
 (0)