|
17 | 17 |
|
18 | 18 | try: |
19 | 19 | import jax |
20 | | - from jax.interpreters.xla import DeviceArray |
| 20 | + from jax import Array |
21 | 21 | from jax.core import Tracer |
22 | | - from jax.interpreters.ad import JVPTracer |
23 | | - from jax.interpreters.partial_eval import JaxprTracer |
24 | 22 |
|
25 | | - JAX_TYPES = (DeviceArray, Tracer, JaxprTracer, JVPTracer) |
| 23 | + # warning based on JAX version |
| 24 | + from packaging import version |
| 25 | + import warnings |
26 | 26 |
|
27 | | - try: |
28 | | - # This class was introduced in 0.4.0 |
29 | | - from jax import Array |
| 27 | + if version.parse(jax.__version__) >= version.parse("0.4.4"): |
| 28 | + import os |
30 | 29 |
|
31 | | - JAX_TYPES += (Array,) |
32 | | - except ImportError: |
33 | | - pass |
34 | | - |
35 | | - try: |
36 | | - # This class is not in older versions of Jax |
37 | | - from jax.interpreters.partial_eval import DynamicJaxprTracer |
| 30 | + if ( |
| 31 | + version.parse(jax.__version__) > version.parse("0.4.6") |
| 32 | + or os.environ.get("JAX_JIT_PJIT_API_MERGE", None) != "0" |
| 33 | + ): |
| 34 | + warnings.warn( |
| 35 | + "The functionality in the perturbation module of Qiskit Dynamics requires a JAX " |
| 36 | + "version <= 0.4.6, due to a bug in JAX versions > 0.4.6. For versions 0.4.4, " |
| 37 | + "0.4.5, and 0.4.6, using the perturbation module functionality requires setting " |
| 38 | + "os.environ['JAX_JIT_PJIT_API_MERGE'] = '0' before importing JAX or Dynamics." |
| 39 | + ) |
38 | 40 |
|
39 | | - JAX_TYPES += (DynamicJaxprTracer,) |
40 | | - except ImportError: |
41 | | - pass |
| 41 | + JAX_TYPES = (Array, Tracer) |
42 | 42 |
|
43 | 43 | from ..dispatch import Dispatch |
44 | 44 | import numpy as np |
|
53 | 53 | def _jax_asarray(array, dtype=None, order=None): |
54 | 54 | """Wrapper for jax.numpy.asarray""" |
55 | 55 | if ( |
56 | | - isinstance(array, DeviceArray) |
| 56 | + isinstance(array, JAX_TYPES) |
57 | 57 | and order is None |
58 | 58 | and (dtype is None or dtype == array.dtype) |
59 | 59 | ): |
|
0 commit comments