Skip to content

Commit 981688c

Browse files
Implement pad (#748)
* Add `pt.pad`
* Refactor `linspace`, `logspace`, and `geomspace` to match the NumPy implementation
* Add `pt.flip`
* Move `flip` to `tensor/subtensor.py` and add a docstring
* Move `slice_at_axis` to `tensor/subtensor` and expose it in `pytensor.tensor`
1 parent f489cf4 commit 981688c

File tree

11 files changed

+1632
-40
lines changed

11 files changed

+1632
-40
lines changed

pytensor/link/jax/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytensor.link.jax.dispatch.blockwise
77
import pytensor.link.jax.dispatch.elemwise
88
import pytensor.link.jax.dispatch.extra_ops
9+
import pytensor.link.jax.dispatch.pad
910
import pytensor.link.jax.dispatch.math
1011
import pytensor.link.jax.dispatch.nlinalg
1112
import pytensor.link.jax.dispatch.random

pytensor/link/jax/dispatch/pad.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import jax.numpy as jnp
2+
import numpy as np
3+
4+
from pytensor.link.jax.dispatch import jax_funcify
5+
from pytensor.tensor.pad import Pad
6+
7+
8+
@jax_funcify.register(Pad)
def jax_funcify_pad(op, **kwargs):
    """Return a JAX implementation of the ``Pad`` op.

    Dispatches on ``op.pad_mode`` and returns a closure wrapping
    :func:`jax.numpy.pad` with the keyword arguments that mode requires.
    Several modes need their extra argument coerced to a static tuple,
    because JAX rejects traced (dynamic) values for those parameters.

    Parameters
    ----------
    op : Pad
        The PyTensor ``Pad`` op being compiled; provides ``pad_mode``,
        ``reflect_type``, and ``has_stat_length``.
    **kwargs
        Unused; accepted for dispatch-interface compatibility.

    Returns
    -------
    callable
        A function ``(x, pad_width, ...)`` implementing the padding.
    """
    pad_mode = op.pad_mode
    reflect_type = op.reflect_type
    has_stat_length = op.has_stat_length

    if pad_mode == "constant":

        def constant_pad(x, pad_width, constant_values):
            return jnp.pad(x, pad_width, mode=pad_mode, constant_values=constant_values)

        return constant_pad

    elif pad_mode == "linear_ramp":

        def lr_pad(x, pad_width, end_values):
            # JAX does not allow a dynamic input if end_values is non-scalar,
            # so materialize it as a static tuple before calling jnp.pad.
            if not isinstance(end_values, int | float):
                end_values = tuple(np.array(end_values))
            return jnp.pad(x, pad_width, mode=pad_mode, end_values=end_values)

        return lr_pad

    elif pad_mode in ("maximum", "minimum", "mean") and has_stat_length:

        def stat_pad(x, pad_width, stat_length):
            # JAX does not allow a dynamic input here; cast to a static tuple.
            return jnp.pad(
                x, pad_width, mode=pad_mode, stat_length=tuple(np.array(stat_length))
            )

        return stat_pad

    elif pad_mode in ("reflect", "symmetric"):

        def loop_pad(x, pad_width):
            return jnp.pad(x, pad_width, mode=pad_mode, reflect_type=reflect_type)

        return loop_pad

    else:
        # Remaining modes ("edge", "wrap", and stat modes without a
        # stat_length) take no extra arguments beyond pad_width.

        def pad(x, pad_width):
            return jnp.pad(x, pad_width, mode=pad_mode)

        return pad

pytensor/tensor/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
130130
from pytensor.tensor.extra_ops import *
131131
from pytensor.tensor.io import *
132132
from pytensor.tensor.math import *
133+
from pytensor.tensor.pad import pad
133134
from pytensor.tensor.shape import (
134135
reshape,
135136
shape,

0 commit comments

Comments
 (0)