If one follows the documentation example and builds the following code:
```python
import jax.numpy as jnp
import qutip
import qutip_jax
from functools import partial
import jax

N = 10000
tlist = jnp.linspace(0.0, 10.0, 200)

# ``jaxdia`` operators support higher-dimensional Hilbert spaces on the GPU
with qutip.CoreOptions(default_dtype="jaxdia"):
    a = qutip.tensor(qutip.qeye(2), qutip.destroy(N))
    sm = qutip.tensor(qutip.destroy(2), qutip.qeye(N))
    H = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm + 2.0 * jnp.pi * 0.25 * (sm * a.dag() + sm.dag() * a)

# using ``jax`` dtype since ``DiffraxIntegrator`` anyway converts the final state to ``jax``
state = qutip.tensor(qutip.fock(2, 0, dtype="jax"), qutip.fock(N, 8, dtype="jax"))
c_ops = [jnp.sqrt(0.1) * a]
e_ops = [a.dag() * a, sm.dag() * sm]

@partial(jax.jit, static_argnames=("omega", ))
def H_1_coeff(t, omega):
    return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)

H_0 = 2.0 * jnp.pi * a.dag() * a + 2.0 * jnp.pi * sm.dag() * sm
H_1_op = sm * a.dag() + sm.dag() * a
H = [H_0, [H_1_op, H_1_coeff]]

result = qutip.mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={
    "method": "diffrax"
}, args={
    "omega": 1.0  # arguments for the coefficient function are passed here
})
```
it throws the following error:

```
ValueError: Terms are not compatible with solver! Got:
ODETerm(vector_field=<PjitFunction of <function dstate at 0x7f19cd28ab60>>)
but expected:
diffrax.AbstractTerm
Note that terms are checked recursively: if you scroll up you may find a root-cause error that is more specific.
```
This is fixed by replacing `@partial(jax.jit, static_argnames=("omega", ))` with `@jax.jit`.
However, the actual reason is unclear to me. Maybe it is because `omega` is a float and cannot be trivially treated as a static argument?
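One plausible explanation, not verified against the qutip-jax source, is that the values in `args` reach the coefficient as JAX values (a `jax.Array`, or a tracer inside the solver's jitted code) rather than as a plain Python `float`. `jax.jit` requires static arguments to be hashable, and JAX arrays and tracers are not, so the call would fail; diffrax would then surface this only as the generic "Terms are not compatible" message, which itself hints that a more specific root-cause error may appear earlier in the traceback. The sketch below uses plain JAX only (no qutip) to show the difference; `coeff_static`, `coeff_traced` and `call_failed` are hypothetical names for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Same body as H_1_coeff, but with omega marked static: jit hashes omega
# and compiles a separate version for every distinct value.
@partial(jax.jit, static_argnames=("omega",))
def coeff_static(t, omega):
    return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)


# The workaround from the issue: omega is an ordinary (traced) argument.
@jax.jit
def coeff_traced(t, omega):
    return 2.0 * jnp.pi * 0.25 * jnp.cos(2.0 * omega * t)


def call_failed(fn, *args, **kwargs):
    """Return True if calling fn raises a TypeError or ValueError."""
    try:
        fn(*args, **kwargs)
        return False
    except (TypeError, ValueError):
        return True


# A Python float is hashable, so it is accepted as a static argument.
ok_static_float = not call_failed(coeff_static, 0.5, omega=1.0)

# A jax.Array is not hashable, so jit rejects it as a static argument.
fail_static_array = call_failed(coeff_static, 0.5, omega=jnp.asarray(1.0))

# Plain jax.jit treats omega as a regular input and accepts the array.
ok_traced_array = not call_failed(coeff_traced, 0.5, omega=jnp.asarray(1.0))

print(ok_static_float, fail_static_array, ok_traced_array)
```

If this explanation holds, plain `@jax.jit` is also the better choice for a coefficient whose parameters come from `args`: a static `omega` would force a recompilation for every new value, whereas a traced `omega` reuses one compiled function.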