import jax.numpy as jnp
import numpy as np

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.pad import Pad


@jax_funcify.register(Pad)
def jax_funcify_pad(op, **kwargs):
    pad_mode = op.pad_mode
    reflect_type = op.reflect_type
    has_stat_length = op.has_stat_length
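
    # The pad mode and its options are static properties of the Op, so the
    # matching jnp.pad call can be picked once at dispatch time and closed over
    # by the returned function; only x, pad_width and the per-mode extras stay
    # as runtime inputs.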
    if pad_mode == "constant":

        def constant_pad(x, pad_width, constant_values):
            return jnp.pad(x, pad_width, mode=pad_mode, constant_values=constant_values)

        return constant_pad

    elif pad_mode == "linear_ramp":

        def lr_pad(x, pad_width, end_values):
            # JAX does not allow a dynamic input if end_values is non-scalar
            if not isinstance(end_values, int | float):
                end_values = tuple(np.array(end_values))
            return jnp.pad(x, pad_width, mode=pad_mode, end_values=end_values)

        return lr_pad

    elif pad_mode in ["maximum", "minimum", "mean"] and has_stat_length:

        def stat_pad(x, pad_width, stat_length):
            # JAX does not allow a dynamic input here, need to cast to tuple
            return jnp.pad(
                x, pad_width, mode=pad_mode, stat_length=tuple(np.array(stat_length))
            )

        return stat_pad
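
    # reflect_type ("even" or "odd") is likewise a static Op property and can
    # be passed straight through to jnp.pad.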
    elif pad_mode in ["reflect", "symmetric"]:

        def loop_pad(x, pad_width):
            return jnp.pad(x, pad_width, mode=pad_mode, reflect_type=reflect_type)

        return loop_pad
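
    # Remaining modes ("edge", "wrap", and the stat modes when no stat_length
    # was given) need no extra keyword arguments.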
    else:

        def pad(x, pad_width):
            return jnp.pad(x, pad_width, mode=pad_mode)

        return pad
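

# Usage sketch (illustrative, not part of the dispatch module): assuming
# PyTensor's `pytensor.tensor.pad.pad` helper and the "JAX" compilation mode,
# a padded graph would be lowered through the dispatcher above roughly like so:
#
#     import numpy as np
#     import pytensor
#     import pytensor.tensor as pt
#     from pytensor.tensor.pad import pad
#
#     x = pt.matrix("x")
#     y = pad(x, pad_width=1, mode="constant", constant_values=0.0)
#     fn = pytensor.function([x], y, mode="JAX")
#     fn(np.ones((2, 2)))  # -> 4x4 array with a border of zeros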