| 1 | +--- |
| 2 | +jupyter: |
| 3 | +  jupytext: |
| 4 | +    text_representation: |
| 5 | +      extension: .md |
| 6 | +      format_name: markdown |
| 7 | +      format_version: '1.3' |
| 8 | +      jupytext_version: 1.13.8 |
| 9 | +  kernelspec: |
| 10 | +    display_name: Python 3 (ipykernel) |
| 11 | +    language: python |
| 12 | +    name: python3 |
| 13 | +--- |
| 14 | + |
| 15 | +# QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differentiation |
| 16 | + |
| 17 | +Authors: Maximilian Meyer-Mölleringhof ( [email protected]), Rochisha Agarwal ( [email protected]), Neill Lambert ( [email protected]) |
| 18 | + |
| 19 | +For many years now, GPUs have been a fundamental tool for accelerating numerical tasks. |
| 20 | +Today, many libraries enable off-the-shelf methods to leverage GPUs' potential to speed up costly calculations. |
| 21 | +QuTiP’s flexible data layer can be used directly with many such libraries, which can drastically reduce computation time. |
| 22 | +Although many frameworks are available, development on the QuTiP side has centered on the QuTiP-JAX integration [\[1\]](#References) due to JAX's robust auto-differentiation features and widespread adoption in machine learning. |
| 23 | + |
| 24 | +In these examples we illustrate how JAX naturally integrates into QuTiP v5 [\[2\]](#References) using the QuTiP-JAX package. |
| 25 | +As a simple first example, we look at a one-dimensional spin chain and how we might employ `mesolve()` and JAX to solve the related master equation. |
| 26 | +In the second part, we focus on the auto-differentiation capabilities. |
| 27 | +For this we first consider the counting statistics of an open quantum system connected to an environment. |
| 28 | +Lastly, we look at a driven qubit system and compute the gradient of its excited-state population with respect to the drive frequency using `mcsolve()`. |
| 29 | + |
| 30 | +## Introduction |
| 31 | + |
| 32 | +To use QuTiP-JAX, the JAX package needs to be installed in addition to the standard QuTiP package; the examples below also import the `qutip_jax` and `diffrax` packages. |
| 33 | +JAX also comes with `jax.numpy`, which closely mirrors the `numpy` API for seamless integration with JAX. |
| 34 | + |
| 35 | +```python |
| 36 | +import jax.numpy as jnp |
| 37 | +import matplotlib.pyplot as plt |
| 38 | +import qutip_jax as qj |
| 39 | +from diffrax import PIDController, Tsit5 |
| 40 | +from jax import default_device, devices, grad, jacfwd, jacrev, jit |
| 41 | +from qutip import (CoreOptions, about, basis, destroy, lindblad_dissipator, |
| 42 | +                   liouvillian, mcsolve, mesolve, projection, qeye, settings, |
| 43 | +                   sigmam, sigmax, sigmay, sigmaz, spost, spre, sprepost, |
| 44 | +                   steadystate, tensor) |
| 45 | + |
| 46 | +%matplotlib inline |
| 47 | +``` |
| 48 | + |
| 49 | +An immediate effect of importing `qutip_jax` is the availability of the data layer formats `jax` and `jaxdia`. |
| 50 | +They provide dense (`jax`) and custom sparse (`jaxdia`) storage. |
| 51 | + |
| 52 | +```python |
| 53 | +print(qeye(3, dtype="jax").dtype.__name__) |
| 54 | +print(qeye(3, dtype="jaxdia").dtype.__name__) |
| 55 | +``` |
| 56 | + |
| 57 | +To use the JAX data layer within the master equation solver as well, we have two options. |
| 58 | +The first is to add `"method": "diffrax"` to the `options` dictionary of the solver. |
| 59 | +The second is to call `qutip_jax.set_as_default()`. |
| 60 | +It automatically switches all data types to JAX-compatible versions and sets the default solver method to `diffrax`. |
| 61 | + |
| 62 | +```python |
| 63 | +qj.set_as_default() |
| 64 | +``` |
| 65 | + |
| 66 | +To revert this setting, we can call the same function with `revert=True`. |
| 67 | + |
| 68 | +```python |
| 69 | +# qj.set_as_default(revert = True) |
| 70 | +``` |
| 71 | + |
| 72 | +## Using the GPU with JAX and Diffrax - 1D Ising Spin Chain |
| 73 | + |
| 74 | +Before diving into the example, it is worth noting here that GPU acceleration depends heavily on the type of problem. |
| 75 | +GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations. |
| 76 | +For a single ODE involving large matrices, the advantages are less straightforward since ODE solvers are inherently sequential. |
| 77 | +However, as illustrated in the QuTiP v5 paper [\[2\]](#References), there is a cross-over point at which using JAX becomes beneficial. |
| 78 | + |
| 79 | +### 1D Ising Spin Chain |
| 80 | + |
| 81 | +To illustrate the usage of QuTiP-JAX, we look at the one-dimensional spin chain with the Hamiltonian |
| 82 | + |
| 83 | +$H = \sum_{n=1}^N g_0 \sigma_z^{(n)} - \sum_{n=1}^{N-1} J_0 \sigma_x^{(n)} \sigma_x^{(n+1)}$. |
| 84 | + |
| 85 | +We consider $N$ spins that share an energy splitting $g_0$ and a nearest-neighbor coupling strength $J_0$. |
| 86 | +The end of the chain connects to an environment described by a Lindblad dissipator, which we model with the collapse operator $\sigma_x^{(N)}$ and coupling rate $\gamma$. |
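
To make the structure of this Hamiltonian concrete, the following is a minimal sketch that builds the same matrix with plain NumPy Kronecker products, independent of QuTiP. The helper names `site_operator` and `ising_hamiltonian` are hypothetical and only serve this illustration; sites are counted from zero here, as in the notebook code.

```python
import numpy as np

# Pauli matrices and the 2x2 identity
SX = np.array([[0.0, 1.0], [1.0, 0.0]])
SZ = np.array([[1.0, 0.0], [0.0, -1.0]])
ID = np.eye(2)


def site_operator(op, n, num_spins):
    """Embed a single-spin operator at site n (0-based) of the chain."""
    out = np.array([[1.0]])
    for k in range(num_spins):
        out = np.kron(out, op if k == n else ID)
    return out


def ising_hamiltonian(num_spins, g0, J0):
    """Dense matrix of H = sum_n g0 sz_n - sum_n J0 sx_n sx_{n+1}."""
    dim = 2**num_spins
    H = np.zeros((dim, dim))
    # energy splitting terms
    for n in range(num_spins):
        H += g0 * site_operator(SZ, n, num_spins)
    # nearest-neighbor interaction terms
    for n in range(num_spins - 1):
        sx_n = site_operator(SX, n, num_spins)
        sx_next = site_operator(SX, n + 1, num_spins)
        H -= J0 * sx_n @ sx_next
    return H


H_dense = ising_hamiltonian(4, g0=1.0, J0=1.4)
print(H_dense.shape)
```

The matrix dimension grows as $2^N$, which is why a dense representation quickly becomes impractical and a sparse format is attractive for longer chains.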
| 87 | + |
| 88 | +The QuTiP v5 paper [\[2\]](#References) includes an extensive study of how the computation time depends on the number of spins $N$. |
| 89 | +Since we cannot replicate the performance of a supercomputer in this example, we focus instead on correctly implementing the solution of the Lindblad equation for this system using JAX and `mesolve()`. |
| 90 | + |
| 91 | +```python |
| 92 | +# system parameters |
| 93 | +N = 4 # number of spins |
| 94 | +g0 = 1 # energy splitting |
| 95 | +J0 = 1.4 # coupling strength |
| 96 | +gamma = 0.1 # dissipation rate |
| 97 | + |
| 98 | +# simulation parameters |
| 99 | +tlist = jnp.linspace(0, 5, 100) |
| 100 | +opt = { |
| 101 | +    "normalize_output": False, |
| 102 | +    "store_states": True, |
| 103 | +    "method": "diffrax", |
| 104 | +    "stepsize_controller": PIDController( |
| 105 | +        rtol=settings.core["rtol"], atol=settings.core["atol"] |
| 106 | +    ), |
| 107 | +    "solver": Tsit5(), |
| 108 | +} |
| 109 | +``` |
| 110 | + |
| 111 | +```python |
| 112 | +with CoreOptions(default_dtype="jaxdia"): |
| 113 | +    # Operators for individual qubits |
| 114 | +    sx_list, sy_list, sz_list = [], [], [] |
| 115 | +    for i in range(N): |
| 116 | +        op_list = [qeye(2)] * N |
| 117 | +        op_list[i] = sigmax() |
| 118 | +        sx_list.append(tensor(op_list)) |
| 119 | +        op_list[i] = sigmay() |
| 120 | +        sy_list.append(tensor(op_list)) |
| 121 | +        op_list[i] = sigmaz() |
| 122 | +        sz_list.append(tensor(op_list)) |
| 123 | + |
| 124 | +    # Hamiltonian - Energy splitting terms |
| 125 | +    H = 0.0 |
| 126 | +    for i in range(N): |
| 127 | +        H += g0 * sz_list[i] |
| 128 | + |
| 129 | +    # Interaction terms |
| 130 | +    for n in range(N - 1): |
| 131 | +        H += -J0 * sx_list[n] * sx_list[n + 1] |
| 132 | + |
| 133 | +    # Collapse operator acting locally on single spin |
| 134 | +    c_ops = [gamma * sx_list[N - 1]] |
| 135 | + |
| 136 | +    # Initial state |
| 137 | +    state_list = [basis(2, 1)] * (N - 1) |
| 138 | +    state_list.append(basis(2, 0)) |
| 139 | +    psi0 = tensor(state_list) |
| 140 | + |
| 141 | +    result = mesolve(H, psi0, tlist, c_ops, e_ops=sz_list, options=opt) |
| 142 | +``` |
| 143 | + |
| 144 | +```python |
| 145 | +for i, s in enumerate(result.expect): |
| 146 | +    plt.plot(tlist, s, label=rf"$n = {i+1}$") |
| 147 | +plt.xlabel("Time") |
| 148 | +plt.ylabel(r"$\langle \sigma^{(n)}_z \rangle$") |
| 149 | +plt.legend() |
| 150 | +plt.show() |
| 151 | +``` |
| 152 | + |
| 153 | +## Auto-Differentiation |
| 154 | + |
| 155 | +In the previous example, we have seen how the new JAX data layer in QuTiP works. |
| 156 | +On top of that, JAX adds auto-differentiation. |
| 157 | +Derivatives are often computed with numerical approximations such as the finite difference method. |
| 158 | +Especially for higher-order derivatives, these methods can become costly and inaccurate. |
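
The loss of accuracy for higher-order derivatives can be seen in a small, self-contained sketch. It estimates the second derivative of $\sin(x)$ with a central finite difference for two step sizes; the function name `second_derivative_fd` and the chosen step sizes are assumptions made purely for illustration.

```python
import math


def second_derivative_fd(f, x, h):
    """Central finite-difference estimate of f''(x) with step h."""
    return (f(x + h) - 2.0 * f(x) + f(x - h)) / h**2


x0 = 1.0
exact = -math.sin(x0)  # exact second derivative of sin(x)

# A moderate step leaves only a small truncation error ...
err_moderate = abs(second_derivative_fd(math.sin, x0, 1e-3) - exact)
# ... while a tiny step amplifies floating-point round-off.
err_tiny = abs(second_derivative_fd(math.sin, x0, 1e-8) - exact)

print(f"h = 1e-3: error {err_moderate:.2e}")
print(f"h = 1e-8: error {err_tiny:.2e}")
```

The step size has to balance truncation error against round-off, and this balance becomes more delicate with every additional derivative order.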
| 159 | + |
| 160 | +Auto-differentiation, on the other hand, exploits the chain rule to compute such derivatives. |
| 161 | +The idea is that any numerical function can be expressed by elementary analytical functions and operations. |
| 162 | +Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible. |
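
As a rough illustration of this idea, the sketch below implements forward-mode differentiation with dual numbers in plain Python. It is not how JAX works internally; the class `Dual` and the helpers `dual_sin`, `dual_exp`, and `func` are hypothetical names introduced only for this example.

```python
import math


class Dual:
    """Number of the form value + deriv * eps with eps**2 = 0."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # sum rule
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule
        return Dual(
            self.value * other.value,
            self.deriv * other.value + self.value * other.deriv,
        )

    __rmul__ = __mul__


def dual_sin(x):
    # chain rule: d/dx sin(u) = cos(u) * u'
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)


def dual_exp(x):
    # chain rule: d/dx exp(u) = exp(u) * u'
    return Dual(math.exp(x.value), math.exp(x.value) * x.deriv)


def func(x):
    """f(x) = exp(2x) * sin(x) + 3x, built from elementary operations."""
    return dual_exp(2 * x) * dual_sin(x) + 3 * x


x0 = 0.5
out = func(Dual(x0, 1.0))  # seed with dx/dx = 1
analytic = math.exp(2 * x0) * (2 * math.sin(x0) + math.cos(x0)) + 3
print(out.deriv, analytic)
```

Each elementary operation propagates its own derivative rule, so the derivative of the composite function is obtained to floating-point accuracy without choosing a step size.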
| 163 | + |
| 164 | +Although there are many applications for this technique, in this chapter we want to focus on two examples where auto-differentiation becomes relevant. |
| 165 | + |
| 166 | +### Statistics of Excitations between Quantum System and Environment |
| 167 | + |
| 168 | +We consider an open quantum system that is in contact with an environment via a single jump operator. |
| 169 | +Additionally, we have a measurement device that tracks the flow of excitations between the system and the environment. |
| 170 | +The probability distribution that describes the number of such exchanged excitations $n$ in a certain time $t$ is called the full counting statistics and denoted by $P_n(t)$. |
| 171 | +These statistics are a defining property of the system, from which many experimental observables, such as the current or the shot noise, can be derived. |
| 172 | + |
| 173 | +For the example here, we can calculate these statistics by using a modified version of the density operator and the Lindblad master equation. |
| 174 | +We introduce the *tilted* density operator $G(z,t) = \sum_n e^{zn} \rho^n (t)$ with $\rho^n(t)$ being the density operator of the system conditioned on $n$ exchanges by time $t$, so $\text{Tr}[\rho^n(t)] = P_n(t)$. |
| 175 | +The master equation for this operator, including the jump operator $C$, is then given as |
| 176 | + |
| 177 | +$\dot{G}(z,t) = -\dfrac{i}{\hbar} [H(t), G(z,t)] + \dfrac{1}{2} [2 e^z C G(z,t) C^\dagger - G(z,t) C^\dagger C - C^\dagger C G(z,t)]$. |
| 178 | + |
| 179 | +We see that for $z = 0$, this master equation becomes the regular Lindblad equation and $G(0,t) = \rho(t)$. |
| 180 | +However, it also allows us to describe the counting statistics through its derivatives |
| 181 | + |
| 182 | +$\langle n^m \rangle (t) = \sum_n n^m \text{Tr} [\rho^n (t)] = \dfrac{d^m}{dz^m} \text{Tr} [G(z,t)]|_{z=0}$. |
| 183 | + |
| 184 | +These derivatives are precisely where JAX's auto-differentiation finds its application for us. |
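
The relation between the $z$-derivatives and the moments can be checked on a toy distribution without any solver. The sketch below uses a Poisson distribution as a hypothetical stand-in for $P_n(t)$ (it is not the distribution of the model studied here) and plain central differences in place of JAX; the names `tilted_trace`, `first`, and `second` are assumptions for this illustration.

```python
import math

# Toy stand-in for P_n(t) = Tr[rho^n(t)]: a Poisson distribution with mean 3.
# (Hypothetical choice for illustration only.)
mean = 3.0
probs = [math.exp(-mean) * mean**n / math.factorial(n) for n in range(60)]


def tilted_trace(z):
    """Analogue of Tr[G(z, t)] = sum_n exp(z * n) * P_n."""
    return sum(math.exp(z * n) * p for n, p in enumerate(probs))


h = 1e-4
# central differences approximate the z-derivatives at z = 0
first = (tilted_trace(h) - tilted_trace(-h)) / (2 * h)
second = (tilted_trace(h) - 2 * tilted_trace(0.0) + tilted_trace(-h)) / h**2

# moments computed directly from the distribution
moment1 = sum(n * p for n, p in enumerate(probs))
moment2 = sum(n**2 * p for n, p in enumerate(probs))

print(first, moment1)
print(second, moment2)
print(second - first**2)  # second cumulant (variance)
```

The last line combines the first two derivatives into the second cumulant, the same combination that the notebook divides by the final time to obtain the shot noise.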
| 185 | + |
| 186 | +```python |
| 187 | +# system parameters |
| 188 | +ed = 1 |
| 189 | +GammaL = 1 |
| 190 | +GammaR = 1 |
| 191 | + |
| 192 | +# simulation parameters |
| 193 | +options = { |
| 194 | +    "method": "diffrax", |
| 195 | +    "normalize_output": False, |
| 196 | +    "stepsize_controller": PIDController(rtol=1e-7, atol=1e-7), |
| 197 | +    "solver": Tsit5(scan_kind="bounded"), |
| 198 | +    "progress_bar": False, |
| 199 | +} |
| 200 | +``` |
| 201 | + |
| 202 | +When working with JAX, you can choose the type of device (processor) to be used. |
| 203 | +In our case, we resort to the CPU since this is a simple Jupyter Notebook. |
| 204 | +However, when running this on your machine, you can opt for your GPU by simply changing the argument below. |
| 205 | + |
| 206 | +```python |
| 207 | +with default_device(devices("cpu")[0]): |
| 208 | +    with CoreOptions(default_dtype="jaxdia"): |
| 209 | +        d = destroy(2) |
| 210 | +        H = ed * d.dag() * d |
| 211 | +        c_op_L = jnp.sqrt(GammaL) * d.dag() |
| 212 | +        c_op_R = jnp.sqrt(GammaR) * d |
| 213 | + |
| 214 | +        L0 = ( |
| 215 | +            liouvillian(H) |
| 216 | +            + lindblad_dissipator(c_op_L) |
| 217 | +            - 0.5 * spre(c_op_R.dag() * c_op_R) |
| 218 | +            - 0.5 * spost(c_op_R.dag() * c_op_R) |
| 219 | +        ) |
| 220 | +        L1 = sprepost(c_op_R, c_op_R.dag()) |
| 221 | + |
| 222 | +        rho0 = steadystate(L0 + L1) |
| 223 | + |
| 224 | +        def rhoz(t, z): |
| 225 | +            L = L0 + jnp.exp(z) * L1  # jump term |
| 226 | +            tlist = jnp.linspace(0, t, 50) |
| 227 | +            result = mesolve(L, rho0, tlist, options=options) |
| 228 | +            return result.final_state.tr() |
| 229 | + |
| 230 | +        # first derivative |
| 231 | +        drhozdz = jacrev(rhoz, argnums=1) |
| 232 | +        # second derivative |
| 233 | +        d2rhozdz = jacfwd(drhozdz, argnums=1) |
| 234 | +``` |
| 235 | + |
| 236 | +```python |
| 237 | +tf = 100 |
| 238 | +Itest = GammaL * GammaR / (GammaL + GammaR) |
| 239 | +shottest = Itest * (1 - 2 * GammaL * GammaR / (GammaL + GammaR) ** 2) |
| 240 | +ncurr = drhozdz(tf, 0.0) / tf |
| 241 | +nshot = (d2rhozdz(tf, 0.0) - drhozdz(tf, 0.0) ** 2) / tf |
| 242 | + |
| 243 | +print("===== RESULTS =====") |
| 244 | +print("Analytical current", Itest) |
| 245 | +print("Numerical current", ncurr) |
| 246 | +print("Analytical shot noise (2nd cumulant)", shottest) |
| 247 | +print("Numerical shot noise (2nd cumulant)", nshot) |
| 248 | +``` |
| 249 | + |
| 250 | +### Driven One Qubit System & Frequency Optimization |
| 251 | + |
| 252 | +As a second example of auto-differentiation, we consider the driven Rabi model, which is given by the time-dependent Hamiltonian |
| 253 | + |
| 254 | +$H(t) = \dfrac{\hbar \omega_0}{2} \sigma_z + \dfrac{\hbar \Omega}{2} \cos (\omega t) \sigma_x$ |
| 255 | + |
| 256 | +with the energy splitting $\omega_0$, the Rabi frequency $\Omega$, the drive frequency $\omega$, and the Pauli matrices $\sigma_{x/z}$. |
| 257 | +When we add dissipation to the system, the dynamics are given by the Lindblad master equation, which introduces the collapse operator $C = \sqrt{\gamma} \sigma_-$ to describe energy relaxation. |
| 258 | + |
| 259 | +For this example, we are interested in the population of the excited state of the qubit |
| 260 | + |
| 261 | +$P_e(t) = \bra{e} \rho(t) \ket{e}$ |
| 262 | + |
| 263 | +and its gradient with respect to the frequency $\omega$. |
| 264 | + |
| 265 | +We want to optimize this quantity by adjusting the drive frequency $\omega$. |
| 266 | +To achieve this, we compute the gradient of $P_e(t)$ with respect to $\omega$ using JAX's auto-differentiation tools and QuTiP's `mcsolve()`. |
| 267 | + |
| 268 | +```python |
| 269 | +# system parameters |
| 270 | +gamma = 0.1 # dissipation rate |
| 271 | +``` |
| 272 | + |
| 273 | +```python |
| 274 | +# time dependent drive |
| 275 | +@jit |
| 276 | +def driving_coeff(t, omega): |
| 277 | +    return jnp.cos(omega * t) |
| 278 | + |
| 279 | + |
| 280 | +# system Hamiltonian |
| 281 | +def setup_system(): |
| 282 | +    H_0 = sigmaz() |
| 283 | +    H_1 = sigmax() |
| 284 | +    H = [H_0, [H_1, driving_coeff]] |
| 285 | +    return H |
| 286 | +``` |
| 287 | + |
| 288 | +```python |
| 289 | +# simulation parameters |
| 290 | +psi0 = basis(2, 0) |
| 291 | +tlist = jnp.linspace(0.0, 10.0, 100) |
| 292 | +c_ops = [jnp.sqrt(gamma) * sigmam()] |
| 293 | +e_ops = [projection(2, 1, 1)] |
| 294 | +``` |
| 295 | + |
| 296 | +```python |
| 297 | +# Objective function: returns final exc. state population |
| 298 | +def f(omega): |
| 299 | +    H = setup_system() |
| 300 | +    arg = {"omega": omega} |
| 301 | +    result = mcsolve(H, psi0, tlist, c_ops, e_ops=e_ops, ntraj=100, args=arg) |
| 302 | +    return result.expect[0][-1] |
| 303 | +``` |
| 304 | + |
| 305 | +```python |
| 306 | +# Gradient of the excited state population with respect to omega |
| 307 | +grad_f = grad(f)(2.0) |
| 308 | +``` |
| 309 | + |
| 310 | +```python |
| 311 | +print(grad_f) |
| 312 | +``` |
| 313 | + |
| 314 | +## References |
| 315 | + |
| 316 | + |
| 317 | + |
| 318 | + |
| 319 | +[1] [QuTiP-JAX](https://github.com/qutip/qutip-jax) |
| 320 | + |
| 321 | +[2] [QuTiP 5: The Quantum Toolbox in Python](https://arxiv.org/abs/2412.04705) |
| 322 | + |
| 323 | + |
| 324 | +## About |
| 325 | + |
| 326 | +```python |
| 327 | +about() |
| 328 | +``` |
| 329 | + |
| 330 | +## Testing |
| 331 | + |
| 332 | +```python |
| 333 | +assert jnp.isclose(Itest, ncurr, rtol=1e-5), "Current calc. deviates" |
| 334 | +assert jnp.isclose(shottest, nshot, rtol=1e-1), "Shot noise calc. deviates." |
| 335 | +``` |