From ea42de83dd2953d97b34d3bb8f77094c8becda22 Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 7 Feb 2025 02:22:14 -0500 Subject: [PATCH] Make create_diagonal support broadcasting --- src/array_api_extra/_lib/_funcs.py | 32 ++++++++++--------- src/array_api_extra/_lib/_utils/_helpers.py | 24 ++++++++++++++ tests/test_funcs.py | 35 +++++++++++++++++---- tests/test_helpers.py | 9 +++++- 4 files changed, 79 insertions(+), 21 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index bd1b5f0..2c2e5a7 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -12,7 +12,7 @@ from ._at import at from ._utils import _compat, _helpers from ._utils._compat import array_namespace, is_jax_array -from ._utils._helpers import asarrays +from ._utils._helpers import asarrays, ndindex from ._utils._typing import Array __all__ = [ @@ -236,7 +236,7 @@ def create_diagonal( Parameters ---------- x : array - A 1-D array. + An array having shape ``(*batch_dims, k)``. offset : int, optional Offset from the leading diagonal (default is ``0``). Use positive ints for diagonals above the leading diagonal, @@ -247,7 +247,8 @@ def create_diagonal( Returns ------- array - A 2-D array with `x` on the diagonal (offset by `offset`). + An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x` + on the diagonal (offset by `offset`). Examples -------- @@ -270,18 +271,21 @@ def create_diagonal( if xp is None: xp = array_namespace(x) - if x.ndim != 1: - err_msg = "`x` must be 1-dimensional." + if x.ndim == 0: + err_msg = "`x` must be at least 1-dimensional." raise ValueError(err_msg) - n = x.shape[0] + abs(offset) - diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x)) - - start = offset if offset >= 0 else abs(offset) * n - stop = min(n * (n - offset), diag.shape[0]) - step = n + 1 - diag = at(diag)[start:stop:step].set(x) - - return xp.reshape(diag, (n, n)) + batch_dims = x.shape[:-1] + n = x.shape[-1] + abs(offset) + diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x)) + + target_slice = slice( + offset if offset >= 0 else abs(offset) * n, + min(n * (n - offset), diag.shape[-1]), + n + 1, + ) + for index in ndindex(*batch_dims): + diag = at(diag)[(*index, target_slice)].set(x[(*index, slice(None))]) + return xp.reshape(diag, (*batch_dims, n, n)) def expand_dims( diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index accae89..e8419f2 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -3,6 +3,7 @@ # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 from __future__ import annotations +from collections.abc import Generator from types import ModuleType from typing import cast @@ -175,3 +176,26 @@ def asarrays( xa, xb = xp.asarray(a), xp.asarray(b) return (xb, xa) if swap else (xa, xb) + + +def ndindex(*x: int) -> Generator[tuple[int, ...]]: + """ + Generate all N-dimensional indices for a given array shape. + + Given the shape of an array, an ndindex instance iterates over the N-dimensional + index of the array. At each iteration a tuple of indices is returned, the last + dimension is iterated over first. + + This has an identical API to numpy.ndindex. + + Parameters + ---------- + *x : int + The shape of the array. + """ + if not x: + yield () + return + for i in ndindex(*x[:-1]): + for j in range(x[-1]): + yield *i, j diff --git a/tests/test_funcs.py b/tests/test_funcs.py index f7a2c4f..e24a7c0 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -23,6 +23,7 @@ from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device +from array_api_extra._lib._utils._helpers import ndindex from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function @@ -221,7 +222,7 @@ def test_xp(self, xp: ModuleType): class TestCreateDiagonal: @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()") - def test_1d(self, xp: ModuleType): + def test_1d_from_numpy(self, xp: ModuleType): # from np.diag tests vals = 100 * xp.arange(5, dtype=xp.float64) b = xp.zeros((5, 5), dtype=xp.float64) @@ -239,7 +240,7 @@ def test_1d(self, xp: ModuleType): @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()") @pytest.mark.parametrize("n", range(1, 10)) @pytest.mark.parametrize("offset", range(1, 10)) - def test_create_diagonal(self, xp: ModuleType, n: int, offset: int): + def test_1d_from_scipy(self, xp: ModuleType, n: int, offset: int): # from scipy._lib tests rng = np.random.default_rng(2347823) one = xp.asarray(1.0) @@ -248,13 +249,35 @@ def test_create_diagonal(self, xp: ModuleType, n: int, offset: int): B = xp.asarray(np.diag(x, offset), dtype=one.dtype) xp_assert_equal(A, B) - def test_0d(self, xp: ModuleType): + def test_0d_raises(self, xp: ModuleType): with pytest.raises(ValueError, match="1-dimensional"): create_diagonal(xp.asarray(1)) - def test_2d(self, xp: ModuleType): - with pytest.raises(ValueError, match="1-dimensional"): - create_diagonal(xp.asarray([[1]])) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()") + @pytest.mark.parametrize( + "shape", + [ + (0,), + (10,), + (0, 1), + (1, 0), + (0, 0), + (4, 2, 1), + (1, 1, 7), + (0, 0, 1), + (3, 2, 4, 5), + ], + ) + def test_nd(self, xp: ModuleType, shape: tuple[int, ...]): + rng = np.random.default_rng(2347823) + b = xp.asarray( + rng.integers((1 << 64) - 1, size=shape, dtype=np.uint64), dtype=xp.uint64 + ) + c = create_diagonal(b) + zero = xp.zeros((), dtype=xp.uint64) + assert c.shape == (*b.shape, b.shape[-1]) + for i in ndindex(*c.shape): + xp_assert_equal(c[i], b[i[:-1]] if i[-2] == i[-1] else zero) @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()") def test_device(self, xp: ModuleType, device: Device): diff --git a/tests/test_helpers.py b/tests/test_helpers.py index f0d8224..1710ff8 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,7 +6,7 @@ from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device -from array_api_extra._lib._utils._helpers import asarrays, in1d +from array_api_extra._lib._utils._helpers import asarrays, in1d, ndindex from array_api_extra._lib._utils._typing import Device from array_api_extra.testing import lazy_xp_function @@ -148,3 +148,10 @@ def test_numpy_generics(self, dtype: type): xa, xb = asarrays(a, 0, xp=np) assert xa.dtype == dtype assert xb.dtype == dtype + + +@pytest.mark.parametrize( + "shape", [(), (1,), (5,), (2, 3), (5, 3, 8), (0,), (3, 0), (0, 0, 1)] +) +def test_ndindex(shape: tuple[int, ...]): + assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape))