Skip to content

Commit f489cf4

Browse files
authored
Added rewrite for matrix inv(inv(x)) -> x (#893)
1 parent ad27dc7 commit f489cf4

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

pytensor/tensor/rewriting/linalg.py

+42
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,45 @@ def svd_uv_merge(fgraph, node):
569569
or len(fgraph.clients[cl.outputs[2]]) > 0
570570
):
571571
return [cl.outputs[1]]
572+
573+
574+
@register_canonicalize
575+
@register_stabilize
576+
@node_rewriter([Blockwise])
577+
def rewrite_inv_inv(fgraph, node):
578+
"""
579+
This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))), we get back our original input without having to compute inverse once.
580+
581+
Here, we check for direct inverse operations (inv/pinv) and allows for any combination of these "inverse" nodes to be simply rewritten.
582+
583+
Parameters
584+
----------
585+
fgraph: FunctionGraph
586+
Function graph being optimized
587+
node: Apply
588+
Node of the function graph to be optimized
589+
590+
Returns
591+
-------
592+
list of Variable, optional
593+
List of optimized variables, or None if no optimization was performed
594+
"""
595+
valid_inverses = (MatrixInverse, MatrixPinv)
596+
# Check if its a valid inverse operation (either inv/pinv)
597+
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
598+
# If the outer operation is not a valid inverse, we do not apply this rewrite
599+
if not isinstance(node.op.core_op, valid_inverses):
600+
return None
601+
602+
potential_inner_inv = node.inputs[0].owner
603+
if potential_inner_inv is None or potential_inner_inv.op is None:
604+
return None
605+
606+
# Check if inner op is blockwise and and possible inv
607+
if not (
608+
potential_inner_inv
609+
and isinstance(potential_inner_inv.op, Blockwise)
610+
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
611+
):
612+
return None
613+
return [potential_inner_inv.inputs[0]]

tests/tensor/rewriting/test_linalg.py

+14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor import tensor as pt
1111
from pytensor.compile import get_default_mode
1212
from pytensor.configdefaults import config
13+
from pytensor.graph.rewriting.utils import rewrite_graph
1314
from pytensor.tensor import swapaxes
1415
from pytensor.tensor.blockwise import Blockwise
1516
from pytensor.tensor.elemwise import DimShuffle
@@ -554,3 +555,16 @@ def test_svd_uv_merge():
554555
assert node.op.compute_uv
555556
svd_counter += 1
556557
assert svd_counter == 1
558+
559+
560+
@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
561+
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
562+
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
563+
def get_pt_function(x, op_name):
564+
return getattr(pt.linalg, op_name)(x)
565+
566+
x = pt.matrix("x")
567+
op1 = get_pt_function(x, inv_op_1)
568+
op2 = get_pt_function(op1, inv_op_2)
569+
rewritten_out = rewrite_graph(op2)
570+
assert rewritten_out == x

0 commit comments

Comments
 (0)