Skip to content

Commit 5affae5

Browse files
authored
Merge pull request #139 from lithomas1/dask-fft
Wrap fft for dask
2 parents e9da040 + 2182b4f commit 5affae5

File tree

6 files changed

+41
-36
lines changed

6 files changed

+41
-36
lines changed

Diff for: .github/workflows/array-api-tests-dask.yml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ jobs:
77
uses: ./.github/workflows/array-api-tests.yml
88
with:
99
package-name: dask
10+
package-version: '>= 2024.9.0'
1011
module-name: dask.array
1112
extra-requires: numpy
1213
pytest-extra-args: --disable-deadline --max-examples=5

Diff for: .github/workflows/array-api-tests.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ jobs:
4040
runs-on: ubuntu-latest
4141
strategy:
4242
matrix:
43-
python-version: ['3.9', '3.10', '3.11', '3.12']
43+
# min version of dask we needs drops support for python 3.9
44+
python-version: ${{ inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']') || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }}
4445

4546
steps:
4647
- name: Checkout array-api-compat

Diff for: array_api_compat/dask/array/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
__array_api_version__ = '2022.12'
77

88
__import__(__package__ + '.linalg')
9+
__import__(__package__ + '.fft')

Diff for: array_api_compat/dask/array/_aliases.py

+13-20
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,9 @@
99

1010
import numpy as np
1111
from numpy import (
12-
# Constants
13-
e,
14-
inf,
15-
nan,
16-
pi,
17-
newaxis,
1812
# Dtypes
13+
iinfo,
14+
finfo,
1915
bool_ as bool,
2016
float32,
2117
float64,
@@ -29,8 +25,6 @@
2925
uint64,
3026
complex64,
3127
complex128,
32-
iinfo,
33-
finfo,
3428
can_cast,
3529
result_type,
3630
)
@@ -206,19 +200,18 @@ def _isscalar(a):
206200

207201
return astype(xp.minimum(xp.maximum(x, min), max), x.dtype)
208202

209-
# exclude these from all since
203+
# exclude these from all since dask.array has no sorting functions
210204
_da_unsupported = ['sort', 'argsort']
211205

212-
common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
206+
_common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported]
213207

214-
__all__ = common_aliases + ['__array_namespace_info__', 'asarray', 'bool',
215-
'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2',
216-
'atanh', 'bitwise_left_shift', 'bitwise_invert',
217-
'bitwise_right_shift', 'concat', 'pow', 'e',
218-
'inf', 'nan', 'pi', 'newaxis', 'float32',
219-
'float64', 'int8', 'int16', 'int32', 'int64',
220-
'uint8', 'uint16', 'uint32', 'uint64',
221-
'complex64', 'complex128', 'iinfo', 'finfo',
222-
'can_cast', 'result_type']
208+
__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos',
209+
'acosh', 'asin', 'asinh', 'atan', 'atan2',
210+
'atanh', 'bitwise_left_shift', 'bitwise_invert',
211+
'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast',
212+
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
213+
'uint8', 'uint16', 'uint32', 'uint64',
214+
'complex64', 'complex128', 'iinfo', 'finfo',
215+
'can_cast', 'result_type']
223216

224-
_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']
217+
_all_ignore = ["get_xp", "da", "np"]

Diff for: array_api_compat/dask/array/fft.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from dask.array.fft import * # noqa: F403
2+
# dask.array.fft doesn't have __all__. If it is added, replace this with
3+
#
4+
# from dask.array.fft import __all__ as linalg_all
5+
_n = {}
6+
exec('from dask.array.fft import *', _n)
7+
del _n['__builtins__']
8+
fft_all = list(_n)
9+
del _n
10+
11+
from ...common import _fft
12+
from ..._internal import get_xp
13+
14+
import dask.array as da
15+
16+
fftfreq = get_xp(da)(_fft.fftfreq)
17+
rfftfreq = get_xp(da)(_fft.rfftfreq)
18+
19+
__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"]
20+
21+
del get_xp
22+
del da
23+
del fft_all
24+
del _fft

Diff for: dask-skips.txt

-15
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,2 @@
1-
# FFT isn't conformant
2-
array_api_tests/test_fft.py
3-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fft]
4-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifft]
5-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftn]
6-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ifftn]
7-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfft]
8-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfft]
9-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftn]
10-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.irfftn]
11-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.hfft]
12-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.ihfft]
13-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.fftfreq]
14-
array_api_tests/test_signatures.py::test_extension_func_signature[fft.rfftfreq]
15-
161
# slow and not implemented in dask
172
array_api_tests/test_linalg.py::test_matrix_power

0 commit comments

Comments
 (0)