Skip to content

Commit e9bcd4f

Browse files
committed
update readme
1 parent a745a23 commit e9bcd4f

1 file changed

Lines changed: 13 additions & 9 deletions

File tree

README.md

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,28 @@ $ pip install jax2torch
1313
```python
1414
import jax
1515
import torch
16-
import jax2torch
16+
from jax2torch import jax2torch
17+
18+
# Jax function
1719

1820
@jax.jit
19-
def jax_square(x):
20-
return x ** 2
21+
def jax_pow(x, y=2):
22+
return x ** y
23+
24+
# convert to Torch function
2125

22-
torch_square = jax2torch(jax_square)
26+
torch_pow = jax2torch(jax_pow)
2327

24-
# Run it on Torch data!
28+
# run it on Torch data!
2529

2630
x = torch.tensor([1., 2., 3.])
27-
y = torch_square(x)
28-
print(y) # tensor([1., 4., 9.])
31+
y = torch_pow(x, 3)
32+
print(y) # tensor([1., 8., 27.])
2933

3034
# And differentiate!
3135

3236
x = torch.tensor([2., 3.], requires_grad=True)
33-
y = torch.sum(torch_square(x))
37+
y = torch.sum(torch_pow(x, 3))
3438
y.backward()
35-
print(x.grad) # tensor([4., 9.])
39+
print(x.grad) # tensor([12., 27.])
3640
```

0 commit comments

Comments
 (0)