Skip to content

Commit de8e38e

Browse files
committed
Make create_diagonal support broadcasting
1 parent 27b0bf2 commit de8e38e

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

src/array_api_extra/_lib/_funcs.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from types import ModuleType
1010
from typing import cast
1111

12+
import numpy as np
13+
1214
from ._at import at
1315
from ._utils import _compat, _helpers
1416
from ._utils._compat import array_namespace, is_jax_array
@@ -172,7 +174,7 @@ def create_diagonal(
172174
Parameters
173175
----------
174176
x : array
175-
A 1-D array.
177+
An array having shape (*broadcast_dims, k).
176178
offset : int, optional
177179
Offset from the leading diagonal (default is ``0``).
178180
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +185,8 @@ def create_diagonal(
183185
Returns
184186
-------
185187
array
186-
A 2-D array with `x` on the diagonal (offset by `offset`).
188+
An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
189+
on the diagonal (offset by `offset`).
187190
188191
Examples
189192
--------
@@ -206,18 +209,20 @@ def create_diagonal(
206209
if xp is None:
207210
xp = array_namespace(x)
208211

209-
if x.ndim != 1:
210-
err_msg = "`x` must be 1-dimensional."
212+
if x.ndim == 0:
213+
err_msg = "`x` must be at least 1-dimensional."
211214
raise ValueError(err_msg)
212-
n = x.shape[0] + abs(offset)
213-
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
214-
215-
start = offset if offset >= 0 else abs(offset) * n
216-
stop = min(n * (n - offset), diag.shape[0])
217-
step = n + 1
218-
diag = at(diag)[start:stop:step].set(x)
219-
220-
return xp.reshape(diag, (n, n))
215+
pre = x.shape[:-1]
216+
n = x.shape[-1] + abs(offset)
217+
diag = xp.zeros((*pre, n ** 2), dtype=x.dtype, device=_compat.device(x))
218+
219+
target_slice = slice(offset if offset >= 0 else abs(offset) * n,
220+
min(n * (n - offset), diag.shape[0]),
221+
n + 1)
222+
for index in np.ndindex(*pre):
223+
indexed_x = x[*index, :]
224+
diag = at(diag)[(*index, target_slice)].set(indexed_x)
225+
return xp.reshape(diag, (*pre, n, n))
221226

222227

223228
def expand_dims(

tests/test_funcs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ def test_0d(self, xp: ModuleType):
194194
create_diagonal(xp.asarray(1))
195195

196196
def test_2d(self, xp: ModuleType):
197-
with pytest.raises(ValueError, match="1-dimensional"):
198-
create_diagonal(xp.asarray([[1]]))
197+
result = create_diagonal(xp.asarray([[1]], dtype=xp.float64))
198+
xp_assert_equal(result, xp.asarray([[[1]]], dtype=xp.float64))
199199

200200
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
201201
def test_device(self, xp: ModuleType, device: Device):

0 commit comments

Comments
 (0)