Skip to content
This repository was archived by the owner on Oct 31, 2025. It is now read-only.

Commit 37d12ac

Browse files
authored
Add warning for JAX versions on import of Dynamics (#232)
1 parent 3a337b8 commit 37d12ac

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

docs/conf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,7 @@
8282
nbsphinx_execute = 'always'
8383
nbsphinx_widgets_path = ''
8484
exclude_patterns = ['_build', '**.ipynb_checkpoints']
85+
86+
# this is tied to the temporary restriction to JAX versions <=0.4.6. See issue #190
87+
import os
88+
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

qiskit_dynamics/dispatch/backends/jax.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,28 @@
1717

1818
try:
1919
import jax
20-
from jax.interpreters.xla import DeviceArray
20+
from jax import Array
2121
from jax.core import Tracer
22-
from jax.interpreters.ad import JVPTracer
23-
from jax.interpreters.partial_eval import JaxprTracer
2422

25-
JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer)
23+
# warning based on JAX version
24+
from packaging import version
25+
import warnings
2626

27-
try:
28-
# This class was introduced in 0.4.0
29-
from jax import Array
27+
if version.parse(jax.__version__) >= version.parse("0.4.4"):
28+
import os
3029

31-
JAX_TYPES += (Array,)
32-
except ImportError:
33-
pass
34-
35-
try:
36-
# This class is not in older versions of Jax
37-
from jax.interpreters.partial_eval import DynamicJaxprTracer
30+
if (
31+
version.parse(jax.__version__) > version.parse("0.4.6")
32+
or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0"
33+
):
34+
warnings.warn(
35+
"The functionality in the perturbation module of Qiskit Dynamics requires a JAX "
36+
"version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, "
37+
"0.4.5, and 0.4.6, using the perturbation module functionality requires setting "
38+
"os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics."
39+
)
3840

39-
JAX_TYPES += (DynamicJaxprTracer,)
40-
except ImportError:
41-
pass
41+
JAX_TYPES = (Array, Tracer)
4242

4343
from ..dispatch import Dispatch
4444
import numpy as np
@@ -53,7 +53,7 @@
5353
def _jax_asarray(array, dtype=None, order=None):
5454
"""Wrapper for jax.numpy.asarray"""
5555
if (
56-
isinstance(array, DeviceArray)
56+
isinstance(array, JAX_TYPES)
5757
and order is None
5858
and (dtype is None or dtype == array.dtype)
5959
):

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
"sympy>=1.12"
2525
]
2626

27-
jax_extras = ['jax>=0.2.26, <= 0.4.6',
28-
'jaxlib>=0.1.75, <= 0.4.6']
27+
jax_extras = ['jax>=0.4.0, <= 0.4.6',
28+
'jaxlib>=0.4.0, <= 0.4.6']
2929

3030
PACKAGES = setuptools.find_packages(exclude=['test*'])
3131

0 commit comments

Comments
 (0)