Skip to content

Commit 5e7128c

Browse files
committed
Fix torch.linalg.solve to handle batching correctly according to the standard
1 parent 34faced commit 5e7128c

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

array_api_compat/torch/linalg.py

+16
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
6060

6161
def solve(x1: array, x2: array, /, **kwargs) -> array:
6262
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
63+
# Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve
64+
# whenever
65+
# 1. x1.ndim - 1 == x2.ndim
66+
# 2. x1.shape[:-1] == x2.shape
67+
#
68+
# See linalg_solve_is_vector_rhs in
69+
# aten/src/ATen/native/LinearAlgebraUtils.h and
70+
# TORCH_META_FUNC(_linalg_solve_ex) in
71+
# aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source code.
72+
#
73+
# The easiest way to work around this is to prepend a size 1 dimension to
74+
# x2, since x2 is already one dimension less than x1.
75+
#
76+
# See https://github.com/pytorch/pytorch/issues/52915
77+
if x2.ndim != 1 and x1.ndim - 1 == x2.ndim and x1.shape[:-1] == x2.shape:
78+
x2 = x2[None]
6379
return torch.linalg.solve(x1, x2, **kwargs)
6480

6581
# torch.trace doesn't support the offset argument and doesn't support stacking

0 commit comments

Comments
 (0)