diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index a9521dc3cd..0f5c1b2fe0 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -6,7 +6,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange
+from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join
 
 
 @singledispatch
@@ -89,3 +89,14 @@ def arange(start, stop, step):
         return torch.arange(start, stop, step, dtype=dtype)
 
     return arange
+
+
+@pytorch_funcify.register(Join)
+def pytorch_funcify_Join(op, **kwargs):
+    def join(axis, *tensors):
+        # tensors could also be tuples, in which case they don't have an ndim attribute
+        tensors = [torch.tensor(tensor) for tensor in tensors]
+
+        return torch.cat(tensors, dim=axis)
+
+    return join
diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py
index 0ddb25765f..b1ad5582c5 100644
--- a/pytensor/link/pytorch/dispatch/elemwise.py
+++ b/pytensor/link/pytorch/dispatch/elemwise.py
@@ -2,6 +2,7 @@
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 
 
@@ -37,6 +38,69 @@ def dimshuffle(x):
     return dimshuffle
 
 
+@pytorch_funcify.register(Sum)
+def pytorch_funcify_sum(op, **kwargs):
+    def torch_sum(x):
+        return torch.sum(x, dim=op.axis)
+
+    return torch_sum
+
+
+@pytorch_funcify.register(All)
+def pytorch_funcify_all(op, **kwargs):
+    def torch_all(x):
+        return torch.all(x, dim=op.axis)
+
+    return torch_all
+
+
+@pytorch_funcify.register(Prod)
+def pytorch_funcify_prod(op, **kwargs):
+    def torch_prod(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.prod(x, dim=d)
+            return x
+        else:
+            return torch.prod(x.flatten(), dim=0)
+
+    return torch_prod
+
+
+@pytorch_funcify.register(Any)
+def pytorch_funcify_any(op, **kwargs):
+    def torch_any(x):
+        return torch.any(x, dim=op.axis)
+
+    return torch_any
+
+
+@pytorch_funcify.register(Max)
+def pytorch_funcify_max(op, **kwargs):
+    def torch_max(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.max(x, dim=d).values
+            return x
+        else:
+            return torch.max(x.flatten(), dim=0).values
+
+    return torch_max
+
+
+@pytorch_funcify.register(Min)
+def pytorch_funcify_min(op, **kwargs):
+    def torch_min(x):
+        if isinstance(op.axis, tuple):
+            for d in sorted(op.axis, reverse=True):
+                x = torch.min(x, dim=d).values
+            return x
+        else:
+            return torch.min(x.flatten(), dim=0).values
+
+    return torch_min
+
+
 @pytorch_funcify.register(Softmax)
 def pytorch_funcify_Softmax(op, **kwargs):
     axis = op.axis
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index cb6e652e23..c6750361a7 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import pytensor.tensor.basic as ptb
 from pytensor.compile.function import function
 from pytensor.compile.mode import get_mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -13,7 +14,7 @@
 from pytensor.graph.op import Op
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import alloc, arange, as_tensor, empty
-from pytensor.tensor.type import scalar, vector
+from pytensor.tensor.type import matrix, scalar, vector
 
 
 torch = pytest.importorskip("torch")
@@ -235,3 +236,42 @@ def test_arange():
         FunctionGraph([start, stop, step], [out]),
         [np.array(1), np.array(10), np.array(2)],
     )
+
+
+def test_pytorch_Join():
+    a = matrix("a")
+    b = matrix("b")
+
+    x = ptb.join(0, a, b)
+    x_fg = FunctionGraph([a, b], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
+        ],
+    )
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0]].astype(config.floatX),
+        ],
+    )
+
+    x = ptb.join(1, a, b)
+    x_fg = FunctionGraph([a, b], [x])
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
+            np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
+        ],
+    )
+    compare_pytorch_and_py(
+        x_fg,
+        [
+            np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
+            np.c_[[5.0, 6.0]].astype(config.floatX),
+        ],
+    )
diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py
index 586789772f..afb62848cc 100644
--- a/tests/link/pytorch/test_elemwise.py
+++ b/tests/link/pytorch/test_elemwise.py
@@ -2,11 +2,12 @@
 import pytest
 
 import pytensor.tensor as pt
+import pytensor.tensor.math as ptm
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
-from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.type import matrix, tensor, tensor3, vector
 from tests.link.pytorch.test_basic import compare_pytorch_and_py
 
 
@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
     compare_pytorch_and_py(fg, [[0.9, 0.9]])
 
 
+@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)])
+def test_pytorch_careduce(fn, axis):
+    a_pt = tensor3("a")
+    test_value = np.array(
+        [
+            [
+                [1, 1, 1, 1],
+                [2, 2, 2, 2],
+            ],
+            [
+                [3, 3, 3, 3],
+                [
+                    4,
+                    4,
+                    4,
+                    4,
+                ],
+            ],
+        ]
+    ).astype(config.floatX)
+
+    x = fn(a_pt, axis=axis)
+    x_fg = FunctionGraph([a_pt], [x])
+
+    compare_pytorch_and_py(x_fg, [test_value])
+
+
+@pytest.mark.parametrize("fn", [ptm.any, ptm.all])
+@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
+def test_pytorch_any_all(fn, axis):
+    a_pt = matrix("a")
+    test_value = np.array([[True, False, True], [False, True, True]])
+
+    x = fn(a_pt, axis=axis)
+    x_fg = FunctionGraph([a_pt], [x])
+
+    compare_pytorch_and_py(x_fg, [test_value])
+
+
 @pytest.mark.parametrize("dtype", ["float64", "int64"])
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax(axis, dtype):
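
Note (not part of the patch): a minimal sketch of how the new Join and CAReduce dispatches are expected to be exercised end to end through the PyTorch backend. It assumes the backend is selected with mode="PYTORCH", as the helpers in tests/link/pytorch do; the variable names below are illustrative only.

    import numpy as np

    import pytensor
    import pytensor.tensor as pt

    # sketch only: assumes the "PYTORCH" mode registered by the PyTorch linker
    a = pt.matrix("a")
    b = pt.matrix("b")

    joined = pt.join(0, a, b)       # Join is lowered to torch.cat by pytorch_funcify_Join
    total = pt.sum(joined, axis=0)  # Sum is lowered to torch.sum by pytorch_funcify_sum

    fn = pytensor.function([a, b], [joined, total], mode="PYTORCH")
    fn(
        np.ones((2, 3), dtype=pytensor.config.floatX),
        np.zeros((1, 3), dtype=pytensor.config.floatX),
    )

Design note: the tuple-axis branches in the Prod/Max/Min dispatches reduce one dimension at a time, iterating the axes in descending order so the remaining indices stay valid, because torch.prod, torch.max and torch.min accept only a single dim, unlike torch.sum.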