1
1
from __future__ import annotations
2
2
3
- from typing import TYPE_CHECKING , Union , Optional , Literal
3
+ from collections .abc import Sequence
4
+ from typing import Union , Optional , Literal
4
5
5
- if TYPE_CHECKING :
6
- from ._typing import Device , ndarray , DType
7
- from collections .abc import Sequence
6
+ from ._typing import Device , Array , DType , Namespace
8
7
9
8
# Note: NumPy fft functions improperly upcast float32 and complex64 to
10
9
# complex128, which is why we require wrapping them all here.
11
10
12
11
def fft (
13
- x : ndarray ,
12
+ x : Array ,
14
13
/ ,
15
- xp ,
14
+ xp : Namespace ,
16
15
* ,
17
16
n : Optional [int ] = None ,
18
17
axis : int = - 1 ,
19
18
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
20
- ) -> ndarray :
19
+ ) -> Array :
21
20
res = xp .fft .fft (x , n = n , axis = axis , norm = norm )
22
21
if x .dtype in [xp .float32 , xp .complex64 ]:
23
22
return res .astype (xp .complex64 )
24
23
return res
25
24
26
25
def ifft (
27
- x : ndarray ,
26
+ x : Array ,
28
27
/ ,
29
- xp ,
28
+ xp : Namespace ,
30
29
* ,
31
30
n : Optional [int ] = None ,
32
31
axis : int = - 1 ,
33
32
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
34
- ) -> ndarray :
33
+ ) -> Array :
35
34
res = xp .fft .ifft (x , n = n , axis = axis , norm = norm )
36
35
if x .dtype in [xp .float32 , xp .complex64 ]:
37
36
return res .astype (xp .complex64 )
38
37
return res
39
38
40
39
def fftn (
41
- x : ndarray ,
40
+ x : Array ,
42
41
/ ,
43
- xp ,
42
+ xp : Namespace ,
44
43
* ,
45
44
s : Sequence [int ] = None ,
46
45
axes : Sequence [int ] = None ,
47
46
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
48
- ) -> ndarray :
47
+ ) -> Array :
49
48
res = xp .fft .fftn (x , s = s , axes = axes , norm = norm )
50
49
if x .dtype in [xp .float32 , xp .complex64 ]:
51
50
return res .astype (xp .complex64 )
52
51
return res
53
52
54
53
def ifftn (
55
- x : ndarray ,
54
+ x : Array ,
56
55
/ ,
57
- xp ,
56
+ xp : Namespace ,
58
57
* ,
59
58
s : Sequence [int ] = None ,
60
59
axes : Sequence [int ] = None ,
61
60
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
62
- ) -> ndarray :
61
+ ) -> Array :
63
62
res = xp .fft .ifftn (x , s = s , axes = axes , norm = norm )
64
63
if x .dtype in [xp .float32 , xp .complex64 ]:
65
64
return res .astype (xp .complex64 )
66
65
return res
67
66
68
67
def rfft (
69
- x : ndarray ,
68
+ x : Array ,
70
69
/ ,
71
- xp ,
70
+ xp : Namespace ,
72
71
* ,
73
72
n : Optional [int ] = None ,
74
73
axis : int = - 1 ,
75
74
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
76
- ) -> ndarray :
75
+ ) -> Array :
77
76
res = xp .fft .rfft (x , n = n , axis = axis , norm = norm )
78
77
if x .dtype == xp .float32 :
79
78
return res .astype (xp .complex64 )
80
79
return res
81
80
82
81
def irfft (
83
- x : ndarray ,
82
+ x : Array ,
84
83
/ ,
85
- xp ,
84
+ xp : Namespace ,
86
85
* ,
87
86
n : Optional [int ] = None ,
88
87
axis : int = - 1 ,
89
88
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
90
- ) -> ndarray :
89
+ ) -> Array :
91
90
res = xp .fft .irfft (x , n = n , axis = axis , norm = norm )
92
91
if x .dtype == xp .complex64 :
93
92
return res .astype (xp .float32 )
94
93
return res
95
94
96
95
def rfftn (
97
- x : ndarray ,
96
+ x : Array ,
98
97
/ ,
99
- xp ,
98
+ xp : Namespace ,
100
99
* ,
101
100
s : Sequence [int ] = None ,
102
101
axes : Sequence [int ] = None ,
103
102
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
104
- ) -> ndarray :
103
+ ) -> Array :
105
104
res = xp .fft .rfftn (x , s = s , axes = axes , norm = norm )
106
105
if x .dtype == xp .float32 :
107
106
return res .astype (xp .complex64 )
108
107
return res
109
108
110
109
def irfftn (
111
- x : ndarray ,
110
+ x : Array ,
112
111
/ ,
113
- xp ,
112
+ xp : Namespace ,
114
113
* ,
115
114
s : Sequence [int ] = None ,
116
115
axes : Sequence [int ] = None ,
117
116
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
118
- ) -> ndarray :
117
+ ) -> Array :
119
118
res = xp .fft .irfftn (x , s = s , axes = axes , norm = norm )
120
119
if x .dtype == xp .complex64 :
121
120
return res .astype (xp .float32 )
122
121
return res
123
122
124
123
def hfft (
125
- x : ndarray ,
124
+ x : Array ,
126
125
/ ,
127
- xp ,
126
+ xp : Namespace ,
128
127
* ,
129
128
n : Optional [int ] = None ,
130
129
axis : int = - 1 ,
131
130
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
132
- ) -> ndarray :
131
+ ) -> Array :
133
132
res = xp .fft .hfft (x , n = n , axis = axis , norm = norm )
134
133
if x .dtype in [xp .float32 , xp .complex64 ]:
135
134
return res .astype (xp .float32 )
136
135
return res
137
136
138
137
def ihfft (
139
- x : ndarray ,
138
+ x : Array ,
140
139
/ ,
141
- xp ,
140
+ xp : Namespace ,
142
141
* ,
143
142
n : Optional [int ] = None ,
144
143
axis : int = - 1 ,
145
144
norm : Literal ["backward" , "ortho" , "forward" ] = "backward" ,
146
- ) -> ndarray :
145
+ ) -> Array :
147
146
res = xp .fft .ihfft (x , n = n , axis = axis , norm = norm )
148
147
if x .dtype in [xp .float32 , xp .complex64 ]:
149
148
return res .astype (xp .complex64 )
@@ -152,12 +151,12 @@ def ihfft(
152
151
def fftfreq (
153
152
n : int ,
154
153
/ ,
155
- xp ,
154
+ xp : Namespace ,
156
155
* ,
157
156
d : float = 1.0 ,
158
157
dtype : Optional [DType ] = None ,
159
- device : Optional [Device ] = None
160
- ) -> ndarray :
158
+ device : Optional [Device ] = None ,
159
+ ) -> Array :
161
160
if device not in ["cpu" , None ]:
162
161
raise ValueError (f"Unsupported device { device !r} " )
163
162
res = xp .fft .fftfreq (n , d = d )
@@ -168,23 +167,27 @@ def fftfreq(
168
167
def rfftfreq (
169
168
n : int ,
170
169
/ ,
171
- xp ,
170
+ xp : Namespace ,
172
171
* ,
173
172
d : float = 1.0 ,
174
173
dtype : Optional [DType ] = None ,
175
- device : Optional [Device ] = None
176
- ) -> ndarray :
174
+ device : Optional [Device ] = None ,
175
+ ) -> Array :
177
176
if device not in ["cpu" , None ]:
178
177
raise ValueError (f"Unsupported device { device !r} " )
179
178
res = xp .fft .rfftfreq (n , d = d )
180
179
if dtype is not None :
181
180
return res .astype (dtype )
182
181
return res
183
182
184
- def fftshift (x : ndarray , / , xp , * , axes : Union [int , Sequence [int ]] = None ) -> ndarray :
183
+ def fftshift (
184
+ x : Array , / , xp : Namespace , * , axes : Union [int , Sequence [int ]] = None
185
+ ) -> Array :
185
186
return xp .fft .fftshift (x , axes = axes )
186
187
187
- def ifftshift (x : ndarray , / , xp , * , axes : Union [int , Sequence [int ]] = None ) -> ndarray :
188
+ def ifftshift (
189
+ x : Array , / , xp : Namespace , * , axes : Union [int , Sequence [int ]] = None
190
+ ) -> Array :
188
191
return xp .fft .ifftshift (x , axes = axes )
189
192
190
193
__all__ = [
0 commit comments