From 322f6ee77f5df73ca73a35a7392c648e95f680ea Mon Sep 17 00:00:00 2001 From: HarshvirSandhu Date: Sat, 7 Sep 2024 17:38:59 +0530 Subject: [PATCH] Implement Dimshuffle using expand_dims/squeeze --- pytensor/link/jax/dispatch/elemwise.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py index 7d9532557b..27b8234d49 100644 --- a/pytensor/link/jax/dispatch/elemwise.py +++ b/pytensor/link/jax/dispatch/elemwise.py @@ -74,12 +74,8 @@ def jax_funcify_DimShuffle(op, **kwargs): def dimshuffle(x): res = jnp.transpose(x, op.transposition) - shape = list(res.shape[: len(op.shuffle)]) - - for augm in op.augment: - shape.insert(augm, 1) - - res = jnp.reshape(res, shape) + res = jax.lax.expand_dims(res, op.augment) + res = jax.lax.squeeze(res, op.drop) if not op.inplace: res = jnp.copy(res)