Skip to content

Commit ce785c7

Browse files
committed
make sure dlpack inits correctly
1 parent e9bcd4f commit ce785c7

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

jax2torch/jax2torch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
# https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9
22

3-
import torch
3+
from jax import dlpack as jax_dlpack
4+
from torch.utils import dlpack as torch_dlpack
45

5-
import jax
66
import jax.numpy as jnp
77
from jax.tree_util import tree_map
88

99
def j2t(x_jax):
10-
x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
10+
x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
1111
return x_torch
1212

1313
def t2j(x_torch):
1414
x_torch = x_torch.contiguous()
15-
x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
15+
x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
1616
return x_jax
1717

1818
def tree_t2j(x_torch):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'jax2torch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.1',
6+
version = '0.0.2',
77
license='MIT',
88
description = 'Jax 2 Torch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)