From 547eaa5b241a4dc18971bed616301698f902544f Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Wed, 5 Mar 2025 09:49:05 +0100 Subject: [PATCH 1/2] Attempt to add support for scipy.signal.gauss_spline --- pytensor/tensor/ssignal.py | 38 ++++++++++++++++++++++++++++++++++++ tests/tensor/test_ssignal.py | 34 ++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 pytensor/tensor/ssignal.py create mode 100644 tests/tensor/test_ssignal.py diff --git a/pytensor/tensor/ssignal.py b/pytensor/tensor/ssignal.py new file mode 100644 index 0000000000..d05c99e1fb --- /dev/null +++ b/pytensor/tensor/ssignal.py @@ -0,0 +1,38 @@ +from pytensor.tensor import Op, as_tensor_variable +from pytensor.graph.basic import Apply +import pytensor.tensor as pt +from pytensor.tensor.type import TensorType +import scipy.signal as scipy_signal + + +class GaussSpline(Op): + __props__ = ("n",) + + def __init__(self, n: int = None): + self.n = n + + def make_node(self, knots): + knots = as_tensor_variable(knots) + if not isinstance(knots.type, TensorType): + raise TypeError("Input must be a TensorType") + + if not isinstance(self.n, int) or self.n is None or self.n < 0: + raise ValueError("n must be a non-negative integer") + + if knots.ndim < 1: + raise TypeError("Input must be at least 1-dimensional") + + out = knots.type() + return Apply(self, [knots], [out]) + + def perform(self, node, inputs, output_storage): + [x] = inputs + [out] = output_storage + out[0] = scipy_signal.gauss_spline(x, self.n) + + def infer_shape(self, fgraph, node, shapes): + return [shapes[0]] + + +def gauss_spline(x, n): + return GaussSpline(n)(x) \ No newline at end of file diff --git a/tests/tensor/test_ssignal.py b/tests/tensor/test_ssignal.py new file mode 100644 index 0000000000..07e44c0026 --- /dev/null +++ b/tests/tensor/test_ssignal.py @@ -0,0 +1,34 @@ +from pytensor.tensor.ssignal import GaussSpline, gauss_spline +from pytensor.tensor.type import matrix +from pytensor import function +from pytensor import tensor as pt +import numpy as np +import pytest +from tests import unittest_tools as utt + +import scipy.signal as scipy_signal + +class TestGaussSpline(utt.InferShapeTester): + + def setup_method(self): + super().setup_method() + self.op_class = GaussSpline + self.op = gauss_spline + + @pytest.mark.parametrize("n", [-1, 1.5, None, "string"]) + def test_make_node_raises(self, n): + a = matrix() + with pytest.raises(ValueError, match="n must be a non-negative integer"): + self.op(a, n=n) + + def test_perform(self): + a = matrix() + f = function([a], self.op(a, n=10)) + a = np.random.random((8, 6)) + assert np.allclose(f(a), scipy_signal.gauss_spline(a, 10)) + + def test_infer_shape(self): + a = matrix() + self._compile_and_check( + [a], [self.op(a, 16)], [np.random.random((12, 4))], self.op_class + ) \ No newline at end of file From 4622edfcf2b0158809b60233a3028c822ff8c85f Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Wed, 5 Mar 2025 09:59:58 +0100 Subject: [PATCH 2/2] Run ruff on changes --- pytensor/tensor/ssignal.py | 10 +++++----- tests/tensor/test_ssignal.py | 15 +++++++-------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/pytensor/tensor/ssignal.py b/pytensor/tensor/ssignal.py index d05c99e1fb..dae02d62ba 100644 --- a/pytensor/tensor/ssignal.py +++ b/pytensor/tensor/ssignal.py @@ -1,14 +1,14 @@ +import scipy.signal as scipy_signal + +from pytensor.graph.basic import Apply from pytensor.tensor import Op, as_tensor_variable -from pytensor.graph.basic import Apply -import pytensor.tensor as pt from pytensor.tensor.type import TensorType -import scipy.signal as scipy_signal class GaussSpline(Op): __props__ = ("n",) - def __init__(self, n: int = None): + def __init__(self, n: int): self.n = n def make_node(self, knots): @@ -35,4 +35,4 @@ def infer_shape(self, fgraph, node, shapes): def gauss_spline(x, n): - return GaussSpline(n)(x) \ No newline at end of file + return GaussSpline(n)(x) diff --git a/tests/tensor/test_ssignal.py b/tests/tensor/test_ssignal.py index 07e44c0026..0b53ccfc9f 100644 --- a/tests/tensor/test_ssignal.py +++ b/tests/tensor/test_ssignal.py @@ -1,15 +1,14 @@ -from pytensor.tensor.ssignal import GaussSpline, gauss_spline -from pytensor.tensor.type import matrix -from pytensor import function -from pytensor import tensor as pt import numpy as np import pytest +import scipy.signal as scipy_signal + +from pytensor import function +from pytensor.tensor.ssignal import GaussSpline, gauss_spline +from pytensor.tensor.type import matrix from tests import unittest_tools as utt -import scipy.signal as scipy_signal class TestGaussSpline(utt.InferShapeTester): - def setup_method(self): super().setup_method() self.op_class = GaussSpline @@ -20,7 +19,7 @@ def test_make_node_raises(self, n): a = matrix() with pytest.raises(ValueError, match="n must be a non-negative integer"): self.op(a, n=n) - + def test_perform(self): a = matrix() f = function([a], self.op(a, n=10)) @@ -31,4 +30,4 @@ def test_infer_shape(self): a = matrix() self._compile_and_check( [a], [self.op(a, 16)], [np.random.random((12, 4))], self.op_class - ) \ No newline at end of file + )