From de67dc1b8bc251f244d9727162cad03493ff1792 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 25 Jan 2025 23:42:49 +0000 Subject: [PATCH 1/2] ENH: New function `lazy_apply` --- docs/api-lazy.md | 15 + docs/conf.py | 2 + docs/index.md | 2 +- docs/testing-utils.md | 14 - pixi.lock | 31 +- pyproject.toml | 8 +- src/array_api_extra/__init__.py | 2 + src/array_api_extra/_lib/_lazy.py | 368 +++++++++++++++++++++ src/array_api_extra/_lib/_utils/_typing.py | 3 +- src/array_api_extra/testing.py | 48 ++- tests/test_lazy.py | 82 +++++ 11 files changed, 535 insertions(+), 40 deletions(-) create mode 100644 docs/api-lazy.md delete mode 100644 docs/testing-utils.md create mode 100644 src/array_api_extra/_lib/_lazy.py create mode 100644 tests/test_lazy.py diff --git a/docs/api-lazy.md b/docs/api-lazy.md new file mode 100644 index 00000000..150fd62e --- /dev/null +++ b/docs/api-lazy.md @@ -0,0 +1,15 @@ +# Tools for lazy backends + +These additional functions are meant to be used to support compatibility with +lazy backends, e.g. Dask or Jax: + +```{eval-rst} +.. currentmodule:: array_api_extra +.. autosummary:: + :nosignatures: + :toctree: generated + + lazy_apply + testing.lazy_xp_function + testing.patch_lazy_xp_functions +``` diff --git a/docs/conf.py b/docs/conf.py index 79000c96..4696e7a6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,6 +53,8 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "dask": ("https://docs.dask.org/en/stable", None), "jax": ("https://jax.readthedocs.io/en/latest", None), } diff --git a/docs/index.md b/docs/index.md index f7c51574..a5c6d7bf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,7 +5,7 @@ :hidden: self api-reference.md -testing-utils.md +api-lazy.md contributing.md contributors.md ``` diff --git a/docs/testing-utils.md b/docs/testing-utils.md deleted file mode 100644 index 49aeb306..00000000 --- a/docs/testing-utils.md +++ /dev/null @@ -1,14 +0,0 @@ -# Testing Utilities - -These additional functions are meant to be used while unit testing Array API -compliant packages: - -```{eval-rst} -.. currentmodule:: array_api_extra.testing -.. autosummary:: - :nosignatures: - :toctree: generated - - lazy_xp_function - patch_lazy_xp_functions -``` diff --git a/pixi.lock b/pixi.lock index 00c8feef..e8da46d4 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1620,16 +1620,22 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/icu-75.1-he02047a_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libblas-3.9.0-26_linux64_mkl.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcblas-3.9.0-26_linux64_mkl.conda - conda: https://prefix.dev/conda-forge/linux-64/libexpat-2.6.4-h5888daf_0.conda - conda: https://prefix.dev/conda-forge/linux-64/libffi-3.4.2-h7f98852_5.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/libgcc-14.2.0-h77fa898_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libgcc-ng-14.2.0-h69a702a_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/libhwloc-2.11.2-default_h0d58e46_1001.conda + - conda: https://prefix.dev/conda-forge/linux-64/libiconv-1.17-hd590300_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/liblapack-3.9.0-26_linux64_mkl.conda - conda: https://prefix.dev/conda-forge/linux-64/liblzma-5.6.3-hb9d3cd8_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda - conda: https://prefix.dev/conda-forge/linux-64/libsqlite-3.48.0-hee588c1_1.conda @@ -1637,6 +1643,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/libstdcxx-ng-14.2.0-h4852527_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda - conda: https://prefix.dev/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/libxml2-2.13.5-h8d12d68_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/llvm-openmp-19.1.7-h024ca30_0.conda - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 @@ -1644,8 +1651,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/markupsafe-3.0.2-py312h178313f_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/mkl-2024.2.2-ha957f24_16.conda - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/ncurses-6.5-h2d0b736_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/numpy-2.0.2-py312h58c1407_1.conda - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.4.0-h7b32b05_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1672,6 +1681,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/tbb-2021.13.0-hceb3a55_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda @@ -1711,12 +1721,19 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libblas-3.9.0-26_osxarm64_openblas.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libcblas-3.9.0-26_osxarm64_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-19.1.7-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libffi-3.4.2-h3422bc3_5.tar.bz2 + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran-5.0.0-13_2_0_hd922786_3.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libgfortran5-13.2.0-hf226fd6_3.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/liblapack-3.9.0-26_osxarm64_openblas.conda - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.6.3-h39f12f2_1.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/libopenblas-0.3.28-openmp_hf332438_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libsqlite-3.48.0-h3f77e49_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/llvm-openmp-19.1.7-hdb05f8b_0.conda - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/markupsafe-3.0.2-py312h998013c_1.conda @@ -1724,6 +1741,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_2.conda + - conda: https://prefix.dev/conda-forge/osx-arm64/numpy-2.0.2-py312h94ee1e1_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/openssl-3.4.0-h81ee809_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1788,18 +1806,28 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/intel-openmp-2024.2.1-h57928b3_1083.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/win-64/libblas-3.9.0-26_win64_mkl.conda + - conda: https://prefix.dev/conda-forge/win-64/libcblas-3.9.0-26_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda - conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2 + - conda: https://prefix.dev/conda-forge/win-64/libhwloc-2.11.2-default_ha69328c_1001.conda + - conda: https://prefix.dev/conda-forge/win-64/libiconv-1.17-hcfcfb64_2.conda + - conda: https://prefix.dev/conda-forge/win-64/liblapack-3.9.0-26_win64_mkl.conda - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.48.0-h67fdade_1.conda + - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_9.conda + - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py312h31fea79_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/mkl-2024.2.2-h66d3029_15.conda - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/numpy-2.0.2-py312h49bc9c5_1.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-ha4e3fda_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda @@ -1825,6 +1853,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-jsmath-1.0.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-qthelp-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/win-64/tbb-2021.13.0-h62715c5_1.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda @@ -3767,7 +3796,7 @@ packages: - pypi: . name: array-api-extra version: 0.6.1.dev0 - sha256: bb6cd89a7f100a73d3f853de571b2f4fff0e70de8df0d113f2f5c1559744e6b6 + sha256: 1e032f707df46a29e306ede97d65b2129e0944b361b96317e5653bd74e695ce2 requires_dist: - array-api-compat>=1.10.0,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index d15aba84..26a73fc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ sphinx-autodoc-typehints = "*" dask-core = "*" pytest = "*" typing-extensions = "*" +numpy = "*" [tool.pixi.feature.docs.tasks] docs = { cmd = "sphinx-build . build/", cwd = "docs" } @@ -311,10 +312,5 @@ checks = [ "ES01", # most docstrings do not need an extended summary ] exclude = [ # don't report on objects that match any of these regex - '.*test_at.*', - '.*test_funcs.*', - '.*test_testing.*', - '.*test_utils.*', - '.*test_version.*', - '.*test_vendor.*', + '.*test_*', ] diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 840dd8e7..aeedd9da 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -12,6 +12,7 @@ setdiff1d, sinc, ) +from ._lib._lazy import lazy_apply __version__ = "0.6.1.dev0" @@ -25,6 +26,7 @@ "expand_dims", "isclose", "kron", + "lazy_apply", "nunique", "pad", "setdiff1d", diff --git a/src/array_api_extra/_lib/_lazy.py b/src/array_api_extra/_lib/_lazy.py new file mode 100644 index 00000000..47a2cc83 --- /dev/null +++ b/src/array_api_extra/_lib/_lazy.py @@ -0,0 +1,368 @@ +"""Public API Functions.""" + +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + +import math +from collections.abc import Callable, Sequence +from functools import partial, wraps +from types import ModuleType +from typing import TYPE_CHECKING, Any, cast, overload + +from ._utils._compat import ( + array_namespace, + is_array_api_obj, + is_dask_namespace, + is_jax_array, + is_jax_namespace, +) +from ._utils._typing import Array, DType + +if TYPE_CHECKING: + # TODO move outside TYPE_CHECKING + # depends on scikit-learn abandoning Python 3.9 + # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 + from typing import ParamSpec, TypeAlias + + import numpy as np + from numpy.typing import ArrayLike + + NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[no-any-explicit] + P = ParamSpec("P") +else: + # Sphinx hacks + NumPyObject = Any + + class P: # pylint: disable=missing-class-docstring + args: tuple + kwargs: dict + + +@overload +def lazy_apply( # type: ignore[valid-type] + func: Callable[P, ArrayLike], + *args: Array, + shape: tuple[int | None, ...] | None = None, + dtype: DType | None = None, + as_numpy: bool = False, + xp: ModuleType | None = None, + **kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues] +) -> Array: ... # numpydoc ignore=GL08 + + +@overload +def lazy_apply( # type: ignore[valid-type] + func: Callable[P, Sequence[ArrayLike]], + *args: Array, + shape: Sequence[tuple[int | None, ...]], + dtype: Sequence[DType] | None = None, + as_numpy: bool = False, + xp: ModuleType | None = None, + **kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues] +) -> tuple[Array, ...]: ... # numpydoc ignore=GL08 + + +def lazy_apply( # type: ignore[valid-type] # numpydoc ignore=GL07,SA04 + func: Callable[P, Array | Sequence[ArrayLike]], + *args: Array, + shape: tuple[int | None, ...] | Sequence[tuple[int | None, ...]] | None = None, + dtype: DType | Sequence[DType] | None = None, + as_numpy: bool = False, + xp: ModuleType | None = None, + **kwargs: P.kwargs, # pyright: ignore[reportGeneralTypeIssues] +) -> Array | tuple[Array, ...]: + """ + Lazily apply an eager function. + + If the backend of the input arrays is lazy, e.g. Dask or jitted JAX, the execution + of the function is delayed until the graph is materialized; if it's eager, the + function is executed immediately. + + Parameters + ---------- + func : callable + The function to apply. + + It must accept one or more array API compliant arrays as positional arguments. + If `as_numpy=True`, inputs are converted to NumPy before they are passed to + `func`. + It must return either a single array-like or a sequence of array-likes. + + `func` must be a pure function, i.e. without side effects, as depending on the + backend it may be executed more than once. + *args : Array + One or more Array API compliant arrays. + + If `as_numpy=True`, you need to be able to apply :func:`numpy.asarray` to them + to convert them to numpy; read notes below about specific backends. + shape : tuple[int | None, ...] | Sequence[tuple[int, ...]], optional + Output shape or sequence of output shapes, one for each output of `func`. + Default: assume single output and broadcast shapes of the input arrays. + dtype : DType | Sequence[DType], optional + Output dtype or sequence of output dtypes, one for each output of `func`. + dtype(s) must belong to the same array namespace as the input arrays. + Default: infer the result type(s) from the input arrays. + as_numpy : bool, optional + If True, convert the input arrays to NumPy before passing them to `func`. + This is particularly useful to make numpy-only functions, e.g. written in Cython + or Numba, work transparently API arrays. + Default: False. + xp : array_namespace, optional + The standard-compatible namespace for `args`. Default: infer. + **kwargs : Any, optional + Additional keyword arguments to pass verbatim to `func`. + Any array objects in them will be converted to numpy when ``as_numpy=True``. + + Returns + ------- + Array | tuple[Array, ...] + The result(s) of `func` applied to the input arrays, wrapped in the same + array namespace as the inputs. + If shape is omitted or a `tuple[int | None, ...]`, this is a single array. + Otherwise, it's a tuple of arrays. + + Notes + ----- + JAX + This allows applying eager functions to jitted JAX arrays, which are lazy. + The function won't be applied until the JAX array is materialized. + When running inside `jax.jit`, `shape` must be fully known, i.e. it cannot + contain any `None` elements. + + Using this with `as_numpy=False` is particularly useful to apply non-jittable + JAX functions to arrays on GPU devices. + If `as_numpy=True`, the :doc:`jax:transfer_guard` may prevent arrays on a GPU + device from being transferred back to CPU. This is treated as an implicit + transfer. + + PyTorch, CuPy + If `as_numpy=True`, these backends raise by default if you attempt to convert + arrays on a GPU device to NumPy. + + Sparse + If `as_numpy=True`, by default sparse prevents implicit densification through + :func:`numpy.asarray`. `This safety mechanism can be disabled + `_. + + Dask + This allows applying eager functions to dask arrays. + The dask graph won't be computed. + + `lazy_apply` doesn't know if `func` reduces along any axes; also, shape + changes are non-trivial in chunked Dask arrays. For these reasons, all inputs + will be rechunked into a single chunk. + + .. warning:: + + The whole operation needs to fit in memory all at once on a single worker. + + The outputs will also be returned as a single chunk and you should consider + rechunking them into smaller chunks afterwards. + + If you want to distribute the calculation across multiple workers, you + should use :func:`dask.array.map_blocks`, :func:`dask.array.map_overlap`, + :func:`dask.array.blockwise`, or a native Dask wrapper instead of + `lazy_apply`. + + Dask wrapping around other backends + If `as_numpy=False`, `func` will receive in input eager arrays of the meta + namespace, as defined by the `._meta` attribute of the input Dask arrays. + The outputs of `func` will be wrapped by the meta namespace, and then wrapped + again by Dask. + + Raises + ------ + jax.errors.TracerArrayConversionError + When `xp=jax.numpy`, `shape` is unknown (it contains None on one or more axes) + and this function was called inside `jax.jit`. + RuntimeError + When `xp=sparse` and auto-densification is disabled. + Exception (backend-specific) + When the backend disallows implicit device to host transfers and the input + arrays are on a device, e.g. on GPU. + + See Also + -------- + jax.transfer_guard + jax.pure_callback + dask.array.map_blocks + dask.array.map_overlap + dask.array.blockwise + """ + if xp is None: + xp = array_namespace(*args) + + # Normalize and validate shape and dtype + shapes: list[tuple[int | None, ...]] + dtypes: list[DType] + multi_output = False + + if shape is None: + shapes = [xp.broadcast_shapes(*(arg.shape for arg in args))] + elif isinstance(shape, tuple) and all(isinstance(s, int | None) for s in shape): + shapes = [shape] # pyright: ignore[reportAssignmentType] + else: + shapes = list(shape) # type: ignore[arg-type] # pyright: ignore[reportAssignmentType] + multi_output = True + + if dtype is None: + dtypes = [xp.result_type(*args)] * len(shapes) + elif multi_output: + if not isinstance(dtype, Sequence): + msg = "Got sequence of shapes but only one dtype" + raise TypeError(msg) + dtypes = list(dtype) # pyright: ignore[reportUnknownArgumentType] + else: + if isinstance(dtype, Sequence): + msg = "Got single shape but multiple dtypes" + raise TypeError(msg) + dtypes = [dtype] + + if len(shapes) != len(dtypes): + msg = f"Got {len(shapes)} shapes and {len(dtypes)} dtypes" + raise ValueError(msg) + if len(shapes) == 0: + msg = "func must return one or more output arrays" + raise ValueError(msg) + del shape + del dtype + + # Backend-specific branches + if is_dask_namespace(xp): + import dask + + metas = [arg._meta for arg in args if hasattr(arg, "_meta")] # pylint: disable=protected-access + meta_xp = array_namespace(*metas) + + wrapped = dask.delayed( # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage] + _lazy_apply_wrapper(func, as_numpy, multi_output, meta_xp), + pure=True, + ) + # This finalizes each arg, which is the same as arg.rechunk(-1). + # Please read docstring above for why we're not using + # dask.array.map_blocks or dask.array.blockwise! + delayed_out = wrapped(*args, **kwargs) + + out = tuple( + xp.from_delayed( + delayed_out[i], # pyright: ignore[reportIndexIssue] + # Dask's unknown shapes diverge from the Array API specification + shape=tuple(math.nan if s is None else s for s in shape), + dtype=dtype, + meta=metas[0], + ) + for i, (shape, dtype) in enumerate(zip(shapes, dtypes, strict=True)) + ) + + elif is_jax_namespace(xp): + # If we're inside jax.jit, we can't eagerly convert + # the JAX tracer objects to numpy. + # Instead, we delay calling wrapped, which will receive + # as arguments and will return JAX eager arrays. + + import jax + + # Shield eager kwargs from being coerced into JAX arrays. + # jax.pure_callback calls jax.jit under the hood, but without the chance of + # passing static_argnames / static_argnums. + lazy_kwargs = {} + eager_kwargs = {} + for k, v in kwargs.items(): + if _contains_jax_arrays(v): + lazy_kwargs[k] = v + else: + eager_kwargs[k] = v + + wrapped = _lazy_apply_wrapper( + partial(func, **eager_kwargs), as_numpy, multi_output, xp + ) + + if any(s is None for shape in shapes for s in shape): + # Unknown output shape. Won't work with jax.jit, but it + # can work with eager jax. + # Raises jax.errors.TracerArrayConversionError if we're inside jax.jit. + out = wrapped(*args, **lazy_kwargs) + + else: + # suppress unused-ignore to run mypy in -e lint as well as -e dev + out = cast( # type: ignore[bad-cast,unused-ignore] + tuple[Array, ...], + jax.pure_callback( + wrapped, + tuple( + jax.ShapeDtypeStruct(shape, dtype) # pyright: ignore[reportUnknownArgumentType] + for shape, dtype in zip(shapes, dtypes, strict=True) + ), + *args, + **lazy_kwargs, + ), + ) + + else: + # Eager backends + wrapped = _lazy_apply_wrapper(func, as_numpy, multi_output, xp) + out = wrapped(*args, **kwargs) + + return out if multi_output else out[0] + + +def _contains_jax_arrays(x: object) -> bool: # numpydoc ignore=PR01,RT01 + """ + Test if x is a JAX array or a nested collection with any JAX arrays in it. + """ + if is_jax_array(x): + return True + if isinstance(x, list | tuple): + return any(_contains_jax_arrays(i) for i in x) # pyright: ignore[reportUnknownArgumentType] + if isinstance(x, dict): + return any(_contains_jax_arrays(i) for i in x.values()) # pyright: ignore[reportUnknownArgumentType] + return False + + +def _as_numpy(x: object) -> Any: # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 + """Recursively convert Array API objects in x to NumPy.""" + import numpy as np # pylint: disable=import-outside-toplevel + + if is_array_api_obj(x): + return np.asarray(x) + if isinstance(x, list) or type(x) is tuple: # pylint: disable=unidiomatic-typecheck + return type(x)(_as_numpy(i) for i in x) # pyright: ignore[reportUnknownArgumentType] + if isinstance(x, tuple): # namedtuple + return type(x)(*(_as_numpy(i) for i in x)) # pyright: ignore[reportUnknownArgumentType] + if isinstance(x, dict): + return {k: _as_numpy(v) for k, v in x.items()} # pyright: ignore[reportUnknownArgumentType] + return x + + +def _lazy_apply_wrapper( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 + func: Callable[..., ArrayLike | Sequence[ArrayLike]], + as_numpy: bool, + multi_output: bool, + xp: ModuleType, +) -> Callable[..., tuple[Array, ...]]: + """ + Helper of `lazy_apply`. + + Given a function that accepts one or more arrays as positional arguments and returns + a single array-like or a sequence of array-likes, return a function that accepts the + same number of Array API arrays and always returns a tuple of Array API array. + + Any keyword arguments are passed through verbatim to the wrapped function. + """ + + # On Dask, @wraps causes the graph key to contain the wrapped function's name + @wraps(func) + def wrapper( # type: ignore[no-any-decorated,no-any-explicit] + *args: Array, **kwargs: Any + ) -> tuple[Array, ...]: # numpydoc ignore=GL08 + if as_numpy: + args = _as_numpy(args) + kwargs = _as_numpy(kwargs) + out = func(*args, **kwargs) + + if multi_output: + assert isinstance(out, Sequence) + return tuple(xp.asarray(o) for o in out) + return (xp.asarray(out),) + + return wrapper diff --git a/src/array_api_extra/_lib/_utils/_typing.py b/src/array_api_extra/_lib/_utils/_typing.py index 83b51d04..95f29f79 100644 --- a/src/array_api_extra/_lib/_utils/_typing.py +++ b/src/array_api_extra/_lib/_utils/_typing.py @@ -5,6 +5,7 @@ # To be changed to a Protocol later (see data-apis/array-api#589) Array = Any # type: ignore[no-any-explicit] Device = Any # type: ignore[no-any-explicit] +DType = Any # type: ignore[no-any-explicit] Index = Any # type: ignore[no-any-explicit] -__all__ = ["Array", "Device", "Index"] +__all__ = ["Array", "DType", "Device", "Index"] diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index cc3f01f8..f0e0d7c1 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -132,12 +132,12 @@ def test_myfunc(xp): a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c = mymodule.myfunc(a) # This is not """ - func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] - if jax_jit: - func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] - "static_argnums": static_argnums, - "static_argnames": static_argnames, - } + func.lazy_xp_function = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] + "allow_dask_compute": allow_dask_compute, + "jax_jit": jax_jit, + "static_argnums": static_argnums, + "static_argnames": static_argnames, + } def patch_lazy_xp_functions( @@ -181,10 +181,13 @@ def xp(request, monkeypatch): if is_dask_namespace(xp): for name, func in globals_.items(): - n = getattr(func, "allow_dask_compute", None) - if n is not None: + kwargs = cast( # type: ignore[no-any-explicit] + "dict[str, Any] | None", getattr(func, "lazy_xp_function", None) + ) + if kwargs is not None: + n = kwargs["allow_dask_compute"] assert isinstance(n, int) - wrapped = _allow_dask_compute(func, n) + wrapped = _dask_wrap(func, n) monkeypatch.setitem(globals_, name, wrapped) elif is_jax_namespace(xp): @@ -192,12 +195,16 @@ def xp(request, monkeypatch): for name, func in globals_.items(): kwargs = cast( # type: ignore[no-any-explicit] - "dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None) + "dict[str, Any] | None", getattr(func, "lazy_xp_function", None) ) - if kwargs is not None: + if kwargs is not None and kwargs["jax_jit"]: # suppress unused-ignore to run mypy in -e lint as well as -e dev - wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore] - monkeypatch.setitem(globals_, name, wrapped) + wrapped = jax.jit( # type: ignore[no-untyped-call,unused-ignore] + func, + static_argnums=kwargs["static_argnums"], + static_argnames=kwargs["static_argnames"], + ) + monkeypatch.setitem(globals_, name, wrapped) # pyright: ignore[reportUnknownArgumentType] class CountingDaskScheduler(SchedulerGetCallable): @@ -236,13 +243,15 @@ def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage] -def _allow_dask_compute( +def _dask_wrap( func: Callable[P, T], n: int ) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 """ Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times. + + After the function returns, materialize the graph in order to re-raise exceptions. """ - import dask.config + import dask func_name = getattr(func, "__name__", str(func)) n_str = f"only up to {n}" if n else "no" @@ -256,7 +265,12 @@ def _allow_dask_compute( @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 scheduler = CountingDaskScheduler(n, msg) - with dask.config.set({"scheduler": scheduler}): - return func(*args, **kwargs) + with dask.config.set({"scheduler": scheduler}): # pyright: ignore[reportPrivateImportUsage] + out = func(*args, **kwargs) + + # Block until the graph materializes and reraise exceptions. This allows + # `pytest.raises` and `pytest.warns` to work as expected. Note that this would + # not work on scheduler='distributed', as it would not block. + return dask.persist(out, scheduler="threads")[0] # type: ignore[no-any-return,attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] return wrapper diff --git a/tests/test_lazy.py b/tests/test_lazy.py new file mode 100644 index 00000000..a01e31c7 --- /dev/null +++ b/tests/test_lazy.py @@ -0,0 +1,82 @@ +from types import ModuleType +from typing import NamedTuple + +import numpy as np +import pytest + +from array_api_extra import lazy_apply +from array_api_extra._lib import Backend +from array_api_extra._lib._testing import xp_assert_equal +from array_api_extra._lib._utils._typing import Array +from array_api_extra.testing import lazy_xp_function + +skip_as_numpy = [ + pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host transfer"), + pytest.mark.skip_xp_backend(Backend.SPARSE, reason="densification"), +] + + +@pytest.mark.parametrize("as_numpy", [False, pytest.param(True, marks=skip_as_numpy)]) +def test_lazy_apply_kwargs(xp: ModuleType, library: Backend, as_numpy: bool) -> None: + expect = np.ndarray if as_numpy or library is Backend.DASK else type(xp.asarray(0)) + + class NT(NamedTuple): + a: Array + + def f( + x: Array, + z: dict[str, list[Array] | tuple[Array, ...] | NT], + msg: str, + msgs: list[str], + ) -> Array: + assert isinstance(x, expect) + assert isinstance(z["foo"], NT) + assert isinstance(z["foo"].a, expect) + assert isinstance(z["bar"][0], expect) + assert isinstance(z["baz"][0], expect) + assert msg == "Hello World" + assert msgs[0] == "Hello World" + return x + + x = xp.asarray(0) + y = lazy_apply( # pyright: ignore[reportCallIssue] + f, + x, + z={"foo": NT(x), "bar": [x], "baz": (x,)}, + msg="Hello World", + msgs=["Hello World"], + shape=x.shape, + dtype=x.dtype, + as_numpy=as_numpy, + ) + xp_assert_equal(x, y) + + +class CustomError(Exception): + pass + + +def raises(x: Array) -> Array: + def eager(_: Array) -> Array: + msg = "Hello World" + raise CustomError(msg) + + return lazy_apply(eager, x, shape=x.shape, dtype=x.dtype) + + +lazy_xp_function(raises) + + +def test_lazy_apply_raises(xp: ModuleType, library: Backend) -> None: + x = xp.asarray(0) + + with pytest.raises( + # FIXME https://github.com/jax-ml/jax/issues/26102 + RuntimeError if library is Backend.JAX else CustomError, + match="Hello World", + ): + # Here we are disregarding the return value, which would + # normally cause the graph not to materialize and the + # exception not to be raised. + # However, lazy_xp_function will do it for us on function exit. + raises(x) From ccffddfb5caa2ef6e3339d527eb2bf62dcc0a511 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 29 Jan 2025 18:02:19 +0000 Subject: [PATCH 2/2] ENH: New functions `lazy_raise`, `lazy_warn`, and `lazy_wait_on` --- docs/api-lazy.md | 3 + docs/conf.py | 1 + pixi.lock | 86 +++-- pyproject.toml | 5 +- src/array_api_extra/__init__.py | 5 +- src/array_api_extra/_lib/_lazy.py | 327 ++++++++++++++++++++ src/array_api_extra/_lib/_utils/_compat.py | 3 + src/array_api_extra/_lib/_utils/_compat.pyi | 1 + vendor_tests/test_vendor.py | 2 + 9 files changed, 386 insertions(+), 47 deletions(-) diff --git a/docs/api-lazy.md b/docs/api-lazy.md index 150fd62e..2052fbae 100644 --- a/docs/api-lazy.md +++ b/docs/api-lazy.md @@ -10,6 +10,9 @@ lazy backends, e.g. Dask or Jax: :toctree: generated lazy_apply + lazy_raise + lazy_wait_on + lazy_warn testing.lazy_xp_function testing.patch_lazy_xp_functions ``` diff --git a/docs/conf.py b/docs/conf.py index 4696e7a6..54b1e8b8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,6 +56,7 @@ "numpy": ("https://numpy.org/doc/stable", None), "dask": ("https://docs.dask.org/en/stable", None), "jax": ("https://jax.readthedocs.io/en/latest", None), + "equinox": ("https://docs.kidger.site/equinox/", None), } nitpick_ignore = [ diff --git a/pixi.lock b/pixi.lock index e8da46d4..8a5fd01a 100644 --- a/pixi.lock +++ b/pixi.lock @@ -9,7 +9,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda @@ -30,9 +29,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda @@ -46,9 +45,9 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda @@ -64,6 +63,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . dev: channels: @@ -75,7 +75,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.8-py312h7900ff3_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -324,10 +323,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.8-py312h81bd7bf_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -566,11 +565,11 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.8-py312h2e8e312_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -784,6 +783,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . dev-cuda: channels: @@ -795,7 +795,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.8-py312h7900ff3_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -1111,10 +1110,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.8-py312h81bd7bf_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -1353,11 +1352,11 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.8-py312h2e8e312_0.conda - conda: https://prefix.dev/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda @@ -1589,6 +1588,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . docs: channels: @@ -1600,7 +1600,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/linux-64/brotli-python-1.1.0-py312h2ec8cdc_2.conda @@ -1693,10 +1692,10 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/brotli-python-1.1.0-py312hde4cb15_2.conda @@ -1779,10 +1778,10 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/beautifulsoup4-4.12.3-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/win-64/brotli-python-1.1.0-py312h275cf98_2.conda @@ -1870,6 +1869,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . lint: channels: @@ -1881,7 +1881,6 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/astroid-3.3.8-py312h7900ff3_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda @@ -1991,10 +1990,10 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/astroid-3.3.8-py312h81bd7bf_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda @@ -2095,10 +2094,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/astroid-3.3.8-py312h2e8e312_0.conda - conda: https://prefix.dev/conda-forge/noarch/babel-2.16.0-pyhd8ed1ab_1.conda @@ -2201,6 +2200,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . tests: channels: @@ -2211,7 +2211,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda @@ -2256,9 +2255,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda @@ -2293,9 +2292,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda @@ -2334,6 +2333,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . tests-backends: channels: @@ -2344,7 +2344,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/aws-c-auth-0.8.1-h205f482_0.conda - conda: https://prefix.dev/conda-forge/linux-64/aws-c-cal-0.8.1-h1a47875_3.conda @@ -2527,9 +2526,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py310ha39cb0e_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-auth-0.8.1-hfc2798a_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-cal-0.8.1-hc8a0bd2_3.conda @@ -2703,10 +2702,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py310h2665a74_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-auth-0.8.1-hd11252f_0.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-cal-0.8.1-h099ea23_3.conda @@ -2857,6 +2856,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py310he5e10e1_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . tests-cuda: channels: @@ -2867,7 +2867,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_kmp_llvm.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/attr-2.5.1-h166bdaf_1.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/aws-c-auth-0.8.1-h205f482_0.conda @@ -3117,9 +3116,9 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py310ha39cb0e_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-auth-0.8.1-hfc2798a_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/aws-c-cal-0.8.1-hc8a0bd2_3.conda @@ -3293,10 +3292,10 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py310h2665a74_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - conda: https://prefix.dev/conda-forge/win-64/_openmp_mutex-4.5-2_gnu.conda - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-auth-0.8.1-hd11252f_0.conda - conda: https://prefix.dev/conda-forge/win-64/aws-c-cal-0.8.1-h099ea23_3.conda @@ -3465,6 +3464,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py310he5e10e1_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . tests-py310: channels: @@ -3475,7 +3475,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda @@ -3515,9 +3514,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda @@ -3551,9 +3550,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda @@ -3591,6 +3590,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . tests-py313: channels: @@ -3601,7 +3601,6 @@ environments: linux-64: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2024.12.14-hbcca054_0.conda @@ -3641,9 +3640,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . osx-arm64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/bzip2-1.0.8-h99b78c6_7.conda - conda: https://prefix.dev/conda-forge/osx-arm64/ca-certificates-2024.12.14-hf0a4a13_0.conda @@ -3679,9 +3678,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . win-64: - - conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/bzip2-1.0.8-h2466b09_7.conda - conda: https://prefix.dev/conda-forge/win-64/ca-certificates-2024.12.14-h56e8100_0.conda @@ -3721,6 +3720,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-h5fd82a7_24.conda - conda: https://prefix.dev/conda-forge/win-64/vc14_runtime-14.42.34433-h6356254_24.conda - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda + - pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 - pypi: . packages: - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 @@ -3782,23 +3782,21 @@ packages: - pkg:pypi/alabaster?source=hash-mapping size: 18684 timestamp: 1733750512696 -- conda: https://prefix.dev/conda-forge/noarch/array-api-compat-1.10.0-pyhd8ed1ab_0.conda - sha256: c98308dcf035a413a635317c69b48143cdf4c5895853457062780395e5ea4633 - md5: e399bc184553ca13cb068d272a995f48 - depends: - - python >=3.9 - license: MIT - license_family: MIT - purls: - - pkg:pypi/array-api-compat?source=hash-mapping - size: 38442 - timestamp: 1735201429468 +- pypi: git+https://github.com/data-apis/array-api-compat#73f642637edfdb261f1e534e492da6d4e1d67ef3 + name: array-api-compat + version: 1.10.1.dev0 + requires_dist: + - cupy ; extra == 'cupy' + - dask ; extra == 'dask' + - jax ; extra == 'jax' + - numpy ; extra == 'numpy' + - pytorch ; extra == 'pytorch' + - sparse>=0.15.1 ; extra == 'sparse' + requires_python: '>=3.9' - pypi: . name: array-api-extra version: 0.6.1.dev0 - sha256: 1e032f707df46a29e306ede97d65b2129e0944b361b96317e5653bd74e695ce2 - requires_dist: - - array-api-compat>=1.10.0,<2 + sha256: 70480eafa2dcced1d14b9c12003f3f9cd97940b4353342ceb448637076f7288d requires_python: '>=3.10' editable: true - conda: https://prefix.dev/conda-forge/noarch/array-api-strict-2.2-pyhd8ed1ab_1.conda diff --git a/pyproject.toml b/pyproject.toml index 26a73fc6..94d8bc44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Typing :: Typed", ] dynamic = ["version"] -dependencies = ["array-api-compat>=1.10.0,<2"] +# dependencies = ["array-api-compat>=1.10.0,<2"] # DNM [project.urls] Homepage = "https://github.com/data-apis/array-api-extra" @@ -48,10 +48,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"] [tool.pixi.dependencies] python = ">=3.10,<3.14" -array-api-compat = ">=1.10.0,<2" +# array-api-compat = ">=1.10.0,<2" # DNM [tool.pixi.pypi-dependencies] array-api-extra = { path = ".", editable = true } +array-api-compat = { git = "https://github.com/data-apis/array-api-compat" } # DNM [tool.pixi.feature.lint.dependencies] typing-extensions = "*" diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index aeedd9da..e3997709 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -12,7 +12,7 @@ setdiff1d, sinc, ) -from ._lib._lazy import lazy_apply +from ._lib._lazy import lazy_apply, lazy_raise, lazy_wait_on, lazy_warn __version__ = "0.6.1.dev0" @@ -27,6 +27,9 @@ "isclose", "kron", "lazy_apply", + "lazy_raise", + "lazy_wait_on", + "lazy_warn", "nunique", "pad", "setdiff1d", diff --git a/src/array_api_extra/_lib/_lazy.py b/src/array_api_extra/_lib/_lazy.py index 47a2cc83..53cbbb8c 100644 --- a/src/array_api_extra/_lib/_lazy.py +++ b/src/array_api_extra/_lib/_lazy.py @@ -4,6 +4,7 @@ from __future__ import annotations import math +import warnings from collections.abc import Callable, Sequence from functools import partial, wraps from types import ModuleType @@ -12,9 +13,11 @@ from ._utils._compat import ( array_namespace, is_array_api_obj, + is_dask_array, is_dask_namespace, is_jax_array, is_jax_namespace, + is_lazy_array, ) from ._utils._typing import Array, DType @@ -366,3 +369,327 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit] return (xp.asarray(out),) return wrapper + + +def lazy_raise( # numpydoc ignore=SA04 + x: Array, + cond: bool | Array, + exc: Exception, + *, + xp: ModuleType | None = None, +) -> Array: + """ + Raise an exception if an eager check fails on a lazy array. + + Consider this snippet:: + + >>> def f(x, xp): + ... if xp.any(x < 0): + ... raise ValueError("Some points are negative") + ... return x + 1 + + The above code fails to compile when x is a JAX array and the function is wrapped + by `jax.jit`; it is also extremely slow on Dask. Other lazy backends, e.g. ndonnx, + are also expected to misbehave. + + `xp.any(x < 0)` is a 0-dimensional array with `dtype=bool`; the `if` statement calls + `bool()` on the Array to convert it to a Python bool. + + On eager backends such as NumPy, this is not a problem. On Dask, `bool()` implicitly + triggers a computation of the whole graph so far; what's worse is that the + intermediate results are discarded to optimize memory usage, so when later on user + explicitly calls `compute()` on their final output, `x` is recalculated from + scratch. On JAX, `bool()` raises if its called code is wrapped by `jax.jit` for the + same reason. + + You should rewrite the above code as follows:: + + >>> def f(x, xp): + ... x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative")) + ... return x + 1 + + When `xp` is eager, this is equivalent to the original code; if the error condition + resolves to True, the function raises immediately and the next line `return x + 1` + is never executed. + When `xp` is lazy, the function always returns a lazy array. When eventually the + user actually computes it, e.g. in Dask by calling `compute()` and in JAX by having + their outermost function decorated with `@jax.jit` return, only then the error + condition is evaluated. If True, the exception is raised and propagated as normal, + and the following nodes of the graph are never executed (so if the health check was + in place to prevent not only incorrect results but e.g. a segmentation fault, it's + still going to achieve its purpose). + + Parameters + ---------- + x : Array + Any one Array, potentially lazy, that is used later on to produce the value + returned by your function. + cond : bool | Array + Must be either a plain Python bool or a 0-dimensional Array with boolean dtype. + If True, raise the exception. If False, return x. + exc : Exception + The exception instance to be raised. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + Array + `x`. If both `x` and `cond` are lazy array, the graph underlying `x` is altered + to raise `exc` if `cond` is True. + + Raises + ------ + type(x) + If `cond` evaluates to True. + + Warnings + -------- + This function raises when x is eager, and quietly skips the check + when x is lazy:: + + >>> def f(x, xp): + ... lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative")) + ... return x + 1 + + And so does this one, as lazy_raise replaces `x` but it does so too late to + contribute to the return value:: + + >>> def f(x, xp): + ... y = x + 1 + ... x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative")) + ... return y + + See Also + -------- + lazy_apply + lazy_warn + lazy_wait_on + dask.graph_manipulation.wait_on + equinox.error_if + + Notes + ----- + This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is + a JAX array on a non-CPU device + (`jax-ml/jax#25995 `_). + """ + + def _lazy_raise(x: Array, cond: Array) -> Array: # numpydoc ignore=PR01,RT01 + """Eager helper of `lazy_raise` running inside the lazy graph.""" + if cond: + raise exc + return x + + return _lazy_wait_on_impl(_lazy_raise, x, cond, xp=xp) + + +# Signature of warnings.warn copied from python/typeshed +@overload +def lazy_warn( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08 + x: Array, + cond: bool | Array, + message: str, + category: type[Warning] | None = None, + stacklevel: int = 1, + source: Any | None = None, + *, + xp: ModuleType | None = None, +) -> None: ... +@overload +def lazy_warn( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ignore=GL08 + x: Array, + cond: bool | Array, + message: Warning, + category: Any = None, + stacklevel: int = 1, + source: Any | None = None, + *, + xp: ModuleType | None = None, +) -> None: ... + + +def lazy_warn( # type: ignore[no-any-explicit] # numpydoc ignore=SA04,PR04 + x: Array, + cond: bool | Array, + message: str | Warning, + category: Any = None, + stacklevel: int = 1, + source: Any | None = None, + *, + xp: ModuleType | None = None, +) -> Array: + """ + Call `warnings.warn` if an eager check fails on a lazy array. + + This functions works in the same way as `lazy_raise`; refer to it + for the detailed explanation. + + You should replace:: + + >>> def f(x, xp): + ... if xp.any(x < 0): + ... warnings.warn("Some points are negative", UserWarning, stacklevel=2) + ... return x + 1 + + with:: + + >>> def f(x, xp): + ... x = lazy_warn(x, xp.any(x < 0), + ... "Some points are negative", UserWarning, stacklevel=2) + ... return x + 1 + + Parameters + ---------- + x : Array + Any one Array, potentially lazy, that is used later on to produce the value + returned by your function. + cond : bool | Array + Must be either a plain Python bool or a 0-dimensional Array with boolean dtype. + If True, raise the exception. If False, return x. + message, category, stacklevel, source : + Parameters to `warnings.warn`. `stacklevel` is automatically increased to + compensate for the extra wrapper function. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + Array + `x`. If both `x` and `cond` are lazy array, the graph underlying `x` is altered + to issue the warning if `cond` is True. + + See Also + -------- + warnings.warn + lazy_apply + lazy_raise + lazy_wait_on + dask.graph_manipulation.wait_on + + Notes + ----- + This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is + a JAX array on a non-CPU device + (`jax-ml/jax#25995 `_). + + On Dask, the warning is typically going to appear on the log of the + worker executing the function instead of on the client. + """ + + def _lazy_warn(x: Array, cond: Array) -> Array: # numpydoc ignore=PR01,RT01 + """Eager helper of `lazy_raise` running inside the lazy graph.""" + if cond: + warnings.warn(message, category, stacklevel=stacklevel + 2, source=source) + return x + + return _lazy_wait_on_impl(_lazy_warn, x, cond, xp=xp) + + +def lazy_wait_on( + x: Array, *wait_on: ArrayLike, xp: ModuleType | None = None +) -> Array: # numpydoc ignore=SA04 + """ + Pause materialization of `x` until `wait_on` has been materialized. + + This is typically used to collect multiple calls to `lazy_raise` and/or + `lazy_warn` from validation functions that would otherwise return None. + If `wait_on` is not a lazy array, just return `x`. + + Read `lazy_raise` for detailed explanation. + + If you use this validation pattern for eager backends:: + + def validate(x, xp): + if xp.any(x < 10): + raise ValueError("Less than 10") + if xp.any(x > 20): + warnings.warn(UserWarning, "More than 20") + + def f(x, xp): + validate(x, xp=xp) + return x + 1 + + You should rewrite it as follows:: + + def validate(x, xp): + # Future that evaluates the checks. Contents are inconsequential. + # Avoid zero-sized arrays, as they may be elided by the graph optimizer. + future = xp.empty(1) + future = lazy_raise(future, xp.any(x < 10), ValueError("Less than 10")) + future = lazy_warn(future, xp.any(x > 20), UserWarning, "More than 20")) + return future + + def f(x, xp): + x = lazy_wait_on(x, validate(x, xp=xp), xp=xp) + return x + 1 + + Parameters + ---------- + x : Array + Any one Array, potentially lazy, that is used later on to produce the value + returned by your function. + *wait_on : ArrayLike + Zero or more objects. Block the materialization of `x` until all lazy arrays in + `wait_on` has been fully materialized. + Eager arrays, python bools and scalars, etc. are ignored. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + Array + `x`. If both `x` and `wait_on` are lazy arrays, the graph + underlying `x` is altered to wait until `wait_on` has been materialized. + If `wait_on` raises, the exception is propagated to `x`. + + See Also + -------- + lazy_apply + lazy_raise + lazy_warn + dask.graph_manipulation.wait_on + """ + xp = array_namespace(x, *wait_on) if xp is None else xp + + if is_dask_namespace(xp): + # Apply an arbitrary reduction so that + # a) all chunks of each of the wait_on objects are materialized, and + # b) the result is a 0-dimensional array, which doesn't interfere with + # map_blocks in _lazy_wait_on_impl. + # + # For all other backends, _lazy_wait_on_impl calls lazy_apply, which can be told + # to disregard the shape of wait_on, so we can skip the reduction. + # + # Dask offers `dask.graph_manipulation.bind` that does exactly the same thing as + # `lazy_wait_on`. As of 2025.1, however, dask.array is in the middle of + # transitioning from HighLevelGraph to dask_expr, and dask.graph_manipulation + # hasn't been migrated yet. + wait_on = tuple(xp.any(w) for w in wait_on if is_dask_array(w)) + + def _lazy_wait_on(x: Array, *_: Array) -> Array: # numpydoc ignore=PR01,RT01 + """Eager helper of `lazy_wait_on` running inside the lazy graph.""" + return x + + return _lazy_wait_on_impl(_lazy_wait_on, x, *wait_on, xp=xp) + + +def _lazy_wait_on_impl( # type: ignore[no-any-explicit] # numpydoc ignore=PR01,RT01 + eager_func: Callable[..., Array], + x: Array, + *wait_on: ArrayLike, + xp: ModuleType | None, +) -> Array: + """Implementation of lazy_raise, lazy_warn, and lazy_wait_on.""" + if not any(is_lazy_array(w) for w in wait_on): + return eager_func(x, *wait_on) + + xp = array_namespace(x, *wait_on) if xp is None else xp + + if is_dask_namespace(xp): + # lazy_apply would rechunk x + # Note that wait_on here are always 0-dimensional, as we special-cased + # them away in lazy_wait_on when there is a chance that they aren't. + return xp.map_blocks(eager_func, x, *wait_on, dtype=x.dtype, meta=x._meta) # pylint: disable=protected-access + + return lazy_apply(eager_func, x, *wait_on, shape=x.shape, dtype=x.dtype, xp=xp) diff --git a/src/array_api_extra/_lib/_utils/_compat.py b/src/array_api_extra/_lib/_utils/_compat.py index 34958149..b9997450 100644 --- a/src/array_api_extra/_lib/_utils/_compat.py +++ b/src/array_api_extra/_lib/_utils/_compat.py @@ -14,6 +14,7 @@ is_dask_namespace, is_jax_array, is_jax_namespace, + is_lazy_array, is_numpy_array, is_numpy_namespace, is_pydata_sparse_array, @@ -35,6 +36,7 @@ is_dask_namespace, is_jax_array, is_jax_namespace, + is_lazy_array, is_numpy_array, is_numpy_namespace, is_pydata_sparse_array, @@ -56,6 +58,7 @@ "is_dask_namespace", "is_jax_array", "is_jax_namespace", + "is_lazy_array", "is_numpy_array", "is_numpy_namespace", "is_pydata_sparse_array", diff --git a/src/array_api_extra/_lib/_utils/_compat.pyi b/src/array_api_extra/_lib/_utils/_compat.pyi index 5c8b6260..1f585a38 100644 --- a/src/array_api_extra/_lib/_utils/_compat.pyi +++ b/src/array_api_extra/_lib/_utils/_compat.pyi @@ -32,5 +32,6 @@ def is_jax_array(x: object, /) -> bool: ... def is_numpy_array(x: object, /) -> bool: ... def is_pydata_sparse_array(x: object, /) -> bool: ... def is_torch_array(x: object, /) -> bool: ... +def is_lazy_array(x: object, /) -> bool: ... def is_writeable_array(x: object, /) -> bool: ... def size(x: Array, /) -> int | None: ... diff --git a/vendor_tests/test_vendor.py b/vendor_tests/test_vendor.py index 7aaa9eba..4613edc7 100644 --- a/vendor_tests/test_vendor.py +++ b/vendor_tests/test_vendor.py @@ -14,6 +14,7 @@ def test_vendor_compat(): is_dask_namespace, is_jax_array, is_jax_namespace, + is_lazy_array, is_numpy_array, is_numpy_namespace, is_pydata_sparse_array, @@ -35,6 +36,7 @@ def test_vendor_compat(): assert not is_dask_namespace(xp) assert not is_jax_array(x) assert not is_jax_namespace(xp) + assert not is_lazy_array(x) assert not is_numpy_array(x) assert not is_numpy_namespace(xp) assert not is_pydata_sparse_array(x)