From a2c3ae06314b735370170c5ee3817bc996e340be Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Thu, 11 Jul 2024 19:53:03 +0200
Subject: [PATCH 1/6] Implemented linear algebra functions in PyTorch

- BlockDiagonal
- Cholesky
- Eigvalsh
- Solve
- SolveTriangular
---
 pytensor/link/pytorch/dispatch/__init__.py |   1 +
 pytensor/link/pytorch/dispatch/slinalg.py  |  86 +++++++++++++
 tests/link/pytorch/test_slinalg.py         | 142 +++++++++++++++++++++
 3 files changed, 229 insertions(+)
 create mode 100644 pytensor/link/pytorch/dispatch/slinalg.py
 create mode 100644 tests/link/pytorch/test_slinalg.py

diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
index fddded525a..f732848afc 100644
--- a/pytensor/link/pytorch/dispatch/__init__.py
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -11,4 +11,5 @@
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
 import pytensor.link.pytorch.dispatch.subtensor
+import pytensor.link.pytorch.dispatch.slinalg
 # isort: on
diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py
new file mode 100644
index 0000000000..6d70f84e14
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/slinalg.py
@@ -0,0 +1,86 @@
+import torch
+
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.slinalg import (
+    BlockDiagonal,
+    Cholesky,
+    Eigvalsh,
+    Solve,
+    SolveTriangular,
+)
+
+
+@pytorch_funcify.register(Eigvalsh)
+def pytorch_funcify_Eigvalsh(op, **kwargs):
+    if op.lower:
+        UPLO = "L"
+    else:
+        UPLO = "U"
+
+    def eigvalsh(a, b):
+        if b is not None:
+            raise NotImplementedError(
+                "torch.linalg.eigvalsh does not support generalized eigenvector problems (b != None)"
+            )
+        return torch.linalg.eigvalsh(a, UPLO=UPLO)
+
+    return eigvalsh
+
+
+@pytorch_funcify.register(Cholesky)
+def pytorch_funcify_Cholesky(op, **kwargs):
+    upper = not op.lower
+
+    def cholesky(a):
+        return torch.linalg.cholesky(a, upper=upper)
+
+    return cholesky
+
+
+@pytorch_funcify.register(Solve)
+def pytorch_funcify_Solve(op, **kwargs):
+    lower = False
+    if op.assume_a != "gen" and op.lower:
+        lower = True
+
+    def solve(a, b):
+        if lower:
+            return torch.linalg.solve(torch.tril(a), b)
+
+        return torch.linalg.solve(a, b)
+
+    return solve
+
+
+@pytorch_funcify.register(SolveTriangular)
+def pytorch_funcify_SolveTriangular(op, **kwargs):
+    if op.check_finite:
+        raise NotImplementedError(
+            "Option check_finite is not implemented in torch.linalg.solve_triangular"
+        )
+
+    upper = not op.lower
+    unit_diagonal = op.unit_diagonal
+    trans = op.trans
+
+    def solve_triangular(A, b):
+        A_p = A
+        if trans == 1 or trans == "T":
+            A_p = A.T
+
+        if trans == 2 or trans == "C":
+            A_p = A.H
+
+        return torch.linalg.solve_triangular(
+            A_p, b, upper=upper, unitriangular=unit_diagonal
+        )
+
+    return solve_triangular
+
+
+@pytorch_funcify.register(BlockDiagonal)
+def pytorch_funcify_BlockDiagonalMatrix(op, **kwargs):
+    def block_diag(*inputs):
+        return torch.block_diag(*inputs)
+
+    return block_diag
diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py
new file mode 100644
index 0000000000..9fb251b0d0
--- /dev/null
+++ b/tests/link/pytorch/test_slinalg.py
@@ -0,0 +1,142 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor import slinalg as pt_slinalg
+from pytensor.tensor.type import matrix, vector
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+@pytest.mark.parametrize("lower", [False, True])
+def test_pytorch_eigvalsh(lower):
+    A = matrix("A")
+    B = matrix("B")
+
+    out = pt_slinalg.eigvalsh(A, B, lower=lower)
+    out_fg = FunctionGraph([A, B], [out])
+
+    with pytest.raises(NotImplementedError):
+        compare_pytorch_and_py(
+            out_fg,
+            [
+                np.array(
+                    [[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]
+                ).astype(config.floatX),
+                np.array(
+                    [[10, 0, 1, 3], [0, 12, 7, 8], [1, 7, 14, 2], [3, 8, 2, 16]]
+                ).astype(config.floatX),
+            ],
+        )
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.array([[6, 3, 1, 5], [3, 0, 5, 1], [1, 5, 6, 2], [5, 1, 2, 2]]).astype(
+                config.floatX
+            ),
+            None,
+        ],
+    )
+
+
+def test_pytorch_basic():
+    rng = np.random.default_rng(28494)
+
+    x = matrix("x")
+    b = vector("b")
+
+    out = pt_slinalg.cholesky(x)
+    out_fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
+                config.floatX
+            )
+        ],
+    )
+
+    out = pt_slinalg.Cholesky(lower=False)(x)
+    out_fg = FunctionGraph([x], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
+                config.floatX
+            )
+        ],
+    )
+
+    out = pt_slinalg.solve(x, b)
+    out_fg = FunctionGraph([x, b], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.eye(10).astype(config.floatX),
+            np.arange(10).astype(config.floatX),
+        ],
+    )
+
+
+@pytest.mark.xfail(reason="Blockwise not implemented")
+@pytest.mark.parametrize(
+    "check_finite",
+    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
+)
+@pytest.mark.parametrize("lower", [False, True])
+@pytest.mark.parametrize("trans", [0, 1, 2, "S"])
+def test_pytorch_SolveTriangular(trans, lower, check_finite):
+    x = matrix("x")
+    b = vector("b")
+
+    out = pt_slinalg.solve_triangular(
+        x,
+        b,
+        trans=trans,
+        lower=lower,
+        check_finite=check_finite,
+    )
+    out_fg = FunctionGraph([x, b], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.eye(10).astype(config.floatX),
+            np.arange(10).astype(config.floatX),
+        ],
+    )
+
+
+def test_pytorch_block_diag():
+    A = matrix("A")
+    B = matrix("B")
+    C = matrix("C")
+    D = matrix("D")
+
+    out = pt_slinalg.block_diag(A, B, C, D)
+    out_fg = FunctionGraph([A, B, C, D], [out])
+
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5)).astype(config.floatX),
+            np.random.normal(size=(3, 3)).astype(config.floatX),
+            np.random.normal(size=(2, 2)).astype(config.floatX),
+            np.random.normal(size=(4, 4)).astype(config.floatX),
+        ],
+    )
+
+
+@pytest.mark.xfail(reason="Blockwise not implemented")
+def test_pytorch_block_diag_blockwise():
+    A = pt.tensor3("A")
+    B = pt.tensor3("B")
+    out = pt_slinalg.block_diag(A, B)
+    out_fg = FunctionGraph([A, B], [out])
+    compare_pytorch_and_py(
+        out_fg,
+        [
+            np.random.normal(size=(5, 5, 5)).astype(config.floatX),
+            np.random.normal(size=(5, 3, 3)).astype(config.floatX),
+        ],
+    )

From b816fe864f70979b4cf7dc8dba00218c89d4ba6a Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Sat, 13 Jul 2024 19:13:36 +0200
Subject: [PATCH 2/6] Fixed SolveTriangular Op in PyTorch

Added support for a b param of ndim = 1.
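
Note: unlike scipy.linalg.solve_triangular, torch.linalg.solve_triangular only
accepts a matrix right-hand side, so a 1-D b has to be promoted to a column
vector before the call and the result squeezed back afterwards. A minimal
standalone torch sketch of that behaviour (illustrative shapes, not the
PyTensor wrapper itself):

    import torch

    A = torch.eye(3) + torch.triu(torch.rand(3, 3))  # upper-triangular system
    b = torch.rand(3)                                # 1-D right-hand side

    # torch expects the right-hand side to be (n, k): pass a column vector ...
    x = torch.linalg.solve_triangular(A, b[:, None], upper=True)

    # ... then flatten the (3, 1) solution back to b's original 1-D shape
    x = x.flatten()
    assert x.shape == b.shape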
---
 pytensor/link/pytorch/dispatch/slinalg.py | 13 +++++++++++--
 tests/link/pytorch/test_slinalg.py        |  3 +--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py
index 6d70f84e14..0765ba0ad0 100644
--- a/pytensor/link/pytorch/dispatch/slinalg.py
+++ b/pytensor/link/pytorch/dispatch/slinalg.py
@@ -71,10 +71,19 @@ def solve_triangular(A, b):
         if trans == 2 or trans == "C":
             A_p = A.H
 
-        return torch.linalg.solve_triangular(
-            A_p, b, upper=upper, unitriangular=unit_diagonal
+        b_p = b
+        if b.ndim == 1:
+            b_p = b[:, None]
+
+        res = torch.linalg.solve_triangular(
+            A_p, b_p, upper=upper, unitriangular=unit_diagonal
         )
 
+        if b.ndim == 1 and res.shape[1] == 1:
+            return res.flatten()
+
+        return res
+
     return solve_triangular
 
 
diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py
index 9fb251b0d0..485f778b60 100644
--- a/tests/link/pytorch/test_slinalg.py
+++ b/tests/link/pytorch/test_slinalg.py
@@ -57,7 +57,7 @@ def test_pytorch_basic():
         ],
     )
 
-    out = pt_slinalg.Cholesky(lower=False)(x)
+    out = pt_slinalg.cholesky(x, lower=False)
     out_fg = FunctionGraph([x], [out])
     compare_pytorch_and_py(
         out_fg,
         [
@@ -79,7 +79,6 @@ def test_pytorch_basic():
     )
 
 
-@pytest.mark.xfail(reason="Blockwise not implemented")
 @pytest.mark.parametrize(
     "check_finite",
     (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),

From 0811baa7fa9c9212e1f223b8516e94f617b80faf Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Thu, 18 Jul 2024 21:02:27 +0200
Subject: [PATCH 3/6] Revisited slinalg implementation in PyTorch

- Removed lower param to solve
- Refactored tests
---
 pytensor/link/pytorch/dispatch/slinalg.py | 17 +++++------------
 tests/link/pytorch/test_slinalg.py        |  8 ++++++--
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py
index 0765ba0ad0..af3ac8efaf 100644
--- a/pytensor/link/pytorch/dispatch/slinalg.py
+++ b/pytensor/link/pytorch/dispatch/slinalg.py
@@ -39,14 +39,7 @@
 
 @pytorch_funcify.register(Solve)
 def pytorch_funcify_Solve(op, **kwargs):
-    lower = False
-    if op.assume_a != "gen" and op.lower:
-        lower = True
-
     def solve(a, b):
-        if lower:
-            return torch.linalg.solve(torch.tril(a), b)
-
         return torch.linalg.solve(a, b)
 
     return solve
@@ -64,12 +57,12 @@ def pytorch_funcify_SolveTriangular(op, **kwargs):
     trans = op.trans
 
     def solve_triangular(A, b):
-        A_p = A
-        if trans == 1 or trans == "T":
+        if trans in [1, "T"]:
             A_p = A.T
-
-        if trans == 2 or trans == "C":
-            A_p = A.H
+        elif trans in [2, "C"]:
+            A_p = A.conj().T
+        else:
+            A_p = A
 
         b_p = b
         if b.ndim == 1:
diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py
index 485f778b60..74f9159de9 100644
--- a/tests/link/pytorch/test_slinalg.py
+++ b/tests/link/pytorch/test_slinalg.py
@@ -40,11 +40,10 @@ def test_pytorch_eigvalsh(lower):
     )
 
 
-def test_pytorch_basic():
+def test_pytorch_cholesky():
     rng = np.random.default_rng(28494)
 
     x = matrix("x")
-    b = vector("b")
 
     out = pt_slinalg.cholesky(x)
     out_fg = FunctionGraph([x], [out])
@@ -68,6 +67,11 @@ def test_pytorch_basic():
         ],
     )
 
+
+def test_pytorch_solve():
+    x = matrix("x")
+    b = vector("b")
+
     out = pt_slinalg.solve(x, b)
     out_fg = FunctionGraph([x, b], [out])
     compare_pytorch_and_py(

From c3c3875e9ef12642624c1dccfc1d09eabf2f3af2 Mon Sep 17 00:00:00 2001
From: Diego Sandoval <46681084+twaclaw@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:03:49 +0200
Subject: [PATCH 4/6] Update pytensor/link/pytorch/dispatch/slinalg.py

Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
---
 pytensor/link/pytorch/dispatch/slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py
index af3ac8efaf..ebc8cd847a 100644
--- a/pytensor/link/pytorch/dispatch/slinalg.py
+++ b/pytensor/link/pytorch/dispatch/slinalg.py
@@ -64,7 +64,7 @@ def solve_triangular(A, b):
         else:
             A_p = A
 
-        b_p = b
+        b_p = pt.atleast_2d(b, left=False)
         if b.ndim == 1:
             b_p = b[:, None]
 

From b85b76f26f1d4185588b7f163c3e9d939756ea85 Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Thu, 18 Jul 2024 21:09:00 +0200
Subject: [PATCH 5/6] Revert "Update pytensor/link/pytorch/dispatch/slinalg.py"

This reverts commit 0dc65e2c5d2550a19fb7dd777c5d4baa433f7602.
---
 pytensor/link/pytorch/dispatch/slinalg.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytensor/link/pytorch/dispatch/slinalg.py b/pytensor/link/pytorch/dispatch/slinalg.py
index ebc8cd847a..af3ac8efaf 100644
--- a/pytensor/link/pytorch/dispatch/slinalg.py
+++ b/pytensor/link/pytorch/dispatch/slinalg.py
@@ -64,7 +64,7 @@ def solve_triangular(A, b):
         else:
             A_p = A
 
-        b_p = pt.atleast_2d(b, left=False)
+        b_p = b
         if b.ndim == 1:
             b_p = b[:, None]
 

From 4067a87ad233242c5b857f593e28ad0a60fb7254 Mon Sep 17 00:00:00 2001
From: Diego Sandoval
Date: Thu, 18 Jul 2024 21:39:41 +0200
Subject: [PATCH 6/6] Removed block_diag test with BlockWise

---
 tests/link/pytorch/test_slinalg.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/tests/link/pytorch/test_slinalg.py b/tests/link/pytorch/test_slinalg.py
index 74f9159de9..3055bf809d 100644
--- a/tests/link/pytorch/test_slinalg.py
+++ b/tests/link/pytorch/test_slinalg.py
@@ -1,7 +1,6 @@
 import numpy as np
 import pytest
 
-import pytensor.tensor as pt
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import slinalg as pt_slinalg
@@ -128,18 +127,3 @@ def test_pytorch_block_diag():
             np.random.normal(size=(4, 4)).astype(config.floatX),
         ],
     )
-
-
-@pytest.mark.xfail(reason="Blockwise not implemented")
-def test_pytorch_block_diag_blockwise():
-    A = pt.tensor3("A")
-    B = pt.tensor3("B")
-    out = pt_slinalg.block_diag(A, B)
-    out_fg = FunctionGraph([A, B], [out])
-    compare_pytorch_and_py(
-        out_fg,
-        [
-            np.random.normal(size=(5, 5, 5)).astype(config.floatX),
-            np.random.normal(size=(5, 3, 3)).astype(config.floatX),
-        ],
-    )
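
For reference, a short end-to-end sketch of how the new dispatches are
exercised once the series is applied. It assumes the PyTorch backend is
selected with mode="PYTORCH" (substitute the registered mode name if it
differs in your version); the graph and input values below are illustrative
only:

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.configdefaults import config
    from pytensor.tensor import slinalg as pt_slinalg

    A = pt.matrix("A")
    b = pt.vector("b")

    # A Cholesky factorization followed by a triangular solve; these lower to
    # torch.linalg.cholesky and torch.linalg.solve_triangular through the
    # dispatches added in this series.
    L = pt_slinalg.cholesky(A)
    x = pt_slinalg.solve_triangular(L, b, lower=True)

    f = pytensor.function([A, b], [L, x], mode="PYTORCH")

    rng = np.random.default_rng(0)
    M = rng.standard_normal((4, 4))
    A_val = (M @ M.T + 4 * np.eye(4)).astype(config.floatX)  # symmetric positive definite
    b_val = np.arange(4.0).astype(config.floatX)

    L_val, x_val = f(A_val, b_val)
    print(np.allclose(L_val @ x_val, b_val, atol=1e-4))  # expect True: L @ x == b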