Skip to content

Workaround np.linalg.solve ambiguity #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion array_api_compat/numpy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from numpy.linalg import * # noqa: F403
from numpy.linalg import __all__ as linalg_all
import numpy as _np

from ..common import _linalg
from .._internal import get_xp
Expand Down Expand Up @@ -27,14 +28,61 @@
diagonal = get_xp(np)(_linalg.diagonal)
trace = get_xp(np)(_linalg.trace)

# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.

# To workaround this, the below is the code from np.linalg.solve except
# only calling solve1 in the exactly 1D case.

# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray:
try:
from numpy.linalg._linalg import (
_makearray, _assert_stacked_2d, _assert_stacked_square,
_commonType, isComplexType, _raise_linalgerror_singular
)
except ImportError:
from numpy.linalg.linalg import (
_makearray, _assert_stacked_2d, _assert_stacked_square,
_commonType, isComplexType, _raise_linalgerror_singular
)
from numpy.linalg import _umath_linalg

x1, _ = _makearray(x1)
_assert_stacked_2d(x1)
_assert_stacked_square(x1)
x2, wrap = _makearray(x2)
t, result_t = _commonType(x1, x2)

# This part is different from np.linalg.solve
if x2.ndim == 1:
gufunc = _umath_linalg.solve1
else:
gufunc = _umath_linalg.solve

# This does nothing currently but is left in because it will be relevant
# when complex dtype support is added to the spec in 2022.
signature = 'DD->D' if isComplexType(t) else 'dd->d'
with _np.errstate(call=_raise_linalgerror_singular, invalid='call',
over='ignore', divide='ignore', under='ignore'):
r = gufunc(x1, x2, signature=signature)

return wrap(r.astype(result_t, copy=False))

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(np.linalg, 'vector_norm'):
vector_norm = np.linalg.vector_norm
else:
vector_norm = get_xp(np)(_linalg.vector_norm)

__all__ = linalg_all + _linalg.__all__
__all__ = linalg_all + _linalg.__all__ + ['solve']

del get_xp
del np
Expand Down
6 changes: 0 additions & 6 deletions numpy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]

# linalg tests require https://github.com/data-apis/array-api-tests/pull/101
# cleanups. Also some tests are using .mT
array_api_tests/test_linalg.py::test_eigvalsh
array_api_tests/test_linalg.py::test_solve
array_api_tests/test_linalg.py::test_trace

# Array methods and attributes not already on np.ndarray cannot be wrapped
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
Expand Down