---
jupyter:
  jupytext:
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.13.8
  kernelspec:
    display_name: Python 3
    language: python
    name: python3
---

# QuTiPv5 Paper Example: QuTiP-JAX with mesolve and auto-differentiation

Authors: Maximilian Meyer-Mölleringhof ([email protected]), Neill Lambert ([email protected])

For many years now, GPUs have been a fundamental tool for accelerating numerical tasks.
Today, many libraries offer off-the-shelf methods to leverage the potential of GPUs to speed up costly calculations.
QuTiP's flexible data layer can be used directly with many such libraries, thereby drastically reducing computation time.
Among this variety of frameworks, QuTiP development has centered on the QuTiP-JAX integration [\[1\]] due to JAX's robust auto-differentiation features and its widespread adoption in machine learning.

In these examples, we illustrate how JAX integrates naturally into QuTiP v5 [\[2\]] using the QuTiP-JAX package.
As a simple first example, we look at a one-dimensional spin chain and how we can employ `mesolve()` and JAX to solve the related master equation.
In the second part, we focus on the auto-differentiation capabilities.
For this, we first consider the counting statistics of an open quantum system connected to an environment.
Lastly, we look at a driven qubit system and compute the gradient of its state population with respect to the drive's frequency using `mcsolve()`.

## Introduction

In addition to the standard QuTiP package, the JAX package needs to be installed in order to use QuTiP-JAX.
JAX also ships with `jax.numpy`, which mirrors all of `numpy`'s functionality for seamless integration with JAX.

```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
import qutip_jax as qj
from diffrax import PIDController, Tsit5
from jax import default_device, devices, grad, jacfwd, jacrev, jit
from qutip import (CoreOptions, about, basis, destroy, lindblad_dissipator,
                   liouvillian, mcsolve, mesolve, projection, qeye, settings,
                   sigmam, sigmax, sigmay, sigmaz, spost, spre, sprepost,
                   steadystate, tensor)

%matplotlib inline
```

An immediate effect of importing `qutip_jax` is the availability of the data layer formats `jax` and `jaxdia`.
They provide dense (`jax`) and custom sparse (`jaxdia`) representations.

```python
print(qeye(3, dtype="jax").dtype.__name__)
print(qeye(3, dtype="jaxdia").dtype.__name__)
```

In order to use the JAX data layer within the master equation solver as well, there are two options.
The first is to simply add `"method": "diffrax"` to the options parameter of the solver (a short sketch of this variant follows below).
The second is to use the `qutip_jax.set_as_default()` method, which automatically switches all data types to JAX-compatible versions and sets the default solver method to `diffrax`.

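To make the first option concrete, here is a minimal, self-contained sketch; the single-qubit model (`H_demo`, `psi0_demo`, `c_ops_demo`, `tlist_demo`) is only a placeholder for illustration and not part of the examples below.

```python
# Sketch of option one: request the diffrax method per solver call.
# Operators are converted to the "jax" format explicitly here, since the
# global default data type has not been switched yet at this point.
H_demo = sigmaz().to("jax")  # placeholder Hamiltonian
psi0_demo = basis(2, 0).to("jax")  # placeholder initial state
c_ops_demo = [jnp.sqrt(0.1) * sigmam().to("jax")]  # placeholder dissipation
tlist_demo = jnp.linspace(0, 1, 11)
res_demo = mesolve(
    H_demo,
    psi0_demo,
    tlist_demo,
    c_ops_demo,
    options={
        "method": "diffrax",
        "stepsize_controller": PIDController(rtol=1e-6, atol=1e-6),
        "solver": Tsit5(),
    },
)
```

The second option switches the defaults globally and is used in the next cell.
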
```python
qj.set_as_default()
```

To revert this setting, we can call the same function with the parameter `revert = True`.

```python
# qj.set_as_default(revert=True)
```

## Using the GPU with JAX and Diffrax - 1D Ising Spin Chain

Before diving into the example, it is worth noting that GPU acceleration depends heavily on the type of problem.
GPUs are good at parallelizing many small matrix-vector operations, such as integrating small systems across multiple parameters or simulating quantum circuits with repeated small matrix operations.
For a single ODE involving large matrices, the advantages are less straightforward, since ODE solvers are inherently sequential.
However, as illustrated in the [QuTiPv5 paper](#References), there is a cross-over point in system size beyond which using JAX becomes beneficial.

### 1D Ising Spin Chain

To illustrate the usage of QuTiP-JAX, we look at the one-dimensional spin chain with the Hamiltonian

$H = \sum_{n=1}^N g_0 \sigma_z^{(n)} - \sum_{n=1}^{N-1} J_0 \sigma_x^{(n)} \sigma_x^{(n+1)}$.

Here we consider $N$ spins that share an energy splitting of $g_0$ and are coupled with strength $J_0$.
The end of the chain connects to an environment described by a Lindblad dissipator, which we model with the collapse operator $\sigma_x^{(N)}$ (acting on the last spin) and coupling rate $\gamma$.

The [QuTiPv5 paper](#References) contains an extensive study of the computation time as a function of the chain length $N$.
In this example we cannot replicate the performance of a supercomputer, of course, so we rather focus on the correct implementation that solves the Lindblad equation for this system using JAX and `mesolve()`.

```python
# system parameters
N = 4  # number of spins
g0 = 1  # energy splitting
J0 = 1.4  # coupling strength
gamma = 0.1  # dissipation rate

# simulation parameters
tlist = jnp.linspace(0, 5, 100)
opt = {
    "normalize_output": False,
    "store_states": True,
    "method": "diffrax",
    "stepsize_controller": PIDController(
        rtol=settings.core["rtol"], atol=settings.core["atol"]
    ),
    "solver": Tsit5(),
}
```

```python
with CoreOptions(default_dtype="jaxdia"):
    # Operators for individual qubits
    sx_list, sy_list, sz_list = [], [], []
    for i in range(N):
        op_list = [qeye(2)] * N
        op_list[i] = sigmax()
        sx_list.append(tensor(op_list))
        op_list[i] = sigmay()
        sy_list.append(tensor(op_list))
        op_list[i] = sigmaz()
        sz_list.append(tensor(op_list))

    # Hamiltonian - Energy splitting terms
    H = 0.0
    for i in range(N):
        H += g0 * sz_list[i]

    # Interaction terms
    for n in range(N - 1):
        H += -J0 * sx_list[n] * sx_list[n + 1]

    # Collapse operator acting locally on the last spin
    c_ops = [gamma * sx_list[N - 1]]

    # Initial state
    state_list = [basis(2, 1)] * (N - 1)
    state_list.append(basis(2, 0))
    psi0 = tensor(state_list)

    result = mesolve(H, psi0, tlist, c_ops, e_ops=sz_list, options=opt)
```

```python
for i, s in enumerate(result.expect):
    plt.plot(tlist, s, label=rf"$n = {i+1}$")
plt.xlabel("Time")
plt.ylabel(r"$\langle \sigma^{(n)}_z \rangle$")
plt.legend()
plt.show()
```

## Auto-Differentiation

In the previous example, we have seen how the new JAX data layer in QuTiP enables us to run calculations on the GPU.
On top of that, JAX adds auto-differentiation capabilities.
Derivatives often have to be computed using numerical approximations (e.g., the finite difference method).
Especially for higher-order derivatives, such methods can become costly and inaccurate.

Auto-differentiation, on the other hand, exploits the chain rule to compute such derivatives.

The idea is that any numerical function can be expressed in terms of elementary analytical functions and operations.
Consequently, using the chain rule, the derivatives of almost any higher-level function become accessible, as the short sketch below illustrates.

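As a small stand-alone illustration (unrelated to the physics examples that follow), `jax.grad` combines the known derivatives of the elementary operations into an exact derivative of the composite function:

```python
# f(x) = sin(x^2): JAX chains the derivatives of sin and x**2 automatically.
def f_demo(x):
    return jnp.sin(x**2)


df_demo = grad(f_demo)

print(df_demo(1.0))  # automatic derivative at x = 1
print(2 * 1.0 * jnp.cos(1.0**2))  # analytical result 2x cos(x^2) for comparison
```
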
Although there are many applications for this technique, in this chapter we focus on two examples where auto-differentiation becomes relevant.

### Statistics of Excitations between Quantum System and Environment

We consider an open quantum system that is in contact with an environment via a single jump operator.
Additionally, we have a measurement device that tracks the flow of excitations between the system and the environment.
The probability distribution $P_n(t)$ that describes the number $n$ of such exchanged excitations within a time $t$ is called the full counting statistics.
These statistics are a defining property from which many experimental observables, such as the shot noise or the average current, can be derived.

For this example, we can calculate these statistics using a modified version of the density operator and the Lindblad master equation.
We introduce the *tilted* density operator $G(z,t) = \sum_n e^{zn} \rho^n (t)$, where $\rho^n(t)$ is the density operator of the system conditioned on $n$ exchanges by time $t$, so that $\text{Tr}[\rho^n(t)] = P_n(t)$.
The master equation for this operator, including the jump operator $C$, is then given as

$\dot{G}(z,t) = -\dfrac{i}{\hbar} [H(t), G(z,t)] + \dfrac{1}{2} \left[ 2 e^z C G(z,t) C^\dagger - G(z,t) C^\dagger C - C^\dagger C G(z,t) \right]$.

We see that for $z = 0$, this master equation becomes the regular Lindblad equation and $G(0,t) = \rho(t)$.
However, it also allows us to describe the counting statistics through its derivatives

$\langle n^m \rangle (t) = \sum_n n^m \text{Tr} [\rho^n (t)] = \dfrac{d^m}{dz^m} \text{Tr} [G(z,t)]|_{z=0}$.

These derivatives are precisely where the auto-differentiation provided by JAX finds its application.

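In the long-time limit, the first two cumulants give the quantities that we compare with known analytical results at the end of this example, namely the average current and the shot noise (second cumulant),

$\langle I \rangle = \lim_{t \to \infty} \dfrac{\langle n \rangle (t)}{t}, \qquad S = \lim_{t \to \infty} \dfrac{\langle n^2 \rangle (t) - \langle n \rangle^2 (t)}{t}$.
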
```python
# system parameters
ed = 1
GammaL = 1
GammaR = 1

# simulation parameters
options = {
    "method": "diffrax",
    "normalize_output": False,
    "stepsize_controller": PIDController(rtol=1e-7, atol=1e-7),
    "solver": Tsit5(scan_kind="bounded"),
    "progress_bar": False,
}
```

```python
with default_device(devices("gpu")[0]):
    with CoreOptions(default_dtype="jaxdia"):
        d = destroy(2)
        H = ed * d.dag() * d
        c_op_L = jnp.sqrt(GammaL) * d.dag()
        c_op_R = jnp.sqrt(GammaR) * d

        L0 = (
            liouvillian(H)
            + lindblad_dissipator(c_op_L)
            - 0.5 * spre(c_op_R.dag() * c_op_R)
            - 0.5 * spost(c_op_R.dag() * c_op_R)
        )
        L1 = sprepost(c_op_R, c_op_R.dag())

        rho0 = steadystate(L0 + L1)

        def rhoz(t, z):
            L = L0 + jnp.exp(z) * L1  # jump term
            tlist = jnp.linspace(0, t, 50)
            result = mesolve(L, rho0, tlist, options=options)
            return result.final_state.tr()

        # first derivative
        drhozdz = jacrev(rhoz, argnums=1)
        # second derivative
        d2rhozdz = jacfwd(drhozdz, argnums=1)
```

```python
tf = 100
Itest = GammaL * GammaR / (GammaL + GammaR)
shottest = Itest * (1 - 2 * GammaL * GammaR / (GammaL + GammaR) ** 2)
ncurr = drhozdz(tf, 0.0) / tf
nshot = (d2rhozdz(tf, 0.0) - drhozdz(tf, 0.0) ** 2) / tf

print("===== RESULTS =====")
print("Analytic current", Itest)
print("Numerical current", ncurr)
print("Analytical shot noise (2nd cumulant)", shottest)
print("Numerical shot noise (2nd cumulant)", nshot)
```

### Driven One Qubit System & Frequency Optimization

As a second example of auto-differentiation, we consider the driven Rabi model, which is given by the time-dependent Hamiltonian

$H(t) = \dfrac{\hbar \omega_0}{2} \sigma_z + \dfrac{\hbar \Omega}{2} \cos (\omega t) \sigma_x$

with the energy splitting $\omega_0$, the Rabi frequency $\Omega$, the drive frequency $\omega$, and the Pauli matrices $\sigma_{x/z}$.
When we add dissipation to the system, the dynamics is given by the Lindblad master equation, which introduces the collapse operator $C = \sqrt{\gamma} \sigma_-$ to describe energy relaxation.

For this example, we are interested in the population of the excited state of the qubit

$P_e(t) = \bra{e} \rho(t) \ket{e}$

and its gradient with respect to the drive frequency $\omega$.

We want to optimize this quantity by adjusting the drive frequency $\omega$.
To achieve this, we compute the gradient of $P_e(t)$ with respect to $\omega$ using JAX's auto-differentiation tools and QuTiP's `mcsolve()`.

```python
# system parameters
gamma = 0.1  # dissipation rate
```

```python
# time dependent drive
@jit
def driving_coeff(t, omega):
    return jnp.cos(omega * t)


# system Hamiltonian
def setup_system():
    H_0 = sigmaz()
    H_1 = sigmax()
    H = [H_0, [H_1, driving_coeff]]
    return H
```

```python
# simulation parameters
psi0 = basis(2, 0)
tlist = jnp.linspace(0.0, 10.0, 100)
c_ops = [jnp.sqrt(gamma) * sigmam()]
e_ops = [projection(2, 1, 1)]
```

```python
# Objective function: returns final exc. state population
def f(omega):
    H = setup_system()
    arg = {"omega": omega}
    result = mcsolve(H, psi0, tlist, c_ops, e_ops=e_ops, ntraj=100, args=arg)
    return result.expect[0][-1]
```

```python
# Compute gradient of the exc. state population w.r.t. omega
grad_f = grad(f)(2.0)
```

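To close the loop on the optimization mentioned above, the gradient could be fed into a plain gradient-ascent update on $\omega$.
The following is only a rough sketch: the step size and iteration count are arbitrary choices (not part of the original example), each iteration reruns the Monte Carlo solver, so it is slow, and the gradient estimate is statistically noisy.

```python
# Rough gradient-ascent sketch; `lr` and the iteration count are arbitrary choices.
lr = 0.1  # step size
omega_opt = 2.0  # starting guess for the drive frequency
for _ in range(3):
    # move omega in the direction that increases the final excited-state population
    omega_opt = omega_opt + lr * grad(f)(omega_opt)
print("omega after a few ascent steps:", omega_opt)
```
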
## References

[1] [QuTiP-JAX](https://github.com/qutip/qutip-jax)

[2] QuTiP v5: The Quantum Toolbox in Python

## About

```python
about()
```

## Testing

```python
assert jnp.isclose(Itest, ncurr, rtol=1e-5), "Current calc. deviates"
assert jnp.isclose(shottest, nshot, rtol=1e-1), "Shot noise calc. deviates."
```