Skip to content

Commit f9306da

Browse files
committed
readme
1 parent 64ddc95 commit f9306da

1 file changed

Lines changed: 33 additions & 0 deletions

File tree

README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,36 @@
11
## jax2torch
22

33
Use Jax functions in Pytorch with DLPack, as outlined <a href="https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9">in a gist</a> by <a href="https://github.com/mattjj">@mattjj</a>. Right now only supports one tensor input (with optional non-tensor input arguments) to one tensor output, for the purposes of <a href="https://github.com/spetti/SMURF">differentiable alignment</a>.
4+
5+
## Install
6+
7+
```bash
8+
$ pip install jax2torch
9+
```
10+
11+
## Usage
12+
13+
```python
14+
import jax
15+
import torch
16+
import jax2torch
17+
18+
@jax.jit
19+
def jax_square(x):
20+
return x ** 2
21+
22+
torch_square = jax2torch(jax_square)
23+
24+
# Run it on Torch data!
25+
26+
x = torch.tensor([1., 2., 3.])
27+
y = torch_square(x)
28+
print(y) # tensor([1., 4., 9.])
29+
30+
# And differentiate!
31+
32+
x = torch.tensor([2., 3.], requires_grad=True)
33+
y = torch.sum(torch_square(x))
34+
y.backward()
35+
print(x.grad) # tensor([4., 9.])
36+
```

0 commit comments

Comments
 (0)