from functools import partial
import jax
from jax2torch import jax2torch
import torch
class A:
def __init__(self, x):
self.x = x
@partial(jax.jit, static_argnames=["a"])
def jax_func(b, a):
return b + a.x
torch_func = jax2torch(jax_func)
def main():
a = A(1.0)
b = torch.tensor([1.0])
out = torch_func(a, b)
print(f"{out=}")
if __name__ == "__main__":
main()
Traceback (most recent call last):
File "/tmp/example.py", line 25, in <module>
main()
File "/tmp/example.py", line 21, in main
out = torch_func(a, b)
^^^^^^^^^^^^^^^^
File "/home/asdf/.local/lib/python3.11/site-packages/jax2torch/jax2torch.py", line 54, in inner
return JaxFun.apply(*bound.arguments.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fdsa/conda-envs/asdfasdf/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/asdf/.local/lib/python3.11/site-packages/jax2torch/jax2torch.py", line 37, in forward
y_, ctx.fun_vjp = jax.vjp(fn, *args)
^^^^^^^^^^^^^^^^^^
File "/fdsa/conda-envs/asdfasdf/lib/python3.11/site-packages/jax/_src/api.py", line 2169, in vjp
return _vjp(
^^^^^
File "/fdsa/conda-envs/asdfasdf/lib/python3.11/site-packages/jax/_src/api.py", line 2175, in _vjp
for arg in primals_flat: dispatch.check_arg(arg)
^^^^^^^^^^^^^^^^^^^^^^^
File "/fdsa/conda-envs/asdfasdf/lib/python3.11/site-packages/jax/_src/dispatch.py", line 282, in check_arg
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
TypeError: Argument '<__main__.A object at 0x15549315a850>' of type <class '__main__.A'> is not a valid JAX type.
The following code fails:
Error:
Obviously, I could make the
Atype a JAX pytree and it would work, but this is somewhat of a pain if I just want to treat it as a static argument. What are your thoughts?