|
15 | 15 | from pytensor.tensor.elemwise import DimShuffle
|
16 | 16 | from pytensor.tensor.math import _allclose, dot, matmul
|
17 | 17 | from pytensor.tensor.nlinalg import (
|
| 18 | + SVD, |
18 | 19 | Det,
|
19 | 20 | KroneckerProduct,
|
20 | 21 | MatrixInverse,
|
21 | 22 | MatrixPinv,
|
22 | 23 | matrix_inverse,
|
| 24 | + svd, |
23 | 25 | )
|
24 | 26 | from pytensor.tensor.rewriting.linalg import inv_as_solve
|
25 | 27 | from pytensor.tensor.slinalg import (
|
@@ -390,3 +392,67 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
|
390 | 392 | test_vals = [x @ np.swapaxes(x, -1, -2) for x in test_vals]
|
391 | 393 |
|
392 | 394 | np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
|
| 395 | + |
| 396 | + |
| 397 | +def test_svd_uv_merge(): |
| 398 | + a = matrix("a") |
| 399 | + s_1 = svd(a, full_matrices=False, compute_uv=False) |
| 400 | + _, s_2, _ = svd(a, full_matrices=False, compute_uv=True) |
| 401 | + _, s_3, _ = svd(a, full_matrices=True, compute_uv=True) |
| 402 | + u_4, s_4, v_4 = svd(a, full_matrices=True, compute_uv=True) |
| 403 | + # `grad` will introduces an SVD Op with compute_uv=True |
| 404 | + # full_matrices = True is not supported for grad of svd |
| 405 | + gs = pt.grad(pt.sum(s_1), a) |
| 406 | + |
| 407 | + # 1. compute_uv=False needs rewriting with compute_uv=True |
| 408 | + f_1 = pytensor.function([a], gs) |
| 409 | + nodes = f_1.maker.fgraph.apply_nodes |
| 410 | + svd_counter = 0 |
| 411 | + for node in nodes: |
| 412 | + if isinstance(node.op, SVD): |
| 413 | + assert node.op.compute_uv |
| 414 | + svd_counter += 1 |
| 415 | + assert svd_counter == 1 |
| 416 | + |
| 417 | + # 2. compute_uv=True needs rewriting with compute=False, reuse node |
| 418 | + f_2 = pytensor.function([a], [s_1, s_2]) |
| 419 | + nodes = f_2.maker.fgraph.apply_nodes |
| 420 | + svd_counter = 0 |
| 421 | + for node in nodes: |
| 422 | + if isinstance(node.op, SVD): |
| 423 | + assert not node.op.compute_uv |
| 424 | + svd_counter += 1 |
| 425 | + assert svd_counter == 1 |
| 426 | + |
| 427 | + # 3. compute_uv=True needs rewriting with compute=False, create new node |
| 428 | + # full_matrices needs to retain the value |
| 429 | + f_3 = pytensor.function([a], [s_2]) |
| 430 | + nodes = f_3.maker.fgraph.apply_nodes |
| 431 | + svd_counter = 0 |
| 432 | + for node in nodes: |
| 433 | + if isinstance(node.op, SVD): |
| 434 | + assert not node.op.compute_uv |
| 435 | + svd_counter += 1 |
| 436 | + assert svd_counter == 1 |
| 437 | + |
| 438 | + # Case 2 of 3. for a different full_matrices |
| 439 | + f_4 = pytensor.function([a], [s_3]) |
| 440 | + nodes = f_4.maker.fgraph.apply_nodes |
| 441 | + svd_counter = 0 |
| 442 | + for node in nodes: |
| 443 | + if isinstance(node.op, SVD): |
| 444 | + assert not node.op.compute_uv |
| 445 | + assert node.op.full_matrices |
| 446 | + svd_counter += 1 |
| 447 | + assert svd_counter == 1 |
| 448 | + |
| 449 | + # 4. No rewrite should happen |
| 450 | + f_5 = pytensor.function([a], [u_4]) |
| 451 | + nodes = f_5.maker.fgraph.apply_nodes |
| 452 | + svd_counter = 0 |
| 453 | + for node in nodes: |
| 454 | + if isinstance(node.op, SVD): |
| 455 | + assert node.op.full_matrices |
| 456 | + assert node.op.compute_uv |
| 457 | + svd_counter += 1 |
| 458 | + assert svd_counter == 1 |
0 commit comments