File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 11# https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9
22
3- import torch
3+ from jax import dlpack as jax_dlpack
4+ from torch .utils import dlpack as torch_dlpack
45
5- import jax
66import jax .numpy as jnp
77from jax .tree_util import tree_map
88
99def j2t (x_jax ):
10- x_torch = torch . utils . dlpack . from_dlpack (jax . dlpack .to_dlpack (x_jax ))
10+ x_torch = torch_dlpack . from_dlpack (jax_dlpack .to_dlpack (x_jax ))
1111 return x_torch
1212
1313def t2j (x_torch ):
1414 x_torch = x_torch .contiguous ()
15- x_jax = jax . dlpack . from_dlpack (torch . utils . dlpack .to_dlpack (x_torch ))
15+ x_jax = jax_dlpack . from_dlpack (torch_dlpack .to_dlpack (x_torch ))
1616 return x_jax
1717
1818def tree_t2j (x_torch ):
Original file line number Diff line number Diff line change 33setup (
44 name = 'jax2torch' ,
55 packages = find_packages (exclude = []),
6- version = '0.0.1 ' ,
6+ version = '0.0.2 ' ,
77 license = 'MIT' ,
88 description = 'Jax 2 Torch' ,
99 author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments