Skip to content

Commit ab821b0

Browse files
committed
Make create_diagonal support broadcasting
1 parent a71bd2e commit ab821b0

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

src/array_api_extra/_lib/_funcs.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ._at import at
1313
from ._utils import _compat, _helpers
1414
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
15+
from ._utils._helpers import asarrays, ndindex
1616
from ._utils._typing import Array
1717

1818
__all__ = [
@@ -172,7 +172,7 @@ def create_diagonal(
172172
Parameters
173173
----------
174174
x : array
175-
A 1-D array.
175+
An array having shape (*broadcast_dims, k).
176176
offset : int, optional
177177
Offset from the leading diagonal (default is ``0``).
178178
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +183,8 @@ def create_diagonal(
183183
Returns
184184
-------
185185
array
186-
A 2-D array with `x` on the diagonal (offset by `offset`).
186+
An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
187+
on the diagonal (offset by `offset`).
187188
188189
Examples
189190
--------
@@ -206,18 +207,21 @@ def create_diagonal(
206207
if xp is None:
207208
xp = array_namespace(x)
208209

209-
if x.ndim != 1:
210-
err_msg = "`x` must be 1-dimensional."
210+
if x.ndim == 0:
211+
err_msg = "`x` must be at least 1-dimensional."
211212
raise ValueError(err_msg)
212-
n = x.shape[0] + abs(offset)
213-
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
214-
215-
start = offset if offset >= 0 else abs(offset) * n
216-
stop = min(n * (n - offset), diag.shape[0])
217-
step = n + 1
218-
diag = at(diag)[start:stop:step].set(x)
219-
220-
return xp.reshape(diag, (n, n))
213+
pre = x.shape[:-1]
214+
n = x.shape[-1] + abs(offset)
215+
diag = xp.zeros((*pre, n**2), dtype=x.dtype, device=_compat.device(x))
216+
217+
target_slice = slice(
218+
offset if offset >= 0 else abs(offset) * n,
219+
min(n * (n - offset), diag.shape[-1]),
220+
n + 1,
221+
)
222+
for index in ndindex(*pre):
223+
diag = at(diag)[(*index, target_slice)].set(x[(*index, slice(None))])
224+
return xp.reshape(diag, (*pre, n, n))
221225

222226

223227
def expand_dims(

src/array_api_extra/_lib/_utils/_helpers.py

+24
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
from collections.abc import Generator
67
from types import ModuleType
78
from typing import cast
89

@@ -175,3 +176,26 @@ def asarrays(
175176
xa, xb = xp.asarray(a), xp.asarray(b)
176177

177178
return (xb, xa) if swap else (xa, xb)
179+
180+
181+
def ndindex(*x: int) -> Generator[tuple[int, ...]]:
182+
"""
183+
Generate all N-dimensional indices for a given array shape.
184+
185+
Given the shape of an array, an ndindex instance iterates over the N-dimensional
186+
index of the array. At each iteration a tuple of indices is returned, the last
187+
dimension is iterated over first.
188+
189+
This has an identical API to numpy.ndindex.
190+
191+
Parameters
192+
----------
193+
*x : int
194+
The shape of the array.
195+
"""
196+
if not x:
197+
yield ()
198+
return
199+
for i in ndindex(*x[:-1]):
200+
for j in range(x[-1]):
201+
yield *i, j

tests/test_funcs.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from array_api_extra._lib import Backend
2222
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2323
from array_api_extra._lib._utils._compat import device as get_device
24+
from array_api_extra._lib._utils._helpers import ndindex
2425
from array_api_extra._lib._utils._typing import Array, Device
2526
from array_api_extra.testing import lazy_xp_function
2627

@@ -162,7 +163,7 @@ def test_xp(self, xp: ModuleType):
162163

163164
class TestCreateDiagonal:
164165
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
165-
def test_1d(self, xp: ModuleType):
166+
def test_1d_from_numpy(self, xp: ModuleType):
166167
# from np.diag tests
167168
vals = 100 * xp.arange(5, dtype=xp.float64)
168169
b = xp.zeros((5, 5), dtype=xp.float64)
@@ -180,7 +181,7 @@ def test_1d(self, xp: ModuleType):
180181
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
181182
@pytest.mark.parametrize("n", range(1, 10))
182183
@pytest.mark.parametrize("offset", range(1, 10))
183-
def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
184+
def test_1d_from_scipy(self, xp: ModuleType, n: int, offset: int):
184185
# from scipy._lib tests
185186
rng = np.random.default_rng(2347823)
186187
one = xp.asarray(1.0)
@@ -189,13 +190,22 @@ def test_create_diagonal(self, xp: ModuleType, n: int, offset: int):
189190
B = xp.asarray(np.diag(x, offset), dtype=one.dtype)
190191
xp_assert_equal(A, B)
191192

192-
def test_0d(self, xp: ModuleType):
193+
def test_0d_raises(self, xp: ModuleType):
193194
with pytest.raises(ValueError, match="1-dimensional"):
194195
create_diagonal(xp.asarray(1))
195196

196-
def test_2d(self, xp: ModuleType):
197-
with pytest.raises(ValueError, match="1-dimensional"):
198-
create_diagonal(xp.asarray([[1]]))
197+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
198+
@pytest.mark.parametrize("shape", [(0,), (10,), (0, 1), (1, 0), (0, 0), (4, 2, 1),
199+
(1, 1, 7), (0, 0, 1), (3, 2, 4, 5)])
200+
def test_nd(self, xp: ModuleType, shape: tuple[int, ...]):
201+
rng = np.random.default_rng(2347823)
202+
b = xp.asarray(rng.integers((1 << 64) - 1, size=shape, dtype=np.uint64),
203+
dtype=xp.uint64)
204+
c = create_diagonal(b)
205+
zero = xp.zeros((), dtype=xp.uint64)
206+
assert c.shape == (*b.shape, b.shape[-1])
207+
for i in ndindex(*c.shape):
208+
xp_assert_equal(c[i], b[i[:-1]] if i[-2] == i[-1] else zero)
199209

200210
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
201211
def test_device(self, xp: ModuleType, device: Device):

tests/test_utils.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from array_api_extra._lib import Backend
77
from array_api_extra._lib._testing import xp_assert_equal
88
from array_api_extra._lib._utils._compat import device as get_device
9-
from array_api_extra._lib._utils._helpers import asarrays, in1d
9+
from array_api_extra._lib._utils._helpers import asarrays, in1d, ndindex
1010
from array_api_extra._lib._utils._typing import Device
1111
from array_api_extra.testing import lazy_xp_function
1212

@@ -151,3 +151,9 @@ def test_asarrays_numpy_generics(dtype: type):
151151
xa, xb = asarrays(a, 0, xp=np)
152152
assert xa.dtype == dtype
153153
assert xb.dtype == dtype
154+
155+
156+
@pytest.mark.parametrize("shape", [(), (1,), (5,), (2, 3), (5, 3, 8), (0,), (3, 0),
157+
(0, 0, 1)])
158+
def test_ndindex(shape: tuple[int, ...]):
159+
assert tuple(ndindex(*shape)) == tuple(np.ndindex(*shape))

0 commit comments

Comments
 (0)