Skip to content

Commit 58246e4

Browse files
authored
Merge pull request #116 from Langhaarzombie/feature/ntbk_jax
QuTiPv5 Paper Notebook: JAX
2 parents 629548e + 4f282da commit 58246e4

File tree

1 file changed

+335
-0
lines changed

1 file changed

+335
-0
lines changed
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
---
2+
jupyter:
3+
jupytext:
4+
text_representation:
5+
extension: .md
6+
format_name: markdown
7+
format_version: '1.3'
8+
jupytext_version: 1.13.8
9+
kernelspec:
10+
display_name: Python 3 (ipykernel)
11+
language: python
12+
name: python3
13+
---
14+
15+
# QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differnetiation
16+
17+
Authors: Maximilian Meyer-Mölleringhof ([email protected]), Rochisha Agarwal ([email protected]), Neill Lambert ([email protected])
18+
19+
For many years now, GPUs have been a fundamental tool for accelerating numerical tasks.
20+
Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations.
21+
QuTiP’s flexible data layer can directly be used with many such libraries and thereby drastically reduce computation time.
22+
Despite a big variety of frameworks, in connection to QuTiP, development has centered on the QuTiP-JAX integration [\[1\]](#References) due to JAX's robust auto-differentiation features and widespread adoption in machine learning.
23+
24+
In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]](#References) using the QuTiP-JAX package.
25+
As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation.
26+
In the second part, we focus on the auto-differentiation capabilities.
27+
For this we first consider the counting statistics of an open quantum system connected to an environment.
28+
Lastly, we look at a driven qubit system and computing the gradient of its state population in respect to the drive's frequency using `mcsolve()`.
29+
30+
## Introduction
31+
32+
In addition to the standard QuTiP package, in order to use QuTiP-JAX, the JAX package needs to be installed.
33+
This package also comes with `jax.numpy` which mirrors all of `numpy`s functionality for seamless integration with JAX.
34+
35+
```python
36+
import jax.numpy as jnp
37+
import matplotlib.pyplot as plt
38+
import qutip_jax as qj
39+
from diffrax import PIDController, Tsit5
40+
from jax import default_device, devices, grad, jacfwd, jacrev, jit
41+
from qutip import (CoreOptions, about, basis, destroy, lindblad_dissipator,
42+
liouvillian, mcsolve, mesolve, projection, qeye, settings,
43+
sigmam, sigmax, sigmay, sigmaz, spost, spre, sprepost,
44+
steadystate, tensor)
45+
46+
%matplotlib inline
47+
```
48+
49+
An immediate effect of importing `qutip_jax` is the availability of the data layer formats `jax` and `jax_dia`.
50+
They allow for dense (`jax`) and custom sparse (`jax_dia`) formats.
51+
52+
```python
53+
print(qeye(3, dtype="jax").dtype.__name__)
54+
print(qeye(3, dtype="jaxdia").dtype.__name__)
55+
```
56+
57+
In order to use the JAX data layer also within the master equation solver, there are two settings we can choose.
58+
First, is simply adding `method: diffrax` to the options parameter of the solver.
59+
Second, is to use the `qutip_jax.set_as_default()` method.
60+
It automatically switches all data data types to JAX compatible versions and sets the default solver method to `diffrax`.
61+
62+
```python
63+
qj.set_as_default()
64+
```
65+
66+
To revert this setting, we can set the function parameter `revert = True`.
67+
68+
```python
69+
# qj.set_as_default(revert = True)
70+
```
71+
72+
## Using the GPU with JAX and Diffrax - 1D Ising Spin Chain
73+
74+
Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem.
75+
GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations.
76+
For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential.
77+
However, as it is illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX becomes beneficial.
78+
79+
### 1D Ising Spin Chain
80+
81+
To illustrate the usage of QuTiP-JAX, we look at the one-dimensional spin chain with the Hamiltonian
82+
83+
$H = \sum_{i=1}^N g_0 \sigma_z^{(n)} - \sum_{n=1}^{N-1} J_0 \sigma_x^{(n)} \sigma_x^{(n+1)}$.
84+
85+
We hereby consider $N$ spins that share an energy splitting of $g_0$ and have a coupling strength $J_0$.
86+
The end of the chain connects to an environment described by a Lindbladian dissipator which we model with the collapse operator $\sigma_x^{(N-1)}$ and coupling rate $\gamma$.
87+
88+
As part of the [QuTiPv5 paper](#References), we see an extensive study on the computation time depending on the dimensionality $N$.
89+
In this example we cannot replicate the performance of a supercomputer of course, so we rather focus on the correct implementation to solve the Lindblad equation for this system using JAX and `mesolve()`.
90+
91+
```python
92+
# system parameters
93+
N = 4 # number of spins
94+
g0 = 1 # energy splitting
95+
J0 = 1.4 # coupling strength
96+
gamma = 0.1 # dissipation rate
97+
98+
# simulation parameters
99+
tlist = jnp.linspace(0, 5, 100)
100+
opt = {
101+
"normalize_output": False,
102+
"store_states": True,
103+
"method": "diffrax",
104+
"stepsize_controller": PIDController(
105+
rtol=settings.core["rtol"], atol=settings.core["atol"]
106+
),
107+
"solver": Tsit5(),
108+
}
109+
```
110+
111+
```python
112+
with CoreOptions(default_dtype="jaxdia"):
113+
# Operators for individual qubits
114+
sx_list, sy_list, sz_list = [], [], []
115+
for i in range(N):
116+
op_list = [qeye(2)] * N
117+
op_list[i] = sigmax()
118+
sx_list.append(tensor(op_list))
119+
op_list[i] = sigmay()
120+
sy_list.append(tensor(op_list))
121+
op_list[i] = sigmaz()
122+
sz_list.append(tensor(op_list))
123+
124+
# Hamiltonian - Energy splitting terms
125+
H = 0.0
126+
for i in range(N):
127+
H += g0 * sz_list[i]
128+
129+
# Interaction terms
130+
for n in range(N - 1):
131+
H += -J0 * sx_list[n] * sx_list[n + 1]
132+
133+
# Collapse operator acting locally on single spin
134+
c_ops = [gamma * sx_list[N - 1]]
135+
136+
# Initial state
137+
state_list = [basis(2, 1)] * (N - 1)
138+
state_list.append(basis(2, 0))
139+
psi0 = tensor(state_list)
140+
141+
result = mesolve(H, psi0, tlist, c_ops, e_ops=sz_list, options=opt)
142+
```
143+
144+
```python
145+
for i, s in enumerate(result.expect):
146+
plt.plot(tlist, s, label=rf"$n = {i+1}$")
147+
plt.xlabel("Time")
148+
plt.ylabel(r"$\langle \sigma^{(n)}_z \rangle$")
149+
plt.legend()
150+
plt.show()
151+
```
152+
153+
## Auto-Differentiation
154+
155+
We have seen in the previous example how the new JAX data-layer in QuTiP works.
156+
On top of that, JAX adds the features of auto-differentiation.
157+
To compute derivatives, it is often numerical approximations (e.g., finite difference method) that need to be employed.
158+
Especially for higher order derivatives, these methods can turn into costly and inaccurate calculations.
159+
160+
Auto-differentiation, on the other hand, exploits the chain rule to compute such derivatives.
161+
The idea is that any numerical function can be expressed by elementary analytical functions and operations.
162+
Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible.
163+
164+
Although there are many applications for this technique, in this chapter we want to focus on two examples where auto-differentiation becomes relevant.
165+
166+
### Statistics of Excitations between Quantum System and Environment
167+
168+
We consider an open quantum system that is in contact with an evironment via a single jump operator.
169+
Additionally, we have a measurement device that tracks the flow of excitations between the system and the environment.
170+
The probability distribution that describes the number of such exchanged excitations $n$ in a certain time $t$ is called the full counting statistics and denoted by $P_n(t)$.
171+
This statistics is a defining property that allows to derive many experimental observables like shot noise or current.
172+
173+
For the example here, we can calculate this statistics by using a modified version of the density operator and Lindblad master equation.
174+
We introduce the *tilted* density operator $G(z,t) = \sum_n e^{zn} \rho^n (t)$ with $\rho^n(t)$ being the density operator of the system conditioned on $n$ exchanges by time $t$, so $\text{Tr}[\rho^n(t)] = P_n(t)$.
175+
The master equation for this operator, including the jump operator $C$, is then given as
176+
177+
$\dot{G}(z,t) = -\dfrac{i}{\hbar} [H(t), G(z,t)] + \dfrac{1}{2} [2 e^z C \rho(t)C^\dagger - \rho C^\dagger C - C^\dagger C \rho(t)]$.
178+
179+
We see that for $z = 0$, this master equation becomes the regular Lindblad equation and $G(0,t) = \rho(t)$.
180+
However, it also allows us to describe the counting statistics through its derivatives
181+
182+
$\langle n^m \rangle (t) = \sum_n n^m \text{Tr} [\rho^n (t)] = \dfrac{d^m}{dz^m} \text{Tr} [G(z,t)]|_{z=0}$.
183+
184+
These derivatives are precisely where the auto-differention by JAX finds its application for us.
185+
186+
```python
187+
# system parameters
188+
ed = 1
189+
GammaL = 1
190+
GammaR = 1
191+
192+
# simulation parameters
193+
options = {
194+
"method": "diffrax",
195+
"normalize_output": False,
196+
"stepsize_controller": PIDController(rtol=1e-7, atol=1e-7),
197+
"solver": Tsit5(scan_kind="bounded"),
198+
"progress_bar": False,
199+
}
200+
```
201+
202+
When working with JAX you can choose the type of device / processor to be used.
203+
In our case, we will resort to the CPU since this is a simple Jupyter Notebook.
204+
However, when running this on your machine, you can opt for using your GPU by simpy changing the argument below.
205+
206+
```python
207+
with default_device(devices("cpu")[0]):
208+
with CoreOptions(default_dtype="jaxdia"):
209+
d = destroy(2)
210+
H = ed * d.dag() * d
211+
c_op_L = jnp.sqrt(GammaL) * d.dag()
212+
c_op_R = jnp.sqrt(GammaR) * d
213+
214+
L0 = (
215+
liouvillian(H)
216+
+ lindblad_dissipator(c_op_L)
217+
- 0.5 * spre(c_op_R.dag() * c_op_R)
218+
- 0.5 * spost(c_op_R.dag() * c_op_R)
219+
)
220+
L1 = sprepost(c_op_R, c_op_R.dag())
221+
222+
rho0 = steadystate(L0 + L1)
223+
224+
def rhoz(t, z):
225+
L = L0 + jnp.exp(z) * L1 # jump term
226+
tlist = jnp.linspace(0, t, 50)
227+
result = mesolve(L, rho0, tlist, options=options)
228+
return result.final_state.tr()
229+
230+
# first derivative
231+
drhozdz = jacrev(rhoz, argnums=1)
232+
# second derivative
233+
d2rhozdz = jacfwd(drhozdz, argnums=1)
234+
```
235+
236+
```python
237+
tf = 100
238+
Itest = GammaL * GammaR / (GammaL + GammaR)
239+
shottest = Itest * (1 - 2 * GammaL * GammaR / (GammaL + GammaR) ** 2)
240+
ncurr = drhozdz(tf, 0.0) / tf
241+
nshot = (d2rhozdz(tf, 0.0) - drhozdz(tf, 0.0) ** 2) / tf
242+
243+
print("===== RESULTS =====")
244+
print("Analytical current", Itest)
245+
print("Numerical current", ncurr)
246+
print("Analytical shot noise (2nd cumulant)", shottest)
247+
print("Numerical shot noise (2nd cumulant)", nshot)
248+
```
249+
250+
### Driven One Qubit System & Frequency Optimization
251+
252+
As a second example for auto differentiation, we consider the driven Rabi model, which is given by the time-dependent Hamiltonian
253+
254+
$H(t) = \dfrac{\hbar \omega_0}{2} \sigma_z + \dfrac{\hbar \Omega}{2} \cos (\omega t) \sigma_x$
255+
256+
with the energy splitting $\omega_0$, $\Omega$ as the Rabi frequency, the drive frequency $\omega$ and $\sigma_{x/z}$ are Pauli matrices.
257+
When we add dissipation to the system, the dynamics is given by the Lindblad master equation, which introduces collapse operator $C = \sqrt{\gamma} \sigma_-$ to describe energy relaxation.
258+
259+
For this example, we are interested in the population of the excited state of the qubit
260+
261+
$P_e(t) = \bra{e} \rho(t) \ket{e}$
262+
263+
and its gradient with respect to the frequency $\omega$.
264+
265+
We want to optimize this quantity by adjusting the drive frequency $\omega$.
266+
To achieve this, we compute the gradient of $P_e(t)$ in respect to $\omega$ by using JAX's auto-differentiation tools and QuTiP's `mcsolve()`.
267+
268+
```python
269+
# system parameters
270+
gamma = 0.1 # dissipation rate
271+
```
272+
273+
```python
274+
# time dependent drive
275+
@jit
276+
def driving_coeff(t, omega):
277+
return jnp.cos(omega * t)
278+
279+
280+
# system Hamiltonian
281+
def setup_system():
282+
H_0 = sigmaz()
283+
H_1 = sigmax()
284+
H = [H_0, [H_1, driving_coeff]]
285+
return H
286+
```
287+
288+
```python
289+
# simulation parameters
290+
psi0 = basis(2, 0)
291+
tlist = jnp.linspace(0.0, 10.0, 100)
292+
c_ops = [jnp.sqrt(gamma) * sigmam()]
293+
e_ops = [projection(2, 1, 1)]
294+
```
295+
296+
```python
297+
# Objective function: returns final exc. state population
298+
def f(omega):
299+
H = setup_system()
300+
arg = {"omega": omega}
301+
result = mcsolve(H, psi0, tlist, c_ops, e_ops=e_ops, ntraj=100, args=arg)
302+
return result.expect[0][-1]
303+
```
304+
305+
```python
306+
# Gradient of the excited state population with respect to omega
307+
grad_f = grad(f)(2.0)
308+
```
309+
310+
```python
311+
print(grad_f)
312+
```
313+
314+
## References
315+
316+
317+
318+
319+
[1] [QuTiP-JAX](https://github.com/qutip/qutip-jax)
320+
321+
[2] [QuTiP 5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705)
322+
323+
324+
## About
325+
326+
```python
327+
about()
328+
```
329+
330+
## Testing
331+
332+
```python
333+
assert jnp.isclose(Itest, ncurr, rtol=1e-5), "Current calc. deviates"
334+
assert jnp.isclose(shottest, nshot, rtol=1e-1), "Shot noise calc. deviates."
335+
```

0 commit comments

Comments
 (0)