Implement Dimshuffle in jax using expand_dims/squeeze #987

Open
wants to merge 1 commit into
base: main
8 changes: 2 additions & 6 deletions pytensor/link/jax/dispatch/elemwise.py
@@ -74,12 +74,8 @@ def jax_funcify_DimShuffle(op, **kwargs):
     def dimshuffle(x):
         res = jnp.transpose(x, op.transposition)
 
-        shape = list(res.shape[: len(op.shuffle)])
-
-        for augm in op.augment:
-            shape.insert(augm, 1)
-
-        res = jnp.reshape(res, shape)
+        res = jax.lax.expand_dims(res, op.augment)
+        res = jax.lax.squeeze(res, op.drop)
@ricardoV94 (Member), Sep 7, 2024:

I think that after the transpose the dropped dimensions are all on the right, not at the positions in op.drop? So it should be something like squeeze(res, tuple(-(np.arange(len(op.drop) + 1))))? Also, I think the squeeze should be done before the expand_dims.

I suspect we need more tests of DimShuffle in JAX backend that test for multiple expand_dims / transposition / drop, to catch accidents like in this refactor (unless I'm wrong and the logic is correct, but I suspect not).

Also, unrelated: the op.inplace logic below is not needed in JAX, which never does in-place operations.
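To make the concern concrete, here is a minimal sketch of one case where the two code paths disagree. It uses NumPy as a stand-in for `jax.numpy` (`transpose`, `expand_dims` and `squeeze` take the same arguments for this purpose). The attribute values are hand-written, hypothetical stand-ins for what `op.shuffle`, `op.drop`, `op.augment` and `op.transposition` would hold for a pattern like `(2, "x", 0)` on a `(2, 1, 3)` input, assuming the transposition is the kept axes followed by the dropped ones.

```python
import numpy as np

# Hypothetical stand-ins for the op attributes of a DimShuffle with
# new_order = (2, "x", 0) applied to an input of shape (2, 1, 3).
shuffle = [2, 0]                 # input axes that are kept, in output order
drop = [1]                       # input axes that are dropped (length 1)
augment = [1]                    # output positions of the new "x" axes
transposition = shuffle + drop   # assumption: dropped axes are moved last

x = np.arange(6).reshape(2, 1, 3)
res = np.transpose(x, transposition)  # shape (3, 2, 1)

# Logic removed by the PR: keep the leading axes, insert 1s, reshape.
shape = list(res.shape[: len(shuffle)])
for augm in augment:
    shape.insert(augm, 1)
expected = np.reshape(res, shape)

# Logic added by the PR: expand first, then squeeze at the op.drop indices.
proposed = np.squeeze(np.expand_dims(res, tuple(augment)), axis=tuple(drop))

print(expected.shape)  # (3, 1, 2)
print(proposed.shape)  # (3, 2, 1)
```

In this case the squeeze removes the axis that `expand_dims` has just inserted and leaves the dropped axis in place, so no error is raised and only the output shape reveals the problem.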

Contributor Author:

In the case of a pure transpose, op.drop would be an empty list. But tuple(-(np.arange(len(op.drop) + 1))) would be non-empty, so if the squeeze is done before expand_dims, the logic would fail in this case.
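The edge case can be checked directly. The sketch below evaluates the suggested expression for an empty `drop` list and compares it with an off-by-one-free variant; `suggested_axes` and `trailing_axes` are hypothetical helper names, and the second one is only a guess at what the suggestion was aiming for. NumPy is used in place of `jax.numpy`.

```python
import numpy as np

def suggested_axes(drop):
    # Expression from the review comment with parentheses balanced.
    # Entries are cast to int only to make the result easy to read.
    return tuple(int(i) for i in -(np.arange(len(drop) + 1)))

def trailing_axes(drop):
    # Assumed intent: the last len(drop) axes and nothing else.
    return tuple(int(i) for i in -(np.arange(len(drop)) + 1))

print(suggested_axes([]))      # (0,)
print(trailing_axes([]))       # ()
print(suggested_axes([1]))     # (0, -1)
print(trailing_axes([0, 2]))   # (-1, -2)

# Pure transpose of a (2, 3) array: nothing is dropped, yet the suggested
# expression asks to squeeze axis 0, which has length 3.
x_t = np.transpose(np.zeros((2, 3)))
try:
    np.squeeze(x_t, axis=suggested_axes([]))
    squeeze_failed = False
except ValueError:
    squeeze_failed = True
print(squeeze_failed)          # True
```

With the shifted variant the axis tuple is empty for a pure transpose, so the squeeze becomes a no-op there.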

@ricardoV94 (Member), Oct 1, 2024:

Sure, but an alternative is still needed. The current implementation in this PR is wrong, AFAICT.
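One possible alternative, sketched here under assumptions rather than taken from the PR, is to squeeze the trailing axes left by the transposition and only then expand at the `augment` positions. The helpers `split_order`, `dimshuffle_reshape` and `dimshuffle_squeeze_expand` are hypothetical; `split_order` rebuilds stand-ins for `op.shuffle`, `op.drop` and `op.augment` from a `new_order` pattern. NumPy replaces `jax.numpy` so the check runs anywhere, and the loop at the end is the kind of multi-pattern comparison the review asks for.

```python
import numpy as np

def split_order(new_order, ndim):
    # Hypothetical reconstruction of the op attributes from a pattern.
    shuffle = [d for d in new_order if d != "x"]          # kept input axes
    drop = [d for d in range(ndim) if d not in shuffle]   # dropped input axes
    augment = [i for i, d in enumerate(new_order) if d == "x"]
    return shuffle, drop, augment

def dimshuffle_reshape(x, new_order):
    # Reference: the reshape-based logic that the PR removes.
    shuffle, drop, augment = split_order(new_order, x.ndim)
    res = np.transpose(x, shuffle + drop)
    shape = list(res.shape[: len(shuffle)])
    for augm in augment:
        shape.insert(augm, 1)
    return np.reshape(res, shape)

def dimshuffle_squeeze_expand(x, new_order):
    # Candidate: squeeze the trailing (dropped) axes, then expand.
    shuffle, drop, augment = split_order(new_order, x.ndim)
    res = np.transpose(x, shuffle + drop)
    res = np.squeeze(res, axis=tuple(range(len(shuffle), res.ndim)))
    return np.expand_dims(res, axis=tuple(augment))

# Patterns mixing transposition, dropped axes and new axes.
cases = [
    ((2, 3), (1, 0)),
    ((2, 3), ("x", 0, "x", 1)),
    ((2, 1, 3), (2, "x", 0)),
    ((1, 2, 1), (1,)),
    ((1, 2, 1, 3), (3, "x", 1, "x")),
]
for in_shape, new_order in cases:
    x = np.arange(np.prod(in_shape)).reshape(in_shape)
    ref = dimshuffle_reshape(x, new_order)
    out = dimshuffle_squeeze_expand(x, new_order)
    assert out.shape == ref.shape and np.array_equal(out, ref)
```

Because the squeeze targets positions `len(shuffle)` onward, the axis tuple is empty for a pure transpose, which avoids the failure mode raised earlier in the thread.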


         if not op.inplace:
             res = jnp.copy(res)