From f007c0dcc0b4e0419d8331e2147075fc90ac33f0 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Thu, 20 Jun 2024 22:13:00 +0530 Subject: [PATCH 1/8] Add Pytorch support for Cum Op --- pytensor/link/pytorch/dispatch/__init__.py | 1 + pytensor/link/pytorch/dispatch/extra_ops.py | 18 ++++++++++++ tests/link/pytorch/test_extra_ops.py | 31 +++++++++++++++++++++ 3 files changed, 50 insertions(+) create mode 100644 pytensor/link/pytorch/dispatch/extra_ops.py create mode 100644 tests/link/pytorch/test_extra_ops.py diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index b6af171995..7e476aba04 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -4,4 +4,5 @@ # # Load dispatch specializations import pytensor.link.pytorch.dispatch.scalar import pytensor.link.pytorch.dispatch.elemwise +import pytensor.link.pytorch.dispatch.extra_ops # isort: on diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py new file mode 100644 index 0000000000..fcbd2c9bfc --- /dev/null +++ b/pytensor/link/pytorch/dispatch/extra_ops.py @@ -0,0 +1,18 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.extra_ops import CumOp + + +@pytorch_funcify.register(CumOp) +def pytorch_funcify_Cumop(op, **kwargs): + dim = op.axis + mode = op.mode + + def cumop(x, dim=dim, mode=mode): + if mode == "add": + return torch.cumsum(x, dim=dim) + else: + return torch.cumprod(x, dim=dim) + + return cumop diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py new file mode 100644 index 0000000000..05079f2841 --- /dev/null +++ b/tests/link/pytorch/test_extra_ops.py @@ -0,0 +1,31 @@ +import numpy as np + +import pytensor.tensor as pt +from pytensor.configdefaults import config +from pytensor.graph import FunctionGraph +from pytensor.graph.op import get_test_value +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_CumOp(): + """Test PyTorch conversion of the `CumOp` `Op`.""" + + # Create a symbolic input for the first input of `CumOp` + a = pt.matrix("a") + + # Create test value tag for a + a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + # Create the output variable + out = pt.cumsum(a, axis=0) + + # Create a PyTensor `FunctionGraph` + fgraph = FunctionGraph([a], [out]) + + # Pass the graph and inputs to the testing function + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # For the second mode of CumOp + out = pt.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) From 6ae355f0962ad59dd174401296f9c4e43499887d Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 22 Jun 2024 16:13:21 +0530 Subject: [PATCH 2/8] Modify test for Cum op --- pytensor/link/pytorch/dispatch/extra_ops.py | 5 ++++- tests/link/pytorch/test_extra_ops.py | 21 ++++++++++++++------- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py index fcbd2c9bfc..0015359a18 100644 --- a/pytensor/link/pytorch/dispatch/extra_ops.py +++ b/pytensor/link/pytorch/dispatch/extra_ops.py @@ -9,7 +9,10 @@ def pytorch_funcify_Cumop(op, **kwargs): dim = op.axis mode = op.mode - def cumop(x, dim=dim, mode=mode): + def cumop(x, dim=dim): + if dim is None: + x = x.reshape(-1) + dim = 0 if mode == "add": return torch.cumsum(x, dim=dim) else: diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py index 05079f2841..2036544099 100644 --- a/tests/link/pytorch/test_extra_ops.py +++ b/tests/link/pytorch/test_extra_ops.py @@ -1,31 +1,38 @@ import numpy as np +import pytest import pytensor.tensor as pt from pytensor.configdefaults import config from pytensor.graph import FunctionGraph -from pytensor.graph.op import get_test_value from tests.link.pytorch.test_basic import compare_pytorch_and_py -def test_pytorch_CumOp(): +@pytest.mark.parametrize( + "axis", + [ + None, + 1, + ], +) +def test_pytorch_CumOp(axis): """Test PyTorch conversion of the `CumOp` `Op`.""" # Create a symbolic input for the first input of `CumOp` a = pt.matrix("a") # Create test value tag for a - a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) # Create the output variable - out = pt.cumsum(a, axis=0) + out = pt.cumsum(a, axis=axis) # Create a PyTensor `FunctionGraph` fgraph = FunctionGraph([a], [out]) # Pass the graph and inputs to the testing function - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_pytorch_and_py(fgraph, [test_value]) # For the second mode of CumOp - out = pt.cumprod(a, axis=1) + out = pt.cumprod(a, axis=axis) fgraph = FunctionGraph([a], [out]) - compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + compare_pytorch_and_py(fgraph, [test_value]) From 33463fc152075c299dc4c874a775700411c545e0 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 24 Jun 2024 00:57:28 +0530 Subject: [PATCH 3/8] Raise TypeError if axis not int or None --- pytensor/link/pytorch/dispatch/extra_ops.py | 8 +++-- pytensor/tensor/extra_ops.py | 7 ++-- tests/link/pytorch/test_extra_ops.py | 36 +++++++++++---------- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py index 0015359a18..f7af1eca7b 100644 --- a/pytensor/link/pytorch/dispatch/extra_ops.py +++ b/pytensor/link/pytorch/dispatch/extra_ops.py @@ -6,13 +6,15 @@ @pytorch_funcify.register(CumOp) def pytorch_funcify_Cumop(op, **kwargs): - dim = op.axis + axis = op.axis mode = op.mode - def cumop(x, dim=dim): - if dim is None: + def cumop(x): + if axis is None: x = x.reshape(-1) dim = 0 + else: + dim = axis if mode == "add": return torch.cumsum(x, dim=dim) else: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 06a82744b2..d064fe7b6d 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -283,8 +283,11 @@ class CumOp(COp): def __init__(self, axis: int | None = None, mode="add"): if mode not in ("add", "mul"): raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"') - self.axis = axis - self.mode = mode + if isinstance(axis, int) or axis is None: + self.axis = axis + self.mode = mode + else: + raise TypeError("axis must be an integer or None.") c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py index 2036544099..e335fbfb91 100644 --- a/tests/link/pytorch/test_extra_ops.py +++ b/tests/link/pytorch/test_extra_ops.py @@ -9,10 +9,7 @@ @pytest.mark.parametrize( "axis", - [ - None, - 1, - ], + [None, 1, (0,)], ) def test_pytorch_CumOp(axis): """Test PyTorch conversion of the `CumOp` `Op`.""" @@ -20,19 +17,24 @@ def test_pytorch_CumOp(axis): # Create a symbolic input for the first input of `CumOp` a = pt.matrix("a") - # Create test value tag for a + # Create test value test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) # Create the output variable - out = pt.cumsum(a, axis=axis) - - # Create a PyTensor `FunctionGraph` - fgraph = FunctionGraph([a], [out]) - - # Pass the graph and inputs to the testing function - compare_pytorch_and_py(fgraph, [test_value]) - - # For the second mode of CumOp - out = pt.cumprod(a, axis=axis) - fgraph = FunctionGraph([a], [out]) - compare_pytorch_and_py(fgraph, [test_value]) + if isinstance(axis, tuple): + with pytest.raises(TypeError, match="axis must be an integer or None."): + out = pt.cumsum(a, axis=axis) + with pytest.raises(TypeError, match="axis must be an integer or None."): + out = pt.cumprod(a, axis=axis) + else: + out = pt.cumsum(a, axis=axis) + # Create a PyTensor `FunctionGraph` + fgraph = FunctionGraph([a], [out]) + + # Pass the graph and inputs to the testing function + compare_pytorch_and_py(fgraph, [test_value]) + + # For the second mode of CumOp + out = pt.cumprod(a, axis=axis) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [test_value]) From bf905cbc3b1577ee053e36bcaf18c876c3ae9921 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 24 Jun 2024 03:26:53 +0530 Subject: [PATCH 4/8] Fix init method of CumOp --- pytensor/tensor/extra_ops.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index d064fe7b6d..94e63d33d6 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -283,11 +283,10 @@ class CumOp(COp): def __init__(self, axis: int | None = None, mode="add"): if mode not in ("add", "mul"): raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"') - if isinstance(axis, int) or axis is None: - self.axis = axis - self.mode = mode - else: + if not (isinstance(axis, int) or axis is None): raise TypeError("axis must be an integer or None.") + self.axis = axis + self.mode = mode c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) From debc3e0c6bf69e8af14aff0c6221c04a0d39a8f1 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 24 Jun 2024 03:32:10 +0530 Subject: [PATCH 5/8] Extend tutorial on documentation for Pytorch --- doc/conf.py | 1 + doc/environment.yml | 1 + doc/extending/creating_a_numba_jax_op.rst | 359 +++++++++++++++------- 3 files changed, 257 insertions(+), 104 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 7820a05a14..9fa44c98f0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -32,6 +32,7 @@ "sphinx.ext.napoleon", "sphinx.ext.linkcode", "sphinx.ext.mathjax", + "sphinx_design" ] needs_sphinx = "3" diff --git a/doc/environment.yml b/doc/environment.yml index c86375ccf1..ae17b6379d 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -13,6 +13,7 @@ dependencies: - mock - pillow - pymc-sphinx-theme + - sphinx-design - pip - pip: - -e .. diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 0d5f6460e9..fa872b427f 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -1,16 +1,15 @@ -Adding JAX and Numba support for `Op`\s +Adding JAX, Numba and Pytorch support for `Op`\s ======================================= -PyTensor is able to convert its graphs into JAX and Numba compiled functions. In order to do -this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba implementation function. +PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do +this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function. -This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will -focus specifically on the JAX case, but the same mechanisms are used for Numba as well. +This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`. -Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX +Step 1: Identify the PyTensor :class:`Op` you'd like to implement ------------------------------------------------------------------------ -Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and +Find the source for the PyTensor :class:`Op` you'd like to be supported and identify the function signature and return values. These can be determined by looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read @@ -46,7 +45,7 @@ which currently has an :meth:`Op.make_node` as follows: return Apply(self, [x], [out_type]) The :class:`Apply` instance that's returned specifies the exact types of inputs that -our JAX implementation will receive and the exact types of outputs it's expected to +our implementation will receive and the exact types of outputs it's expected to return--both in terms of their data types and number of dimensions/shapes. The actual inputs our implementation will receive are necessarily numeric values or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the @@ -57,7 +56,7 @@ automatically converted to PyTensor variables via :func:`as_tensor_variable`. There is another parameter, `axis`, that is used to determine the direction of the operation, hence shape of the output. The check that follows imply that `axis` must refer to a dimension in the input tensor. The input's elements -could also have any data type (e.g. floats, ints), so our JAX implementation +could also have any data type (e.g. floats, ints), so our implementation must be able to handle all the possible data types. It also tells us that there's only one return value, that it has a data type @@ -89,42 +88,76 @@ as :class:`CumsumOp`\ :class:`Op`. The difference lies in that the `mode` attrib c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis) `__props__` is used to parametrize the general behavior of the :class:`Op`. One need to -pay attention to this to decide whether the JAX implementation should support all variants +pay attention to this to decide whether the implementation should support all variants or raise an explicit NotImplementedError for cases that are not supported e.g., when :class:`CumsumOp` of :class:`CumOp("add")` is supported but not :class:`CumprodOp` of :class:`CumOp("mul")`. Next, we look at the :meth:`Op.perform` implementation to see exactly how the inputs and outputs are used to compute the outputs for an :class:`Op` -in Python. This method is effectively what needs to be implemented in JAX. +in Python. This method is effectively what needs to be implemented. -Step 2: Find the relevant JAX method (or something close) +Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close) --------------------------------------------------------- With a precise idea of what the PyTensor :class:`Op` does we need to figure out how -to implement it in JAX. In the best case scenario, JAX has a similarly named +to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named function that performs exactly the same computations as the :class:`Op`. For example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye` -(see `the documentation `_). +(see `the documentation `_) and a Pytorch equivalent :func:`torch.eye` (see `documentation `_). -If we wanted to implement an :class:`Op` like :class:`IfElse`, we might need to +If we wanted to implement an :class:`Op` like :class:`DimShuffle`, we might need to recreate the functionality with some custom logic. In many cases, at least some custom logic is needed to reformat the inputs and outputs so that they exactly match the `Op`'s. -Here's an example for :class:`IfElse`: +Here's an example for :class:`DimShuffle`: -.. code:: python - def ifelse(cond, *args, n_outs=n_outs): - res = jax.lax.cond( - cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None - ) - return res if n_outs > 1 else res[0] +.. tab-set:: + + .. tab-item:: JAX/Numba + + .. code:: python + + def dimshuffle(x, op): + res = jnp.transpose(x, op.transposition) + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + res = jnp.reshape(res, shape) + + if not op.inplace: + res = jnp.copy(res) + + return res + + .. tab-item:: Pytorch + + .. code:: python + + def dimshuffle(x, op): + res = torch.permute(x, op.transposition) + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + res = torch.reshape(res, shape) + + if not op.inplace: + res = res.clone() + + return res In this case, :class:`CumOp` is implemented with NumPy's :func:`numpy.cumsum` and :func:`numpy.cumprod`, which have JAX equivalents: :func:`jax.numpy.cumsum` -and :func:`jax.numpy.cumprod`. +and :func:`jax.numpy.cumprod`. The Pytorch equivalents are :func:`torch.cumsum` +and :func:`torch.cumprod` .. code:: python @@ -136,132 +169,250 @@ and :func:`jax.numpy.cumprod`. else: z[0] = np.cumprod(x, axis=self.axis) -Step 3: Register the function with the `jax_funcify` dispatcher +Step 3: Register the function with the respective dispatcher --------------------------------------------------------------- -With the PyTensor `Op` replicated in JAX, we'll need to register the -function with the PyTensor JAX `Linker`. This is done through the use of +With the PyTensor `Op` replicated, we'll need to register the +function with the backends `Linker`. This is done through the use of `singledispatch`. If you don't know how `singledispatch` works, see the `Python documentation `_. -The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and -:func:`pytensor.link.jax.dispatch.jax_funcify`. +The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify`, +:func:`pytensor.link.jax.dispatch.jax_funcify` and :func:`pytensor.link.pytorch.dispatch.pytorch_funcify`. Here's an example for the `CumOp`\ `Op`: -.. code:: python +.. tab-set:: - import jax.numpy as jnp + .. tab-item:: JAX/Numba - from pytensor.tensor.extra_ops import CumOp - from pytensor.link.jax.dispatch import jax_funcify + .. code:: python + import jax.numpy as jnp - @jax_funcify.register(CumOp) - def jax_funcify_CumOp(op, **kwargs): - axis = op.axis - mode = op.mode + from pytensor.tensor.extra_ops import CumOp + from pytensor.link.jax.dispatch import jax_funcify - def cumop(x, axis=axis, mode=mode): - if mode == "add": - return jnp.cumsum(x, axis=axis) - else: - return jnp.cumprod(x, axis=axis) - return cumop + @jax_funcify.register(CumOp) + def jax_funcify_CumOp(op, **kwargs): + axis = op.axis + mode = op.mode -Suppose `jnp.cumprod` does not exist, we will need to register the function as follows: + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return jnp.cumsum(x, axis=axis) + else: + return jnp.cumprod(x, axis=axis) -.. code:: python + return cumop - import jax.numpy as jnp + Suppose `jnp.cumprod` does not exist, we will need to register the function as follows: - from pytensor.tensor.extra_ops import CumOp - from pytensor.link.jax.dispatch import jax_funcify + .. code:: python + import jax.numpy as jnp - @jax_funcify.register(CumOp) - def jax_funcify_CumOp(op, **kwargs): - axis = op.axis - mode = op.mode + from pytensor.tensor.extra_ops import CumOp + from pytensor.link.jax.dispatch import jax_funcify - def cumop(x, axis=axis, mode=mode): - if mode == "add": - return jnp.cumsum(x, axis=axis) - else: - raise NotImplementedError("JAX does not support cumprod function at the moment.") - return cumop + @jax_funcify.register(CumOp) + def jax_funcify_CumOp(op, **kwargs): + axis = op.axis + mode = op.mode -Step 4: Write tests -------------------- + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return jnp.cumsum(x, axis=axis) + else: + raise NotImplementedError("JAX does not support cumprod function at the moment.") -Test that your registered `Op` is working correctly by adding tests to the -appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of -the modules in ``tests.link.numba``). The tests should ensure that your implementation can -handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. -Check the existing tests for the general outline of these kinds of tests. In -most cases, a helper function can be used to easily verify the correspondence -between a JAX/Numba implementation and its `Op`. + return cumop -For example, the :func:`compare_jax_and_py` function streamlines the steps -involved in making comparisons with `Op.perform`. -Here's a small example of a test for :class:`CumOp` above: + .. tab-item:: Pytorch -.. code:: python - - import numpy as np - import pytensor.tensor as pt - from pytensor.configdefaults import config - from tests.link.jax.test_basic import compare_jax_and_py - from pytensor.graph import FunctionGraph - from pytensor.graph.op import get_test_value + .. code:: python - def test_jax_CumOp(): - """Test JAX conversion of the `CumOp` `Op`.""" + import torch - # Create a symbolic input for the first input of `CumOp` - a = pt.matrix("a") + from pytensor.link.pytorch.dispatch.basic import pytorch_funcify + from pytensor.tensor.extra_ops import CumOp - # Create test value tag for a - a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) - # Create the output variable - out = pt.cumsum(a, axis=0) + @pytorch_funcify.register(CumOp) + def pytorch_funcify_Cumop(op, **kwargs): + axis = op.axis + mode = op.mode - # Create a PyTensor `FunctionGraph` - fgraph = FunctionGraph([a], [out]) + def cumop(x,): + if axis is None: + x = x.reshape(-1) + dim = 0 + else: + dim=axis + if mode == "add": + return torch.cumsum(x, dim=dim) + else: + return torch.cumprod(x, dim=dim) - # Pass the graph and inputs to the testing function - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + return cumop - # For the second mode of CumOp - out = pt.cumprod(a, axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) -If the variant :class:`CumprodOp` is not implemented, we can add a test for it as follows: + Suppose `torch.cumprod` does not exist, we will need to register the function as follows: + + .. code:: python + + import torch + + from pytensor.tensor.extra_ops import CumOp + from pytensor.link.pytorch.dispatch import pytorch_funcify -.. code:: python - import pytest + @pytorch_funcify.register(CumOp) + def pytorch_funcify_Cumop(op, **kwargs): + axis = op.axis + mode = op.mode + + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return torch.cumsum(x, axis=axis) + else: + raise NotImplementedError("Pytorch does not support cumprod function at the moment.") + + return cumop + +Step 4: Write tests +------------------- +.. tab-set:: + + .. tab-item:: JAX/Numba + + Test that your registered `Op` is working correctly by adding tests to the + appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of + the modules in ``tests.link.numba``). The tests should ensure that your implementation can + handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. + Check the existing tests for the general outline of these kinds of tests. In + most cases, a helper function can be used to easily verify the correspondence + between a JAX/Numba implementation and its `Op`. + + For example, the :func:`compare_jax_and_py` function streamlines the steps + involved in making comparisons with `Op.perform`. + + Here's a small example of a test for :class:`CumOp` above: + + .. code:: python + + import numpy as np + import pytensor.tensor as pt + from pytensor.configdefaults import config + from tests.link.jax.test_basic import compare_jax_and_py + from pytensor.graph import FunctionGraph + from pytensor.graph.op import get_test_value + + def test_jax_CumOp(): + """Test JAX conversion of the `CumOp` `Op`.""" + + # Create a symbolic input for the first input of `CumOp` + a = pt.matrix("a") + + # Create test value tag for a + a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + # Create the output variable + out = pt.cumsum(a, axis=0) + + # Create a PyTensor `FunctionGraph` + fgraph = FunctionGraph([a], [out]) + + # Pass the graph and inputs to the testing function + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # For the second mode of CumOp + out = pt.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + If the variant :class:`CumprodOp` is not implemented, we can add a test for it as follows: + + .. code:: python + + import pytest + + def test_jax_CumOp(): + """Test JAX conversion of the `CumOp` `Op`.""" + a = pt.matrix("a") + a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + with pytest.raises(NotImplementedError): + out = pt.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + - def test_jax_CumOp(): - """Test JAX conversion of the `CumOp` `Op`.""" - a = pt.matrix("a") - a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + .. tab-item:: Pytorch - with pytest.raises(NotImplementedError): - out = pt.cumprod(a, axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + Test that your registered `Op` is working correctly by adding tests to the + appropriate test suites in PyTensor (``tests.link.pytorch``). The tests should ensure that your implementation can + handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. + Check the existing tests for the general outline of these kinds of tests. In + most cases, a helper function can be used to easily verify the correspondence + between a Pytorch implementation and its `Op`. + + For example, the :func:`compare_pytorch_and_py` function streamlines the steps + involved in making comparisons with `Op.perform`. + + Here's a small example of a test for :class:`CumOp` above: + + .. code:: python + + import numpy as np + import pytest + import pytensor.tensor as pt + from pytensor.configdefaults import config + from tests.link.pytorch.test_basic import compare_pytorch_and_py + from pytensor.graph import FunctionGraph + + @pytest.mark.parametrize( + "axis", + [None, 1, (0,)], + ) + def test_pytorch_CumOp(axis): + """Test PyTorch conversion of the `CumOp` `Op`.""" + + # Create a symbolic input for the first input of `CumOp` + a = pt.matrix("a") + + # Create test value + test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + + # Create the output variable + if isinstance(axis, tuple): + with pytest.raises(TypeError, match="axis must be an integer or None."): + out = pt.cumsum(a, axis=axis) + with pytest.raises(TypeError, match="axis must be an integer or None."): + out = pt.cumprod(a, axis=axis) + else: + out = pt.cumsum(a, axis=axis) + # Create a PyTensor `FunctionGraph` + fgraph = FunctionGraph([a], [out]) + + # Pass the graph and inputs to the testing function + compare_pytorch_and_py(fgraph, [test_value]) + + # For the second mode of CumOp + out = pt.cumprod(a, axis=axis) + fgraph = FunctionGraph([a], [out]) + compare_pytorch_and_py(fgraph, [test_value]) + Note ---- In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows: .. code:: python + def test_jax_Eye(): """Test JAX conversion of the `Eye` `Op`.""" From 2bc7ddce45bfa0f2e5bbead889f36a597a65b5ca Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 24 Jun 2024 21:04:53 +0530 Subject: [PATCH 6/8] Add tab for Numba --- doc/extending/creating_a_numba_jax_op.rst | 225 +++++++++++++++++++++- 1 file changed, 219 insertions(+), 6 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index fa872b427f..4d345e1b8b 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -116,7 +116,7 @@ Here's an example for :class:`DimShuffle`: .. tab-set:: - .. tab-item:: JAX/Numba + .. tab-item:: JAX .. code:: python @@ -134,6 +134,105 @@ Here's an example for :class:`DimShuffle`: res = jnp.copy(res) return res + + .. tab-item:: Numba + + .. code:: python + + def numba_funcify_DimShuffle(op, node, **kwargs): + shuffle = tuple(op.shuffle) + transposition = tuple(op.transposition) + augment = tuple(op.augment) + inplace = op.inplace + + ndim_new_shape = len(shuffle) + len(augment) + + no_transpose = all(i == j for i, j in enumerate(transposition)) + if no_transpose: + + @numba_basic.numba_njit + def transpose(x): + return x + + else: + + @numba_basic.numba_njit + def transpose(x): + return np.transpose(x, transposition) + + shape_template = (1,) * ndim_new_shape + + # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below + # is typed as `getitem(Tuple(), int)`, which has no implementation + # (since getting an item from an empty sequence doesn't make sense). + # To avoid this compile-time error, we omit the expression altogether. + if len(shuffle) > 0: + # Use the statically known shape if available + if all(length is not None for length in node.outputs[0].type.shape): + shape = node.outputs[0].type.shape + + @numba_basic.numba_njit + def find_shape(array_shape): + return shape + + else: + + @numba_basic.numba_njit + def find_shape(array_shape): + shape = shape_template + j = 0 + for i in range(ndim_new_shape): + if i not in augment: + length = array_shape[j] + shape = numba_basic.tuple_setitem(shape, i, length) + j = j + 1 + return shape + + else: + + @numba_basic.numba_njit + def find_shape(array_shape): + return shape_template + + if ndim_new_shape > 0: + + @numba_basic.numba_njit + def dimshuffle_inner(x, shuffle): + x = transpose(x) + shuffle_shape = x.shape[: len(shuffle)] + new_shape = find_shape(shuffle_shape) + + # FIXME: Numba's `array.reshape` only accepts C arrays. + res_reshape = np.reshape(np.ascontiguousarray(x), new_shape) + + if not inplace: + return res_reshape.copy() + else: + return res_reshape + + else: + + @numba_basic.numba_njit + def dimshuffle_inner(x, shuffle): + return np.reshape(np.ascontiguousarray(x), ()) + + # Without the following wrapper function we would see this error: + # E No implementation of function Function() found for signature: + # E + # E >>> getitem(UniTuple(int64 x 2), slice) + # E + # E There are 22 candidate implementations: + # E - Of which 22 did not match due to: + # E Overload of function 'getitem': File: : Line N/A. + # E With argument(s): '(UniTuple(int64 x 2), slice)': + # E No match. + # ...(on this line)... + # E shuffle_shape = res.shape[: len(shuffle)] + @numba_basic.numba_njit(inline="always") + def dimshuffle(x): + return dimshuffle_inner(np.asarray(x), shuffle) + + return dimshuffle .. tab-item:: Pytorch @@ -184,7 +283,7 @@ Here's an example for the `CumOp`\ `Op`: .. tab-set:: - .. tab-item:: JAX/Numba + .. tab-item:: JAX .. code:: python @@ -230,6 +329,82 @@ Here's an example for the `CumOp`\ `Op`: return cumop + .. tab-item:: Numba + + .. code:: python + + import numpy as np + + from pytensor import config + from pytensor.graph import Apply + from pytensor.link.numba.dispatch import basic as numba_basic + from pytensor.tensor import TensorVariable + from pytensor.tensor.extra_ops import CumOp, + + def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): + axis = op.axis + mode = op.mode + ndim = cast(TensorVariable, node.outputs[0]).ndim + + if axis is not None: + if axis < 0: + axis = ndim + axis + if axis < 0 or axis >= ndim: + raise ValueError(f"Invalid axis {axis} for array with ndim {ndim}") + + reaxis_first = (axis, *(i for i in range(ndim) if i != axis)) + reaxis_first_inv = tuple(np.argsort(reaxis_first)) + + if mode == "add": + if axis is None or ndim == 1: + + @numba_basic.numba_njit(fastmath=config.numba__fastmath) + def cumop(x): + return np.cumsum(x) + + else: + + @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + def cumop(x): + out_dtype = x.dtype + if x.shape[axis] < 2: + return x.astype(out_dtype) + + x_axis_first = x.transpose(reaxis_first) + res = np.empty(x_axis_first.shape, dtype=out_dtype) + + res[0] = x_axis_first[0] + for m in range(1, x.shape[axis]): + res[m] = res[m - 1] + x_axis_first[m] + + return res.transpose(reaxis_first_inv) + + else: + if axis is None or ndim == 1: + + @numba_basic.numba_njit(fastmath=config.numba__fastmath) + def cumop(x): + return np.cumprod(x) + + else: + + @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + def cumop(x): + out_dtype = x.dtype + if x.shape[axis] < 2: + return x.astype(out_dtype) + + x_axis_first = x.transpose(reaxis_first) + res = np.empty(x_axis_first.shape, dtype=out_dtype) + + res[0] = x_axis_first[0] + for m in range(1, x.shape[axis]): + res[m] = res[m - 1] * x_axis_first[m] + + return res.transpose(reaxis_first) + + return cumop + .. tab-item:: Pytorch @@ -287,15 +462,15 @@ Step 4: Write tests ------------------- .. tab-set:: - .. tab-item:: JAX/Numba + .. tab-item:: JAX Test that your registered `Op` is working correctly by adding tests to the - appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of - the modules in ``tests.link.numba``). The tests should ensure that your implementation can + appropriate test suites in PyTensor (e.g. in ``tests.link.jax``). + The tests should ensure that your implementation can handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. Check the existing tests for the general outline of these kinds of tests. In most cases, a helper function can be used to easily verify the correspondence - between a JAX/Numba implementation and its `Op`. + between a Numba implementation and its `Op`. For example, the :func:`compare_jax_and_py` function streamlines the steps involved in making comparisons with `Op.perform`. @@ -351,6 +526,44 @@ Step 4: Write tests compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + .. tab-item:: Numba + + Test that your registered `Op` is working correctly by adding tests to the + appropriate test suites in PyTensor (e.g. in ``tests.link.numba``). + The tests should ensure that your implementation can + handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`. + Check the existing tests for the general outline of these kinds of tests. In + most cases, a helper function can be used to easily verify the correspondence + between a Numba implementation and its `Op`. + + For example, the :func:`compare_numba_and_py` function streamlines the steps + involved in making comparisons with `Op.perform`. + + Here's a small example of a test for :class:`CumOp` above: + + .. code:: python + + from tests.link.numba.test_basic import compare_numba_and_py + from pytensor.graph import FunctionGraph + from pytensor.compile.sharedvalue import SharedVariable + from pytensor.graph.basic import Constant + from pytensor.tensor import extra_ops + + def test_CumOp(val, axis, mode): + g = extra_ops.CumOp(axis=axis, mode=mode)(val) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, SharedVariable | Constant) + ], + ) + + + .. tab-item:: Pytorch Test that your registered `Op` is working correctly by adding tests to the From ec87e4e3eebd35d4354e18c266510065624ddcd4 Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Mon, 24 Jun 2024 23:13:34 +0530 Subject: [PATCH 7/8] Add intersphinx mapping --- doc/conf.py | 9 ++++++++- doc/extending/creating_a_numba_jax_op.rst | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 9fa44c98f0..a47bed060c 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -32,9 +32,16 @@ "sphinx.ext.napoleon", "sphinx.ext.linkcode", "sphinx.ext.mathjax", - "sphinx_design" + "sphinx_design", + "sphinx.ext.intersphinx" ] +intersphinx_mapping = { + "jax": ("https://jax.readthedocs.io/en/latest", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable", None), +} + needs_sphinx = "3" todo_include_todos = True diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 4d345e1b8b..1e735bab13 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -104,7 +104,7 @@ With a precise idea of what the PyTensor :class:`Op` does we need to figure out to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named function that performs exactly the same computations as the :class:`Op`. For example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye` -(see `the documentation `_) and a Pytorch equivalent :func:`torch.eye` (see `documentation `_). +and a Pytorch equivalent: :func:`torch.eye`. If we wanted to implement an :class:`Op` like :class:`DimShuffle`, we might need to recreate the functionality with some custom logic. In many cases, at least some From a6e6bd81366056f1c4d4d81db10705ecd453d16c Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Fri, 28 Jun 2024 17:00:08 +0530 Subject: [PATCH 8/8] Parametrize dtype --- doc/extending/creating_a_numba_jax_op.rst | 12 ++++++++---- tests/link/pytorch/test_extra_ops.py | 11 +++++++---- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 1e735bab13..42c7304b5c 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -588,17 +588,21 @@ Step 4: Write tests from pytensor.graph import FunctionGraph @pytest.mark.parametrize( - "axis", + "dtype", + ["float64", "int64"], + ) + @pytest.mark.parametrize( + "axis", [None, 1, (0,)], ) - def test_pytorch_CumOp(axis): + def test_pytorch_CumOp(axis, dtype): """Test PyTorch conversion of the `CumOp` `Op`.""" # Create a symbolic input for the first input of `CumOp` - a = pt.matrix("a") + a = pt.matrix("a", dtype=dtype) # Create test value - test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + test_value = np.arange(9, dtype=dtype).reshape((3, 3)) # Create the output variable if isinstance(axis, tuple): diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py index e335fbfb91..72faa3d0d0 100644 --- a/tests/link/pytorch/test_extra_ops.py +++ b/tests/link/pytorch/test_extra_ops.py @@ -2,23 +2,26 @@ import pytest import pytensor.tensor as pt -from pytensor.configdefaults import config from pytensor.graph import FunctionGraph from tests.link.pytorch.test_basic import compare_pytorch_and_py +@pytest.mark.parametrize( + "dtype", + ["float64", "int64"], +) @pytest.mark.parametrize( "axis", [None, 1, (0,)], ) -def test_pytorch_CumOp(axis): +def test_pytorch_CumOp(axis, dtype): """Test PyTorch conversion of the `CumOp` `Op`.""" # Create a symbolic input for the first input of `CumOp` - a = pt.matrix("a") + a = pt.matrix("a", dtype=dtype) # Create test value - test_value = np.arange(9, dtype=config.floatX).reshape((3, 3)) + test_value = np.arange(9, dtype=dtype).reshape((3, 3)) # Create the output variable if isinstance(axis, tuple):