forked from data-apis/array-api-compat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlinalg.py
90 lines (76 loc) · 3.18 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from numpy.linalg import * # noqa: F403
from numpy.linalg import __all__ as linalg_all
import numpy as _np
from ..common import _linalg
from .._internal import get_xp
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
import numpy as np
cross = get_xp(np)(_linalg.cross)
outer = get_xp(np)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(np)(_linalg.eigh)
qr = get_xp(np)(_linalg.qr)
slogdet = get_xp(np)(_linalg.slogdet)
svd = get_xp(np)(_linalg.svd)
cholesky = get_xp(np)(_linalg.cholesky)
matrix_rank = get_xp(np)(_linalg.matrix_rank)
pinv = get_xp(np)(_linalg.pinv)
matrix_norm = get_xp(np)(_linalg.matrix_norm)
svdvals = get_xp(np)(_linalg.svdvals)
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.
# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
try:
from numpy.linalg._linalg import (
_makearray, _assert_stacked_2d, _assert_stacked_square,
_commonType, isComplexType, _raise_linalgerror_singular
)
except ImportError:
from numpy.linalg.linalg import (
_makearray, _assert_stacked_2d, _assert_stacked_square,
_commonType, isComplexType, _raise_linalgerror_singular
)
from numpy.linalg import _umath_linalg
x1, _ = _makearray(x1)
_assert_stacked_2d(x1)
_assert_stacked_square(x1)
x2, wrap = _makearray(x2)
t, result_t = _commonType(x1, x2)
# This part is different from np.linalg.solve
if x2.ndim == 1:
gufunc = _umath_linalg.solve1
else:
gufunc = _umath_linalg.solve
# This does nothing currently but is left in because it will be relevant
# when complex dtype support is added to the spec in 2022.
signature = 'DD->D' if isComplexType(t) else 'dd->d'
with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
over='ignore', divide='ignore', under='ignore'):
r = gufunc(x1, x2, signature=signature)
return wrap(r.astype(result_t, copy=False))
# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)
__all__ = linalg_all + _linalg.__all__ + ['solve']
del get_xp
del np
del linalg_all
del _linalg