Skip to content

Commit 15c4942

Browse files
add pulse control example
1 parent b761b44 commit 15c4942

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

examples/analog_evolution_jax.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
Parameterized Hamiltonian (Pulse control/Analog simulation) with AD/JIT support using jax ode solver
3+
"""
4+
5+
import optax
6+
from jax.experimental.ode import odeint
7+
import tensorcircuit as tc
8+
9+
K = tc.set_backend("jax")
10+
tc.set_dtype("complex128")
11+
12+
hx = tc.quantum.PauliStringSum2COO([[1]])
13+
hz = tc.quantum.PauliStringSum2COO([[3]])
14+
15+
16+
# psi = -i H psi
17+
# we want to optimize the final z expectation over parameters params
18+
# a single qubit example below
19+
20+
21+
def final_z(b):
22+
def f(y, t, b):
23+
h = b[3] * K.sin(b[0] * t + b[1]) * hx + K.cos(b[2]) * hz
24+
return -1.0j * K.sparse_dense_matmul(h, y)
25+
26+
y0 = tc.array_to_tensor([1, 0])
27+
y0 = K.reshape(y0, [-1, 1])
28+
t = tc.array_to_tensor([0.0, 10.0], dtype=tc.rdtypestr)
29+
yf = odeint(f, y0, t, b)
30+
c = tc.Circuit(1, inputs=K.reshape(yf[-1], [-1]))
31+
return K.real(c.expectation_ps(z=[0]))
32+
33+
34+
vgf = K.jit(K.value_and_grad(final_z))
35+
36+
37+
opt = K.optimizer(optax.adam(0.1))
38+
b = K.implicit_randn([4])
39+
for _ in range(50):
40+
v, gs = vgf(b)
41+
b = opt.update(gs, b)
42+
print(v, b)

0 commit comments

Comments
 (0)