File tree 2 files changed +11
-14
lines changed
pytensor/link/pytorch/dispatch
2 files changed +11
-14
lines changed Original file line number Diff line number Diff line change @@ -39,14 +39,7 @@ def cholesky(a):
39
39
40
40
@pytorch_funcify .register (Solve )
41
41
def pytorch_funcify_Solve (op , ** kwargs ):
42
- lower = False
43
- if op .assume_a != "gen" and op .lower :
44
- lower = True
45
-
46
42
def solve (a , b ):
47
- if lower :
48
- return torch .linalg .solve (torch .tril (a ), b )
49
-
50
43
return torch .linalg .solve (a , b )
51
44
52
45
return solve
@@ -64,12 +57,12 @@ def pytorch_funcify_SolveTriangular(op, **kwargs):
64
57
trans = op .trans
65
58
66
59
def solve_triangular (A , b ):
67
- A_p = A
68
- if trans == 1 or trans == "T" :
60
+ if trans in [1 , "T" ]:
69
61
A_p = A .T
70
-
71
- if trans == 2 or trans == "C" :
72
- A_p = A .H
62
+ elif trans in [2 , "C" ]:
63
+ A_p = A .conj ().T
64
+ else :
65
+ A_p = A
73
66
74
67
b_p = b
75
68
if b .ndim == 1 :
Original file line number Diff line number Diff line change @@ -40,11 +40,10 @@ def test_pytorch_eigvalsh(lower):
40
40
)
41
41
42
42
43
- def test_pytorch_basic ():
43
+ def test_pytorch_cholesky ():
44
44
rng = np .random .default_rng (28494 )
45
45
46
46
x = matrix ("x" )
47
- b = vector ("b" )
48
47
49
48
out = pt_slinalg .cholesky (x )
50
49
out_fg = FunctionGraph ([x ], [out ])
@@ -68,6 +67,11 @@ def test_pytorch_basic():
68
67
],
69
68
)
70
69
70
+
71
+ def test_pytorch_solve ():
72
+ x = matrix ("x" )
73
+ b = vector ("b" )
74
+
71
75
out = pt_slinalg .solve (x , b )
72
76
out_fg = FunctionGraph ([x , b ], [out ])
73
77
compare_pytorch_and_py (
You can’t perform that action at this time.
0 commit comments