Skip to content

Commit 0fd29e5

Browse files
committed
Make create_diagonal support broadcasting
1 parent e585c97 commit 0fd29e5

File tree

4 files changed

+79
-21
lines changed

4 files changed

+79
-21
lines changed

src/array_api_extra/_lib/_funcs.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ._at import at
1313
from ._utils import _compat, _helpers
1414
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
15+
from ._utils._helpers import asarrays, ndindex
1616
from ._utils._typing import Array
1717

1818
__all__ = [
@@ -236,7 +236,7 @@ def create_diagonal(
236236
Parameters
237237
----------
238238
x : array
239-
A 1-D array.
239+
An array having shape ``(*batch_dims, k)``.
240240
offset : int, optional
241241
Offset from the leading diagonal (default is ``0``).
242242
Use positive ints for diagonals above the leading diagonal,
@@ -247,7 +247,8 @@ def create_diagonal(
247247
Returns
248248
-------
249249
array
250-
A 2-D array with `x` on the diagonal (offset by `offset`).
250+
An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
251+
on the diagonal (offset by `offset`).
251252
252253
Examples
253254
--------
@@ -270,18 +271,21 @@ def create_diagonal(
270271
if xp is None:
271272
xp = array_namespace(x)
272273

273-
if x.ndim != 1:
274-
err_msg = "`x` must be 1-dimensional."
274+
if x.ndim == 0:
275+
err_msg = "`x` must be at least 1-dimensional."
275276
raise ValueError(err_msg)
276-
n = x.shape[0] + abs(offset)
277-
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
278-
279-
start = offset if offset >= 0 else abs(offset) * n
280-
stop = min(n * (n - offset), diag.shape[0])
281-
step = n + 1
282-
diag = at(diag)[start:stop:step].set(x)
283-
284-
return xp.reshape(diag, (n, n))
277+
pre = x.shape[:-1]
278+
n = x.shape[-1] + abs(offset)
279+
diag = xp.zeros((*pre, n**2), dtype=x.dtype, device=_compat.device(x))
280+
281+
target_slice = slice(
282+
offset if offset >= 0 else abs(offset) * n,
283+
min(n * (n - offset), diag.shape[-1]),
284+
n + 1,
285+
)
286+
for index in ndindex(*pre):
287+
diag = at(diag)[(*index, target_slice)].set(x[(*index, slice(None))])
288+
return xp.reshape(diag, (*pre, n, n))
285289

286290

287291
def expand_dims(

src/array_api_extra/_lib/_utils/_helpers.py

+24
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
from collections.abc import Generator
67
from types import ModuleType
78
from typing import cast
89

@@ -175,3 +176,26 @@ def asarrays(
175176
xa, xb = xp.asarray(a), xp.asarray(b)
176177

177178
return (xb, xa) if swap else (xa, xb)
179+
180+
181+
def ndindex(*x: int) -> Generator[tuple[int, ...]]:
182+
"""
183+
Generate all N-dimensional indices for a given array shape.
184+
185+
Given the shape of an array, an ndindex instance iterates over the N-dimensional
186+
index of the array. At each iteration a tuple of indices is returned, the last
187+
dimension is iterated over first.
188+
189+
This has an identical API to numpy.ndindex.
190+
191+
Parameters
192+
----------
193+
*x : int
194+
The shape of the array.
195+
"""
196+
if not x:
197+
yield ()
198+
return
199+
for i in ndindex(*x[:-1]):
200+
for j in range(x[-1]):
201+
yield *i, j

tests/test_funcs.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from array_api_extra._lib import Backend
2424
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2525
from array_api_extra._lib._utils._compat import device as get_device
26+
from array_api_extra._lib._utils._helpers import ndindex
2627
from array_api_extra._lib._utils._typing import Array, Device
2728
from array_api_extra.testing import lazy_xp_function
2829

@@ -221,7 +222,7 @@ def test_xp(self, xp: ModuleType):
221222

222223
class TestCreateDiagonal:
223224
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
224-
def test_1d(self, xp: ModuleType):
225+
def test_1d_from_numpy(self, xp: ModuleType):
225226
# from np.diag tests
226227
vals = 100 * xp.arange(5, dtype=xp.float64)
227228
b = xp.zeros((5, 5), dtype=xp.float64)
@@ -239,7 +240,7 @@ def test_1d(self, xp: ModuleType):
239240
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
240241
@pytest.mark.parametrize("n", range(1, 10))
241242
@pytest.mark.parametrize("offset", range(1, 10))
242-
def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
243+
def test_1d_from_scipy(self, xp: ModuleType, n: int, offset: int):
243244
# from scipy._lib tests
244245
rng = np.random.default_rng(2347823)
245246
one = xp.asarray(1.0)
@@ -248,13 +249,35 @@ def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
248249
B = xp.asarray(np.diag(x, offset), dtype=one.dtype)
249250
xp_assert_equal(A, B)
250251

251-
def test_0d(self, xp: ModuleType):
252+
def test_0d_raises(self, xp: ModuleType):
252253
with pytest.raises(ValueError, match="1-dimensional"):
253254
create_diagonal(xp.asarray(1))
254255

255-
def test_2d(self, xp: ModuleType):
256-
with pytest.raises(ValueError, match="1-dimensional"):
257-
create_diagonal(xp.asarray([[1]]))
256+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
257+
@pytest.mark.parametrize(
258+
"shape",
259+
[
260+
(0,),
261+
(10,),
262+
(0, 1),
263+
(1, 0),
264+
(0, 0),
265+
(4, 2, 1),
266+
(1, 1, 7),
267+
(0, 0, 1),
268+
(3, 2, 4, 5),
269+
],
270+
)
271+
def test_nd(self, xp: ModuleType, shape: tuple[int, ...]):
272+
rng = np.random.default_rng(2347823)
273+
b = xp.asarray(
274+
rng.integers((1 << 64) - 1, size=shape, dtype=np.uint64), dtype=xp.uint64
275+
)
276+
c = create_diagonal(b)
277+
zero = xp.zeros((), dtype=xp.uint64)
278+
assert c.shape == (*b.shape, b.shape[-1])
279+
for i in ndindex(*c.shape):
280+
xp_assert_equal(c[i], b[i[:-1]] if i[-2] == i[-1] else zero)
258281

259282
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
260283
def test_device(self, xp: ModuleType, device: Device):

tests/test_helpers.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from array_api_extra._lib import Backend
77
from array_api_extra._lib._testing import xp_assert_equal
88
from array_api_extra._lib._utils._compat import device as get_device
9-
from array_api_extra._lib._utils._helpers import asarrays, in1d
9+
from array_api_extra._lib._utils._helpers import asarrays, in1d, ndindex
1010
from array_api_extra._lib._utils._typing import Device
1111
from array_api_extra.testing import lazy_xp_function
1212

@@ -148,3 +148,10 @@ def test_numpy_generics(self, dtype: type):
148148
xa, xb = asarrays(a, 0, xp=np)
149149
assert xa.dtype == dtype
150150
assert xb.dtype == dtype
151+
152+
153+
@pytest.mark.parametrize(
154+
"shape", [(), (1,), (5,), (2, 3), (5, 3, 8), (0,), (3, 0), (0, 0, 1)]
155+
)
156+
def test_ndindex(shape: tuple[int, ...]):
157+
assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape))

0 commit comments

Comments
 (0)