@@ -322,18 +322,20 @@ def _setup():
322
322
platform_name = platform .system ()
323
323
324
324
try :
325
- AF_SEARCH_PATH = os .environ ['AF_PATH' ]
325
+ AF_PATH = os .environ ['AF_PATH' ]
326
326
except :
327
- AF_SEARCH_PATH = None
327
+ AF_PATH = None
328
328
pass
329
329
330
+ AF_SEARCH_PATH = AF_PATH
331
+
330
332
try :
331
333
CUDA_PATH = os .environ ['CUDA_PATH' ]
332
334
except :
333
335
CUDA_PATH = None
334
336
pass
335
337
336
- CUDA_EXISTS = False
338
+ CUDA_FOUND = False
337
339
338
340
assert (len (platform_name ) >= 3 )
339
341
if platform_name == 'Windows' or platform_name [:3 ] == 'CYG' :
@@ -353,7 +355,7 @@ def _setup():
353
355
AF_SEARCH_PATH = "C:/Program Files/ArrayFire/v3/"
354
356
355
357
if CUDA_PATH is not None :
356
- CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/bin' ) and os .path .isdir (CUDA_PATH + '/nvvm/bin/' )
358
+ CUDA_FOUND = os .path .isdir (CUDA_PATH + '/bin' ) and os .path .isdir (CUDA_PATH + '/nvvm/bin/' )
357
359
358
360
elif platform_name == 'Darwin' :
359
361
@@ -367,7 +369,7 @@ def _setup():
367
369
if CUDA_PATH is None :
368
370
CUDA_PATH = '/usr/local/cuda/'
369
371
370
- CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
372
+ CUDA_FOUND = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
371
373
372
374
elif platform_name == 'Linux' :
373
375
pre = 'lib'
@@ -379,20 +381,23 @@ def _setup():
379
381
if CUDA_PATH is None :
380
382
CUDA_PATH = '/usr/local/cuda/'
381
383
382
- if platform .architecture ()[0 ][:2 ] == 64 :
383
- CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib64' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib64' )
384
+ if platform .architecture ()[0 ][:2 ] == '64' :
385
+ CUDA_FOUND = os .path .isdir (CUDA_PATH + '/lib64' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib64' )
384
386
else :
385
- CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
387
+ CUDA_FOUND = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
386
388
else :
387
389
raise OSError (platform_name + ' not supported' )
388
390
389
- return pre , post , AF_SEARCH_PATH , CUDA_EXISTS
391
+ if AF_PATH is None :
392
+ os .environ ['AF_PATH' ] = AF_SEARCH_PATH
393
+
394
+ return pre , post , AF_SEARCH_PATH , CUDA_FOUND
390
395
391
396
class _clibrary (object ):
392
397
393
398
def __libname (self , name , head = 'af' ):
394
399
libname = self .__pre + head + name + self .__post
395
- libname_full = self .AF_SEARCH_PATH + '/lib/' + libname
400
+ libname_full = self .AF_PATH + '/lib/' + libname
396
401
return (libname , libname_full )
397
402
398
403
def set_unsafe (self , name ):
@@ -405,25 +410,27 @@ def __init__(self):
405
410
406
411
more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
407
412
408
- pre , post , AF_SEARCH_PATH , CUDA_EXISTS = _setup ()
413
+ pre , post , AF_PATH , CUDA_FOUND = _setup ()
409
414
410
415
self .__pre = pre
411
416
self .__post = post
412
- self .AF_SEARCH_PATH = AF_SEARCH_PATH
417
+ self .AF_PATH = AF_PATH
418
+ self .CUDA_FOUND = CUDA_FOUND
413
419
414
420
self .__name = None
415
421
416
- self .__clibs = {'cuda' : None ,
417
- 'opencl' : None ,
418
- 'cpu' : None ,
419
- '' : None }
422
+ self .__clibs = {'cuda' : None ,
423
+ 'opencl' : None ,
424
+ 'cpu' : None ,
425
+ 'unified' : None }
420
426
421
- self .__backend_map = {0 : 'default ' ,
427
+ self .__backend_map = {0 : 'unified ' ,
422
428
1 : 'cpu' ,
423
429
2 : 'cuda' ,
424
430
4 : 'opencl' }
425
431
426
432
self .__backend_name_map = {'default' : 0 ,
433
+ 'unified' : 0 ,
427
434
'cpu' : 1 ,
428
435
'cuda' : 2 ,
429
436
'opencl' : 4 }
@@ -442,8 +449,9 @@ def __init__(self):
442
449
for libname in libnames :
443
450
try :
444
451
ct .cdll .LoadLibrary (libname )
445
- self .__clibs [name ] = ct .CDLL (libname )
446
- self .__name = name
452
+ __name = 'unified' if name == '' else name
453
+ self .__clibs [__name ] = ct .CDLL (libname )
454
+ self .__name = __name
447
455
break ;
448
456
except :
449
457
pass
0 commit comments