File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -13,24 +13,28 @@ $ pip install jax2torch
1313``` python
1414import jax
1515import torch
16- import jax2torch
16+ from jax2torch import jax2torch
17+
18+ # Jax function
1719
1820@jax.jit
19- def jax_square (x ):
20- return x ** 2
21+ def jax_pow (x , y = 2 ):
22+ return x ** y
23+
24+ # convert to Torch function
2125
22- torch_square = jax2torch(jax_square )
26+ torch_pow = jax2torch(jax_pow )
2327
24- # Run it on Torch data!
28+ # run it on Torch data!
2529
2630x = torch.tensor([1 ., 2 ., 3 .])
27- y = torch_square(x )
28- print (y) # tensor([1., 4 ., 9 .])
31+ y = torch_pow(x, 3 )
32+ print (y) # tensor([1., 8 ., 27 .])
2933
3034# And differentiate!
3135
3236x = torch.tensor([2 ., 3 .], requires_grad = True )
33- y = torch.sum(torch_square(x ))
37+ y = torch.sum(torch_pow(x, 3 ))
3438y.backward()
35- print (x.grad) # tensor([4 ., 9 .])
39+ print (x.grad) # tensor([12 ., 27 .])
3640```
You can’t perform that action at this time.
0 commit comments