Skip to content

Commit ab99a1e

Browse files
First draft JAX overload
1 parent c28faaa commit ab99a1e

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

pytensor/link/jax/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytensor.link.jax.dispatch.subtensor
88
import pytensor.link.jax.dispatch.shape
99
import pytensor.link.jax.dispatch.extra_ops
10+
import pytensor.link.jax.dispatch.pad
1011
import pytensor.link.jax.dispatch.nlinalg
1112
import pytensor.link.jax.dispatch.slinalg
1213
import pytensor.link.jax.dispatch.random

pytensor/link/jax/dispatch/pad.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import jax.numpy as jnp
22

33
from pytensor.link.jax.dispatch import jax_funcify
4-
from pytensor.tensor.pad import Pad, allowed_kwargs
4+
from pytensor.tensor.pad import Pad
55

66

7-
@jax_funcify([Pad])
7+
fixed_kwargs = {"reflect": ["reflect_type"], "symmetric": ["reflect_type"]}
8+
9+
10+
@jax_funcify.register(Pad)
811
def jax_funcify_pad(op, **kwargs):
912
pad_mode = op.pad_mode
10-
expected_kwargs = allowed_kwargs[pad_mode]
13+
expected_kwargs = fixed_kwargs.get(pad_mode, {})
1114
mode_kwargs = {kwarg: getattr(op, kwarg) for kwarg in expected_kwargs}
1215

13-
def pad(x, pad_width, pad_mode=pad_mode):
14-
return jnp.pad(x, pad_width=pad_width, pad_mode=pad_mode, **mode_kwargs)
16+
def pad(x, pad_width, *args):
17+
print(args)
18+
return jnp.pad(x, pad_width, mode=pad_mode, **mode_kwargs)
1519

1620
return pad

tests/link/jax/test_pad.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor.tensor as pt
5+
from pytensor import config
6+
from pytensor.graph import FunctionGraph
7+
from pytensor.tensor.pad import PadMode
8+
from tests.link.jax.test_basic import compare_jax_and_py
9+
10+
11+
jax = pytest.importorskip("jax")
12+
floatX = config.floatX
13+
RTOL = ATOL = 1e-6 if floatX.endswith("64") else 1e-3
14+
15+
test_kwargs = {
16+
"constant": {"constant_values": 0},
17+
"linear_ramp": {"end_values": 0},
18+
"maximum": {"stat_length": None},
19+
"mean": {"stat_length": [[1, 2], [3, 3]]},
20+
"median": {"stat_length": 2},
21+
"reflect": {"reflect_type": "even"},
22+
"symmetric": {"reflect_type": "even"},
23+
}
24+
25+
26+
@pytest.mark.parametrize(
27+
"mode",
28+
[
29+
"constant",
30+
"edge",
31+
"linear_ramp",
32+
"wrap",
33+
"symmetric",
34+
"mean",
35+
"maximum",
36+
"minimum",
37+
],
38+
)
39+
def test_jax_pad(mode: PadMode):
40+
x_pt = pt.dmatrix("x")
41+
x = np.random.normal(size=(3, 3))
42+
kwargs = test_kwargs.get(mode, {})
43+
44+
res = pt.pad(x_pt, mode=mode, pad_width=3, **kwargs)
45+
res_fg = FunctionGraph([x_pt], [res])
46+
47+
compare_jax_and_py(
48+
res_fg,
49+
[x],
50+
assert_fn=lambda x, y: np.testing.assert_allclose(x, y, rtol=RTOL, atol=ATOL),
51+
py_mode="FAST_COMPILE",
52+
)

0 commit comments

Comments
 (0)