diff --git a/pytensor/link/pytorch/dispatch/extra_ops.py b/pytensor/link/pytorch/dispatch/extra_ops.py
index f7af1eca7b..74284d651d 100644
--- a/pytensor/link/pytorch/dispatch/extra_ops.py
+++ b/pytensor/link/pytorch/dispatch/extra_ops.py
@@ -1,7 +1,7 @@
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
-from pytensor.tensor.extra_ops import CumOp
+from pytensor.tensor.extra_ops import CumOp, Repeat, Unique
 
 
 @pytorch_funcify.register(CumOp)
@@ -21,3 +21,38 @@ def cumop(x):
             return torch.cumprod(x, dim=dim)
 
     return cumop
+
+
+@pytorch_funcify.register(Repeat)
+def pytorch_funcify_Repeat(op, **kwargs):
+    axis = op.axis
+
+    def repeat(x, repeats):
+        return x.repeat_interleave(repeats, dim=axis)
+
+    return repeat
+
+
+@pytorch_funcify.register(Unique)
+def pytorch_funcify_Unique(op, **kwargs):
+    return_index = op.return_index
+
+    if return_index:
+        # TODO: evaluate whether it is worth implementing this param
+        # (see https://github.com/pytorch/pytorch/issues/36748)
+        raise NotImplementedError("return_index is not implemented for pytorch")
+
+    axis = op.axis
+    return_inverse = op.return_inverse
+    return_counts = op.return_counts
+
+    def unique(x):
+        return torch.unique(
+            x,
+            sorted=True,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=axis,
+        )
+
+    return unique
diff --git a/tests/link/pytorch/test_extra_ops.py b/tests/link/pytorch/test_extra_ops.py
index 72faa3d0d0..221855864a 100644
--- a/tests/link/pytorch/test_extra_ops.py
+++ b/tests/link/pytorch/test_extra_ops.py
@@ -41,3 +41,61 @@ def test_pytorch_CumOp(axis, dtype):
     out = pt.cumprod(a, axis=axis)
     fgraph = FunctionGraph([a], [out])
     compare_pytorch_and_py(fgraph, [test_value])
+
+
+@pytest.mark.parametrize(
+    "axis, repeats",
+    [
+        (0, (1, 2, 3)),
+        (1, (3, 3)),
+        pytest.param(
+            None,
+            3,
+            marks=pytest.mark.xfail(reason="Reshape not implemented"),
+        ),
+    ],
+)
+def test_pytorch_Repeat(axis, repeats):
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.arange(6, dtype="float64").reshape((3, 2))
+
+    out = pt.repeat(a, repeats, axis=axis)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1])
+def test_pytorch_Unique_axis(axis):
+    a = pt.matrix("a", dtype="float64")
+
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )
+
+    out = pt.unique(a, axis=axis)
+    fgraph = FunctionGraph([a], [out])
+    compare_pytorch_and_py(fgraph, [test_value])
+
+
+@pytest.mark.parametrize("return_inverse", [False, True])
+@pytest.mark.parametrize("return_counts", [False, True])
+@pytest.mark.parametrize(
+    "return_index",
+    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
+)
+def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
+    a = pt.matrix("a", dtype="float64")
+    test_value = np.array(
+        [[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
+    )
+
+    out = pt.unique(
+        a,
+        return_index=return_index,
+        return_inverse=return_inverse,
+        return_counts=return_counts,
+        axis=0,
+    )
+    fgraph = FunctionGraph([a], [out[0] if isinstance(out, list) else out])
+    compare_pytorch_and_py(fgraph, [test_value])