1
1
from __future__ import annotations
2
+
2
3
from functools import partial
4
+ from typing import TYPE_CHECKING
3
5
4
- from ...common import _aliases
5
- from ...common ._helpers import _check_device
6
+ import numpy as np
6
7
7
8
from ..._internal import get_xp
9
+ from ...common import _aliases , _linalg
10
+ from ...common ._helpers import _check_device
8
11
9
- import numpy as np
10
-
11
- from typing import TYPE_CHECKING
12
12
if TYPE_CHECKING :
13
- from typing import Optional , Union
14
- from ...common ._typing import ndarray , Device , Dtype
13
+ from typing import Optional , Tuple , Union
14
+
15
+ from ...common ._typing import Device , Dtype , ndarray
15
16
16
17
import dask .array as da
17
18
24
25
# not pass stop/step as keyword arguments, which will cause
25
26
# an error with dask
26
27
28
+
27
29
# TODO: delete the xp stuff, it shouldn't be necessary
28
30
def dask_arange (
29
31
start : Union [int , float ],
@@ -34,7 +36,7 @@ def dask_arange(
34
36
xp ,
35
37
dtype : Optional [Dtype ] = None ,
36
38
device : Optional [Device ] = None ,
37
- ** kwargs
39
+ ** kwargs ,
38
40
) -> ndarray :
39
41
_check_device (xp , device )
40
42
args = [start ]
@@ -47,10 +49,11 @@ def dask_arange(
47
49
args .append (step )
48
50
return xp .arange (* args , dtype = dtype , ** kwargs )
49
51
52
+
50
53
arange = get_xp (da )(dask_arange )
51
54
eye = get_xp (da )(_aliases .eye )
52
55
53
- asarray = partial (_aliases ._asarray , namespace = ' dask.array' )
56
+ asarray = partial (_aliases ._asarray , namespace = " dask.array" )
54
57
asarray .__doc__ = _aliases ._asarray .__doc__
55
58
56
59
linspace = get_xp (da )(_aliases .linspace )
@@ -86,3 +89,22 @@ def dask_arange(
86
89
matmul = get_xp (np )(_aliases .matmul )
87
90
tensordot = get_xp (np )(_aliases .tensordot )
88
91
92
+
93
+ EighResult = _linalg .EighResult
94
+ QRResult = _linalg .QRResult
95
+ SlogdetResult = _linalg .SlogdetResult
96
+ SVDResult = _linalg .SVDResult
97
+ qr = get_xp (da )(_linalg .qr )
98
+ cholesky = get_xp (da )(_linalg .cholesky )
99
+ matrix_rank = get_xp (da )(_linalg .matrix_rank )
100
+ matrix_norm = get_xp (da )(_linalg .matrix_norm )
101
+
102
+
103
+ def svdvals (x : ndarray ) -> Union [ndarray , Tuple [ndarray , ...]]:
104
+ # TODO: can't avoid computing U or V for dask
105
+ _ , s , _ = da .linalg .svd (x )
106
+ return s
107
+
108
+
109
+ vector_norm = get_xp (da )(_linalg .vector_norm )
110
+ diagonal = get_xp (da )(_linalg .diagonal )
0 commit comments