Skip to content

Commit 597ebe7

Browse files
committed
Attempt to detect FFTW libraries in setup
1 parent 8d91669 commit 597ebe7

File tree

2 files changed

+44
-21
lines changed

2 files changed

+44
-21
lines changed

README.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ and the C-library
6262

6363
Note that *mpi4py* requires a working MPI installation, with the compiler
6464
wrapper *mpicc* on your search path. All of the above dependencies are
65-
available and will be downlaoded through the conda-forge channel if
65+
available and will be downloaded through the conda-forge channel if
6666
conda is used for installation.
6767

6868
For IO you need to install either `h5py <https://www.h5py.org>`_ or
6969
`netCDF4 <http://unidata.github.io/netcdf4-python/>`_ with support for
7070
MPI. These libraries are, unfortunately, not compiled with MPI on
7171
conda-forge. The two libraries are available, though, for both OSX and
72-
linux (`h5py-parallel <https://anaconda.org/spectraldns/h5py-parallel>`_
73-
and `netcdf4-parallel <https://anaconda.org/spectraldns/netcdf4-parallel>`_)
74-
from the `spectralDNS <https://anaconda.org/spectralDNS>`_ channel
75-
on anaconda cloud.
72+
linux from the `spectralDNS <https://anaconda.org/spectralDNS>`_ channel
73+
on anaconda cloud::
74+
75+
conda install -c spectralDNS h5py-parallel netcdf4-parallel

setup.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,68 @@
22

33
import os, sys
44
import shutil
5+
from distutils import ccompiler
56
from setuptools import setup
67
from setuptools.extension import Extension
78
from numpy import get_include
89

910
cwd = os.path.abspath(os.path.dirname(__file__))
1011
fftwdir = os.path.join(cwd, 'mpi4py_fft', 'fftw')
1112

12-
# For now assuming that all precisions are available
13+
include_dirs = [get_include(), os.path.join(sys.prefix, 'include')]
14+
library_dirs = [os.path.join(sys.prefix, 'lib')]
15+
for f in ('FFTW_ROOT', 'FFTW_DIR'):
16+
if f in os.environ['PATH']:
17+
library_dirs.append(os.path.join(os.environ[f], 'lib'))
18+
include_dirs.append(os.path.join(os.environ[f], 'include'))
19+
20+
compiler = ccompiler.new_compiler()
21+
assert compiler.find_library_file(library_dirs, 'fftw3'), 'Cannot find FFTW library!'
22+
has_threads = compiler.find_library_file(library_dirs, 'fftw3_threads')
23+
24+
prec_map = {'float': 'fftwf_', 'double': 'fftw_', 'long double': 'fftwl_'}
1325

14-
prec = {'fftwf_': 'float', 'fftw': 'double', 'fftwl_': 'long double'}
1526
libs = {
16-
'fftwf_': ['m', 'fftw3f', 'fftw3f_threads'],
17-
'fftw_': ['m', 'fftw3', 'fftw3_threads'],
18-
'fftwl_': ['m', 'fftw3l', 'fftw3l_threads']}
27+
'float': ['fftw3f'],
28+
'double': ['fftw3'],
29+
'long double': ['fftw3l']
30+
}
31+
32+
has_prec = ['double']
33+
for d in ('float', 'long double'):
34+
if compiler.find_library_file(library_dirs, libs[d][0]):
35+
has_prec.append(d)
36+
37+
if has_threads:
38+
for d in has_prec:
39+
libs[d].append('_'.join((libs[d][0], 'threads')))
40+
if sys.platform in ('unix', 'darwin'):
41+
libs[d].append('m')
1942

20-
for fl in ('fftw_planxfftn.h', 'fftw_planxfftn.c', 'fftw_xfftn.pyx', 'fftw_xfftn.pxd'):
21-
for p in ('fftwf_', 'fftwl_'):
43+
# Generate files with float and long double if needed
44+
for d in has_prec[1:]:
45+
p = prec_map[d]
46+
for fl in ('fftw_planxfftn.h', 'fftw_planxfftn.c', 'fftw_xfftn.pyx', 'fftw_xfftn.pxd'):
2247
fp = fl.replace('fftw_', p)
2348
shutil.copy(os.path.join(fftwdir, fl), os.path.join(fftwdir, fp))
2449
sedcmd = "sed -i ''" if sys.platform == 'darwin' else "sed -i''"
2550
os.system(sedcmd + " 's/fftw_/{0}/g' {1}".format(p, os.path.join(fftwdir, fp)))
26-
os.system(sedcmd + " 's/double/{0}/g' {1}".format(prec[p], os.path.join(fftwdir, fp)))
51+
os.system(sedcmd + " 's/double/{0}/g' {1}".format(d, os.path.join(fftwdir, fp)))
2752

2853
ext = [Extension("mpi4py_fft.fftw.utilities",
2954
sources=[os.path.join(fftwdir, "utilities.pyx")],
30-
libraries=libs[p],
3155
include_dirs=[get_include(),
32-
os.path.join(sys.prefix, 'include')],
33-
library_dirs=[os.path.join(sys.prefix, 'lib')])]
56+
os.path.join(sys.prefix, 'include')])]
3457

35-
for p in ('fftw_', 'fftwf_', 'fftwl_'):
58+
for d in has_prec:
59+
p = prec_map[d]
3660
ext.append(Extension("mpi4py_fft.fftw.{}xfftn".format(p),
3761
sources=[os.path.join(fftwdir, "{}xfftn.pyx".format(p)),
3862
os.path.join(fftwdir, "{}planxfftn.c".format(p))],
3963
#define_macros=[('NPY_NO_DEPRECATED_API', 'NPY_1_7_API_VERSION')],
40-
libraries=libs[p],
41-
include_dirs=[get_include(),
42-
os.path.join(sys.prefix, 'include')],
43-
library_dirs=[os.path.join(sys.prefix, 'lib')]))
64+
libraries=libs[d],
65+
include_dirs=include_dirs,
66+
library_dirs=library_dirs))
4467

4568
with open("README.rst", "r") as fh:
4669
long_description = fh.read()

0 commit comments

Comments
 (0)