|
3 | 3 | import os, sys
|
4 | 4 | import shutil
|
5 | 5 | from distutils import ccompiler
|
| 6 | +import six |
6 | 7 | from setuptools import setup
|
7 | 8 | from setuptools.extension import Extension
|
8 | 9 | from numpy import get_include
|
|
13 | 14 | include_dirs = [get_include(), os.path.join(sys.prefix, 'include')]
|
14 | 15 | library_dirs = [os.path.join(sys.prefix, 'lib')]
|
15 | 16 | for f in ('FFTW_ROOT', 'FFTW_DIR'):
|
16 |
| - if f in os.environ['PATH']: |
| 17 | + if f in os.environ: |
17 | 18 | library_dirs.append(os.path.join(os.environ[f], 'lib'))
|
18 | 19 | include_dirs.append(os.path.join(os.environ[f], 'include'))
|
19 | 20 |
|
| 21 | +prec_map = {'float': 'f', 'double': '', 'long double': 'l'} |
20 | 22 | compiler = ccompiler.new_compiler()
|
21 |
| -assert compiler.find_library_file(library_dirs, 'fftw3'), 'Cannot find FFTW library!' |
22 |
| -has_threads = compiler.find_library_file(library_dirs, 'fftw3_threads') |
23 | 23 |
|
24 |
| -prec_map = {'float': 'fftwf_', 'double': 'fftw_', 'long double': 'fftwl_'} |
25 |
| - |
26 |
| -libs = { |
27 |
| - 'float': ['fftw3f'], |
28 |
| - 'double': ['fftw3'], |
29 |
| - 'long double': ['fftw3l'] |
30 |
| - } |
31 |
| - |
32 |
| -has_prec = ['double'] |
33 |
| -for d in ('float', 'long double'): |
34 |
| - if compiler.find_library_file(library_dirs, libs[d][0]): |
35 |
| - has_prec.append(d) |
36 |
| - |
37 |
| -if has_threads: |
38 |
| - for d in has_prec: |
39 |
| - libs[d].append('_'.join((libs[d][0], 'threads'))) |
| 24 | +libs = {} |
| 25 | +for d in ('float', 'double', 'long double'): |
| 26 | + lib = 'fftw3'+prec_map[d] |
| 27 | + if compiler.find_library_file(library_dirs, lib): |
| 28 | + libs[d] = [lib] |
| 29 | + tlib = '_'.join((lib, 'threads')) |
| 30 | + if compiler.find_library_file(library_dirs, tlib): |
| 31 | + libs[d].append(tlib) |
40 | 32 | if sys.platform in ('unix', 'darwin'):
|
41 | 33 | libs[d].append('m')
|
42 | 34 |
|
43 | 35 | # Generate files with float and long double if needed
|
44 |
| -for d in has_prec[1:]: |
45 |
| - p = prec_map[d] |
| 36 | +for d in ('float', 'long double'): |
| 37 | + p = 'fftw'+prec_map[d]+'_' |
46 | 38 | for fl in ('fftw_planxfftn.h', 'fftw_planxfftn.c', 'fftw_xfftn.pyx', 'fftw_xfftn.pxd'):
|
47 | 39 | fp = fl.replace('fftw_', p)
|
48 | 40 | shutil.copy(os.path.join(fftwdir, fl), os.path.join(fftwdir, fp))
|
|
55 | 47 | include_dirs=[get_include(),
|
56 | 48 | os.path.join(sys.prefix, 'include')])]
|
57 | 49 |
|
58 |
| -for d in has_prec: |
59 |
| - p = prec_map[d] |
| 50 | +for d, v in six.iteritems(libs): |
| 51 | + p = 'fftw'+prec_map[d]+'_' |
60 | 52 | ext.append(Extension("mpi4py_fft.fftw.{}xfftn".format(p),
|
61 | 53 | sources=[os.path.join(fftwdir, "{}xfftn.pyx".format(p)),
|
62 | 54 | os.path.join(fftwdir, "{}planxfftn.c".format(p))],
|
63 | 55 | #define_macros=[('NPY_NO_DEPRECATED_API', 'NPY_1_7_API_VERSION')],
|
64 |
| - libraries=libs[d], |
| 56 | + libraries=v, |
65 | 57 | include_dirs=include_dirs,
|
66 | 58 | library_dirs=library_dirs))
|
67 | 59 |
|
|
92 | 84 | ],
|
93 | 85 | ext_modules=ext,
|
94 | 86 | install_requires=["mpi4py", "numpy", "six"],
|
95 |
| - setup_requires=["setuptools>=18.0", "cython>=0.25"] |
| 87 | + setup_requires=["setuptools>=18.0", "cython>=0.25", "six"] |
96 | 88 | )
|
0 commit comments