Skip to content

Commit bad8d20

Browse files
authored
Slogdet returns naive expression and is optimized later (#1041)
1 parent 33a4d48 commit bad8d20

File tree

4 files changed

+197
-54
lines changed

4 files changed

+197
-54
lines changed

Diff for: pytensor/tensor/nlinalg.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.gradient import DisconnectedType
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.op import Op
14+
from pytensor.tensor import TensorLike
1415
from pytensor.tensor import basic as ptb
1516
from pytensor.tensor import math as ptm
1617
from pytensor.tensor.basic import as_tensor_variable, diagonal
@@ -266,7 +267,33 @@ def __str__(self):
266267
return "SLogDet"
267268

268269

269-
slogdet = Blockwise(SLogDet())
270+
def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
    """
    Return the sign and the natural log of the absolute value of a determinant.

    The graph built here is deliberately naive: it is expressed directly in
    terms of ``det`` and is meant to be optimized later by rewrites that
    target the det operation.

    Parameters
    ----------
    x : (..., M, M) tensor or tensor_like
        Input tensor, has to be square.

    Returns
    -------
    A tuple with the following attributes:

    sign : (...) tensor_like
        A number representing the sign of the determinant. For a real matrix,
        this is 1, 0, or -1.
    logabsdet : (...) tensor_like
        The natural log of the absolute value of the determinant.

    If the determinant is zero, then `sign` will be 0 and `logabsdet`
    will be -inf. In all cases, the determinant is equal to
    ``sign * exp(logabsdet)``.
    """
    # Build det(x) once so both outputs share the same determinant node.
    determinant = det(x)
    sign = ptm.sign(determinant)
    logabsdet = ptm.log(ptm.abs(determinant))
    return sign, logabsdet
270297

271298

272299
class Eig(Op):

Diff for: pytensor/tensor/rewriting/linalg.py

+72-48
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from collections.abc import Callable
33
from typing import cast
44

5+
import numpy as np
6+
57
from pytensor import Variable
68
from pytensor import tensor as pt
79
from pytensor.compile import optdb
@@ -11,7 +13,7 @@
1113
in2out,
1214
node_rewriter,
1315
)
14-
from pytensor.scalar.basic import Mul
16+
from pytensor.scalar.basic import Abs, Log, Mul, Sign
1517
from pytensor.tensor.basic import (
1618
AllocDiag,
1719
ExtractDiag,
@@ -30,11 +32,11 @@
3032
KroneckerProduct,
3133
MatrixInverse,
3234
MatrixPinv,
35+
SLogDet,
3336
det,
3437
inv,
3538
kron,
3639
pinv,
37-
slogdet,
3840
svd,
3941
)
4042
from pytensor.tensor.rewriting.basic import (
@@ -785,45 +787,6 @@ def rewrite_det_blockdiag(fgraph, node):
785787
return [prod(det_sub_matrices)]
786788

787789

788-
@register_canonicalize
789-
@register_stabilize
790-
@node_rewriter([slogdet])
791-
def rewrite_slogdet_blockdiag(fgraph, node):
792-
"""
793-
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
794-
795-
slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)
796-
797-
Parameters
798-
----------
799-
fgraph: FunctionGraph
800-
Function graph being optimized
801-
node: Apply
802-
Node of the function graph to be optimized
803-
804-
Returns
805-
-------
806-
list of Variable, optional
807-
List of optimized variables, or None if no optimization was performed
808-
"""
809-
# Check for inner block_diag operation
810-
potential_block_diag = node.inputs[0].owner
811-
if not (
812-
potential_block_diag
813-
and isinstance(potential_block_diag.op, Blockwise)
814-
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
815-
):
816-
return None
817-
818-
# Find the composing sub_matrices
819-
sub_matrices = potential_block_diag.inputs
820-
sign_sub_matrices, logdet_sub_matrices = zip(
821-
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
822-
)
823-
824-
return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]
825-
826-
827790
@register_canonicalize
828791
@register_stabilize
829792
@node_rewriter([ExtractDiag])
@@ -860,10 +823,10 @@ def rewrite_diag_kronecker(fgraph, node):
860823

861824
@register_canonicalize
862825
@register_stabilize
863-
@node_rewriter([slogdet])
864-
def rewrite_slogdet_kronecker(fgraph, node):
826+
@node_rewriter([det])
827+
def rewrite_det_kronecker(fgraph, node):
865828
"""
866-
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
829+
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those
867830
868831
Parameters
869832
----------
@@ -884,13 +847,12 @@ def rewrite_slogdet_kronecker(fgraph, node):
884847

885848
# Find the matrices
886849
a, b = potential_kron.inputs
887-
signs, logdets = zip(*[slogdet(a), slogdet(b)])
850+
dets = [det(a), det(b)]
888851
sizes = [a.shape[-1], b.shape[-1]]
889852
prod_sizes = prod(sizes, no_zeros_in_input=True)
890-
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
891-
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
853+
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])
892854

893-
return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
855+
return [det_final]
894856

895857

896858
@register_canonicalize
@@ -989,3 +951,65 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
989951
"jax",
990952
position=0.9, # Run before canonicalization
991953
)
954+
955+
956+
@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
    """
    Rewrite sign(det), log(det) and log(abs(det)) in terms of the SLogDet operation.

    The rewrite is only applied when *every* use of ``det(x)`` is one of the
    recognised patterns, and every use of an intermediate ``abs(det(x))`` is a
    ``log``. If the determinant (or its absolute value) is needed by anything
    else, the ``Det`` node has to stay in the graph, and introducing ``SLogDet``
    next to it would compute the determinant twice.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    dictionary of Variables, optional
        Dictionary of nodes and what they should be replaced with, or None if no optimization was performed
    """

    def is_elemwise_of(apply_node, scalar_op_type):
        # True when `apply_node` applies an Elemwise wrapping the given scalar op.
        return isinstance(apply_node.op, Elemwise) and isinstance(
            apply_node.op.scalar_op, scalar_op_type
        )

    # Maps each variable to replace to a tag naming the SLogDet-based
    # expression that will stand in for it.
    dummy_replacements = {}
    for client, _ in fgraph.clients[node.outputs[0]]:
        # Check for sign(det)
        if is_elemwise_of(client, Sign):
            dummy_replacements[client.outputs[0]] = "sign"

        # Check for log(abs(det))
        elif is_elemwise_of(client, Abs):
            abs_clients = fgraph.clients[client.outputs[0]]
            # abs(det) must feed only into log(...). Any other consumer keeps
            # abs(det), and therefore det itself, alive in the graph.
            if not abs_clients or not all(
                is_elemwise_of(abs_client, Log) for abs_client, _ in abs_clients
            ):
                return None
            for log_node, _ in abs_clients:
                dummy_replacements[log_node.outputs[0]] = "log_abs_det"

        # Check for log(det)
        elif is_elemwise_of(client, Log):
            dummy_replacements[client.outputs[0]] = "log_det"

        # Det is used directly for something else, don't rewrite to avoid computing two dets
        else:
            return None

    if not dummy_replacements:
        return None

    [x] = node.inputs
    sign_det_x, log_abs_det_x = SLogDet()(x)
    # log(det) is nan for a negative determinant; log_abs_det alone would hide that.
    log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
    slogdet_specialization_map = {
        "sign": sign_det_x,
        "log_abs_det": log_abs_det_x,
        "log_det": log_det_x,
    }
    return {
        old_var: slogdet_specialization_map[tag]
        for old_var, tag in dummy_replacements.items()
    }

Diff for: tests/link/pytorch/test_nlinalg.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import Sequence
2+
13
import numpy as np
24
import pytest
35

@@ -22,13 +24,13 @@ def matrix_test():
2224

2325
@pytest.mark.parametrize(
2426
"func",
25-
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
27+
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
2628
)
2729
def test_lin_alg_no_params(func, matrix_test):
2830
x, test_value = matrix_test
2931

3032
out = func(x)
31-
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
33+
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])
3234

3335
def assert_fn(x, y):
3436
np.testing.assert_allclose(x, y, rtol=1e-3)

Diff for: tests/tensor/rewriting/test_linalg.py

+93-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
KroneckerProduct,
2222
MatrixInverse,
2323
MatrixPinv,
24+
SLogDet,
2425
matrix_inverse,
2526
svd,
2627
)
@@ -719,7 +720,7 @@ def test_det_blockdiag_rewrite():
719720

720721

721722
def test_slogdet_blockdiag_rewrite():
722-
n_matrices = 100
723+
n_matrices = 10
723724
matrix_size = (5, 5)
724725
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
725726
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
@@ -776,11 +777,34 @@ def test_diag_kronecker_rewrite():
776777
)
777778

778779

780+
def test_det_kronecker_rewrite():
    """Check that det(kron(a, b)) compiles without a KroneckerProduct node and matches NumPy."""
    a, b = pt.dmatrices("a", "b")
    det_output = pt.linalg.det(pt.linalg.kron(a, b))
    f_rewritten = function([a, b], [det_output], mode="FAST_RUN")

    # Rewrite Test: no KroneckerProduct node may remain in the compiled graph
    assert all(
        not isinstance(node.op, KroneckerProduct)
        for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # Value Test: compare against the dense NumPy computation
    a_test, b_test = np.random.rand(2, 20, 20)
    expected_det = np.linalg.det(np.kron(a_test, b_test))
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected_det,
        f_rewritten(a_test, b_test),
        atol=tol,
        rtol=tol,
    )
801+
802+
779803
def test_slogdet_kronecker_rewrite():
780804
a, b = pt.dmatrices("a", "b")
781805
kron_prod = pt.linalg.kron(a, b)
782806
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
783-
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
807+
f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN")
784808

785809
# Rewrite Test
786810
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -790,7 +814,7 @@ def test_slogdet_kronecker_rewrite():
790814
a_test, b_test = np.random.rand(2, 20, 20)
791815
kron_prod_test = np.kron(a_test, b_test)
792816
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
793-
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
817+
rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test)
794818
assert_allclose(
795819
sign_output_test,
796820
rewritten_sign_val,
@@ -906,3 +930,69 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
906930
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
907931
nodes = f_rewritten.maker.fgraph.apply_nodes
908932
assert any(isinstance(node.op, Cholesky) for node in nodes)
933+
934+
935+
def test_slogdet_specialization():
    """Check when sign(det), log(abs(det)) and log(det) are compiled to SLogDet, and when they are not."""
    x = pt.dmatrix("x")
    a = np.random.rand(20, 20)
    det_x = pt.linalg.det(x)
    det_a = np.linalg.det(a)
    log_abs_det_x = pt.log(pt.abs(det_x))
    log_abs_det_a = np.log(np.abs(det_a))
    log_det_x = pt.log(det_x)
    log_det_a = np.log(det_a)
    sign_det_x = pt.sign(det_x)
    sign_det_a = np.sign(det_a)
    exp_det_x = pt.exp(det_x)

    tol = 1e-3 if config.floatX == "float32" else 1e-8

    def compile_ops(outputs):
        # Compile with FAST_RUN and return the function together with the ops of its graph.
        fn = function([x], outputs, mode="FAST_RUN")
        return fn, [node.op for node in fn.maker.fgraph.apply_nodes]

    def assert_specialized(ops):
        # Exactly one SLogDet and no remaining Det.
        assert sum(isinstance(op, SLogDet) for op in ops) == 1
        assert not any(isinstance(op, Det) for op in ops)

    # REWRITE TESTS
    # sign(det(x)), log(abs(det(x))) and log(det(x)), each compiled on its own
    single_output_cases = (
        (sign_det_x, sign_det_a),
        (log_abs_det_x, log_abs_det_a),
        (log_det_x, log_det_a),
    )
    for symbolic_out, expected in single_output_cases:
        f, ops = compile_ops([symbolic_out])
        assert_specialized(ops)
        assert_allclose(expected, f(a), atol=tol, rtol=tol)

    # More than 1 valid function
    _, ops = compile_ops([sign_det_x, log_abs_det_x])
    assert_specialized(ops)

    # Other functions (the rewrite should not be applied to these)
    # Only invalid functions
    _, ops = compile_ops([exp_det_x])
    assert not any(isinstance(op, SLogDet) for op in ops)

    # Invalid + Valid function
    _, ops = compile_ops([exp_det_x, sign_det_x])
    assert not any(isinstance(op, SLogDet) for op in ops)

0 commit comments

Comments
 (0)