Skip to content

Commit c926979

Browse files
committed
Improving interface for possibly missing precisions
1 parent 57a4d4e commit c926979

File tree

4 files changed

+32
-16
lines changed

4 files changed

+32
-16
lines changed

conf/meta.yaml

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ requirements:
3030
- pyfftw
3131
- six
3232
- nomkl
33-
- scipy
33+
- scipy >=1.0.0
3434
- h5py-parallel
3535
- netcdf4-parallel
3636

3737
test:
3838
requires:
3939
- coverage # [py3k]
4040
- codecov # [py3k]
41+
- scipy >=1.0.0
4142

4243
source_files:
4344
- tests

mpi4py_fft/fftw/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .xfftn import *
22
from .factory import get_planned_FFT, export_wisdom, import_wisdom, \
3-
forget_wisdom, cleanup, set_timelimit
4-
3+
forget_wisdom, cleanup, set_timelimit, get_fftw_lib

mpi4py_fft/fftw/factory.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,34 @@
11
import six
22
import numpy as np
33
from mpi4py import MPI
4-
from . import fftw_xfftn
5-
try:
6-
from . import fftwf_xfftn
7-
except ImportError:
8-
fftwf_xfftn = None
9-
try:
10-
from . import fftwl_xfftn
11-
except ImportError:
12-
fftwl_xfftn = None
134
from .utilities import FFTW_FORWARD, FFTW_MEASURE
145

6+
def get_fftw_lib(dtype):
7+
dtype = np.dtype(dtype).char.upper()
8+
if dtype == 'G':
9+
try:
10+
from . import fftwl_xfftn
11+
return fftwl_xfftn
12+
except ImportError:
13+
return None
14+
elif dtype == 'D':
15+
try:
16+
from . import fftw_xfftn
17+
return fftw_xfftn
18+
except ImportError:
19+
return None
20+
elif dtype == 'F':
21+
try:
22+
from . import fftwf_xfftn
23+
return fftwf_xfftn
24+
except ImportError:
25+
return None
26+
1527
fftlib = {}
16-
for k, v in zip(('F', 'D', 'G'), (fftwf_xfftn, fftw_xfftn, fftwl_xfftn)):
17-
if v is not None:
18-
fftlib[k] = v
28+
for t in 'fdg':
29+
lib = get_fftw_lib(t)
30+
if lib is not None:
31+
fftlib[t.upper()] = lib
1932

2033
comm = MPI.COMM_WORLD
2134

tests/test_fftw.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ def test_fftw():
3535

3636
dims = (1, 2, 3)
3737
sizes = (7, 8, 10)
38-
types = 'fdg'
38+
types = ''
39+
for t in 'fdg':
40+
if fftw.get_fftw_lib(t):
41+
types += t
3942
fflags = (fftw.FFTW_ESTIMATE, fftw.FFTW_DESTROY_INPUT)
4043
iflags = (fftw.FFTW_ESTIMATE, fftw.FFTW_DESTROY_INPUT)
4144

0 commit comments

Comments
 (0)