From 48d16ffdfc3fe189abbd7e043059edfb68164b92 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 28 Jun 2024 15:36:16 +0530 Subject: [PATCH 01/10] Pytorch CAReduce Ops --- pytensor/link/pytorch/dispatch/basic.py | 19 +++++++++- pytensor/link/pytorch/dispatch/elemwise.py | 9 +++++ tests/link/pytorch/test_basic.py | 42 +++++++++++++++++++++- tests/link/pytorch/test_elemwise.py | 10 ++++++ 4 files changed, 78 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index a9521dc3cd..4da7c00307 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -6,7 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise -from pytensor.tensor.basic import Alloc, AllocEmpty, ARange +from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Join @singledispatch @@ -89,3 +89,20 @@ def arange(start, stop, step): return torch.arange(start, stop, step, dtype=dtype) return arange + + +@pytorch_funcify.register(Join) +def jax_funcify_Join(op, **kwargs): + def join(axis, *tensors): + # tensors could also be tuples, and in this case they don't have a ndim + tensors = [torch.tensor(tensor) for tensor in tensors] + view = op.view + if (view != -1) and all( + tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :] + ): + return tensors[view] + + else: + return torch.cat(tensors, dim=axis) + + return join diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index f39e108bed..954c67bebf 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,6 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise +from pytensor.tensor.math import Sum @pytorch_funcify.register(Elemwise) @@ -34,3 +35,11 @@ def dimshuffle(x): return res return dimshuffle + + +@pytorch_funcify.register(Sum) +def pytorch_funcify_careduce(op, **kwargs): + def torch_sum(x): + return torch.sum(x, dim=op.axis) + + return torch_sum diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index cb6e652e23..393808823a 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import pytensor.tensor.basic as ptb from pytensor.compile.function import function from pytensor.compile.mode import get_mode from pytensor.compile.sharedvalue import SharedVariable, shared @@ -13,7 +14,7 @@ from pytensor.graph.op import Op from pytensor.raise_op import CheckAndRaise from pytensor.tensor import alloc, arange, as_tensor, empty -from pytensor.tensor.type import scalar, vector +from pytensor.tensor.type import matrix, scalar, vector torch = pytest.importorskip("torch") @@ -235,3 +236,42 @@ def test_arange(): FunctionGraph([start, stop, step], [out]), [np.array(1), np.array(10), np.array(2)], ) + + +def test_jax_Join(): + a = matrix("a") + b = matrix("b") + + x = ptb.join(0, a, b) + x_fg = FunctionGraph([a, b], [x]) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), + ], + ) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0]].astype(config.floatX), + ], + ) + + x = ptb.join(1, a, b) + x_fg = FunctionGraph([a, b], [x]) + compare_pytorch_and_py( 
+ x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), + ], + ) + compare_pytorch_and_py( + x_fg, + [ + np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), + np.c_[[5.0, 6.0]].astype(config.floatX), + ], + ) diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 1d843b8051..2da5579402 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -53,3 +53,13 @@ def test_pytorch_elemwise(): fg = FunctionGraph([x], [out]) compare_pytorch_and_py(fg, [[0.9, 0.9]]) + + +def test_pytorch_sum(): + a_pt = vector("a") + test_value = np.r_[1, 2, 3].astype(config.floatX) + + x = pt.math.sum(a_pt, axis=None) + x_fg = FunctionGraph([a_pt], [x]) + + compare_pytorch_and_py(x_fg, [test_value]) From a8c1fb1f3c7de4a55b7b7292074303a90fafe9b6 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 28 Jun 2024 15:39:13 +0530 Subject: [PATCH 02/10] rename test --- tests/link/pytorch/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 393808823a..c6750361a7 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -238,7 +238,7 @@ def test_arange(): ) -def test_jax_Join(): +def test_pytorch_Join(): a = matrix("a") b = matrix("b") From 14dde44435dd9575986596c6c2fd499c50ce050f Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 28 Jun 2024 16:43:33 +0530 Subject: [PATCH 03/10] remove view kwarg --- pytensor/link/pytorch/dispatch/basic.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 4da7c00307..0f5c1b2fe0 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -92,17 +92,11 @@ def arange(start, stop, step): @pytorch_funcify.register(Join) -def jax_funcify_Join(op, **kwargs): +def pytorch_funcify_Join(op, **kwargs): def join(axis, *tensors): # tensors could also be tuples, and in this case they don't have a ndim tensors = [torch.tensor(tensor) for tensor in tensors] - view = op.view - if (view != -1) and all( - tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :] - ): - return tensors[view] - - else: - return torch.cat(tensors, dim=axis) + + return torch.cat(tensors, dim=axis) return join From 0bdeb2e40fa83530281e38855b4d154e09dcc2fc Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 29 Jun 2024 16:33:44 +0530 Subject: [PATCH 04/10] Add pytorch implementation of All --- pytensor/link/pytorch/dispatch/elemwise.py | 14 ++++++++++++-- tests/link/pytorch/test_elemwise.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 954c67bebf..ad24768a53 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,7 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Sum +from pytensor.tensor.math import All, Sum @pytorch_funcify.register(Elemwise) @@ -38,8 +38,18 @@ def dimshuffle(x): @pytorch_funcify.register(Sum) -def pytorch_funcify_careduce(op, **kwargs): +def pytorch_funcify_sum(op, **kwargs): def torch_sum(x): return torch.sum(x, dim=op.axis) return torch_sum + + +@pytorch_funcify.register(All) 
+def pytorch_funcify_all(op, **kwargs): + dim = op.axis + + def torch_all(x): + return torch.all(x, dim=dim) + + return torch_all diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 2da5579402..4126718400 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -63,3 +63,13 @@ def test_pytorch_sum(): x_fg = FunctionGraph([a_pt], [x]) compare_pytorch_and_py(x_fg, [test_value]) + + +def test_pytorch_all(): + a_pt = matrix("a") + test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + + x = pt.math.all(a_pt, axis=None) + x_fg = FunctionGraph([a_pt], [x]) + + compare_pytorch_and_py(x_fg, [test_value]) From 519fd477e4f8ed65a3c5bc7e14d564cc94ece367 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 29 Jun 2024 18:25:58 +0530 Subject: [PATCH 05/10] Fix ruff format --- pytensor/link/pytorch/dispatch/elemwise.py | 1 - tests/link/pytorch/test_elemwise.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 0f7e7383b3..6f9f1c231c 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -6,7 +6,6 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad - @pytorch_funcify.register(Elemwise) def pytorch_funcify_Elemwise(op, node, **kwargs): scalar_op = op.scalar_op diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 1b8c729b88..c757603c48 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -75,6 +75,8 @@ def test_pytorch_all(): x_fg = FunctionGraph([a_pt], [x]) compare_pytorch_and_py(x_fg, [test_value]) + + @pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("axis", [None, 0, 1]) def test_softmax(axis, dtype): From 9073a563219c6b2c22a0dd6febd42f83cca31d67 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 2 Jul 2024 01:27:59 +0530 Subject: [PATCH 06/10] Implement Prod and Any Ops and parameterize tests --- pytensor/link/pytorch/dispatch/elemwise.py | 22 ++++++++++++- tests/link/pytorch/test_elemwise.py | 37 ++++++++++++++++++---- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 6f9f1c231c..3fcec04a80 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,7 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import All, Sum +from pytensor.tensor.math import All, Any, Prod, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -56,6 +56,26 @@ def torch_all(x): return torch_all +@pytorch_funcify.register(Prod) +def pytorch_funcify_prod(op, **kwargs): + dim = op.axis[0] + + def torch_prod(x): + return torch.prod(x, dim=dim) + + return torch_prod + + +@pytorch_funcify.register(Any) +def pytorch_funcify_any(op, **kwargs): + dim = op.axis + + def torch_any(x): + return torch.any(x, dim=dim) + + return torch_any + + @pytorch_funcify.register(Softmax) def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index c757603c48..01a16ab230 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -2,6 +2,7 @@ import pytest import 
pytensor.tensor as pt +import pytensor.tensor.math as ptm from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as pt_elemwise @@ -57,21 +58,45 @@ def test_pytorch_elemwise(): compare_pytorch_and_py(fg, [[0.9, 0.9]]) -def test_pytorch_sum(): - a_pt = vector("a") - test_value = np.r_[1, 2, 3].astype(config.floatX) +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_pytorch_sum(axis): + a_pt = matrix("a") + test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) - x = pt.math.sum(a_pt, axis=None) + x = pt.math.sum(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) compare_pytorch_and_py(x_fg, [test_value]) -def test_pytorch_all(): +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_pytorch_all(axis): + a_pt = matrix("a") + test_value = np.array([[True, False, True], [False, True, True]]) + + x = ptm.all(a_pt, axis=axis) + x_fg = FunctionGraph([a_pt], [x]) + + compare_pytorch_and_py(x_fg, [test_value]) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_pytorch_prod(axis): a_pt = matrix("a") test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) - x = pt.math.all(a_pt, axis=None) + x = ptm.prod(a_pt, axis=axis) + x_fg = FunctionGraph([a_pt], [x]) + + compare_pytorch_and_py(x_fg, [test_value]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_pytorch_any(axis): + a_pt = matrix("a") + test_value = np.array([[True, False, True], [False, True, True]]) + + x = ptm.any(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) compare_pytorch_and_py(x_fg, [test_value]) From 036082b47792b92fb2e12ec4bb7cf3b3a00c037e Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Tue, 2 Jul 2024 23:22:23 +0530 Subject: [PATCH 07/10] Implement Max and Min Op --- pytensor/link/pytorch/dispatch/elemwise.py | 22 +++++++++++++++++++++- tests/link/pytorch/test_elemwise.py | 22 ++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 3fcec04a80..8592ef648a 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -2,7 +2,7 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import All, Any, Prod, Sum +from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -76,6 +76,26 @@ def torch_any(x): return torch_any +@pytorch_funcify.register(Max) +def pytorch_funcify_max(op, **kwargs): + dim = op.axis[0] + + def torch_max(x): + return torch.max(x, dim=dim).values + + return torch_max + + +@pytorch_funcify.register(Min) +def pytorch_funcify_min(op, **kwargs): + dim = op.axis[0] + + def torch_min(x): + return torch.min(x, dim=dim).values + + return torch_min + + @pytorch_funcify.register(Softmax) def pytorch_funcify_Softmax(op, **kwargs): axis = op.axis diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 01a16ab230..c80944a539 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -102,6 +102,28 @@ def test_pytorch_any(axis): compare_pytorch_and_py(x_fg, [test_value]) +@pytest.mark.parametrize("axis", [0, 1]) +def test_pytorch_max(axis): + a_pt = matrix("a") + test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + + x = ptm.max(a_pt, axis=axis) + x_fg = FunctionGraph([a_pt], [x]) + + 
compare_pytorch_and_py(x_fg, [test_value]) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_pytorch_min(axis): + a_pt = matrix("a") + test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + + x = ptm.min(a_pt, axis=axis) + x_fg = FunctionGraph([a_pt], [x]) + + compare_pytorch_and_py(x_fg, [test_value]) + + @pytest.mark.parametrize("dtype", ["float64", "int64"]) @pytest.mark.parametrize("axis", [None, 0, 1]) def test_softmax(axis, dtype): From 8dfa0598b99aaf4d06f15b43293c04069e4dd941 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Thu, 4 Jul 2024 01:07:50 +0530 Subject: [PATCH 08/10] Modify Ops for multiple dims --- pytensor/link/pytorch/dispatch/elemwise.py | 35 ++++---- tests/link/pytorch/test_elemwise.py | 95 ++++++++++++++++++---- 2 files changed, 101 insertions(+), 29 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 8592ef648a..7b85c0838d 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -48,50 +48,55 @@ def torch_sum(x): @pytorch_funcify.register(All) def pytorch_funcify_all(op, **kwargs): - dim = op.axis - def torch_all(x): - return torch.all(x, dim=dim) + return torch.all(x, dim=op.axis) return torch_all @pytorch_funcify.register(Prod) def pytorch_funcify_prod(op, **kwargs): - dim = op.axis[0] - def torch_prod(x): - return torch.prod(x, dim=dim) + if isinstance(op.axis, tuple): + for d in op.axis: + x = torch.prod(x, dim=d, keepdim=True) + return x.squeeze() + else: + return torch.prod(x.flatten(), dim=0) return torch_prod @pytorch_funcify.register(Any) def pytorch_funcify_any(op, **kwargs): - dim = op.axis - def torch_any(x): - return torch.any(x, dim=dim) + return torch.any(x, dim=op.axis) return torch_any @pytorch_funcify.register(Max) def pytorch_funcify_max(op, **kwargs): - dim = op.axis[0] - def torch_max(x): - return torch.max(x, dim=dim).values + if isinstance(op.axis, tuple): + for d in op.axis: + x = torch.max(x, dim=d, keepdim=True).values + return x.squeeze() + else: + return torch.max(x.flatten(), dim=0).values return torch_max @pytorch_funcify.register(Min) def pytorch_funcify_min(op, **kwargs): - dim = op.axis[0] - def torch_min(x): - return torch.min(x, dim=dim).values + if isinstance(op.axis, tuple): + for d in op.axis: + x = torch.min(x, dim=d, keepdim=True).values + return x.squeeze() + else: + return torch.min(x.flatten(), dim=0).values return torch_min diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index c80944a539..2e953c015d 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -7,7 +7,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.tensor import elemwise as pt_elemwise from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax -from pytensor.tensor.type import matrix, tensor, vector +from pytensor.tensor.type import matrix, tensor, tensor3, vector from tests.link.pytorch.test_basic import compare_pytorch_and_py @@ -58,10 +58,26 @@ def test_pytorch_elemwise(): compare_pytorch_and_py(fg, [[0.9, 0.9]]) -@pytest.mark.parametrize("axis", [None, 0, 1]) +@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)]) def test_pytorch_sum(axis): - a_pt = matrix("a") - test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + a_pt = tensor3("a") + test_value = np.array( + [ + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + ], + [ + [3, 3, 3, 3], + [ + 4, + 4, + 4, + 4, + ], + ], + ] + ).astype(config.floatX) x = 
pt.math.sum(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) @@ -80,10 +96,26 @@ def test_pytorch_all(axis): compare_pytorch_and_py(x_fg, [test_value]) -@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (1, 2)]) def test_pytorch_prod(axis): - a_pt = matrix("a") - test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + a_pt = tensor3("a") + test_value = np.array( + [ + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + ], + [ + [3, 3, 3, 3], + [ + 4, + 4, + 4, + 4, + ], + ], + ] + ).astype(config.floatX) x = ptm.prod(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) @@ -91,7 +123,7 @@ def test_pytorch_prod(axis): compare_pytorch_and_py(x_fg, [test_value]) -@pytest.mark.parametrize("axis", [None, 0, 1]) +@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) def test_pytorch_any(axis): a_pt = matrix("a") test_value = np.array([[True, False, True], [False, True, True]]) @@ -102,10 +134,26 @@ def test_pytorch_any(axis): compare_pytorch_and_py(x_fg, [test_value]) -@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("axis", [None, 0, 1, (1, -1)]) def test_pytorch_max(axis): - a_pt = matrix("a") - test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + a_pt = tensor3("a") + test_value = np.array( + [ + [ + [1, 1, 1, 1], + [2, 2, 2, 2], + ], + [ + [3, 3, 3, 3], + [ + 4, + 4, + 4, + 4, + ], + ], + ] + ).astype(config.floatX) x = ptm.max(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) @@ -113,10 +161,29 @@ def test_pytorch_max(axis): compare_pytorch_and_py(x_fg, [test_value]) -@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("axis", [None, 0, 1, (1, -1)]) def test_pytorch_min(axis): - a_pt = matrix("a") - test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) + a_pt = tensor3("a") + test_value = np.array( + [ + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], + [ + [4, 4, 4, 4], + [ + 5, + 5, + 5, + 5, + ], + [ + 6, + 6, + 6, + 6, + ], + ], + ] + ).astype(config.floatX) x = ptm.min(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) From 18bbc8c16cd8d66ff2e942ab8ff3547cbf76bc3f Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Thu, 4 Jul 2024 10:49:46 +0530 Subject: [PATCH 09/10] Parametrize tests --- pytensor/link/pytorch/dispatch/elemwise.py | 18 ++-- tests/link/pytorch/test_elemwise.py | 105 ++------------------- 2 files changed, 15 insertions(+), 108 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 7b85c0838d..f0d4b892f1 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -58,9 +58,9 @@ def torch_all(x): def pytorch_funcify_prod(op, **kwargs): def torch_prod(x): if isinstance(op.axis, tuple): - for d in op.axis: - x = torch.prod(x, dim=d, keepdim=True) - return x.squeeze() + for d in op.axis[::-1]: + x = torch.prod(x, dim=d) + return x else: return torch.prod(x.flatten(), dim=0) @@ -79,9 +79,9 @@ def torch_any(x): def pytorch_funcify_max(op, **kwargs): def torch_max(x): if isinstance(op.axis, tuple): - for d in op.axis: - x = torch.max(x, dim=d, keepdim=True).values - return x.squeeze() + for d in op.axis[::-1]: + x = torch.max(x, dim=d).values + return x else: return torch.max(x.flatten(), dim=0).values @@ -92,9 +92,9 @@ def torch_max(x): def pytorch_funcify_min(op, **kwargs): def torch_min(x): if isinstance(op.axis, tuple): - for d in op.axis: - x = torch.min(x, dim=d, keepdim=True).values - return x.squeeze() + for d in op.axis[::-1]: + x = torch.min(x, dim=d).values + return x 
else: return torch.min(x.flatten(), dim=0).values diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 2e953c015d..64a4b80710 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -58,8 +58,9 @@ def test_pytorch_elemwise(): compare_pytorch_and_py(fg, [[0.9, 0.9]]) +@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min]) @pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)]) -def test_pytorch_sum(axis): +def test_pytorch_careduce(fn, axis): a_pt = tensor3("a") test_value = np.array( [ @@ -79,113 +80,19 @@ def test_pytorch_sum(axis): ] ).astype(config.floatX) - x = pt.math.sum(a_pt, axis=axis) - x_fg = FunctionGraph([a_pt], [x]) - - compare_pytorch_and_py(x_fg, [test_value]) - - -@pytest.mark.parametrize("axis", [None, 0, 1]) -def test_pytorch_all(axis): - a_pt = matrix("a") - test_value = np.array([[True, False, True], [False, True, True]]) - - x = ptm.all(a_pt, axis=axis) - x_fg = FunctionGraph([a_pt], [x]) - - compare_pytorch_and_py(x_fg, [test_value]) - - -@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (1, 2)]) -def test_pytorch_prod(axis): - a_pt = tensor3("a") - test_value = np.array( - [ - [ - [1, 1, 1, 1], - [2, 2, 2, 2], - ], - [ - [3, 3, 3, 3], - [ - 4, - 4, - 4, - 4, - ], - ], - ] - ).astype(config.floatX) - - x = ptm.prod(a_pt, axis=axis) + x = fn(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) compare_pytorch_and_py(x_fg, [test_value]) +@pytest.mark.parametrize("fn", [ptm.any, ptm.all]) @pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) -def test_pytorch_any(axis): +def test_pytorch_any_all(fn, axis): a_pt = matrix("a") test_value = np.array([[True, False, True], [False, True, True]]) - x = ptm.any(a_pt, axis=axis) - x_fg = FunctionGraph([a_pt], [x]) - - compare_pytorch_and_py(x_fg, [test_value]) - - -@pytest.mark.parametrize("axis", [None, 0, 1, (1, -1)]) -def test_pytorch_max(axis): - a_pt = tensor3("a") - test_value = np.array( - [ - [ - [1, 1, 1, 1], - [2, 2, 2, 2], - ], - [ - [3, 3, 3, 3], - [ - 4, - 4, - 4, - 4, - ], - ], - ] - ).astype(config.floatX) - - x = ptm.max(a_pt, axis=axis) - x_fg = FunctionGraph([a_pt], [x]) - - compare_pytorch_and_py(x_fg, [test_value]) - - -@pytest.mark.parametrize("axis", [None, 0, 1, (1, -1)]) -def test_pytorch_min(axis): - a_pt = tensor3("a") - test_value = np.array( - [ - [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]], - [ - [4, 4, 4, 4], - [ - 5, - 5, - 5, - 5, - ], - [ - 6, - 6, - 6, - 6, - ], - ], - ] - ).astype(config.floatX) - - x = ptm.min(a_pt, axis=axis) + x = fn(a_pt, axis=axis) x_fg = FunctionGraph([a_pt], [x]) compare_pytorch_and_py(x_fg, [test_value]) From 35153f7ce66abb84715e24621660da7b8f01d5e3 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Thu, 4 Jul 2024 14:42:44 +0530 Subject: [PATCH 10/10] Increase readability --- pytensor/link/pytorch/dispatch/elemwise.py | 6 +++--- tests/link/pytorch/test_elemwise.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index f0d4b892f1..b1ad5582c5 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -58,7 +58,7 @@ def torch_all(x): def pytorch_funcify_prod(op, **kwargs): def torch_prod(x): if isinstance(op.axis, tuple): - for d in op.axis[::-1]: + for d in sorted(op.axis, reverse=True): x = torch.prod(x, dim=d) return x else: @@ -79,7 +79,7 @@ def torch_any(x): def pytorch_funcify_max(op, **kwargs): def 
torch_max(x): if isinstance(op.axis, tuple): - for d in op.axis[::-1]: + for d in sorted(op.axis, reverse=True): x = torch.max(x, dim=d).values return x else: @@ -92,7 +92,7 @@ def torch_max(x): def pytorch_funcify_min(op, **kwargs): def torch_min(x): if isinstance(op.axis, tuple): - for d in op.axis[::-1]: + for d in sorted(op.axis, reverse=True): x = torch.min(x, dim=d).values return x else: diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 64a4b80710..afb62848cc 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -59,7 +59,7 @@ def test_pytorch_elemwise(): @pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min]) -@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)]) +@pytest.mark.parametrize("axis", [None, 0, 1, (0, -1)]) def test_pytorch_careduce(fn, axis): a_pt = tensor3("a") test_value = np.array(
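
For reference, patches 08-10 converge on one pattern for every CAReduce op: flatten and reduce once when axis is None, otherwise reduce dim by dim from the highest index down so the remaining dim indices stay valid. Below is a minimal standalone sketch of that pattern, assuming only the public torch API; the names torch_careduce and max_values are hypothetical helpers for illustration and are not part of this PR:

    import torch

    def torch_careduce(reduce_fn, axis):
        # axis mirrors PyTensor's op.axis: None or a tuple of ints.
        def inner(x):
            if axis is None:
                # Full reduction: collapse to 1-D, then reduce the single dim.
                return reduce_fn(x.flatten(), dim=0)
            # Reduce the highest dim first so the dims still to be
            # reduced keep their original indices.
            for d in sorted(axis, reverse=True):
                x = reduce_fn(x, dim=d)
            return x

        return inner

    def max_values(x, dim):
        # torch.max/torch.min return (values, indices) when given a dim;
        # CAReduce only needs the values, so unwrap them here.
        return torch.max(x, dim=dim).values

    x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
    print(torch_careduce(torch.sum, (0, 2))(x))  # tensor of shape (3,)
    print(torch_careduce(max_values, None)(x))   # tensor(23.)

Sorting the dims and iterating high-to-low, rather than walking op.axis in the order given, is what makes the loop in PATCH 10 safe for arbitrary axis tuples: each reduction removes one dim, and reducing from the highest index down guarantees that every lower dim index still refers to the same dim on the next iteration.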