JAX Integration #62

Open
fabioseel opened this issue Nov 19, 2024 · 0 comments
Labels
Feature A new capability in the library

fabioseel commented Nov 19, 2024

Let's try to find a way to integrate JAX (ley) into our pipeline.
There are several libraries targeting JAX/torch exchange; e.g., my quick search showed:

The main challenge, as far as I know, is getting gradients to propagate correctly across the JAX/torch boundary.

The path forward would probably be to set up a simple model that has a JAX part in it and see whether and how it works.
I imagine the most sensible approach is to keep the JAX code internal to a model component and work with torch tensors on all interfaces.
In particular, the interaction with the objective system should be considered.
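
To make the idea concrete, here is a minimal sketch of what such a component might look like. It is not tied to our codebase: `jax_component`, `JaxBridge`, and `JaxLayer` are hypothetical names, the JAX function is a toy stand-in, and the conversion goes through NumPy on the CPU (so it copies data and ignores GPU placement). The gradient hand-off uses `jax.vjp` inside a `torch.autograd.Function`, which is one possible approach and not necessarily what the exchange libraries do.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def jax_component(x, w):
    # Hypothetical stand-in for whatever the JAX side actually computes.
    return jnp.tanh(x @ w)


def _to_jax(t):
    # torch -> NumPy -> JAX; copies through host memory (CPU-only assumption).
    return jnp.asarray(t.detach().cpu().contiguous().numpy())


def _to_torch(a, dtype, device):
    # np.array makes a writable copy, which torch.from_numpy expects.
    return torch.from_numpy(np.array(a)).to(dtype=dtype, device=device)


class JaxBridge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        # jax.vjp returns the output plus a function that pulls cotangents back.
        y, vjp_fn = jax.vjp(jax_component, _to_jax(x), _to_jax(w))
        ctx.vjp_fn = vjp_fn
        ctx.meta = (x.dtype, x.device)
        return _to_torch(y, x.dtype, x.device)

    @staticmethod
    def backward(ctx, grad_out):
        # Feed torch's upstream gradient into the JAX pull-back.
        grad_x, grad_w = ctx.vjp_fn(_to_jax(grad_out))
        dtype, device = ctx.meta
        return _to_torch(grad_x, dtype, device), _to_torch(grad_w, dtype, device)


class JaxLayer(torch.nn.Module):
    """Torch tensors on the interface, JAX on the inside."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # The parameter lives on the torch side so torch optimizers can see it.
        self.weight = torch.nn.Parameter(0.1 * torch.randn(in_features, out_features))

    def forward(self, x):
        return JaxBridge.apply(x, self.weight)


layer = JaxLayer(4, 3)
x = torch.randn(2, 4, requires_grad=True)
loss = layer(x).sum()
loss.backward()  # fills x.grad and layer.weight.grad via jax.vjp
```

A quick sanity check for a prototype like this would be to compare `x.grad` and `layer.weight.grad` against the same computation written in pure torch. Whether the host round trip is acceptable performance-wise is an open question.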

@fabioseel fabioseel added the Feature A new capability in the library label Nov 19, 2024