Commit 46a906b

Add test for arbitrary padding at higher dimensions
1 parent 0f3c464 commit 46a906b
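
The assertion updates in this diff read the padding mode from the Op that created the output variable (z.owner.op.pad_mode) rather than from the variable itself. A minimal sketch of that pattern, assuming the pad helper exercised by these tests is importable from pytensor.tensor.pad (the import path is an assumption; the call pattern mirrors the test file):

    import numpy as np
    import pytensor
    from pytensor.tensor.pad import pad  # assumed import path

    x = np.zeros((3, 3), dtype=pytensor.config.floatX)
    z = pad(x, 1, mode="constant", constant_values=0)

    # pad() returns a symbolic variable; the mode is stored on the Op that
    # produced it, so it is read back through z.owner.op, not as z.pad_mode.
    assert z.owner.op.pad_mode == "constant"

    f = pytensor.function([], z, mode="FAST_COMPILE")
    np.testing.assert_allclose(f(), np.pad(x, 1, mode="constant", constant_values=0))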

File tree: 1 file changed, +49 -7 lines

tests/tensor/test_pad.py (+49 -7)
@@ -26,7 +26,7 @@ def test_constant_pad(
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="constant", constant_values=constant)
     z = pad(x, pad_width, mode="constant", constant_values=constant)
-    assert z.pad_mode == "constant"
+    assert z.owner.op.pad_mode == "constant"
 
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
@@ -43,7 +43,7 @@ def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="edge")
     z = pad(x, pad_width, mode="edge")
-    assert z.pad_mode == "edge"
+    assert z.owner.op.pad_mode == "edge"
 
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
@@ -65,7 +65,7 @@ def test_linear_ramp_pad(
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values)
     z = pad(x, pad_width, mode="linear_ramp", end_values=end_values)
-    assert z.pad_mode == "linear_ramp"
+    assert z.owner.op.pad_mode == "linear_ramp"
 
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
@@ -89,8 +89,7 @@ def test_stat_pad(
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length)
     z = pad(x, pad_width, mode=stat, stat_length=stat_length)
-    assert z.pad_mode == stat
-    assert z.stat_length_input == (stat_length is not None)
+    assert z.owner.op.pad_mode == stat
 
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
@@ -107,7 +106,7 @@ def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="wrap")
     z = pad(x, pad_width, mode="wrap")
-    assert z.pad_mode == "wrap"
+    assert z.owner.op.pad_mode == "wrap"
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -128,7 +127,50 @@ def test_symmetric_pad(size, pad_width, reflect_type):
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
     z = pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
-    assert z.pad_mode == "symmetric"
+    assert z.owner.op.pad_mode == "symmetric"
+    f = pytensor.function([], z, mode="FAST_COMPILE")
+
+    np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
+
+
+@pytest.mark.parametrize(
+    "mode",
+    [
+        "constant",
+        "edge",
+        "linear_ramp",
+        "wrap",
+        "symmetric",
+        "mean",
+        "maximum",
+        "minimum",
+    ],
+)
+@pytest.mark.parametrize("padding", ["symmetric", "asymmetric"])
+def test_nd_padding(mode, padding):
+    rng = np.random.default_rng()
+    n = rng.integers(3, 10)
+    if padding == "symmetric":
+        pad_width = [(i, i) for i in rng.integers(1, 5, size=n)]
+        stat_length = [(i, i) for i in rng.integers(1, 5, size=n)]
+    else:
+        pad_width = rng.integers(1, 5, size=(n, 2)).tolist()
+        stat_length = rng.integers(1, 5, size=(n, 2)).tolist()
+
+    test_kwargs = {
+        "constant": {"constant_values": 0},
+        "linear_ramp": {"end_values": 0},
+        "maximum": {"stat_length": stat_length},
+        "mean": {"stat_length": stat_length},
+        "minimum": {"stat_length": stat_length},
+        "reflect": {"reflect_type": "even"},
+        "symmetric": {"reflect_type": "even"},
+    }
+
+    x = np.random.normal(size=(2,) * n).astype(floatX)
+    kwargs = test_kwargs.get(mode, {})
+    expected = np.pad(x, pad_width, mode=mode, **kwargs)
+    z = pad(x, pad_width, mode=mode, **kwargs)
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
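
The (before, after) per-axis widths that test_nd_padding draws in its "asymmetric" branch follow NumPy's pad_width convention. A NumPy-only sketch with hand-picked values (the numbers are illustrative, not taken from the test):

    import numpy as np

    x = np.arange(6, dtype=float).reshape(2, 3)

    # One (before, after) pair per axis, the same argument shape the "asymmetric"
    # branch builds with rng.integers(1, 5, size=(n, 2)).tolist().
    pad_width = [(1, 0), (2, 1)]  # axis 0: 1 row before, 0 after; axis 1: 2 cols before, 1 after

    y = np.pad(x, pad_width, mode="edge")
    assert y.shape == (2 + 1 + 0, 3 + 2 + 1)  # -> (3, 6)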

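Likewise, the stat_length argument passed for the "mean", "maximum" and "minimum" modes bounds how many edge values feed the statistic on each side. A NumPy-only sketch (values are illustrative):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0, 4.0])

    # Left pad uses the mean of the first 2 values, right pad the mean of the last 1.
    y = np.pad(x, (2, 1), mode="mean", stat_length=(2, 1))

    np.testing.assert_allclose(y, [1.5, 1.5, 1.0, 2.0, 3.0, 4.0, 4.0])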