Skip to content

Commit 57a4d4e

Browse files
committed
Fixing import export wisdom for MPI
1 parent 9d85e33 commit 57a4d4e

File tree

4 files changed

+62
-24
lines changed

4 files changed

+62
-24
lines changed

mpi4py_fft/fftw/factory.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import six
22
import numpy as np
3+
from mpi4py import MPI
34
from . import fftw_xfftn
45
try:
56
from . import fftwf_xfftn
@@ -16,6 +17,8 @@
1617
if v is not None:
1718
fftlib[k] = v
1819

20+
comm = MPI.COMM_WORLD
21+
1922
def get_planned_FFT(input_array, output_array, axes=(-1,), kind=FFTW_FORWARD,
2023
threads=1, flags=(FFTW_MEASURE,), normalize=1):
2124
"""Return instance of transform class
@@ -84,17 +87,20 @@ def export_wisdom(filename):
8487
Note
8588
----
8689
Wisdom is stored for all precisions available: float, double and long
87-
double, using, respectively, prefix ``F_``, ``D_`` and ``G_``. Wisdom is
88-
imported using :func:`.import_wisdom`.
90+
double, using, respectively, prefix ``Fn_``, ``Dn_`` and ``Gn_``, where
91+
n is the rank of the processor.
92+
Wisdom is imported using :func:`.import_wisdom`, which must be called
93+
with the same MPI configuration as used with :func:`.export_wisdom`.
8994
9095
See also
9196
--------
9297
:func:`.import_wisdom`
9398
9499
"""
100+
rank = str(comm.Get_rank())
95101
e = []
96102
for key, lib in six.iteritems(fftlib):
97-
e.append(lib.export_wisdom(bytearray(key+'_'+filename, 'utf-8')))
103+
e.append(lib.export_wisdom(bytearray(key+rank+'_'+filename, 'utf-8')))
98104
assert np.all(np.array(e) == 1), "Not able to export wisdom {}".format(filename)
99105

100106
def import_wisdom(filename):
@@ -108,17 +114,22 @@ def import_wisdom(filename):
108114
Note
109115
----
110116
Wisdom is imported for all available precisions: float, double and long
111-
double, using, respectively, prefix ``F_``, ``D_`` and ``G_``. Wisdom is
112-
exported using :func:`.export_wisdom`.
117+
double, using, respectively, prefix ``Fn_``, ``Dn_`` and ``Gn_``, where
118+
n is the rank of the processor.
119+
Wisdom is exported using :func:`.export_wisdom`.
120+
Note that importing wisdom only works when using the same MPI configuration
121+
as used with :func:`.export_wisdom`.
122+
113123
114124
See also
115125
--------
116126
:func:`.export_wisdom`
117127
118128
"""
129+
rank = str(comm.Get_rank())
119130
e = []
120131
for key, lib in six.iteritems(fftlib):
121-
e.append(lib.import_wisdom(bytearray(key+'_'+filename, 'utf-8')))
132+
e.append(lib.import_wisdom(bytearray(key+rank+'_'+filename, 'utf-8')))
122133
assert np.all(np.array(e) == 1), "Not able to import wisdom {}".format(filename)
123134

124135
def forget_wisdom():

tests/test_fftw.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from __future__ import print_function
2-
import six
32
from time import time
3+
import six
44
import numpy as np
5-
import scipy
6-
import pyfftw
5+
from scipy.fftpack import dctn as scipy_dctn
6+
from scipy.fftpack import dstn as scipy_dstn
7+
import scipy.fftpack
78
from mpi4py_fft import fftw
89

10+
has_pyfftw = True
11+
try:
12+
import pyfftw
13+
except ImportError:
14+
has_pyfftw = False
15+
916
abstol = dict(f=5e-4, d=1e-12, g=1e-14)
1017

1118
kinds = {'dst4': fftw.FFTW_RODFT11, # no scipy to compare with
@@ -46,26 +53,27 @@ def test_fftw():
4653
outshape = list(shape)
4754
outshape[axes[-1]] = shape[axes[-1]]//2+1
4855
output_array = fftw.aligned(outshape, dtype=typecode.upper())
49-
oa = output_array if typecode=='d' else None # Test for both types of signature
56+
oa = output_array if typecode == 'd' else None # Test for both types of signature
5057
rfftn = fftw.rfftn(input_array, None, axes, threads, fflags, output_array=oa)
5158
A = np.random.random(shape).astype(typecode)
5259
input_array[:] = A
5360
B = rfftn()
5461
assert id(B) == id(rfftn.output_array)
55-
B2 = pyfftw.interfaces.numpy_fft.rfftn(input_array, axes=axes)
56-
assert allclose(B, B2), np.linalg.norm(B-B2)
57-
ia = input_array if typecode=='d' else None
62+
if has_pyfftw:
63+
B2 = pyfftw.interfaces.numpy_fft.rfftn(input_array, axes=axes)
64+
assert allclose(B, B2), np.linalg.norm(B-B2)
65+
ia = input_array if typecode == 'd' else None
5866
sa = np.take(input_array.shape, axes) if shape[axes[-1]] % 2 == 1 else None
5967
irfftn = fftw.irfftn(output_array, sa, axes, threads, iflags, output_array=ia)
60-
irfftn.input_array[...] = B2
68+
irfftn.input_array[...] = B
6169
A2 = irfftn(normalize=True)
6270
assert allclose(A, A2), np.linalg.norm(A-A2)
6371
hfftn = fftw.hfftn(output_array, sa, axes, threads, fflags, output_array=ia)
64-
hfftn.input_array[...] = B2
72+
hfftn.input_array[...] = B
6573
AC = hfftn().copy()
6674
ihfftn = fftw.ihfftn(input_array, None, axes, threads, iflags, output_array=oa)
6775
A2 = ihfftn(AC, implicit=False, normalize=True)
68-
assert allclose(A2, B2), print(np.linalg.norm(A2-B2))
76+
assert allclose(A2, B), print(np.linalg.norm(A2-B))
6977

7078
# c2c
7179
input_array = fftw.aligned(shape, dtype=typecode.upper())
@@ -79,8 +87,9 @@ def test_fftw():
7987
ifftn.input_array[...] = D
8088
C2 = ifftn(normalize=True)
8189
assert allclose(C, C2), np.linalg.norm(C-C2)
82-
D2 = pyfftw.interfaces.numpy_fft.fftn(C, axes=axes)
83-
assert allclose(D, D2), np.linalg.norm(D-D2)
90+
if has_pyfftw:
91+
D2 = pyfftw.interfaces.numpy_fft.fftn(C, axes=axes)
92+
assert allclose(D, D2), np.linalg.norm(D-D2)
8493

8594
# r2r
8695
input_array = fftw.aligned(shape, dtype=typecode)
@@ -93,7 +102,7 @@ def test_fftw():
93102
A2 = idct(B, implicit=True, normalize=True)
94103
assert allclose(A, A2), np.linalg.norm(A-A2)
95104
if typecode is not 'g' and not type is 4:
96-
B2 = scipy.fftpack.dctn(A, axes=axes, type=type)
105+
B2 = scipy_dctn(A, axes=axes, type=type)
97106
assert allclose(B, B2), np.linalg.norm(B-B2)
98107

99108
dst = fftw.dstn(input_array, None, axes, type, threads, fflags, output_array=oa)
@@ -102,7 +111,7 @@ def test_fftw():
102111
A2 = idst(B, implicit=True, normalize=True)
103112
assert allclose(A, A2), np.linalg.norm(A-A2)
104113
if typecode is not 'g' and not type is 4:
105-
B2 = scipy.fftpack.dstn(A, axes=axes, type=type)
114+
B2 = scipy_dstn(A, axes=axes, type=type)
106115
assert allclose(B, B2), np.linalg.norm(B-B2)
107116

108117
# Different r2r transforms along all axes. Just pick
@@ -128,8 +137,8 @@ def test_fftw():
128137

129138
def test_wisdom():
130139
# Test a simple export/import call
131-
fftw.export_wisdom('wisdom.dat')
132-
fftw.import_wisdom('wisdom.dat')
140+
fftw.export_wisdom('newwisdom.dat')
141+
fftw.import_wisdom('newwisdom.dat')
133142
fftw.forget_wisdom()
134143

135144
def test_timelimit():
@@ -151,4 +160,3 @@ def test_timelimit():
151160
test_fftw()
152161
test_wisdom()
153162
test_timelimit()
154-

tests/test_libfft.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
import numpy as np
44
from mpi4py_fft.libfft import FFT
55

6+
has_pyfftw = True
7+
try:
8+
import pyfftw
9+
except ImportError:
10+
has_pyfftw = False
11+
612
abstol = dict(f=5e-5, d=1e-14, g=1e-15)
713

814
def allclose(a, b):
@@ -17,6 +23,8 @@ def test_libfft():
1723
types = 'fdgFDG'
1824

1925
for use_pyfftw in (False, True):
26+
if has_pyfftw is False and use_pyfftw is True:
27+
continue
2028
t0 = 0
2129
for typecode in types:
2230
for dim in dims:
@@ -49,6 +57,8 @@ def test_libfft():
4957
# difficult to initialize. We solve this problem by making one extra
5058
# transform
5159
for use_pyfftw in (True, False):
60+
if has_pyfftw is False and use_pyfftw is True:
61+
continue
5262
for padding in (1.5, 2.0):
5363
for typecode in types:
5464
for dim in dims:

tests/test_mpifft.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
from __future__ import print_function
22
import functools
33
import numpy as np
4-
import pyfftw
54
from mpi4py import MPI
65
from mpi4py_fft.mpifft import PFFT
76
from mpi4py_fft.pencil import Subcomm
87
from mpi4py_fft import fftw, Function
98

9+
has_pyfftw = True
10+
try:
11+
import pyfftw
12+
except ImportError:
13+
has_pyfftw = False
14+
1015
abstol = dict(f=0.1, d=2e-10, g=1e-10)
1116

1217
def allclose(a, b):
@@ -64,6 +69,8 @@ def test_mpifft():
6469
padding = False
6570
for collapse in (True, False):
6671
for use_pyfftw in (False, True):
72+
if has_pyfftw is False and use_pyfftw is True:
73+
continue
6774
transforms = None
6875
if dim < 3:
6976
allaxes = [None, (-1,), (-2,),
@@ -159,6 +166,8 @@ def test_mpifft():
159166

160167
padding = [1.5]*len(shape)
161168
for use_pyfftw in (True, False):
169+
if has_pyfftw is False and use_pyfftw is True:
170+
continue
162171
if dim < 3:
163172
allaxes = [None, (-1,), (-2,),
164173
(-1, -2,), (-2, -1),

0 commit comments

Comments
 (0)