Skip to content

Commit 1a1c62b

Browse files
authored
added rewrites for inv(diag(x)) and inv(eye) (#898)
* updated tests * updated rewrites * parameterized tests and added batch case * minor changes
1 parent 7eca252 commit 1a1c62b

File tree

2 files changed

+190
-6
lines changed

2 files changed

+190
-6
lines changed

pytensor/tensor/rewriting/linalg.py

+93-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import cast
44

55
from pytensor import Variable
6+
from pytensor import tensor as pt
67
from pytensor.graph import Apply, FunctionGraph
78
from pytensor.graph.rewriting.basic import (
89
copy_stack_trace,
@@ -48,6 +49,7 @@
4849

4950

5051
logger = logging.getLogger(__name__)
52+
# Core ops that the inverse-related rewrites in this module treat as a matrix inverse (inv and pinv)
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)
5153

5254

5355
def is_matrix_transpose(x: TensorVariable) -> bool:
@@ -592,11 +594,10 @@ def rewrite_inv_inv(fgraph, node):
592594
list of Variable, optional
593595
List of optimized variables, or None if no optimization was performed
594596
"""
595-
valid_inverses = (MatrixInverse, MatrixPinv)
596597
# Check if its a valid inverse operation (either inv/pinv)
597598
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
598599
# If the outer operation is not a valid inverse, we do not apply this rewrite
599-
if not isinstance(node.op.core_op, valid_inverses):
600+
if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
600601
return None
601602

602603
potential_inner_inv = node.inputs[0].owner
@@ -607,7 +608,96 @@ def rewrite_inv_inv(fgraph, node):
607608
if not (
608609
potential_inner_inv
609610
and isinstance(potential_inner_inv.op, Blockwise)
610-
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
611+
and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
611612
):
612613
return None
613614
return [potential_inner_inv.inputs[0]]
615+
616+
617+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_eye_to_eye(fgraph, node):
    """
    Replace the inverse of an identity matrix with the identity matrix itself.

    This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself.
    The presence of an identity matrix is identified by checking whether we have a constant k = 0
    for an Eye Op inside an inverse op.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
        return None

    # Check whether input to inverse is Eye
    potential_eye = node.inputs[0]
    eye_node = potential_eye.owner
    if not (eye_node and isinstance(eye_node.op, Eye)):
        return None

    # The offset k is the last input of Eye; the 1's are on the main diagonal only when k == 0.
    # Only inputs carrying a `data` attribute can be inspected here. Anything else has an
    # unknown offset, so the rewrite must not be applied (and must not crash on a missing `.item()`).
    k_data = getattr(eye_node.inputs[-1], "data", None)
    if k_data is None or k_data.item() != 0:
        return None

    # NOTE(review): a rectangular Eye (n != m) is not excluded by this check -- confirm whether
    # pinv of a non-square Eye needs a dedicated guard.
    return [potential_eye]
648+
649+
650+
@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
    """
    Replace the inverse of a diagonal matrix with the reciprocal of its diagonal.

    This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix
    with the new diagonal entries as reciprocals of the original diagonal elements.
    This function deals with diagonal matrices built with pt.diag, as well as those arising from the
    multiplication of eye with a scalar/vector/matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
        return None

    # NOTE(review): zero diagonal entries become inf after taking reciprocals; confirm this is
    # the intended result when the outer op is pinv.
    inputs = node.inputs[0]

    # Check for use of pt.diag first
    if (
        inputs.owner
        and isinstance(inputs.owner.op, AllocDiag)
        and AllocDiag.is_offset_zero(inputs.owner)
    ):
        inv_input = inputs.owner.inputs[0]
        # pt.diag only builds a diagonal matrix from a 1-D input (a 2-D input would have its
        # diagonal extracted instead), so any other rank cannot be rebuilt correctly here.
        if inv_input.type.ndim != 1:
            return None
        return [pt.diag(1 / inv_input)]

    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
    inputs_or_none = _find_diag_from_eye_mul(inputs)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none

    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

    non_eye_input = non_eye_inputs[0]

    # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those.
    # The extracted diagonal is padded to a row so that it broadcasts against the columns of eye.
    if non_eye_input.type.broadcastable[-2:] == (False, False):
        non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2)
        non_eye_input = pt.shape_padaxis(non_eye_diag, -2)

    return [eye_input / non_eye_input]

tests/tensor/rewriting/test_linalg.py

+97-3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
from tests.test_rop import break_op
4242

4343

44+
# Shared absolute/relative tolerance for numeric comparisons; looser when floatX is float32
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8
45+
46+
4447
def test_rop_lop():
4548
mx = matrix("mx")
4649
mv = matrix("mv")
@@ -557,14 +560,105 @@ def test_svd_uv_merge():
557560
assert svd_counter == 1
558561

559562

563+
def get_pt_function(x, op_name):
    """Apply the ``pt.linalg`` function called ``op_name`` to ``x`` and return the result."""
    linalg_fn = getattr(pt.linalg, op_name)
    return linalg_fn(x)
565+
566+
560567
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
    """Applying two inverse ops in a row (any inv/pinv combination) must rewrite to the input."""
    x = pt.matrix("x")
    inner = get_pt_function(x, inv_op_1)
    outer = get_pt_function(inner, inv_op_2)
    assert rewrite_graph(outer) == x
575+
576+
577+
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_eye_to_eye(inv_op):
    """The inverse of an identity matrix must be rewritten away and still evaluate to identity."""
    x = pt.eye(10)
    x_inv = get_pt_function(x, inv_op)
    f_rewritten = function([], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # Rewrite Test
    # An inverse may show up bare or wrapped in an op exposing it as `core_op`; check both forms
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # Value Test
    x_test = np.eye(10)
    x_inv_val = np.linalg.inv(x_test)
    rewritten_val = f_rewritten()

    assert_allclose(
        x_inv_val,
        rewritten_val,
        atol=ATOL,
        rtol=RTOL,
    )
599+
600+
601+
@pytest.mark.parametrize(
    "shape",
    [(), (7,), (7, 7), (5, 7, 7)],
    ids=["scalar", "vector", "matrix", "batched"],
)
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_eye_mul(shape, inv_op):
    """The inverse of ``eye * x`` must be rewritten away and match the numerical inverse."""
    # Initializing x based on scalar/vector/matrix
    x = pt.tensor("x", shape=shape)
    x_diag = pt.eye(7) * x
    # Calculating inverse using pt.linalg.inv
    x_inv = get_pt_function(x_diag, inv_op)

    # REWRITE TEST
    f_rewritten = function([x], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # An inverse may show up bare or wrapped in an op exposing it as `core_op`; check both forms
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    # Seeded for reproducibility; values are kept away from zero so the reference inverse is well conditioned.
    # np.asarray makes the scalar case an array as well, so every shape goes through the same path.
    rng = np.random.default_rng(898)
    x_test = np.asarray(rng.uniform(0.5, 1.5, size=shape), dtype=config.floatX)
    x_test_matrix = np.eye(7) * x_test
    inverse_matrix = np.linalg.inv(x_test_matrix)
    rewritten_inverse = f_rewritten(x_test)

    assert_allclose(
        inverse_matrix,
        rewritten_inverse,
        atol=ATOL,
        rtol=RTOL,
    )
638+
639+
640+
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_diag(inv_op):
    """The inverse of ``pt.diag(x)`` must be rewritten away and match the numerical inverse."""
    x = pt.dvector("x")
    x_diag = pt.diag(x)
    x_inv = get_pt_function(x_diag, inv_op)

    # REWRITE TEST
    f_rewritten = function([x], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # An inverse may show up bare or wrapped in an op exposing it as `core_op`; check both forms
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    # Seeded for reproducibility; values are kept away from zero so the reference inverse is well conditioned
    rng = np.random.default_rng(898)
    x_test = rng.uniform(0.5, 1.5, size=10)
    x_test_matrix = np.eye(10) * x_test
    inverse_matrix = np.linalg.inv(x_test_matrix)
    rewritten_inverse = f_rewritten(x_test)

    assert_allclose(
        inverse_matrix,
        rewritten_inverse,
        atol=ATOL,
        rtol=RTOL,
    )

0 commit comments

Comments
 (0)