Skip to content

Commit c5d55ae

Browse files
committed
Fix ruff errors for dask/array/linalg
1 parent 5cd47df commit c5d55ae

File tree

2 files changed

+81
-57
lines changed

2 files changed

+81
-57
lines changed

array_api_compat/dask/array/_aliases.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from __future__ import annotations
2+
23
from functools import partial
4+
from typing import TYPE_CHECKING
35

4-
from ...common import _aliases
5-
from ...common._helpers import _check_device
6+
import numpy as np
67

78
from ..._internal import get_xp
9+
from ...common import _aliases, _linalg
10+
from ...common._helpers import _check_device
811

9-
import numpy as np
10-
11-
from typing import TYPE_CHECKING
1212
if TYPE_CHECKING:
13-
from typing import Optional, Union
14-
from ...common._typing import ndarray, Device, Dtype
13+
from typing import Optional, Tuple, Union
14+
15+
from ...common._typing import Device, Dtype, ndarray
1516

1617
import dask.array as da
1718

@@ -24,6 +25,7 @@
2425
# not pass stop/step as keyword arguments, which will cause
2526
# an error with dask
2627

28+
2729
# TODO: delete the xp stuff, it shouldn't be necessary
2830
def dask_arange(
2931
start: Union[int, float],
@@ -34,7 +36,7 @@ def dask_arange(
3436
xp,
3537
dtype: Optional[Dtype] = None,
3638
device: Optional[Device] = None,
37-
**kwargs
39+
**kwargs,
3840
) -> ndarray:
3941
_check_device(xp, device)
4042
args = [start]
@@ -47,10 +49,11 @@ def dask_arange(
4749
args.append(step)
4850
return xp.arange(*args, dtype=dtype, **kwargs)
4951

52+
5053
arange = get_xp(da)(dask_arange)
5154
eye = get_xp(da)(_aliases.eye)
5255

53-
asarray = partial(_aliases._asarray, namespace='dask.array')
56+
asarray = partial(_aliases._asarray, namespace="dask.array")
5457
asarray.__doc__ = _aliases._asarray.__doc__
5558

5659
linspace = get_xp(da)(_aliases.linspace)
@@ -86,3 +89,22 @@ def dask_arange(
8689
matmul = get_xp(np)(_aliases.matmul)
8790
tensordot = get_xp(np)(_aliases.tensordot)
8891

92+
93+
EighResult = _linalg.EighResult
94+
QRResult = _linalg.QRResult
95+
SlogdetResult = _linalg.SlogdetResult
96+
SVDResult = _linalg.SVDResult
97+
qr = get_xp(da)(_linalg.qr)
98+
cholesky = get_xp(da)(_linalg.cholesky)
99+
matrix_rank = get_xp(da)(_linalg.matrix_rank)
100+
matrix_norm = get_xp(da)(_linalg.matrix_norm)
101+
102+
103+
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
104+
# TODO: can't avoid computing U or V for dask
105+
_, s, _ = da.linalg.svd(x)
106+
return s
107+
108+
109+
vector_norm = get_xp(da)(_linalg.vector_norm)
110+
diagonal = get_xp(da)(_linalg.diagonal)

array_api_compat/dask/array/linalg.py

+50-48
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,50 @@
1-
from __future__ import annotations
2-
3-
from dask.array.linalg import *
4-
from ...common import _linalg
5-
from ..._internal import get_xp
6-
from dask.array import matmul, tensordot, trace, outer
7-
from ._aliases import matrix_transpose, vecdot
8-
9-
import dask.array as da
10-
11-
from typing import TYPE_CHECKING
12-
if TYPE_CHECKING:
13-
from typing import Union, Tuple
14-
from ...common._typing import ndarray
15-
16-
# cupy.linalg doesn't have __all__. If it is added, replace this with
17-
#
18-
# from cupy.linalg import __all__ as linalg_all
19-
_n = {}
20-
exec('from dask.array.linalg import *', _n)
21-
del _n['__builtins__']
22-
linalg_all = list(_n)
23-
del _n
24-
25-
EighResult = _linalg.EighResult
26-
QRResult = _linalg.QRResult
27-
SlogdetResult = _linalg.SlogdetResult
28-
SVDResult = _linalg.SVDResult
29-
qr = get_xp(da)(_linalg.qr)
30-
cholesky = get_xp(da)(_linalg.cholesky)
31-
matrix_rank = get_xp(da)(_linalg.matrix_rank)
32-
matrix_norm = get_xp(da)(_linalg.matrix_norm)
33-
34-
def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
35-
# TODO: can't avoid computing U or V for dask
36-
_, s, _ = svd(x)
37-
return s
38-
39-
vector_norm = get_xp(da)(_linalg.vector_norm)
40-
diagonal = get_xp(da)(_linalg.diagonal)
41-
42-
__all__ = linalg_all + ["EighResult", "QRResult", "SlogdetResult",
43-
"SVDResult", "qr", "cholesky", "matrix_rank", "matrix_norm",
44-
"svdvals", "vector_norm", "diagonal"]
45-
46-
del get_xp
47-
del da
48-
del _linalg
1+
import dask.array as _da
2+
from dask.array import (
3+
matmul,
4+
outer,
5+
tensordot,
6+
trace,
7+
)
8+
from dask.array.linalg import * # noqa: F401, F403
9+
10+
from .._internal import _get_all_public_members
11+
from ._aliases import (
12+
EighResult,
13+
QRResult,
14+
SlogdetResult,
15+
SVDResult,
16+
cholesky,
17+
diagonal,
18+
matrix_norm,
19+
matrix_rank,
20+
matrix_transpose,
21+
qr,
22+
svdvals,
23+
vecdot,
24+
vector_norm,
25+
)
26+
27+
__all__ = [
28+
"matmul",
29+
"outer",
30+
"tensordot",
31+
"trace",
32+
]
33+
34+
__all__ += _get_all_public_members(_da.linalg)
35+
36+
__all__ += [
37+
"EighResult",
38+
"QRResult",
39+
"SlogdetResult",
40+
"SVDResult",
41+
"qr",
42+
"cholesky",
43+
"matrix_rank",
44+
"matrix_norm",
45+
"matrix_transpose",
46+
"vecdot",
47+
"svdvals",
48+
"vector_norm",
49+
"diagonal",
50+
]

0 commit comments

Comments
 (0)