Skip to content

Commit 920b409

Browse files
Add rewrite to merge multiple SVD Ops with different settings (#769)
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent a8d7638 commit 920b409

File tree

2 files changed

+128
-1
lines changed

2 files changed

+128
-1
lines changed

pytensor/tensor/rewriting/linalg.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,25 @@
44

55
from pytensor import Variable
66
from pytensor.graph import Apply, FunctionGraph
7-
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
7+
from pytensor.graph.rewriting.basic import (
8+
copy_stack_trace,
9+
node_rewriter,
10+
)
811
from pytensor.tensor.basic import TensorVariable, diagonal
912
from pytensor.tensor.blas import Dot22
1013
from pytensor.tensor.blockwise import Blockwise
1114
from pytensor.tensor.elemwise import DimShuffle
1215
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
1316
from pytensor.tensor.nlinalg import (
17+
SVD,
1418
KroneckerProduct,
1519
MatrixInverse,
1620
MatrixPinv,
1721
det,
1822
inv,
1923
kron,
2024
pinv,
25+
svd,
2126
)
2227
from pytensor.tensor.rewriting.basic import (
2328
register_canonicalize,
@@ -377,3 +382,59 @@ def local_lift_through_linalg(
377382
return [block_diag(*inner_matrices)]
378383
else:
379384
raise NotImplementedError # pragma: no cover
385+
386+
387+
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([Blockwise])
def svd_uv_merge(fgraph, node):
    """Merge multiple ``SVD`` ``Op``s applied to the same input matrix.

    When several SVD decompositions of the same input coexist in the graph
    (possibly with different ``compute_uv`` settings), rewrite so that one
    decomposition can be reused instead of being recomputed:

    * An SVD with ``compute_uv=True`` whose ``u``/``v`` outputs are unused is
      replaced by (a reused or freshly created) ``compute_uv=False`` variant,
      keeping its ``full_matrices`` setting.
    * An SVD with ``compute_uv=False`` reuses ``s`` from a sibling SVD with
      ``compute_uv=True`` whose ``u`` or ``v`` output is actually used.
    """
    core_op = node.op.core_op
    if not isinstance(core_op, SVD):
        return

    (x,) = node.inputs

    def _svd_siblings():
        # Other Blockwise(SVD) nodes consuming the same input matrix
        # (the current node is included; callers filter it out via the
        # compute_uv checks below).
        for client, _ in fgraph.clients[x]:
            if client == "output":
                continue
            if isinstance(client.op, Blockwise) and isinstance(client.op.core_op, SVD):
                yield client

    if core_op.compute_uv:
        # This node produces [u, s, v]. If either u or v is consumed
        # somewhere, the full decomposition is needed — nothing to do.
        u, s, v = node.outputs
        if fgraph.clients[u] or fgraph.clients[v]:
            return

        # Only s is used: prefer reusing an existing s-only SVD of x.
        for sibling in _svd_siblings():
            if not sibling.op.core_op.compute_uv:
                return {s: sibling.outputs[0]}

        # No reusable node found — introduce a fresh compute_uv=False SVD.
        return {s: svd(x, full_matrices=core_op.full_matrices, compute_uv=False)}

    else:
        # This node produces only [s]. If a sibling computes the full
        # decomposition and its u or v is genuinely used, take s from it.
        for sibling in _svd_siblings():
            if sibling.op.core_op.compute_uv and (
                fgraph.clients[sibling.outputs[0]]
                or fgraph.clients[sibling.outputs[2]]
            ):
                return [sibling.outputs[1]]

tests/tensor/rewriting/test_linalg.py

+66
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
from pytensor.tensor.elemwise import DimShuffle
1616
from pytensor.tensor.math import _allclose, dot, matmul
1717
from pytensor.tensor.nlinalg import (
18+
SVD,
1819
Det,
1920
KroneckerProduct,
2021
MatrixInverse,
2122
MatrixPinv,
2223
matrix_inverse,
24+
svd,
2325
)
2426
from pytensor.tensor.rewriting.linalg import inv_as_solve
2527
from pytensor.tensor.slinalg import (
@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
390392
test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
391393

392394
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
395+
396+
397+
def _assert_single_svd(fn, *, compute_uv, full_matrices=None):
    # Assert the compiled graph contains exactly one SVD Op with the
    # requested `compute_uv` (and, when given, `full_matrices`) setting.
    svd_ops = [
        node.op for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, SVD)
    ]
    assert len(svd_ops) == 1
    assert svd_ops[0].compute_uv == compute_uv
    if full_matrices is not None:
        assert svd_ops[0].full_matrices == full_matrices


def test_svd_uv_merge():
    """The `svd_uv_merge` rewrite leaves a single SVD node per input matrix,
    with `compute_uv` chosen according to which outputs are actually used.
    """
    a = matrix("a")
    s_1 = svd(a, full_matrices=False, compute_uv=False)
    _, s_2, _ = svd(a, full_matrices=False, compute_uv=True)
    _, s_3, _ = svd(a, full_matrices=True, compute_uv=True)
    u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True)
    # `grad` introduces an SVD Op with compute_uv=True
    # (full_matrices=True is not supported for the grad of svd).
    gs = pt.grad(pt.sum(s_1), a)

    # 1. compute_uv=False needs rewriting with compute_uv=True
    _assert_single_svd(pytensor.function([a], gs), compute_uv=True)

    # 2. compute_uv=True needs rewriting with compute_uv=False, reuse node
    _assert_single_svd(pytensor.function([a], [s_1, s_2]), compute_uv=False)

    # 3. compute_uv=True needs rewriting with compute_uv=False, create new node
    # (full_matrices needs to retain its value).
    _assert_single_svd(pytensor.function([a], [s_2]), compute_uv=False)

    # Case 2 of 3. for a different full_matrices
    _assert_single_svd(
        pytensor.function([a], [s_3]), compute_uv=False, full_matrices=True
    )

    # 4. No rewrite should happen
    _assert_single_svd(
        pytensor.function([a], [u_4]), compute_uv=True, full_matrices=True
    )

0 commit comments

Comments
 (0)