Skip to content

Commit 0bbe9ff

Browse files
committed
Make create_diagonal support broadcasting
1 parent 27b0bf2 commit 0bbe9ff

File tree

3 files changed

+56
-24
lines changed

3 files changed

+56
-24
lines changed

src/array_api_extra/_lib/_funcs.py

+22-22
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ._at import at
1313
from ._utils import _compat, _helpers
1414
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays
15+
from ._utils._helpers import asarrays, ndindex
1616
from ._utils._typing import Array
1717

1818
__all__ = [
@@ -29,8 +29,7 @@
2929

3030

3131
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
32-
"""
33-
Recursively expand the dimension of an array to at least `ndim`.
32+
"""Recursively expand the dimension of an array to at least `ndim`.
3433
3534
Parameters
3635
----------
@@ -72,8 +71,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
7271

7372

7473
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
75-
"""
76-
Estimate a covariance matrix.
74+
"""Estimate a covariance matrix.
7775
7876
Covariance indicates the level to which two variables vary together.
7977
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
@@ -166,13 +164,12 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
166164
def create_diagonal(
167165
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
168166
) -> Array:
169-
"""
170-
Construct a diagonal array.
167+
"""Construct a diagonal array.
171168
172169
Parameters
173170
----------
174171
x : array
175-
A 1-D array.
172+
An array having shape (*broadcast_dims, k).
176173
offset : int, optional
177174
Offset from the leading diagonal (default is ``0``).
178175
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +180,8 @@ def create_diagonal(
183180
Returns
184181
-------
185182
array
186-
A 2-D array with `x` on the diagonal (offset by `offset`).
183+
An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
184+
on the diagonal (offset by `offset`).
187185
188186
Examples
189187
--------
@@ -206,25 +204,27 @@ def create_diagonal(
206204
if xp is None:
207205
xp = array_namespace(x)
208206

209-
if x.ndim != 1:
210-
err_msg = "`x` must be 1-dimensional."
207+
if x.ndim == 0:
208+
err_msg = "`x` must be at least 1-dimensional."
211209
raise ValueError(err_msg)
212-
n = x.shape[0] + abs(offset)
213-
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
214-
215-
start = offset if offset >= 0 else abs(offset) * n
216-
stop = min(n * (n - offset), diag.shape[0])
217-
step = n + 1
218-
diag = at(diag)[start:stop:step].set(x)
219-
220-
return xp.reshape(diag, (n, n))
210+
pre = x.shape[:-1]
211+
n = x.shape[-1] + abs(offset)
212+
diag = xp.zeros((*pre, n**2), dtype=x.dtype, device=_compat.device(x))
213+
214+
target_slice = slice(
215+
offset if offset >= 0 else abs(offset) * n,
216+
min(n * (n - offset), diag.shape[-1]),
217+
n + 1,
218+
)
219+
for index in ndindex(*pre):
220+
diag = at(diag)[(*index, target_slice)].set(x[*index, :])
221+
return xp.reshape(diag, (*pre, n, n))
221222

222223

223224
def expand_dims(
224225
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
225226
) -> Array:
226-
"""
227-
Expand the shape of an array.
227+
"""Expand the shape of an array.
228228
229229
Insert (a) new axis/axes that will appear at the position(s) specified by
230230
`axis` in the expanded array shape.

src/array_api_extra/_lib/_utils/_helpers.py

+23
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
44
from __future__ import annotations
55

6+
from collections.abc import Generator
67
from types import ModuleType
78
from typing import cast
89

@@ -175,3 +176,25 @@ def asarrays(
175176
xa, xb = xp.asarray(a), xp.asarray(b)
176177

177178
return (xb, xa) if swap else (xa, xb)
179+
180+
181+
def ndindex(*x: int) -> Generator[tuple[int, ...]]:
182+
"""Generate all N-dimensional indices for a given array shape.
183+
184+
Given the shape of an array, an ndindex instance iterates over the N-dimensional
185+
index of the array. At each iteration a tuple of indices is returned, the last
186+
dimension is iterated over first.
187+
188+
This has an identical API to numpy.ndindex.
189+
190+
Parameters
191+
----------
192+
x : int
193+
The shape of the array.
194+
"""
195+
if not x:
196+
yield ()
197+
return
198+
for i in ndindex(*x[:-1]):
199+
for j in range(x[-1]):
200+
yield *i, j

tests/test_funcs.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from array_api_extra._lib import Backend
2222
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2323
from array_api_extra._lib._utils._compat import device as get_device
24+
from array_api_extra._lib._utils._helpers import ndindex
2425
from array_api_extra._lib._utils._typing import Array, Device
2526
from array_api_extra.testing import lazy_xp_function
2627

@@ -193,9 +194,17 @@ def test_0d(self, xp: ModuleType):
193194
with pytest.raises(ValueError, match="1-dimensional"):
194195
create_diagonal(xp.asarray(1))
195196

197+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
196198
def test_2d(self, xp: ModuleType):
197-
with pytest.raises(ValueError, match="1-dimensional"):
198-
create_diagonal(xp.asarray([[1]]))
199+
result = create_diagonal(xp.asarray([[1]]))
200+
xp_assert_equal(result, xp.asarray([[[1]]]))
201+
b = xp.zeros((3, 2, 4, 5), dtype=xp.int64)
202+
for i in ndindex(*b.shape):
203+
b = at(b)[*i].set(hash(i))
204+
c = create_diagonal(b)
205+
zero = xp.zeros((), dtype=xp.int64)
206+
for i in ndindex(*c.shape):
207+
xp_assert_equal(c[*i], b[*(i[:-1])] if i[-2] == i[-1] else zero)
199208

200209
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
201210
def test_device(self, xp: ModuleType, device: Device):

0 commit comments

Comments
 (0)