Let's try to find a way to integrate JAX (ley) into our pipeline.
There are several libraries targeting JAX/torch interoperability, e.g., a quick search showed:
As far as I know, the main challenge is propagating gradients correctly across the JAX/torch boundary.
The way forward would probably be to set up a simple model that has a JAX part in it and see whether, and how, it works out.
I imagine the most sensible approach would be to keep the JAX code internal to a model component and use torch tensors on all interfaces.
In particular, the objective system should be taken into account.
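A minimal sketch of that setup, assuming CPU execution, float32 tensors, and a copy through NumPy at each boundary crossing (a zero-copy DLPack path exists, but its exact API differs across JAX and torch versions). The names `JaxFunction`, `JaxLayer`, `_to_jax` and `_to_torch` are hypothetical. The forward pass calls `jax.vjp`, which returns both the output and a vector-Jacobian-product closure; the backward pass feeds torch's incoming gradient through that closure. Torch autograd therefore sees the JAX part as a single differentiable op, while the parameters stay ordinary `torch.nn.Parameter`s.

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch


def _to_jax(t: torch.Tensor) -> jax.Array:
    # Assumption: CPU copy via NumPy; detach so autograd state is not exported.
    return jnp.asarray(t.detach().cpu().numpy())


def _to_torch(a: jax.Array) -> torch.Tensor:
    # np.array makes a writable copy, which torch.from_numpy expects.
    return torch.from_numpy(np.array(a))


class JaxFunction(torch.autograd.Function):
    """Runs a JAX function in forward and its VJP in backward (hypothetical bridge)."""

    @staticmethod
    def forward(ctx, fn, *inputs):
        # jax.vjp evaluates fn and returns a closure computing vector-Jacobian products.
        out, vjp_fn = jax.vjp(fn, *(_to_jax(x) for x in inputs))
        ctx.vjp_fn = vjp_fn
        return _to_torch(out)

    @staticmethod
    def backward(ctx, grad_out):
        # Pull torch's upstream gradient back through the JAX computation.
        grads = ctx.vjp_fn(_to_jax(grad_out))
        # None for `fn`, which is not a tensor; one gradient per tensor input.
        return (None, *(_to_torch(g) for g in grads))


class JaxLayer(torch.nn.Module):
    """Hypothetical model component: torch parameters on the outside, JAX math inside."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(0.1 * torch.randn(in_features, out_features))

    @staticmethod
    def _jax_forward(x, w):
        # Pure JAX computation; any differentiable jnp code could go here.
        return jnp.tanh(x @ w)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return JaxFunction.apply(self._jax_forward, x, self.weight)


# Usage: the layer behaves like any torch module, including .backward().
torch.manual_seed(0)
layer = JaxLayer(3, 2)
x = torch.randn(4, 3)
loss = layer(x).sum()
loss.backward()
print(layer.weight.grad.shape)
```

Because the weights are torch parameters, a standard torch optimizer and loss can treat `JaxLayer` like any other module; the trade-off in this sketch is one host copy per boundary crossing in each direction.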