Skip to content

Commit 8657396

Browse files
authored
Merge pull request #93 from asmeurer/np-solve-fix
Workaround np.linalg.solve ambiguity
2 parents 817d2ff + 0afd1a7 commit 8657396

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

array_api_compat/numpy/linalg.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from numpy.linalg import * # noqa: F403
22
from numpy.linalg import __all__ as linalg_all
3+
import numpy as _np
34

45
from ..common import _linalg
56
from .._internal import get_xp
@@ -27,14 +28,61 @@
2728
diagonal = get_xp(np)(_linalg.diagonal)
2829
trace = get_xp(np)(_linalg.trace)
2930

31+
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
32+
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
33+
# of matrices. The np.linalg.solve behavior of allowing stacks of both
34+
# matrices and vectors is ambiguous c.f.
35+
# https://github.com/numpy/numpy/issues/15349 and
36+
# https://github.com/data-apis/array-api/issues/285.
37+
38+
# To workaround this, the below is the code from np.linalg.solve except
39+
# only calling solve1 in the exactly 1D case.
40+
41+
# This code is here instead of in common because it is numpy specific. Also
42+
# note that CuPy's solve() does not currently support broadcasting (see
43+
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
44+
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
45+
try:
46+
from numpy.linalg._linalg import (
47+
_makearray, _assert_stacked_2d, _assert_stacked_square,
48+
_commonType, isComplexType, _raise_linalgerror_singular
49+
)
50+
except ImportError:
51+
from numpy.linalg.linalg import (
52+
_makearray, _assert_stacked_2d, _assert_stacked_square,
53+
_commonType, isComplexType, _raise_linalgerror_singular
54+
)
55+
from numpy.linalg import _umath_linalg
56+
57+
x1, _ = _makearray(x1)
58+
_assert_stacked_2d(x1)
59+
_assert_stacked_square(x1)
60+
x2, wrap = _makearray(x2)
61+
t, result_t = _commonType(x1, x2)
62+
63+
# This part is different from np.linalg.solve
64+
if x2.ndim == 1:
65+
gufunc = _umath_linalg.solve1
66+
else:
67+
gufunc = _umath_linalg.solve
68+
69+
# This does nothing currently but is left in because it will be relevant
70+
# when complex dtype support is added to the spec in 2022.
71+
signature = 'DD->D' if isComplexType(t) else 'dd->d'
72+
with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
73+
over='ignore', divide='ignore', under='ignore'):
74+
r = gufunc(x1, x2, signature=signature)
75+
76+
return wrap(r.astype(result_t, copy=False))
77+
3078
# These functions are completely new here. If the library already has them
3179
# (i.e., numpy 2.0), use the library version instead of our wrapper.
3280
if hasattr(np.linalg, 'vector_norm'):
3381
vector_norm = np.linalg.vector_norm
3482
else:
3583
vector_norm = get_xp(np)(_linalg.vector_norm)
3684

37-
__all__ = linalg_all + _linalg.__all__
85+
__all__ = linalg_all + _linalg.__all__ + ['solve']
3886

3987
del get_xp
4088
del np

numpy-xfails.txt

-6
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
1010
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
1111
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
1212

13-
# linalg tests require https://github.com/data-apis/array-api-tests/pull/101
14-
# cleanups. Also some tests are using .mT
15-
array_api_tests/test_linalg.py::test_eigvalsh
16-
array_api_tests/test_linalg.py::test_solve
17-
array_api_tests/test_linalg.py::test_trace
18-
1913
# Array methods and attributes not already on np.ndarray cannot be wrapped
2014
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
2115
array_api_tests/test_signatures.py::test_array_method_signature[to_device]

0 commit comments

Comments
 (0)