Commit 0b56ed9

ricardoV94, jessegrabowski, and zaxtax authored
Implement batched convolve1d (#1318)
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>
1 parent a038c8e commit 0b56ed9
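
For orientation, here is a minimal usage sketch (not part of the diff) of the API this commit adds. `convolve1d` wraps the core `Conv1d` Op in `Blockwise`, so leading batch dimensions broadcast like a NumPy gufunc with signature `(n),(k)->(o)`:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# A batch of signals convolved with a single kernel: Blockwise broadcasts
# the kernel's (empty) batch dimensions against the data's batch dimension.
data = pt.matrix("data")      # shape (batch, n)
kernel = pt.vector("kernel")  # shape (k,)
out = pt.signal.convolve1d(data, kernel, mode="full")

fn = pytensor.function([data, kernel], out)
res = fn(np.random.default_rng(0).normal(size=(4, 8)), np.ones(3))
print(res.shape)  # (4, 10), since "full" gives n + k - 1 outputs per signal
```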

File tree

14 files changed: +261 -0 lines changed

Diff for: pytensor/link/jax/dispatch/__init__.py

+1
```diff
@@ -14,6 +14,7 @@
 import pytensor.link.jax.dispatch.scalar
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.signal
 import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.sort
 import pytensor.link.jax.dispatch.sparse
```

Diff for: pytensor/link/jax/dispatch/signal/__init__.py

+1
```python
import pytensor.link.jax.dispatch.signal.conv
```

Diff for: pytensor/link/jax/dispatch/signal/conv.py

+14
```python
import jax

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.signal.conv import Conv1d


@jax_funcify.register(Conv1d)
def jax_funcify_Conv1d(op, node, **kwargs):
    mode = op.mode

    def conv1d(data, kernel):
        return jax.numpy.convolve(data, kernel, mode=mode)

    return conv1d
```
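
As a hedged sketch (not part of the diff, assuming a working JAX install) of how this registration is exercised: compiling a graph containing `Conv1d` with PyTensor's JAX backend routes evaluation through `jax.numpy.convolve`:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
# mode="JAX" selects PyTensor's JAX linker, which looks up the
# jax_funcify registration for Conv1d shown above.
fn = pytensor.function([x, y], pt.signal.convolve1d(x, y, mode="valid"), mode="JAX")
print(fn(np.ones(5), np.ones(3)))  # [3. 3. 3.]
```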

Diff for: pytensor/link/numba/dispatch/__init__.py

+2
```diff
@@ -9,9 +9,11 @@
 import pytensor.link.numba.dispatch.random
 import pytensor.link.numba.dispatch.scan
 import pytensor.link.numba.dispatch.scalar
+import pytensor.link.numba.dispatch.signal
 import pytensor.link.numba.dispatch.slinalg
 import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
 import pytensor.link.numba.dispatch.tensor_basic
 
+
 # isort: on
```

Diff for: pytensor/link/numba/dispatch/signal/__init__.py

+1
```python
import pytensor.link.numba.dispatch.signal.conv
```

Diff for: pytensor/link/numba/dispatch/signal/conv.py

+16
```python
import numpy as np

from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import numba_njit
from pytensor.tensor.signal.conv import Conv1d


@numba_funcify.register(Conv1d)
def numba_funcify_Conv1d(op, node, **kwargs):
    mode = op.mode

    @numba_njit
    def conv1d(data, kernel):
        return np.convolve(data, kernel, mode=mode)

    return conv1d
```
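
Similarly, a hedged sketch (not part of the diff, assuming a working Numba install) of exercising the Numba registration: compiling with PyTensor's NUMBA mode dispatches `Conv1d` to the njit-compiled `np.convolve` wrapper above:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
# mode="NUMBA" selects PyTensor's Numba linker, which looks up the
# numba_funcify registration for Conv1d shown above.
fn = pytensor.function([x, y], pt.signal.convolve1d(x, y, mode="full"), mode="NUMBA")
res = fn(np.ones(5), np.ones(3))
print(res.shape)  # (7,): full convolution has length 5 + 3 - 1
```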

Diff for: pytensor/tensor/__init__.py

+1
```diff
@@ -116,6 +116,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 # isort: off
 from pytensor.tensor import linalg
 from pytensor.tensor import special
+from pytensor.tensor import signal
 
 # For backward compatibility
 from pytensor.tensor import nlinalg
```

Diff for: pytensor/tensor/signal/__init__.py

+4
```python
from pytensor.tensor.signal.conv import convolve1d


__all__ = ("convolve1d",)
```

Diff for: pytensor/tensor/signal/conv.py

+132
```python
from typing import TYPE_CHECKING, Literal, cast

from numpy import convolve as numpy_convolve

from pytensor.graph import Apply, Op
from pytensor.scalar.basic import upcast
from pytensor.tensor.basic import as_tensor_variable, join, zeros
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import maximum, minimum
from pytensor.tensor.type import vector
from pytensor.tensor.variable import TensorVariable


if TYPE_CHECKING:
    from pytensor.tensor import TensorLike


class Conv1d(Op):
    __props__ = ("mode",)
    gufunc_signature = "(n),(k)->(o)"

    def __init__(self, mode: Literal["full", "valid"] = "full"):
        if mode not in ("full", "valid"):
            raise ValueError(f"Invalid mode: {mode}")
        self.mode = mode

    def make_node(self, in1, in2):
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)

        assert in1.ndim == 1
        assert in2.ndim == 1

        dtype = upcast(in1.dtype, in2.dtype)

        n = in1.type.shape[0]
        k = in2.type.shape[0]

        if n is None or k is None:
            out_shape = (None,)
        elif self.mode == "full":
            out_shape = (n + k - 1,)
        else:  # mode == "valid"
            out_shape = (max(n, k) - min(n, k) + 1,)

        out = vector(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2], [out])

    def perform(self, node, inputs, outputs):
        # We use numpy_convolve as that's what scipy would use if
        # method="direct" was passed, and mode != "same" (which this Op
        # doesn't cover anyway).
        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)

    def infer_shape(self, fgraph, node, shapes):
        in1_shape, in2_shape = shapes
        n = in1_shape[0]
        k = in2_shape[0]
        if self.mode == "full":
            shape = n + k - 1
        else:  # mode == "valid"
            shape = maximum(n, k) - minimum(n, k) + 1
        return [[shape]]

    def L_op(self, inputs, outputs, output_grads):
        in1, in2 = inputs
        [grad] = output_grads

        if self.mode == "full":
            valid_conv = type(self)(mode="valid")
            in1_bar = valid_conv(grad, in2[::-1])
            in2_bar = valid_conv(grad, in1[::-1])

        else:  # mode == "valid"
            full_conv = type(self)(mode="full")
            n = in1.shape[0]
            k = in2.shape[0]
            kmn = maximum(0, k - n)
            nkm = maximum(0, n - k)
            # We need mode="full" if k >= n else "valid" for `in1_bar` (and the
            # opposite for `in2_bar`), but mode is not symbolic. Instead, we
            # always use mode="full" and slice the result so it behaves like
            # "valid" for the input that's shorter.
            in1_bar = full_conv(grad, in2[::-1])
            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
            in2_bar = full_conv(grad, in1[::-1])
            in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]

        return [in1_bar, in2_bar]


def convolve1d(
    in1: "TensorLike",
    in2: "TensorLike",
    mode: Literal["full", "valid", "same"] = "full",
) -> TensorVariable:
    """Convolve two one-dimensional arrays.

    Convolve in1 and in2, with the output size determined by the mode argument.

    Parameters
    ----------
    in1 : (..., N,) tensor_like
        First input.
    in2 : (..., M,) tensor_like
        Second input.
    mode : {'full', 'valid', 'same'}, optional
        A string indicating the size of the output:
        - 'full': The output is the full discrete linear convolution of the
          inputs, with shape (..., N+M-1,).
        - 'valid': The output consists only of elements that do not rely on
          zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
        - 'same': The output is the same size as in1, centered with respect
          to the 'full' output.

    Returns
    -------
    out : tensor_variable
        The discrete linear convolution of in1 with in2.

    """
    in1 = as_tensor_variable(in1)
    in2 = as_tensor_variable(in2)

    if mode == "same":
        # We implement "same" as "valid" with padded `in1`.
        in1_batch_shape = tuple(in1.shape)[:-1]
        zeros_left = in2.shape[0] // 2
        zeros_right = (in2.shape[0] - 1) // 2
        in1 = join(
            -1,
            zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
            in1,
            zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
        )
        mode = "valid"

    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
```
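
Two details of this file are worth unpacking. First, the valid-mode gradient trick is a length argument: `grad` has length |n - k| + 1, so `full_conv(grad, in2[::-1])` has length |n - k| + k; trimming `max(0, k - n)` elements from each side leaves exactly n for `in1_bar` (and symmetrically for `in2_bar`), which is why static slicing works even though `mode` is not symbolic. Second, `mode="same"` is reduced to `mode="valid"` on a zero-padded input. The NumPy sketch below (not part of the commit) checks that padding identity:

```python
import numpy as np

rng = np.random.default_rng(0)
data, kernel = rng.normal(size=8), rng.normal(size=5)
k = kernel.shape[0]

# Pad k//2 zeros on the left and (k-1)//2 on the right, then convolve in
# "valid" mode: this reproduces NumPy's centered "same" output.
padded = np.concatenate([np.zeros(k // 2), data, np.zeros((k - 1) // 2)])
np.testing.assert_allclose(
    np.convolve(data, kernel, mode="same"),
    np.convolve(padded, kernel, mode="valid"),
)
```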

Diff for: tests/link/jax/signal/__init__.py

Whitespace-only changes.

Diff for: tests/link/jax/signal/test_conv.py

+18
```python
import numpy as np
import pytest

from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.jax.test_basic import compare_jax_and_py


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    x = dmatrix("x")
    y = dmatrix("y")
    out = convolve1d(x[None], y[:, None], mode=mode)

    rng = np.random.default_rng()
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    compare_jax_and_py([x, y], out, [test_x, test_y])
```

Diff for: tests/link/numba/signal/test_conv.py

+22
```python
import numpy as np
import pytest

from pytensor.tensor import dmatrix
from pytensor.tensor.signal import convolve1d
from tests.link.numba.test_basic import compare_numba_and_py


pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode):
    x = dmatrix("x")
    y = dmatrix("y")
    out = convolve1d(x[None], y[:, None], mode=mode)

    rng = np.random.default_rng()
    test_x = rng.normal(size=(3, 5))
    test_y = rng.normal(size=(7, 11))
    # Blockwise dispatch for numba can't be run in object mode
    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
```

Diff for: tests/tensor/signal/__init__.py

Whitespace-only changes.

Diff for: tests/tensor/signal/test_conv.py

+49
```python
from functools import partial

import numpy as np
import pytest
from scipy.signal import convolve as scipy_convolve

from pytensor import config, function
from pytensor.tensor import matrix, vector
from pytensor.tensor.signal.conv import convolve1d
from tests import unittest_tools as utt


@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
def test_convolve1d(mode, data_shape, kernel_shape):
    data = vector("data")
    kernel = vector("kernel")
    op = partial(convolve1d, mode=mode)

    rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
    data_val = rng.normal(size=data_shape).astype(data.dtype)
    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)

    fn = function([data, kernel], op(data, kernel))
    np.testing.assert_allclose(
        fn(data_val, kernel_val),
        scipy_convolve(data_val, kernel_val, mode=mode),
        rtol=1e-6 if config.floatX == "float32" else 1e-15,
    )
    utt.verify_grad(op=lambda x: op(x, kernel_val), pt=[data_val])


def test_convolve1d_batch():
    x = matrix("data")
    y = matrix("kernel")
    out = convolve1d(x, y)

    rng = np.random.default_rng(38)
    x_test = rng.normal(size=(2, 8)).astype(x.dtype)
    y_test = x_test[::-1]

    res = out.eval({x: x_test, y: y_test})
    # y_test is x_test with its rows swapped, and convolution is commutative,
    # so res[0] and res[1] should both equal convolve(x_test[0], y_test[0]).
    rtol = 1e-6 if config.floatX == "float32" else 1e-15
    res_np = np.convolve(x_test[0], y_test[0])
    np.testing.assert_allclose(res[0], res_np, rtol=rtol)
    np.testing.assert_allclose(res[1], res_np, rtol=rtol)
```