Commit 05118b5

Implemented linear algebra functions in PyTorch
- BlockDiagonal
- Cholesky
- Eigvalsh
- Solve
- SolveTriangular
1 parent 4c71816 commit 05118b5
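
For orientation, the ops dispatched in this commit become available whenever a PyTensor graph is compiled against the PyTorch backend. Below is a minimal sketch of exercising one of them end to end, assuming the backend is selected with mode="PYTORCH" (the mode name is an assumption of this sketch, not something this diff establishes):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import slinalg as pt_slinalg

# Build a small graph that uses one of the newly dispatched Ops.
x = pt.matrix("x")
out = pt_slinalg.cholesky(x)

# Compile for the PyTorch backend ("PYTORCH" assumed as the mode name).
f = pytensor.function([x], out, mode="PYTORCH")

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))
spd = a @ a.T + 4 * np.eye(4)  # symmetric positive definite input
print(f(spd.astype(pytensor.config.floatX)))  # lower-triangular factor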

File tree

4 files changed: 232 additions (+), 1 deletion (-)


pytensor/link/pytorch/dispatch/__init__.py (+1)

@@ -6,4 +6,5 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.sort
+import pytensor.link.pytorch.dispatch.slinalg
 # isort: on

pytensor/link/pytorch/dispatch/basic.py (+3 -1)

@@ -12,7 +12,9 @@
 @singledispatch
 def pytorch_typify(data, dtype=None, **kwargs):
     r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
-    return torch.as_tensor(data, dtype=dtype)
+    if data is not None:
+        return torch.as_tensor(data, dtype=dtype)
+    return None


 @singledispatch
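
This None passthrough matters for the new slinalg dispatches: Eigvalsh takes an optional second input b, and the test below passes None for it, so typify must not feed None to torch.as_tensor (which errors on NoneType). A minimal standalone illustration of the guarded behaviour, using a simplified copy of the function above rather than the dispatch-registered one:

import torch

def typify(data, dtype=None):
    # Guarded conversion: missing (None) inputs pass through untouched.
    if data is not None:
        return torch.as_tensor(data, dtype=dtype)
    return None

print(typify([1.0, 2.0]))  # tensor([1., 2.])
print(typify(None))        # None; torch.as_tensor(None) would raise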
pytensor/link/pytorch/dispatch/slinalg.py (+86, new file)

@@ -0,0 +1,86 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.slinalg import (
+    BlockDiagonal,
+    Cholesky,
+    Eigvalsh,
+    Solve,
+    SolveTriangular,
+)
+
+
+@pytorch_funcify.register(Eigvalsh)
+def pytorch_funcify_Eigvalsh(op, **kwargs):
+    if op.lower:
+        UPLO = "L"
+    else:
+        UPLO = "U"
+
+    def eigvalsh(a, b):
+        if b is not None:
+            raise NotImplementedError(
+                "torch.linalg.eigvalsh does not support generalized eigenvector problems (b != None)"
+            )
+        return torch.linalg.eigvalsh(a, UPLO=UPLO)
+
+    return eigvalsh
+
+
+@pytorch_funcify.register(Cholesky)
+def pytorch_funcify_Cholesky(op, **kwargs):
+    upper = not op.lower
+
+    def cholesky(a):
+        return torch.linalg.cholesky(a, upper=upper)
+
+    return cholesky
+
+
+@pytorch_funcify.register(Solve)
+def pytorch_funcify_Solve(op, **kwargs):
+    lower = False
+    if op.assume_a != "gen" and op.lower:
+        lower = True
+
+    def solve(a, b):
+        if lower:
+            return torch.linalg.solve(torch.tril(a), b)
+
+        return torch.linalg.solve(a, b)
+
+    return solve
+
+
+@pytorch_funcify.register(SolveTriangular)
+def pytorch_funcify_SolveTriangular(op, **kwargs):
+    if op.check_finite:
+        raise NotImplementedError(
+            "Option check_finite is not implemented in torch.linalg.solve_triangular"
+        )
+
+    upper = not op.lower
+    unit_diagonal = op.unit_diagonal
+    trans = op.trans
+
+    def solve_triangular(A, b):
+        A_p = A
+        if trans == 1 or trans == "T":
+            A_p = A.T
+
+        if trans == 2 or trans == "C":
+            A_p = A.H
+
+        return torch.linalg.solve_triangular(
+            A_p, b, upper=upper, unitriangular=unit_diagonal
+        )
+
+    return solve_triangular
+
+
+@pytorch_funcify.register(BlockDiagonal)
+def pytorch_funcify_BlockDiagonalMatrix(op, **kwargs):
+    def block_diag(*inputs):
+        return torch.block_diag(*inputs)
+
+    return block_diag
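
Each registration above turns an abstract PyTensor Op into a plain callable over torch tensors. A minimal sketch of resolving one of these Ops through the dispatcher in isolation; the linker normally supplies extra keyword arguments, but a bare call works for the registrations in this file:

import torch

import pytensor.link.pytorch.dispatch.slinalg  # noqa: F401  (runs the registrations)
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.slinalg import Cholesky

# singledispatch picks the handler by Op type and returns the torch closure.
cholesky_fn = pytorch_funcify(Cholesky(lower=True))

a = torch.randn(4, 4)
spd = a @ a.T + 4 * torch.eye(4)  # symmetric positive definite input
print(cholesky_fn(spd))  # lower-triangular factor from torch.linalg.cholesky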

tests/link/pytorch/test_slinalg.py (+142, new file)

@@ -0,0 +1,142 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor import slinalg as pt_slinalg
+from pytensor.tensor.type import matrix, vector
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+@pytest.mark.parametrize("lower", [False, True])
+def test_pytorch_eigvalsh(lower):
+    A = matrix("A")
+    B = matrix("B")
+
+    out = pt_slinalg.eigvalsh(A, B, lower=lower)
+    out_fg = FunctionGraph([A, B], [out])
+
+    with pytest.raises(NotImplementedError):
+        compare_pytorch_and_py(
+            out_fg,
+            [
+                np.array(
+                    [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
+                ).astype(config.floatX),
+                np.array(
+                    [[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
+                ).astype(config.floatX),
+            ],
+        )
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
+                config.floatX
+            ),
+            None,
+        ],
+    )
+
+
+def test_pytorch_basic():
+    rng = np.random.default_rng(28494)
+
+    x = matrix("x")
+    b = vector("b")
+
+    out = pt_slinalg.cholesky(x)
+    out_fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
+                config.floatX
+            )
+        ],
+    )
+
+    out = pt_slinalg.Cholesky(lower=False)(x)
+    out_fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
+                config.floatX
+            )
+        ],
+    )
+
+    out = pt_slinalg.solve(x, b)
+    out_fg = FunctionGraph([x, b], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.eye(10).astype(config.floatX),
+            np.arange(10).astype(config.floatX),
+        ],
+    )
+
+
+@pytest.mark.xfail(reason="Blockwise not implemented")
+@pytest.mark.parametrize(
+    "check_finite",
+    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
+)
+@pytest.mark.parametrize("lower", [False, True])
+@pytest.mark.parametrize("trans", [0, 1, 2, "S"])
+def test_pytorch_SolveTriangular(trans, lower, check_finite):
+    x = matrix("x")
+    b = vector("b")
+
+    out = pt_slinalg.solve_triangular(
+        x,
+        b,
+        trans=trans,
+        lower=lower,
+        check_finite=check_finite,
+    )
+    out_fg = FunctionGraph([x, b], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.eye(10).astype(config.floatX),
+            np.arange(10).astype(config.floatX),
+        ],
+    )
+
+
+def test_pytorch_block_diag():
+    A = matrix("A")
+    B = matrix("B")
+    C = matrix("C")
+    D = matrix("D")
+
+    out = pt_slinalg.block_diag(A, B, C, D)
+    out_fg = FunctionGraph([A, B, C, D], [out])
+
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5)).astype(config.floatX),
+            np.random.normal(size=(3, 3)).astype(config.floatX),
+            np.random.normal(size=(2, 2)).astype(config.floatX),
+            np.random.normal(size=(4, 4)).astype(config.floatX),
+        ],
+    )
+
+
+@pytest.mark.xfail(reason="Blockwise not implemented")
+def test_pytorch_block_diag_blockwise():
+    A = pt.tensor3("A")
+    B = pt.tensor3("B")
+    out = pt_slinalg.block_diag(A, B)
+    out_fg = FunctionGraph([A, B], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5, 5)).astype(config.floatX),
+            np.random.normal(size=(5, 3, 3)).astype(config.floatX),
+        ],
+    )

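The tests above use the existing compare_pytorch_and_py helper from tests/link/pytorch/test_basic.py, which presumably compiles the same FunctionGraph with both the PyTorch linker and the reference Python backend and asserts matching outputs. They can be run with pytest tests/link/pytorch/test_slinalg.py.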