Skip to content

Commit f3d2ede

Browse files
Implemented JAX backend for Eigvalsh (#867)
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent 920b409 commit f3d2ede

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

pytensor/link/jax/dispatch/slinalg.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,30 @@
11
import jax
22

33
from pytensor.link.jax.dispatch.basic import jax_funcify
4-
from pytensor.tensor.slinalg import BlockDiagonal, Cholesky, Solve, SolveTriangular
4+
from pytensor.tensor.slinalg import (
5+
BlockDiagonal,
6+
Cholesky,
7+
Eigvalsh,
8+
Solve,
9+
SolveTriangular,
10+
)
11+
12+
13+
@jax_funcify.register(Eigvalsh)
14+
def jax_funcify_Eigvalsh(op, **kwargs):
15+
if op.lower:
16+
UPLO = "L"
17+
else:
18+
UPLO = "U"
19+
20+
def eigvalsh(a, b):
21+
if b is not None:
22+
raise NotImplementedError(
23+
"jax.numpy.linalg.eigvalsh does not support generalized eigenvector problems (b != None)"
24+
)
25+
return jax.numpy.linalg.eigvalsh(a, UPLO=UPLO)
26+
27+
return eigvalsh
528

629

730
@jax_funcify.register(Cholesky)

tests/link/jax/test_slinalg.py

+31
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,34 @@ def test_jax_block_diag_blockwise():
163163
np.random.normal(size=(5, 3, 3)).astype(config.floatX),
164164
],
165165
)
166+
167+
168+
@pytest.mark.parametrize("lower", [False, True])
169+
def test_jax_eigvalsh(lower):
170+
A = matrix("A")
171+
B = matrix("B")
172+
173+
out = pt_slinalg.eigvalsh(A, B, lower=lower)
174+
out_fg = FunctionGraph([A, B], [out])
175+
176+
with pytest.raises(NotImplementedError):
177+
compare_jax_and_py(
178+
out_fg,
179+
[
180+
np.array(
181+
[[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
182+
).astype(config.floatX),
183+
np.array(
184+
[[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
185+
).astype(config.floatX),
186+
],
187+
)
188+
compare_jax_and_py(
189+
out_fg,
190+
[
191+
np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
192+
config.floatX
193+
),
194+
None,
195+
],
196+
)

0 commit comments

Comments
 (0)