Skip to content

Commit 48cf2b4

Browse files
committed
Set the AF_PATH to one of known paths if none is found
This helps users avoiding setting the path
1 parent 81babf5 commit 48cf2b4

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

arrayfire/library.py

+27-19
Original file line numberDiff line numberDiff line change
@@ -322,18 +322,20 @@ def _setup():
322322
platform_name = platform.system()
323323

324324
try:
325-
AF_SEARCH_PATH = os.environ['AF_PATH']
325+
AF_PATH = os.environ['AF_PATH']
326326
except:
327-
AF_SEARCH_PATH = None
327+
AF_PATH = None
328328
pass
329329

330+
AF_SEARCH_PATH = AF_PATH
331+
330332
try:
331333
CUDA_PATH = os.environ['CUDA_PATH']
332334
except:
333335
CUDA_PATH= None
334336
pass
335337

336-
CUDA_EXISTS = False
338+
CUDA_FOUND = False
337339

338340
assert(len(platform_name) >= 3)
339341
if platform_name == 'Windows' or platform_name[:3] == 'CYG':
@@ -353,7 +355,7 @@ def _setup():
353355
AF_SEARCH_PATH="C:/Program Files/ArrayFire/v3/"
354356

355357
if CUDA_PATH is not None:
356-
CUDA_EXISTS = os.path.isdir(CUDA_PATH + '/bin') and os.path.isdir(CUDA_PATH + '/nvvm/bin/')
358+
CUDA_FOUND = os.path.isdir(CUDA_PATH + '/bin') and os.path.isdir(CUDA_PATH + '/nvvm/bin/')
357359

358360
elif platform_name == 'Darwin':
359361

@@ -367,7 +369,7 @@ def _setup():
367369
if CUDA_PATH is None:
368370
CUDA_PATH='/usr/local/cuda/'
369371

370-
CUDA_EXISTS = os.path.isdir(CUDA_PATH + '/lib') and os.path.isdir(CUDA_PATH + '/nvvm/lib')
372+
CUDA_FOUND = os.path.isdir(CUDA_PATH + '/lib') and os.path.isdir(CUDA_PATH + '/nvvm/lib')
371373

372374
elif platform_name == 'Linux':
373375
pre = 'lib'
@@ -379,20 +381,23 @@ def _setup():
379381
if CUDA_PATH is None:
380382
CUDA_PATH='/usr/local/cuda/'
381383

382-
if platform.architecture()[0][:2] == 64:
383-
CUDA_EXISTS = os.path.isdir(CUDA_PATH + '/lib64') and os.path.isdir(CUDA_PATH + '/nvvm/lib64')
384+
if platform.architecture()[0][:2] == '64':
385+
CUDA_FOUND = os.path.isdir(CUDA_PATH + '/lib64') and os.path.isdir(CUDA_PATH + '/nvvm/lib64')
384386
else:
385-
CUDA_EXISTS = os.path.isdir(CUDA_PATH + '/lib') and os.path.isdir(CUDA_PATH + '/nvvm/lib')
387+
CUDA_FOUND = os.path.isdir(CUDA_PATH + '/lib') and os.path.isdir(CUDA_PATH + '/nvvm/lib')
386388
else:
387389
raise OSError(platform_name + ' not supported')
388390

389-
return pre, post, AF_SEARCH_PATH, CUDA_EXISTS
391+
if AF_PATH is None:
392+
os.environ['AF_PATH'] = AF_SEARCH_PATH
393+
394+
return pre, post, AF_SEARCH_PATH, CUDA_FOUND
390395

391396
class _clibrary(object):
392397

393398
def __libname(self, name, head='af'):
394399
libname = self.__pre + head + name + self.__post
395-
libname_full = self.AF_SEARCH_PATH + '/lib/' + libname
400+
libname_full = self.AF_PATH + '/lib/' + libname
396401
return (libname, libname_full)
397402

398403
def set_unsafe(self, name):
@@ -405,25 +410,27 @@ def __init__(self):
405410

406411
more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
407412

408-
pre, post, AF_SEARCH_PATH, CUDA_EXISTS = _setup()
413+
pre, post, AF_PATH, CUDA_FOUND = _setup()
409414

410415
self.__pre = pre
411416
self.__post = post
412-
self.AF_SEARCH_PATH = AF_SEARCH_PATH
417+
self.AF_PATH = AF_PATH
418+
self.CUDA_FOUND = CUDA_FOUND
413419

414420
self.__name = None
415421

416-
self.__clibs = {'cuda' : None,
417-
'opencl' : None,
418-
'cpu' : None,
419-
'' : None}
422+
self.__clibs = {'cuda' : None,
423+
'opencl' : None,
424+
'cpu' : None,
425+
'unified' : None}
420426

421-
self.__backend_map = {0 : 'default',
427+
self.__backend_map = {0 : 'unified',
422428
1 : 'cpu' ,
423429
2 : 'cuda' ,
424430
4 : 'opencl' }
425431

426432
self.__backend_name_map = {'default' : 0,
433+
'unified' : 0,
427434
'cpu' : 1,
428435
'cuda' : 2,
429436
'opencl' : 4}
@@ -442,8 +449,9 @@ def __init__(self):
442449
for libname in libnames:
443450
try:
444451
ct.cdll.LoadLibrary(libname)
445-
self.__clibs[name] = ct.CDLL(libname)
446-
self.__name = name
452+
__name = 'unified' if name == '' else name
453+
self.__clibs[__name] = ct.CDLL(libname)
454+
self.__name = __name
447455
break;
448456
except:
449457
pass

0 commit comments

Comments
 (0)