@@ -315,31 +315,85 @@ class BACKEND(_Enum):
315
315
CUDA = _Enum_Type (2 )
316
316
OPENCL = _Enum_Type (4 )
317
317
318
- class _clibrary (object ):
318
+ def _setup ():
319
+ import platform
320
+ import os
321
+
322
+ platform_name = platform .system ()
323
+
324
+ try :
325
+ AF_SEARCH_PATH = os .environ ['AF_PATH' ]
326
+ except :
327
+ AF_SEARCH_PATH = None
328
+ pass
329
+
330
+ try :
331
+ CUDA_PATH = os .environ ['CUDA_PATH' ]
332
+ except :
333
+ CUDA_PATH = None
334
+ pass
335
+
336
+ CUDA_EXISTS = False
337
+
338
+ assert (len (platform_name ) >= 3 )
339
+ if platform_name == 'Windows' or platform_name [:3 ] == 'CYG' :
340
+
341
+ ## Windows specific setup
342
+ pre = ''
343
+ post = '.dll'
344
+ if platform_name == "Windows" :
345
+ '''
346
+ Supressing crashes caused by missing dlls
347
+ http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
348
+ https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
349
+ '''
350
+ ct .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 )
351
+
352
+ if AF_SEARCH_PATH is None :
353
+ AF_SEARCH_PATH = "C:/Program Files/ArrayFire/v3/"
354
+
355
+ if CUDA_PATH is not None :
356
+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/bin' ) and os .path .isdir (CUDA_PATH + '/nvvm/bin/' )
357
+
358
+ elif platform_name == 'Darwin' :
359
+
360
+ ## OSX specific setup
361
+ pre = 'lib'
362
+ post = '.dylib'
363
+
364
+ if AF_SEARCH_PATH is None :
365
+ AF_SEARCH_PATH = '/usr/local/'
366
+
367
+ if CUDA_PATH is None :
368
+ CUDA_PATH = '/usr/local/cuda/'
369
+
370
+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
319
371
320
- def __libname (self , name ):
321
- platform_name = platform .system ()
322
- assert (len (platform_name ) >= 3 )
323
-
324
- libname = 'libaf' + name
325
- if platform_name == 'Linux' :
326
- libname += '.so'
327
- elif platform_name == 'Darwin' :
328
- libname += '.dylib'
329
- elif platform_name == "Windows" or platform_name [:3 ] == "CYG" :
330
- libname += '.dll'
331
- libname = libname [3 :] # remove 'lib'
332
- if platform_name == "Windows" :
333
- '''
334
- Supressing crashes caused by missing dlls
335
- http://stackoverflow.com/questions/8347266/missing-dll-print-message-instead-of-launching-a-popup
336
- https://msdn.microsoft.com/en-us/library/windows/desktop/ms680621.aspx
337
- '''
338
- ct .windll .kernel32 .SetErrorMode (0x0001 | 0x0002 );
372
+ elif platform_name == 'Linux' :
373
+ pre = 'lib'
374
+ post = '.so'
375
+
376
+ if AF_SEARCH_PATH is None :
377
+ AF_SEARCH_PATH = '/opt/arrayfire-3/'
378
+
379
+ if CUDA_PATH is None :
380
+ CUDA_PATH = '/usr/local/cuda/'
381
+
382
+ if platform .architecture ()[0 ][:2 ] == 64 :
383
+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib64' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib64' )
339
384
else :
340
- raise OSError (platform_name + ' not supported' )
385
+ CUDA_EXISTS = os .path .isdir (CUDA_PATH + '/lib' ) and os .path .isdir (CUDA_PATH + '/nvvm/lib' )
386
+ else :
387
+ raise OSError (platform_name + ' not supported' )
341
388
342
- return libname
389
+ return pre , post , AF_SEARCH_PATH , CUDA_EXISTS
390
+
391
+ class _clibrary (object ):
392
+
393
+ def __libname (self , name , head = 'af' ):
394
+ libname = self .__pre + head + name + self .__post
395
+ libname_full = self .AF_SEARCH_PATH + '/lib/' + libname
396
+ return (libname , libname_full )
343
397
344
398
def set_unsafe (self , name ):
345
399
lib = self .__clibs [name ]
@@ -348,6 +402,15 @@ def set_unsafe(self, name):
348
402
self .__name = name
349
403
350
404
def __init__ (self ):
405
+
406
+ more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
407
+
408
+ pre , post , AF_SEARCH_PATH , CUDA_EXISTS = _setup ()
409
+
410
+ self .__pre = pre
411
+ self .__post = post
412
+ self .AF_SEARCH_PATH = AF_SEARCH_PATH
413
+
351
414
self .__name = None
352
415
353
416
self .__clibs = {'cuda' : None ,
@@ -365,18 +428,29 @@ def __init__(self):
365
428
'cuda' : 2 ,
366
429
'opencl' : 4 }
367
430
368
- # Iterate in reverse order of preference
369
- for name in ('cpu' , 'opencl' , 'cuda' , '' ):
431
+ # Try to pre-load forge library if it exists
432
+ libnames = self .__libname ('forge' , '' )
433
+ for libname in libnames :
370
434
try :
371
- libname = self .__libname (name )
372
435
ct .cdll .LoadLibrary (libname )
373
- self .__clibs [name ] = ct .CDLL (libname )
374
- self .__name = name
375
436
except :
376
437
pass
377
438
439
+ # Iterate in reverse order of preference
440
+ for name in ('cpu' , 'opencl' , 'cuda' , '' ):
441
+ libnames = self .__libname (name )
442
+ for libname in libnames :
443
+ try :
444
+ ct .cdll .LoadLibrary (libname )
445
+ self .__clibs [name ] = ct .CDLL (libname )
446
+ self .__name = name
447
+ break ;
448
+ except :
449
+ pass
450
+
378
451
if (self .__name is None ):
379
- raise RuntimeError ("Could not load any ArrayFire libraries" )
452
+ raise RuntimeError ("Could not load any ArrayFire libraries.\n " +
453
+ more_info_str )
380
454
381
455
def get_id (self , name ):
382
456
return self .__backend_name_map [name ]
0 commit comments