diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py index 7d9532557b..27b8234d49 100644 --- a/pytensor/link/jax/dispatch/elemwise.py +++ b/pytensor/link/jax/dispatch/elemwise.py @@ -74,12 +74,8 @@ def jax_funcify_DimShuffle(op, **kwargs): def dimshuffle(x): res = jnp.transpose(x, op.transposition) - shape = list(res.shape[: len(op.shuffle)]) - - for augm in op.augment: - shape.insert(augm, 1) - - res = jnp.reshape(res, shape) + res = jax.lax.expand_dims(res, op.augment) + res = jax.lax.squeeze(res, op.drop) if not op.inplace: res = jnp.copy(res)