Skip to content

Commit 4d96578

Browse files
committed
More work on setup
1 parent 504abaf commit 4d96578

File tree

3 files changed

+31
-26
lines changed

3 files changed

+31
-26
lines changed

mpi4py_fft/fftw/factory.py

+14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
from .utilities import FFTW_FORWARD, FFTW_MEASURE
55

66
def get_fftw_lib(dtype):
7+
"""Return compiled fftw module interfacing the FFTW library
8+
9+
Parameters
10+
----------
11+
dtype : dtype
12+
Data precision
13+
14+
Returns
15+
-------
16+
Module or ``None``
17+
Module can be either :mod:`.fftwf_xfftn`, :mod:`.fftw_xfftn` or
18+
:mod:`.fftwl_xfftn`, depending on precision.
19+
"""
20+
721
dtype = np.dtype(dtype).char.upper()
822
if dtype == 'G':
923
try:

mpi4py_fft/fftw/fftw_xfftn.pyx

-1
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,3 @@ cdef class FFT:
292292
output_array *= self._M
293293

294294
return output_array
295-

setup.py

+17-25
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os, sys
44
import shutil
55
from distutils import ccompiler
6+
import six
67
from setuptools import setup
78
from setuptools.extension import Extension
89
from numpy import get_include
@@ -13,36 +14,27 @@
1314
include_dirs = [get_include(), os.path.join(sys.prefix, 'include')]
1415
library_dirs = [os.path.join(sys.prefix, 'lib')]
1516
for f in ('FFTW_ROOT', 'FFTW_DIR'):
16-
if f in os.environ['PATH']:
17+
if f in os.environ:
1718
library_dirs.append(os.path.join(os.environ[f], 'lib'))
1819
include_dirs.append(os.path.join(os.environ[f], 'include'))
1920

21+
prec_map = {'float': 'f', 'double': '', 'long double': 'l'}
2022
compiler = ccompiler.new_compiler()
21-
assert compiler.find_library_file(library_dirs, 'fftw3'), 'Cannot find FFTW library!'
22-
has_threads = compiler.find_library_file(library_dirs, 'fftw3_threads')
2323

24-
prec_map = {'float': 'fftwf_', 'double': 'fftw_', 'long double': 'fftwl_'}
25-
26-
libs = {
27-
'float': ['fftw3f'],
28-
'double': ['fftw3'],
29-
'long double': ['fftw3l']
30-
}
31-
32-
has_prec = ['double']
33-
for d in ('float', 'long double'):
34-
if compiler.find_library_file(library_dirs, libs[d][0]):
35-
has_prec.append(d)
36-
37-
if has_threads:
38-
for d in has_prec:
39-
libs[d].append('_'.join((libs[d][0], 'threads')))
24+
libs = {}
25+
for d in ('float', 'double', 'long double'):
26+
lib = 'fftw3'+prec_map[d]
27+
if compiler.find_library_file(library_dirs, lib):
28+
libs[d] = [lib]
29+
tlib = '_'.join((lib, 'threads'))
30+
if compiler.find_library_file(library_dirs, tlib):
31+
libs[d].append(tlib)
4032
if sys.platform in ('unix', 'darwin'):
4133
libs[d].append('m')
4234

4335
# Generate files with float and long double if needed
44-
for d in has_prec[1:]:
45-
p = prec_map[d]
36+
for d in ('float', 'long double'):
37+
p = 'fftw'+prec_map[d]+'_'
4638
for fl in ('fftw_planxfftn.h', 'fftw_planxfftn.c', 'fftw_xfftn.pyx', 'fftw_xfftn.pxd'):
4739
fp = fl.replace('fftw_', p)
4840
shutil.copy(os.path.join(fftwdir, fl), os.path.join(fftwdir, fp))
@@ -55,13 +47,13 @@
5547
include_dirs=[get_include(),
5648
os.path.join(sys.prefix, 'include')])]
5749

58-
for d in has_prec:
59-
p = prec_map[d]
50+
for d, v in six.iteritems(libs):
51+
p = 'fftw'+prec_map[d]+'_'
6052
ext.append(Extension("mpi4py_fft.fftw.{}xfftn".format(p),
6153
sources=[os.path.join(fftwdir, "{}xfftn.pyx".format(p)),
6254
os.path.join(fftwdir, "{}planxfftn.c".format(p))],
6355
#define_macros=[('NPY_NO_DEPRECATED_API', 'NPY_1_7_API_VERSION')],
64-
libraries=libs[d],
56+
libraries=v,
6557
include_dirs=include_dirs,
6658
library_dirs=library_dirs))
6759

@@ -92,5 +84,5 @@
9284
],
9385
ext_modules=ext,
9486
install_requires=["mpi4py", "numpy", "six"],
95-
setup_requires=["setuptools>=18.0", "cython>=0.25"]
87+
setup_requires=["setuptools>=18.0", "cython>=0.25", "six"]
9688
)

0 commit comments

Comments
 (0)