Skip to content

Commit 2fb9f0e

Browse files
committed
Fixed SolveTriangular Op in PyTorch
Added support for a b param of ndim = 1.
1 parent cf74fda commit 2fb9f0e

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

Diff for: pytensor/link/pytorch/dispatch/slinalg.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,19 @@ def solve_triangular(A, b):
7171
if trans == 2 or trans == "C":
7272
A_p = A.H
7373

74-
return torch.linalg.solve_triangular(
75-
A_p, b, upper=upper, unitriangular=unit_diagonal
74+
b_p = b
75+
if b.ndim == 1:
76+
b_p = b[:, None]
77+
78+
res = torch.linalg.solve_triangular(
79+
A_p, b_p, upper=upper, unitriangular=unit_diagonal
7680
)
7781

82+
if b.ndim == 1 and res.shape[1] == 1:
83+
return res.flatten()
84+
85+
return res
86+
7887
return solve_triangular
7988

8089

Diff for: tests/link/pytorch/test_slinalg.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_pytorch_basic():
5757
],
5858
)
5959

60-
out = pt_slinalg.Cholesky(lower=False)(x)
60+
out = pt_slinalg.cholesky(x, lower=False)
6161
out_fg = FunctionGraph([x], [out])
6262
compare_pytorch_and_py(
6363
out_fg,
@@ -79,7 +79,6 @@ def test_pytorch_basic():
7979
)
8080

8181

82-
@pytest.mark.xfail(reason="Blockwise not implemented")
8382
@pytest.mark.parametrize(
8483
"check_finite",
8584
(False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),

0 commit comments

Comments
 (0)