1
1
from __future__ import print_function
2
- import six
3
2
from time import time
3
+ import six
4
4
import numpy as np
5
- import scipy
6
- import pyfftw
5
+ from scipy .fftpack import dctn as scipy_dctn
6
+ from scipy .fftpack import dstn as scipy_dstn
7
+ import scipy .fftpack
7
8
from mpi4py_fft import fftw
8
9
10
+ has_pyfftw = True
11
+ try :
12
+ import pyfftw
13
+ except ImportError :
14
+ has_pyfftw = False
15
+
9
16
abstol = dict (f = 5e-4 , d = 1e-12 , g = 1e-14 )
10
17
11
18
kinds = {'dst4' : fftw .FFTW_RODFT11 , # no scipy to compare with
@@ -46,26 +53,27 @@ def test_fftw():
46
53
outshape = list (shape )
47
54
outshape [axes [- 1 ]] = shape [axes [- 1 ]]// 2 + 1
48
55
output_array = fftw .aligned (outshape , dtype = typecode .upper ())
49
- oa = output_array if typecode == 'd' else None # Test for both types of signature
56
+ oa = output_array if typecode == 'd' else None # Test for both types of signature
50
57
rfftn = fftw .rfftn (input_array , None , axes , threads , fflags , output_array = oa )
51
58
A = np .random .random (shape ).astype (typecode )
52
59
input_array [:] = A
53
60
B = rfftn ()
54
61
assert id (B ) == id (rfftn .output_array )
55
- B2 = pyfftw .interfaces .numpy_fft .rfftn (input_array , axes = axes )
56
- assert allclose (B , B2 ), np .linalg .norm (B - B2 )
57
- ia = input_array if typecode == 'd' else None
62
+ if has_pyfftw :
63
+ B2 = pyfftw .interfaces .numpy_fft .rfftn (input_array , axes = axes )
64
+ assert allclose (B , B2 ), np .linalg .norm (B - B2 )
65
+ ia = input_array if typecode == 'd' else None
58
66
sa = np .take (input_array .shape , axes ) if shape [axes [- 1 ]] % 2 == 1 else None
59
67
irfftn = fftw .irfftn (output_array , sa , axes , threads , iflags , output_array = ia )
60
- irfftn .input_array [...] = B2
68
+ irfftn .input_array [...] = B
61
69
A2 = irfftn (normalize = True )
62
70
assert allclose (A , A2 ), np .linalg .norm (A - A2 )
63
71
hfftn = fftw .hfftn (output_array , sa , axes , threads , fflags , output_array = ia )
64
- hfftn .input_array [...] = B2
72
+ hfftn .input_array [...] = B
65
73
AC = hfftn ().copy ()
66
74
ihfftn = fftw .ihfftn (input_array , None , axes , threads , iflags , output_array = oa )
67
75
A2 = ihfftn (AC , implicit = False , normalize = True )
68
- assert allclose (A2 , B2 ), print (np .linalg .norm (A2 - B2 ))
76
+ assert allclose (A2 , B ), print (np .linalg .norm (A2 - B ))
69
77
70
78
# c2c
71
79
input_array = fftw .aligned (shape , dtype = typecode .upper ())
@@ -79,8 +87,9 @@ def test_fftw():
79
87
ifftn .input_array [...] = D
80
88
C2 = ifftn (normalize = True )
81
89
assert allclose (C , C2 ), np .linalg .norm (C - C2 )
82
- D2 = pyfftw .interfaces .numpy_fft .fftn (C , axes = axes )
83
- assert allclose (D , D2 ), np .linalg .norm (D - D2 )
90
+ if has_pyfftw :
91
+ D2 = pyfftw .interfaces .numpy_fft .fftn (C , axes = axes )
92
+ assert allclose (D , D2 ), np .linalg .norm (D - D2 )
84
93
85
94
# r2r
86
95
input_array = fftw .aligned (shape , dtype = typecode )
@@ -93,7 +102,7 @@ def test_fftw():
93
102
A2 = idct (B , implicit = True , normalize = True )
94
103
assert allclose (A , A2 ), np .linalg .norm (A - A2 )
95
104
if typecode is not 'g' and not type is 4 :
96
- B2 = scipy . fftpack . dctn (A , axes = axes , type = type )
105
+ B2 = scipy_dctn (A , axes = axes , type = type )
97
106
assert allclose (B , B2 ), np .linalg .norm (B - B2 )
98
107
99
108
dst = fftw .dstn (input_array , None , axes , type , threads , fflags , output_array = oa )
@@ -102,7 +111,7 @@ def test_fftw():
102
111
A2 = idst (B , implicit = True , normalize = True )
103
112
assert allclose (A , A2 ), np .linalg .norm (A - A2 )
104
113
if typecode is not 'g' and not type is 4 :
105
- B2 = scipy . fftpack . dstn (A , axes = axes , type = type )
114
+ B2 = scipy_dstn (A , axes = axes , type = type )
106
115
assert allclose (B , B2 ), np .linalg .norm (B - B2 )
107
116
108
117
# Different r2r transforms along all axes. Just pick
@@ -128,8 +137,8 @@ def test_fftw():
128
137
129
138
def test_wisdom ():
130
139
# Test a simple export/import call
131
- fftw .export_wisdom ('wisdom .dat' )
132
- fftw .import_wisdom ('wisdom .dat' )
140
+ fftw .export_wisdom ('newwisdom .dat' )
141
+ fftw .import_wisdom ('newwisdom .dat' )
133
142
fftw .forget_wisdom ()
134
143
135
144
def test_timelimit ():
@@ -151,4 +160,3 @@ def test_timelimit():
151
160
test_fftw ()
152
161
test_wisdom ()
153
162
test_timelimit ()
154
-
0 commit comments