diff --git a/qiskit_dynamics/__init__.py b/qiskit_dynamics/__init__.py index 5eb49f98c..79f3eb58a 100644 --- a/qiskit_dynamics/__init__.py +++ b/qiskit_dynamics/__init__.py @@ -45,3 +45,4 @@ from . import signals from . import pulse from . import backend +from . import compute_utils diff --git a/qiskit_dynamics/compute_utils/__init__.py b/qiskit_dynamics/compute_utils/__init__.py new file mode 100644 index 000000000..bcc7af51e --- /dev/null +++ b/qiskit_dynamics/compute_utils/__init__.py @@ -0,0 +1,24 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +==================================================== +Compute Utils (:mod:`qiskit_dynamics.compute_utils`) +==================================================== + +.. currentmodule:: qiskit_dynamics.compute_utils + +This submodule contains utilities to aid in running computations, and is based in JAX. +""" + +from .parallel_maps import grid_map +from .pytree_utils import tree_concatenate, tree_product diff --git a/qiskit_dynamics/compute_utils/parallel_maps.py b/qiskit_dynamics/compute_utils/parallel_maps.py new file mode 100644 index 000000000..6767d45ba --- /dev/null +++ b/qiskit_dynamics/compute_utils/parallel_maps.py @@ -0,0 +1,347 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=invalid-name,no-member + +""" +Utilities for mapping functions over arrays in parallel. +""" + +from typing import Callable, Optional, Tuple, List +from itertools import product +from functools import partial +import inspect + +import jax +import jax.numpy as jnp +import numpy as np +from jax.tree_util import tree_flatten, tree_map +from jax.sharding import Mesh +from jax.experimental.maps import xmap + +from qiskit import QiskitError + +from .pytree_utils import tree_concatenate, tree_product + + +def grid_map( + f: Callable, + *args: Tuple["PyTree"], + devices: Optional[np.array] = None, + max_vmap_size: Optional[int] = None, + nonjax_argnums: Optional[List[int]] = None, + nonjax_argnames: Optional[List[str]] = None, + key: Optional[jnp.ndarray] = None, + keys_per_grid_point: int = 1, +) -> "PyTree": + """Map a function ``f`` over all combinations of inputs specified by the positional arguments, + utilizing a mix of device parallelization and vectorization. + + This function evaluates a function ``f`` with multiple inputs over a grid of input values. For + example, suppose we have a function ``f`` of two inputs, the first being a scalar and the second + being an array of shape ``(2,)``, and whose output is an array. We want to evaluate ``f(a, b)`` + for all combinations of ``a`` in ``a_vals = jnp.array([1, 2])`` and ``b`` in + ``b_vals = jnp.array([[3, 4], [5, 6], [7, 8]])``. This can be done with ``grid_map`` as follows: + + .. code-block:: python + + out = grid_map(f, a_vals, b_vals) + out == jnp.array([ + [f(1, [3, 4]), f(1, [5, 6]), f(1, [7, 8]])], + [f(2, [3, 4]), f(2, [5, 6]), f(2, [7, 8]])] + ]) + + Note that the above output array ``out`` satisfies ``out[i, j] == f(a[i], b[j])``. + + More generally, this function can be used with functions ``f`` with PyTree inputs and output. + Abusing notation, for a PyTree ``pt``, let ``pt[idx]`` denote ``tree_map(pt, lambda a: a[idx])`` + (i.e. ``pt[idx]`` denotes the PyTree when indexing all leaves with ``idx``). Let ``a1``, ..., + ``am`` denote PyTrees whose leaves are all arrays with dimension at least 1, and within each + PyTree, all leaves have the same length. It holds that + ``grid_map(f, a1, ..., am)[idx1, ..., idxm] == f(a1[idx1], ..., am[idxm])``, assuming the + evaluation ``f(a1[idx1], ..., am[idxm])`` is well-defined. + + In addition to this, the arguments ``devices`` and ``max_vmap_size`` enable configuration of + parallelization and vectorization. ``devices`` specify the list of JAX-visible devices to + parallelize over, with the calls to ``f`` being evenly split across devices. Within calls to a + single device, ``max_vmap_size`` controls the number of calls to ``f`` that are executed + simultaneously using vectorization. All function evaluations are executed in a serial loop, in + chunks of size ``max_vmap_size * len(devices)``, with a final iteration of size + ``k * len(devices)`` for some ``k < len(devices)``. + + Finally, the arguments ``key`` and ``keys_per_grid_point`` provide the option to supply every + call to ``f`` with a randomly generated JAX ``key``, used for pseudo-random number generation in + ``f``. If ``key`` is specified, it is assumed that the signature of ``f`` is of the form + ``f(*args, key)``, i.e. the random key is consumed as the last argument, and ``args`` are the + standard arguments of ``f`` being mapped over. The keys provided to ``f`` are generated + pseudo-randomly from the ``key`` provided to ``grid_map``. ``keys_per_grid_point`` controls how + many times ``f`` is evaluated for a given set of deterministic ``args``. If + ``keys_per_grid_point == 1``, the output of ``grid_map`` will have the same format as described + above, except that ``f`` will have been provided with a random key for each evaluation. If + ``keys_per_grid_point > 1``, an additional axis will be added to the output arrays indexing + repeated evaluation of the function for a fixed value of the deterministic arguments, but for + different keys. Lastly, the ``key`` argument of ``f`` is assumed to be a JAX-compatible + argument. + + Notes: + * This function is a convenience wrapper around JAX's ``xmap`` transformation. + * The ``nonjax_argnums`` and ``nonjax_argnames`` arguments can be used to prevent JAX mapping + over a subset of the arguments. If these are used, a normal python loop will be used to map + over the product of these arguments, and the remaining arguments will be mapped over using + JAX's mapping functionality. As such, parallelization will only be utilized for the remaining + arguments. Note that the "non-JAX" arguments specified by ``nonjax_argnums`` and + ``nonjax_argnames`` are assumed to be standard iterators over the different values of the + arguments (in contrast to the PyTree structure of JAX-compatible arguments.) Note, however, + that the output of ``f`` is still assumed to output a PyTree with consistent shape across + all argument values. + + Args: + f: The function to map. + *args: A tuple of PyTrees. Should be the same length as the number of arguments to ``f``. + devices: 1d numpy object array of devices to parallelize over. Defaults to + ``np.array(jax.devices())``. + max_vmap_size: The maximum number of inputs to vectorize over within a device. If the first + device type is CPU, this will default to ``1``, and if GPU, will default to + ``len(input_array) / len(devices)``. + nonjax_argnums: Positional arguments to not map over. + nonjax_argnames: Named arguments to not map over. + key: A JAX key to be used for generating randomness. See the function doc string for + how this impacts the behaviour. + keys_per_grid_point: If ``key is not None``, controls the number of times ``f`` is + evaluated with a random key per the rest of the inputs. + Returns: + PyTree containing ``f`` evaluated on all combinations of inputs. + Raises: + QiskitError: If ``devices`` is of invalid shape. + """ + + if devices is None: + devices = np.array(jax.devices()) + elif not devices.ndim == 1: + raise QiskitError("devices must be a 1d array.") + + if (nonjax_argnums is None) and (nonjax_argnames is None): + # take product of args and map over leading axis + if key is None: + args_product = tree_product(args) + else: + args_product = _tree_product_with_keys( + args, key=key, keys_per_grid_point=keys_per_grid_point + ) + + output_1d = _1d_map(f, *args_product, devices=devices, max_vmap_size=max_vmap_size) + + # reshape first axis and return result + map_shape = tuple(len(tree_flatten(arg)[0][0]) for arg in args) + + # add an extra dimension if more than one key per input was used + if key is not None and keys_per_grid_point > 1: + map_shape = map_shape + (keys_per_grid_point,) + + return tree_map(lambda x: x.reshape(map_shape + x.shape[1:]), output_1d) + + if nonjax_argnums is None: + nonjax_argnums = [] + else: + for idx in nonjax_argnums: + if not isinstance(idx, int): + raise QiskitError("All entries in nonjax_argnums must be ints.") + + # convert argnames to argnums + if nonjax_argnames is not None: + all_argnames = inspect.getfullargspec(f).args + new_argnums = [all_argnames.index(name) for name in nonjax_argnames] + nonjax_argnums = nonjax_argnums + new_argnums + + # get unique argnums and sort them + nonjax_argnums = list(set(nonjax_argnums)) + nonjax_argnums.sort() + + # redefined function with nonjax args moved to the front + g = _move_args_to_front(f, nonjax_argnums) + + nonjax_args = [] + dynamic_args = [] + for idx, arg in enumerate(args): + if idx in nonjax_argnums: + nonjax_args.append(arg) + else: + dynamic_args.append(arg) + nonjax_args = tuple(nonjax_args) + dynamic_args = tuple(dynamic_args) + + nonjax_args_product = product(*nonjax_args) + + # setup dynamic_args_product depending on of randomness is involved + if key is not None: + num_nonjax_combos = np.prod(tuple(len(arg) for arg in nonjax_args)) + keys = jax.random.split(key, num_nonjax_combos) + dynamic_args_product = _tree_product_with_keys(dynamic_args, keys[0], keys_per_grid_point) + + # used to later replace keys without taking whole product + num_keys_per_map = ( + np.prod(tuple(len(tree_flatten(arg)[0][0]) for arg in dynamic_args)) + * keys_per_grid_point + ) + else: + dynamic_args_product = tree_product(dynamic_args) + + outputs = [] + for idx, current_nonjax_args in enumerate(nonjax_args_product): + + if key is not None and idx > 0: + dynamic_args_product = dynamic_args_product[:-1] + ( + jax.random.split(keys[idx], num_keys_per_map), + ) + + outputs.append( + _1d_map( + partial(g, *current_nonjax_args), + *dynamic_args_product, + devices=devices, + max_vmap_size=max_vmap_size, + ) + ) + + output_1d = tree_concatenate(jax.device_put(outputs, devices[0])) + + # reshape first axis to be multidimensional with the arguments in the nonjax + dynamic order + map_shape = tuple(len(arg) for arg in nonjax_args) + tuple( + len(tree_flatten(arg)[0][0]) for arg in dynamic_args + ) + + # if keys_per_grid_point > 1 add a further dimension + if key is not None and keys_per_grid_point > 1: + map_shape = map_shape + (keys_per_grid_point,) + + # reshape based on input shapes + reshaped_output = tree_map(lambda x: x.reshape(map_shape + x.shape[1:]), output_1d) + + # reorder first axes to correspond to the original argument order + num_args = len(args) if (key is None or keys_per_grid_point == 1) else len(args) + 1 + current_arg_order = nonjax_argnums + list( + idx for idx in range(num_args) if idx not in nonjax_argnums + ) + original_arg_location = [current_arg_order.index(idx) for idx in range(num_args)] + + def axis_reorder(x): + x_axis_order = original_arg_location + list(range(num_args, x.ndim)) + return x.transpose(x_axis_order) + + return tree_map(axis_reorder, reshaped_output) + + +def _1d_map( + f: Callable, + *args: Tuple["PyTree"], + devices: Optional[np.array] = None, + max_vmap_size: Optional[int] = None, +) -> jnp.array: + """Map f over the leading axis of args (assumed to be PyTrees) using a combination of device + parallelization and vectorization. + + Implicit in this mapping is the assumption that all leaves are arrays that have at least one + dimension and have the same length. + + The mapping is parallelized over ``devices`` in chunks of ``vmap_size`` per device. Each chunk + of size ``vmap_size`` passed to a single device will be evaluated via vectorization. This is a + convenience wrapper over the ``xmap`` transformation in JAX. + + Args: + f: The function to map, assumed to be a function of a single array. + *args: The arguments to map ``f`` over. + devices: 1d numpy object array of devices to parallelize over. Defaults to + ``np.array(jax.devices())``. + max_vmap_size: The maximum number of inputs to vectorize over within a device. If the first + device type is CPU, this will default to ``1``, and if GPU, will default to + ``len(input_array) / len(devices)``. + Returns: + ``f`` mapped over the leading axis of ``input_array``. + Raises: + QiskitError: If devices are of invalid shape. + """ + + if devices is None: + devices = np.array(jax.devices()) + elif not devices.ndim == 1: + raise QiskitError("devices must be a 1d array.") + + # we should be able to rewrite everything after this using a single evaluation of xmap_f by + # utilizing SerialLoop, but it's currently raising errors when used with odeint + xmap_f = xmap( + f, + in_axes={0: "a"}, + out_axes={0: "a"}, + axis_resources={"a": ("x",)}, + ) + + # get number of inputs being mapped over + axis_size = len(tree_flatten(args[0])[0][0]) + + # set max_vmap_size based on device type + if max_vmap_size is None: + if devices[0].platform == "cpu": + max_vmap_size = 1 + else: + max_vmap_size = int(axis_size / len(devices)) + + def input_index(start_idx, end_idx): + return tree_map(lambda x: x[start_idx:end_idx], args) + + # iterate in chunks + outputs = [] + current_idx = 0 + while current_idx < axis_size: + num_evals_remaining = axis_size - current_idx + last_idx = current_idx + # if there are more evaluations remaining than there are devices, evaluate + if num_evals_remaining > len(devices): + vmap_size = min(int((axis_size - current_idx) / len(devices)), max_vmap_size) + current_idx = last_idx + vmap_size * len(devices) + with Mesh(devices, ("x",)): + outputs.append(xmap_f(*input_index(last_idx, current_idx))) + else: + current_idx = last_idx + num_evals_remaining + with Mesh(devices[:num_evals_remaining], ("x",)): + outputs.append(xmap_f(*input_index(last_idx, current_idx))) + + # combine and return outcomes + return tree_concatenate(jax.device_put(outputs, devices[0])) + + +def _move_args_to_front(f, argnums): + """Define a new function ``g`` giving the same output as ``f``, but with the positional args + whose locations are given by ``argnums`` moved to the beginning of ``g``. ``argnums`` is assumed + to be a sorted list of integers. + """ + + def g(*args): + f_args = list(args[len(argnums) :]) + for idx, arg in zip(argnums, args[: len(argnums)]): + f_args.insert(idx, arg) + + return f(*f_args) + + return g + + +def _tree_product_with_keys(trees, key: jnp.ndarray, keys_per_grid_point: int = 1): + + # take args product with a placeholder for proper structure + key_placeholder = jnp.array([0] * keys_per_grid_point) + args_product = tree_product(trees + (key_placeholder,)) + + # generate an array of keys + num_keys_needed = len(tree_flatten(args_product)[0][0]) + keys = jax.random.split(key, num_keys_needed) + + # replace placeholder with actual keys + return args_product[:-1] + (keys,) diff --git a/qiskit_dynamics/compute_utils/pytree_utils.py b/qiskit_dynamics/compute_utils/pytree_utils.py new file mode 100644 index 000000000..56215f9a5 --- /dev/null +++ b/qiskit_dynamics/compute_utils/pytree_utils.py @@ -0,0 +1,131 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=invalid-name + +"""Utility functions for working with pytrees. See the JAX documentation page on pytrees for more +https://jax.readthedocs.io/en/latest/pytrees.html. +""" + +from typing import Iterable + +import numpy as np +import jax.numpy as jnp +from jax.tree_util import tree_flatten, tree_map +from jax.experimental.maps import xmap + +from qiskit import QiskitError + + +def tree_concatenate(trees: Iterable["PyTree"], axis: int = 0) -> "PyTree": + """Given an Iterable of PyTrees with the same tree definition and whose leaves are all arrays, + return a single PyTree whose leaves are the concatenated arrays of the inputs. + + For the concatenation to be possible, this function necessarily requires that ``leaf.ndim >= 1`` + (i.e. the leaves are not scalars), and ``leaf.shape[1:]`` is the same for each leaf across all + trees. + + Args: + trees: Iterable of the trees to concatenate. + axis: Concatenation axis passed directly to ``jax.numpy.concatenate``. + + Returns: + PyTree: The concatenated PyTree. + + Raises: + QiskitError: If the tree definitions don't agree. If the concatenation fails due to the + assumptions on dimension or shape for the leaves not being satisfied, an error will be + raised directly by JAX. + """ + + leaves0, tree_def = tree_flatten(trees[0]) + + leaves_list = [leaves0] + for tree in trees[1:]: + next_leaves, next_tree_def = tree_flatten(tree) + + if next_tree_def != tree_def: + raise QiskitError("All trees passed to tree_stack must have the same tree definition.") + + leaves_list.append(next_leaves) + + concatenated_leaves = [jnp.concatenate(x, axis=axis) for x in zip(*leaves_list)] + + return tree_def.unflatten(concatenated_leaves) + + +def tree_product(trees: Iterable["PyTree"]) -> "PyTree": + """Take the "Cartesian product" of an iterable of PyTrees along the leading axis of their + leaves. + + The simplest usage of this function is when the trees are simple individual arrays. As an + example, given ``a = jnp.array([1., 2., 3.])`` and ``b = jnp.array([-1, 0, 1])``, + ``a_out, b_out = tree_product([a, b])``, it holds that + ``a_out == jnp.array([1., 1., 1., 2., 2., 2., 3., 3., 3.])`` and + ``b_out == jnp.array([-1, 0, 1, -1, 0, 1, -1, 0, 1])``. I.e., ``zip(a_out, b_out)`` will iterate + over the Cartesian product of the entries of ``a`` and ``b``. + + This behaviour is extended to PyTrees, with the restriction that within a PyTree, each leaf must + have a leading axis of the same length. (The inputs can be viewed as a list of PyTrees that have + been passed through :func:`tree_concatenate`.) Abusing notation, for a PyTree ``pt``, let + ``pt[idx]`` denote ``tree_map(pt, lambda a: a[idx])`` (i.e. ``pt[idx]`` denotes the PyTree when + indexing all leaves with ``idx``). Let ``a1``, ..., ``am`` denote PyTrees whose leaves are all + arrays with dimension at least 1, and within each PyTree, all leaves have the same length. + It holds that + ``tree_product([a1, ..., am])[idx1, ..., idxm] = (a1[idx1], ..., am[idxm])``. + + For example, given + ``a = (jnp.array([1., 2., 3.]), jnp.array([[4., 5.], [6., 7.], [8., 9.]]))`` and + ``b = (jnp.array([0, 1]), jnp.array([0, 2]), jnp.array([3, 4]))`` and + ``a_out, b_out = tree_product([a, b])``, it holds that + ``a_out[0] == jnp.array([1., 1., 2., 2., 3., 3.])`` + ``a_out[1] == jnp.array([[4., 5.], [4., 5.], [6., 7.], [6., 7.], [8., 9.], [8., 9.]])``, + ``b_out[0] == jnp.array([0, 1, 0, 1, 0, 1])``, and + ``b_out[1] == jnp.array([0, 2, 0, 2, 0, 2])``. + + Args: + trees: Iterably of PyTrees. + Returns: + PyTree: A list of PyTrees. + Raises: + QiskitError: if leaves of input do not satisfy the function requirements. + """ + + # validate that, within each tree, the leaves all have the same length + for tree in trees: + leaves, _ = tree_flatten(tree) + if any( + not isinstance(leaf, (np.ndarray, jnp.ndarray)) or leaf.ndim == 0 for leaf in leaves + ): + raise QiskitError("All pytree leaves must be arrays having dimension at least 1.") + + len0 = len(leaves[0]) + if any(len(leaf) != len0 for leaf in leaves[1:]): + raise QiskitError( + "pytree_product requires that all leaves within a given tree have the same " + "length." + ) + + # compute the Cartesian product where first len(trees) leading dimensions index each combination + outer_product_trees = xmap( + lambda *args: args, + in_axes=[{0: f"a{k}"} for k in range(len(trees))], + out_axes=tuple([{k: f"a{k}" for k in range(len(trees))}] * len(trees)), + )(*trees) + + # flatten first len(trees) dimensions + num_trees = len(trees) + + def flatten_func(leaf): + shape = leaf.shape + return leaf.reshape((np.prod(shape[:num_trees]),) + shape[num_trees:]) + + return tree_map(flatten_func, outer_product_trees) diff --git a/test/dynamics/compute_utils/__init__.py b/test/dynamics/compute_utils/__init__.py new file mode 100644 index 000000000..7da6f60fa --- /dev/null +++ b/test/dynamics/compute_utils/__init__.py @@ -0,0 +1,15 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Dynamics compute_utils module tests. +""" diff --git a/test/dynamics/compute_utils/test_parallel_maps.py b/test/dynamics/compute_utils/test_parallel_maps.py new file mode 100644 index 000000000..176de1c60 --- /dev/null +++ b/test/dynamics/compute_utils/test_parallel_maps.py @@ -0,0 +1,406 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=invalid-name + +""" +Tests for parallel maps. Note that we can't test actual parallelism here, we can only verify the +correctness of the output. +""" + +import unittest + +import numpy as np +import jax.numpy as jnp +from jax import random +from qiskit import QiskitError +from qiskit_dynamics.compute_utils.parallel_maps import ( + grid_map, + _move_args_to_front, + _tree_product_with_keys, +) + + +class Testgrid_map(unittest.TestCase): + """Test grid_map.""" + + def test_device_dim_error(self): + """Test error is raised if device dimension is not 1.""" + x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + with self.assertRaisesRegex(QiskitError, "devices must be a 1d"): + grid_map(jnp.sin, x, devices=np.array([[1, 2], [3, 4]])) + + def test_1d_grid(self): + """Test correct output when run on a 1d grid.""" + x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + output = grid_map(jnp.sin, x) + expected = jnp.sin(x) + self.assertTrue(np.allclose(output, expected)) + + def test_2d_grid(self): + """Test correct output when run on a 2d grid.""" + x = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + y = np.array([6.0, 7.0, 8.0, 9.0, 10.0]) + + def f(x, y): + return jnp.sin(0.1 * x + 0.2 * y) + + output = grid_map(f, x, y) + expected = jnp.array([[f(a, b) for b in y] for a in x]) + self.assertTrue(np.allclose(output, expected)) + + def test_3d_grid(self): + """Test correct output when run on a 3d grid.""" + x = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + y = np.array([6.0, 7.0, 8.0, 9.0, 10.0]) + z = np.array([11.0, 12.0]) + + def f(x, y, z): + return jnp.sin(0.1 * x + 0.2 * y + 0.3 * z) + + output = grid_map(f, x, y, z) + expected = jnp.array([[[f(a, b, c) for c in z] for b in y] for a in x]) + self.assertTrue(np.allclose(output, expected)) + + def test_1d_grid_pytree(self): + """Test correct function on 1d pytree grid.""" + x = (np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])) + + def f(a): + return a[0] * a[1] + + output = grid_map(f, x) + expected = jnp.array([4.0, 10.0, 18.0]) + self.assertTrue(np.allclose(output, expected)) + + def test_2d_grid_pytree_output(self): + """Test correct function on 2d pytree grid, with a pytree output.""" + + x = (np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0, 6.0])) + y = {"arg0": jnp.array([[0, 1], [1, 0]]), "arg1": jnp.array([3, 4])} + + def f(a, b): + return (a[0] * a[1], b["arg0"] * b["arg1"]) + + output = grid_map(f, x, y) + expected = ( + jnp.array([[4.0, 4.0], [10.0, 10.0], [18.0, 18.0]]), + jnp.array([[[0, 3], [4, 0]], [[0, 3], [4, 0]], [[0, 3], [4, 0]]]), + ) + self.assertTrue(len(output) == 2) + self.assertTrue(isinstance(output, tuple)) + self.assertTrue(np.allclose(output[0], expected[0])) + self.assertTrue(np.allclose(output[1], expected[1])) + + def test_arrays_of_different_shape(self): + """Test correct mapping when input arrays are of different shape.""" + x = np.array([[1.0, 2.0], [3.0, 4.0]]) + y = np.array( + [[[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]]] + ) + + def f(x, y): + return y @ x + + output = grid_map(f, x, y) + expected = jnp.array( + [[y[0] @ x[0], y[1] @ x[0], y[2] @ x[0]], [y[0] @ x[1], y[1] @ x[1], y[2] @ x[1]]] + ) + self.assertTrue(np.allclose(output, expected)) + + def test_correct_mapping_max_vmap(self): + """Test correct mapping with a max vmap.""" + + x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + output = grid_map(jnp.sin, x, max_vmap_size=3) + expected = jnp.sin(x) + self.assertTrue(np.allclose(output, expected)) + + def test_key_inclusion(self): + """Test correct handling of key generation.""" + + def f(a, b, key): + return {"a": a, "b": b, "key": key} + + key = random.PRNGKey(1234) + output = grid_map(f, jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), key=key) + + expected_keys = random.split(key, 4).reshape(2, 2, 2) + self.assertTrue(np.allclose(output["a"], jnp.array([1.0, 1.0, 2.0, 2.0]).reshape(2, 2))) + self.assertTrue(np.allclose(output["b"], jnp.array([3.0, 4.0, 3.0, 4.0]).reshape(2, 2))) + self.assertTrue(np.allclose(output["key"], expected_keys)) + + def test_key_inclusion_2_per(self): + """Test correct handling of key generation with 2 per key.""" + + def f(a, b, key): + return {"a": a, "b": b, "key": key} + + key = random.PRNGKey(1234) + output = grid_map( + f, jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]), key=key, keys_per_grid_point=2 + ) + + expected_keys = random.split(key, 8).reshape(2, 2, 2, 2) + self.assertTrue( + np.allclose( + output["a"], jnp.array([1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]).reshape(2, 2, 2) + ) + ) + self.assertTrue( + np.allclose( + output["b"], jnp.array([3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0]).reshape(2, 2, 2) + ) + ) + self.assertTrue(np.allclose(output["key"], expected_keys)) + + +class Testgrid_map_nonjax_args(unittest.TestCase): + """Test grid_map with nonjax arguments.""" + + def test_non_int_argnum(self): + """Validate error is raised for non integer argnum.""" + + with self.assertRaisesRegex(QiskitError, "must be int"): + grid_map(None, None, nonjax_argnums=["not_an_int"]) + + def test_3_args_2_nonjax_args(self): + """Test case with 3 args and 2 nonjax args.""" + + def f(a, b, c): + if b: + if c == "string": + return a + else: + return a**2 + else: + if c == "string": + return a**3 + else: + return a**4 + + a_list = jnp.array([2.0, 3.0, 4.0]) + b_list = [True, False] + c_list = ["string", "notstring"] + + expected = np.zeros((len(a_list), len(b_list), len(c_list))) + for a_idx, a in enumerate(a_list): + for b_idx, b in enumerate(b_list): + for c_idx, c in enumerate(c_list): + expected[a_idx, b_idx, c_idx] = f(a, b, c) + + self.assertTrue( + np.allclose(expected, grid_map(f, a_list, b_list, c_list, nonjax_argnums=[1, 2])) + ) + + def test_4_args_2_nonjax_args(self): + """Test case with 4 args and 2 nonjax args.""" + + def f(a, b, c, d): + if b: + if c == "string": + return a + 1j * d + else: + return a**2 + 2j * d + else: + if c == "string": + return a**3 + 3j * d + else: + return a**4 + 4j * d + + a_list = jnp.array([2.0, 3.0, 4.0]) + b_list = [True, False] + c_list = ["string", "notstring"] + d_list = jnp.array([5.0, 6.0, 7.0, 8.0]) + + expected = np.zeros((len(a_list), len(b_list), len(c_list), len(d_list)), dtype=complex) + for a_idx, a in enumerate(a_list): + for b_idx, b in enumerate(b_list): + for c_idx, c in enumerate(c_list): + for d_idx, d in enumerate(d_list): + expected[a_idx, b_idx, c_idx, d_idx] = f(a, b, c, d) + + self.assertTrue( + np.allclose( + expected, grid_map(f, a_list, b_list, c_list, d_list, nonjax_argnames=["b", "c"]) + ) + ) + + def test_4_args_2_nonjax_args_non_consecutive(self): + """Test case with 4 args and 2 non-consecutive nonjax args.""" + + def f(a, b, d, c): + if b: + if c == "string": + return a + 1j * d + else: + return a**2 + 2j * d + else: + if c == "string": + return a**3 + 3j * d + else: + return a**4 + 4j * d + + a_list = jnp.array([2.0, 3.0, 4.0]) + b_list = [True, False] + c_list = ["string", "notstring"] + d_list = jnp.array([5.0, 6.0, 7.0, 8.0]) + + expected = np.zeros((len(a_list), len(b_list), len(d_list), len(c_list)), dtype=complex) + for a_idx, a in enumerate(a_list): + for b_idx, b in enumerate(b_list): + for d_idx, d in enumerate(d_list): + for c_idx, c in enumerate(c_list): + expected[a_idx, b_idx, d_idx, c_idx] = f(a, b, d, c) + + self.assertTrue( + np.allclose( + expected, grid_map(f, a_list, b_list, d_list, c_list, nonjax_argnames=["b", "c"]) + ) + ) + + def test_key_inclusion(self): + """Test correct handling of key generation with nonjax argnums.""" + + def f(a, b, key): + return {"a": a, "b": b, "key": key} + + key = random.PRNGKey(1234) + output = grid_map( + f, jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0]), key=key, nonjax_argnums=[1] + ) + + nonjax_keys = random.split(key, 2) + + expected_keys = jnp.array( + [random.split(nonjax_keys[0], 3), random.split(nonjax_keys[1], 3)] + ).transpose((1, 0, 2)) + + self.assertTrue( + np.allclose(output["a"], jnp.array([1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).reshape(3, 2)) + ) + self.assertTrue( + np.allclose(output["b"], jnp.array([4.0, 5.0, 4.0, 5.0, 4.0, 5.0]).reshape(3, 2)) + ) + self.assertTrue(np.allclose(output["key"], expected_keys)) + + def test_key_inclusion_2_per(self): + """Test correct handling of key generation with nonjax argnums and 2 keys per input.""" + + def f(a, b, key): + return {"a": a, "b": b, "key": key} + + key = random.PRNGKey(1234) + output = grid_map( + f, + jnp.array([1.0, 2.0, 3.0]), + jnp.array([4.0, 5.0]), + key=key, + keys_per_grid_point=2, + nonjax_argnums=[1], + ) + + nonjax_keys = random.split(key, 2) + + expected_keys = ( + jnp.array([random.split(nonjax_keys[0], 6), random.split(nonjax_keys[1], 6)]) + .reshape(2, 3, 2, 2) + .transpose((1, 0, 2, 3)) + ) + + self.assertTrue( + np.allclose( + output["a"], + jnp.array([1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0]).reshape( + 3, 2, 2 + ), + ) + ) + self.assertTrue( + np.allclose( + output["b"], + jnp.array([4.0, 4.0, 5.0, 5.0, 4.0, 4.0, 5.0, 5.0, 4.0, 4.0, 5.0, 5.0]).reshape( + 3, 2, 2 + ), + ) + ) + self.assertTrue(np.allclose(output["key"], expected_keys)) + + +class Test_move_args_to_front(unittest.TestCase): + """Tests for helper function _move_args_to_front.""" + + def test_array_building_func_case1(self): + """Test a function that compiles scalar inputs into a 1d array.""" + + def f(a, b, c, d, e): + return np.array([a, b, c, d, e]) + + g = _move_args_to_front(f, argnums=[1, 4]) + + # b e a c d + out = g(1, 2, 3, 4, 5) + expected = np.array([3, 1, 4, 5, 2]) + + self.assertTrue(np.allclose(out, expected)) + + def test_array_building_func_case2(self): + """Test a function that compiles scalar inputs into a 1d array.""" + + def f(a, b, c, d, e): + return np.array([a, b, c, d, e]) + + g = _move_args_to_front(f, argnums=[1, 3]) + + # b d a c e + out = g(1, 2, 3, 4, 5) + expected = np.array([3, 1, 4, 2, 5]) + + self.assertTrue(np.allclose(out, expected)) + + +class Test_tree_product_with_keys(unittest.TestCase): + """Test cases for _tree_product_with_keys.""" + + def test_invalid_key(self): + """Test key of incorrect type.""" + + with self.assertRaisesRegex(QiskitError, "Invalid format"): + _tree_product_with_keys((jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), key="key") + + def test_case1(self): + """Simple test case.""" + key = random.PRNGKey(1234) + + output = _tree_product_with_keys( + (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), + key=key, + ) + + expected_keys = random.split(key, 4) + self.assertTrue(isinstance(output, tuple) and len(output) == 3) + self.assertTrue(np.allclose(output[0], np.array([1.0, 1.0, 2.0, 2.0]))) + self.assertTrue(np.allclose(output[1], np.array([3.0, 4.0, 3.0, 4.0]))) + self.assertTrue(np.allclose(output[2], expected_keys)) + + def test_case_2_per(self): + """Test case with with 2 keys per input.""" + key = random.PRNGKey(1234) + + output = _tree_product_with_keys( + (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), key=key, keys_per_grid_point=2 + ) + + expected_keys = random.split(key, 8) + self.assertTrue(isinstance(output, tuple) and len(output) == 3) + self.assertTrue(np.allclose(output[0], np.array([1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]))) + self.assertTrue(np.allclose(output[1], np.array([3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0]))) + self.assertTrue(np.allclose(output[2], expected_keys)) diff --git a/test/dynamics/compute_utils/test_pytree_utils.py b/test/dynamics/compute_utils/test_pytree_utils.py new file mode 100644 index 000000000..815e3e10b --- /dev/null +++ b/test/dynamics/compute_utils/test_pytree_utils.py @@ -0,0 +1,125 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2024. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +# pylint: disable=invalid-name + +""" +Tests for pytree utils. +""" + +import unittest + +import jax.numpy as jnp + +from qiskit import QiskitError + +from qiskit_dynamics.compute_utils.pytree_utils import tree_concatenate, tree_product + + +class Testtree_concatenate(unittest.TestCase): + """Test tree_concatenate.""" + + def test_arrays(self): + """Test on raw arrays.""" + out = tree_concatenate([jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0])]) + self.assertTrue(all(out == jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]))) + + def test_pytree_different_shapes(self): + """Test on pytrees whose entries have different shapes.""" + tree0 = (jnp.array([1.0, 2.0]), [jnp.array([[3.0, 4.0], [5.0, 6.0]])]) + tree1 = (jnp.array([7.0, 8.0, 9.0]), [jnp.array([[10.0, 11.0]])]) + + out = tree_concatenate([tree0, tree1]) + self.assertTrue(len(out) == 2) + self.assertTrue(all(out[0] == jnp.array([1.0, 2.0, 7.0, 8.0, 9.0]))) + self.assertTrue( + all((out[1][0] == jnp.array([[3.0, 4.0], [5.0, 6.0], [10.0, 11.0]])).flatten()) + ) + + def test_tree_def_error(self): + """Test that inconsistent tree defs results in a raised error.""" + tree0 = (1, 2) + tree1 = (1,) + + with self.assertRaisesRegex(QiskitError, "same tree def"): + tree_concatenate([tree0, tree1]) + + +class Testtree_product(unittest.TestCase): + """Test tree_product.""" + + def test_arrays(self): + """Test on raw arrays.""" + + a = jnp.array([1.0, 2.0, 3.0]) + b = jnp.array([-1, 0, 1]) + out = tree_product([a, b]) + self.assertTrue(all(out[0] == jnp.array([1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0]))) + self.assertTrue(all(out[1] == jnp.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]))) + + def test_3_arrays(self): + """Test on raw arrays.""" + + a = jnp.array([1.0, 2.0, 3.0]) + b = jnp.array([-1, 0]) + c = jnp.array([[0.0, 1.0]]) + out = tree_product([a, b, c]) + self.assertTrue(all(out[0] == jnp.array([1.0, 1.0, 2.0, 2.0, 3.0, 3.0]))) + self.assertTrue(all(out[1] == jnp.array([-1, 0, -1, 0, -1, 0]))) + self.assertTrue( + all( + ( + out[2] + == jnp.array( + [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]] + ) + ).flatten() + ) + ) + + def test_1_input(self): + """Test edge case of only one input.""" + + a = jnp.array([1.0, 2.0, 3.0]) + out = tree_product([a]) + self.assertTrue(all(out[0] == a)) + + def test_simple_pytree(self): + """Test a case of two simple pytrees.""" + + a = (jnp.array([1.0, 2.0, 3.0]), jnp.array([[4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])) + b = (jnp.array([0, 1]), jnp.array([0, 2]), jnp.array([3, 4])) + a_out, b_out = tree_product([a, b]) + self.assertTrue(all(a_out[0] == jnp.array([1.0, 1.0, 2.0, 2.0, 3.0, 3.0]))) + self.assertTrue( + all( + ( + a_out[1] + == jnp.array( + [[4.0, 5.0], [4.0, 5.0], [6.0, 7.0], [6.0, 7.0], [8.0, 9.0], [8.0, 9.0]] + ) + ).flatten() + ) + ) + self.assertTrue(all(b_out[0] == jnp.array([0, 1, 0, 1, 0, 1]))) + self.assertTrue(all(b_out[1] == jnp.array([0, 2, 0, 2, 0, 2]))) + + def test_scalar_error(self): + """Test an error is raised if a scalar is leaf is present.""" + + with self.assertRaisesRegex(QiskitError, "dimension at least 1"): + tree_product([jnp.array(1)]) + + def test_length_error(self): + """Test an error is raised if two leaves in the same tree have different lengths.""" + + with self.assertRaisesRegex(QiskitError, "all leaves within a given tree have the same"): + tree_product([(jnp.array([1]), jnp.array([2, 3]))])