9
9
from types import ModuleType
10
10
from typing import cast
11
11
12
+ import numpy as np
13
+
12
14
from ._at import at
13
15
from ._utils import _compat , _helpers
14
16
from ._utils ._compat import array_namespace , is_jax_array
@@ -172,7 +174,7 @@ def create_diagonal(
172
174
Parameters
173
175
----------
174
176
x : array
175
- A 1-D array.
177
+ An array having shape (*broadcast_dims, k) .
176
178
offset : int, optional
177
179
Offset from the leading diagonal (default is ``0``).
178
180
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +185,8 @@ def create_diagonal(
183
185
Returns
184
186
-------
185
187
array
186
- A 2-D array with `x` on the diagonal (offset by `offset`).
188
+ An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
189
+ on the diagonal (offset by `offset`).
187
190
188
191
Examples
189
192
--------
@@ -206,18 +209,20 @@ def create_diagonal(
206
209
if xp is None :
207
210
xp = array_namespace (x )
208
211
209
- if x .ndim != 1 :
210
- err_msg = "`x` must be 1-dimensional."
212
+ if x .ndim == 0 :
213
+ err_msg = "`x` must be at least 1-dimensional."
211
214
raise ValueError (err_msg )
212
- n = x .shape [0 ] + abs (offset )
213
- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
214
-
215
- start = offset if offset >= 0 else abs (offset ) * n
216
- stop = min (n * (n - offset ), diag .shape [0 ])
217
- step = n + 1
218
- diag = at (diag )[start :stop :step ].set (x )
219
-
220
- return xp .reshape (diag , (n , n ))
215
+ pre = x .shape [:- 1 ]
216
+ n = x .shape [- 1 ] + abs (offset )
217
+ diag = xp .zeros ((* pre , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
218
+
219
+ target_slice = slice (offset if offset >= 0 else abs (offset ) * n ,
220
+ min (n * (n - offset ), diag .shape [0 ]),
221
+ n + 1 )
222
+ for index in np .ndindex (* pre ):
223
+ indexed_x = x [* index , :]
224
+ diag = at (diag )[(* index , target_slice )].set (indexed_x )
225
+ return xp .reshape (diag , (* pre , n , n ))
221
226
222
227
223
228
def expand_dims (
0 commit comments