diff --git a/docs/api-reference.md b/docs/api-reference.md index 32205248..2483a55d 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -8,6 +8,7 @@ at atleast_nd + broadcast_shapes cov create_diagonal expand_dims diff --git a/docs/conf.py b/docs/conf.py index afa3bd5e..eff2a33d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -54,6 +54,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "array-api": ("https://data-apis.org/array-api/draft", None), + "numpy": ("https://numpy.org/doc/stable", None), "jax": ("https://jax.readthedocs.io/en/latest", None), } diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 840dd8e7..4a49fd48 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -4,6 +4,7 @@ from ._lib._at import at from ._lib._funcs import ( atleast_nd, + broadcast_shapes, cov, create_diagonal, expand_dims, @@ -20,6 +21,7 @@ "__version__", "at", "atleast_nd", + "broadcast_shapes", "cov", "create_diagonal", "expand_dims", diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index f7eb8c88..a5729559 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -17,6 +17,7 @@ __all__ = [ "atleast_nd", + "broadcast_shapes", "cov", "create_diagonal", "expand_dims", @@ -71,6 +72,69 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array return x +# `float` in signature to accept `math.nan` for Dask. +# `int`s are still accepted as `float` is a superclass of `int` in typing +def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]: + """ + Compute the shape of the broadcasted arrays. + + Duplicates :func:`numpy.broadcast_shapes`, with additional support for + None and NaN sizes. + + This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape`` + without needing to worry about the backend potentially deep copying + the arrays. 
+ + Parameters + ---------- + *shapes : tuple[int | None, ...] + Shapes of the arrays to broadcast. + + Returns + ------- + tuple[int | None, ...] + The shape of the broadcasted arrays. + + See Also + -------- + numpy.broadcast_shapes : Equivalent NumPy function. + array_api.broadcast_arrays : Function to broadcast actual arrays. + + Notes + ----- + This function accepts the Array API's ``None`` for unknown sizes, + as well as Dask's non-standard ``math.nan``. + Regardless of input, the output always contains ``None`` for unknown sizes. + + Examples + -------- + >>> import array_api_extra as xpx + >>> xpx.broadcast_shapes((2, 3), (2, 1)) + (2, 3) + >>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3)) + (4, 2, 3) + """ + if not shapes: + return () # Match numpy output + + ndim = max(len(shape) for shape in shapes) + out: list[int | None] = [] + for axis in range(-ndim, 0): + sizes = {shape[axis] for shape in shapes if axis >= -len(shape)} + # Dask uses NaN for unknown shape, which predates the Array API spec for None + none_size = None in sizes or math.nan in sizes + sizes -= {1, None, math.nan} + if len(sizes) > 1: + msg = ( + "shape mismatch: objects cannot be broadcast to a single shape: " + f"{shapes}." + ) + raise ValueError(msg) + out.append(None if none_size else cast(int, sizes.pop()) if sizes else 1) + + return tuple(out) + + def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: """ Estimate a covariance matrix. 
class TestBroadcastShapes:
    # Exercises `broadcast_shapes` against NumPy for fully known shapes and
    # against hand-written expectations for None / NaN (unknown) sizes.

    @pytest.mark.parametrize(
        "args",
        [
            (),
            ((),),
            ((), ()),
            ((1,),),
            ((1,), (1,)),
            ((2,), (1,)),
            ((3, 1, 4), (2, 1)),
            ((1, 1, 4), (2, 1)),
            ((1,), ()),
            ((), (2,), ()),
            ((0,),),
            ((0,), (1,)),
            ((2, 0), (1, 1)),
            ((2, 0, 3), (2, 1, 1)),
        ],
    )
    def test_simple(self, args: tuple[tuple[int, ...], ...]):
        # NumPy is the reference for shapes made only of known sizes.
        reference = np.broadcast_shapes(*args)
        assert broadcast_shapes(*args) == reference

    @pytest.mark.parametrize(
        "args",
        [
            ((2,), (3,)),
            ((2, 3), (1, 2)),
            ((2,), (0,)),
            ((2, 0, 2), (1, 3, 1)),
        ],
    )
    def test_fail(self, args: tuple[tuple[int, ...], ...]):
        # Both implementations must reject the shapes with the same message.
        for func in (np.broadcast_shapes, broadcast_shapes):
            with pytest.raises(
                ValueError, match="cannot be broadcast to a single shape"
            ):
                _ = func(*args)

    @pytest.mark.parametrize(
        "args",
        [
            ((None,), (None,)),
            ((math.nan,), (None,)),
            ((1, None, 2, 4), (2, 3, None, 1), (2, None, None, 4)),
            ((1, math.nan, 2), (4, 2, 3, math.nan), (4, 2, None, None)),
            ((math.nan, 1), (None, 2), (None, 2)),
        ],
    )
    def test_none(self, args: tuple[tuple[float | None, ...], ...]):
        # The last entry of each case is the expected result; the entries
        # before it are the shapes to broadcast.
        *inputs, expected = args
        assert broadcast_shapes(*inputs) == expected