File tree 1 file changed +42
-0
lines changed
1 file changed +42
-0
lines changed Original file line number Diff line number Diff line change
1
+ """
2
+ Parameterized Hamiltonian (Pulse control/Analog simulation) with AD/JIT support using jax ode solver
3
+ """
4
+
5
+ import optax
6
+ from jax .experimental .ode import odeint
7
+ import tensorcircuit as tc
8
+
9
+ K = tc .set_backend ("jax" )
10
+ tc .set_dtype ("complex128" )
11
+
12
+ hx = tc .quantum .PauliStringSum2COO ([[1 ]])
13
+ hz = tc .quantum .PauliStringSum2COO ([[3 ]])
14
+
15
+
16
+ # psi = -i H psi
17
+ # we want to optimize the final z expectation over parameters params
18
+ # a single qubit example below
19
+
20
+
21
+ def final_z (b ):
22
+ def f (y , t , b ):
23
+ h = b [3 ] * K .sin (b [0 ] * t + b [1 ]) * hx + K .cos (b [2 ]) * hz
24
+ return - 1.0j * K .sparse_dense_matmul (h , y )
25
+
26
+ y0 = tc .array_to_tensor ([1 , 0 ])
27
+ y0 = K .reshape (y0 , [- 1 , 1 ])
28
+ t = tc .array_to_tensor ([0.0 , 10.0 ], dtype = tc .rdtypestr )
29
+ yf = odeint (f , y0 , t , b )
30
+ c = tc .Circuit (1 , inputs = K .reshape (yf [- 1 ], [- 1 ]))
31
+ return K .real (c .expectation_ps (z = [0 ]))
32
+
33
+
34
+ vgf = K .jit (K .value_and_grad (final_z ))
35
+
36
+
37
+ opt = K .optimizer (optax .adam (0.1 ))
38
+ b = K .implicit_randn ([4 ])
39
+ for _ in range (50 ):
40
+ v , gs = vgf (b )
41
+ b = opt .update (gs , b )
42
+ print (v , b )
You can’t perform that action at this time.
0 commit comments