Skip to content

Commit ded67f9

Browse files
committed
Revisited slinalg implementation in PyTorch
- Removed lower param to solve - Refactored tests
1 parent ae76cdf commit ded67f9

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

pytensor/link/pytorch/dispatch/slinalg.py

+5-12
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,7 @@ def cholesky(a):
3939

4040
@pytorch_funcify.register(Solve)
4141
def pytorch_funcify_Solve(op, **kwargs):
42-
lower = False
43-
if op.assume_a != "gen" and op.lower:
44-
lower = True
45-
4642
def solve(a, b):
47-
if lower:
48-
return torch.linalg.solve(torch.tril(a), b)
49-
5043
return torch.linalg.solve(a, b)
5144

5245
return solve
@@ -64,12 +57,12 @@ def pytorch_funcify_SolveTriangular(op, **kwargs):
6457
trans = op.trans
6558

6659
def solve_triangular(A, b):
67-
A_p = A
68-
if trans == 1 or trans == "T":
60+
if trans in [1, "T"]:
6961
A_p = A.T
70-
71-
if trans == 2 or trans == "C":
72-
A_p = A.H
62+
elif trans in [2, "C"]:
63+
A_p = A.conj().T
64+
else:
65+
A_p = A
7366

7467
b_p = b
7568
if b.ndim == 1:

tests/link/pytorch/test_slinalg.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@ def test_pytorch_eigvalsh(lower):
4040
)
4141

4242

43-
def test_pytorch_basic():
43+
def test_pytorch_cholesky():
4444
rng = np.random.default_rng(28494)
4545

4646
x = matrix("x")
47-
b = vector("b")
4847

4948
out = pt_slinalg.cholesky(x)
5049
out_fg = FunctionGraph([x], [out])
@@ -68,6 +67,11 @@ def test_pytorch_basic():
6867
],
6968
)
7069

70+
71+
def test_pytorch_solve():
72+
x = matrix("x")
73+
b = vector("b")
74+
7175
out = pt_slinalg.solve(x, b)
7276
out_fg = FunctionGraph([x, b], [out])
7377
compare_pytorch_and_py(

0 commit comments

Comments
 (0)