12
12
from ._at import at
13
13
from ._utils import _compat , _helpers
14
14
from ._utils ._compat import array_namespace , is_jax_array
15
- from ._utils ._helpers import asarrays
15
+ from ._utils ._helpers import asarrays , ndindex
16
16
from ._utils ._typing import Array
17
17
18
18
__all__ = [
29
29
30
30
31
31
def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType | None = None ) -> Array :
32
- """
33
- Recursively expand the dimension of an array to at least `ndim`.
32
+ """Recursively expand the dimension of an array to at least `ndim`.
34
33
35
34
Parameters
36
35
----------
@@ -72,8 +71,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
72
71
73
72
74
73
def cov (m : Array , / , * , xp : ModuleType | None = None ) -> Array :
75
- """
76
- Estimate a covariance matrix.
74
+ """Estimate a covariance matrix.
77
75
78
76
Covariance indicates the level to which two variables vary together.
79
77
If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
@@ -166,13 +164,12 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
166
164
def create_diagonal (
167
165
x : Array , / , * , offset : int = 0 , xp : ModuleType | None = None
168
166
) -> Array :
169
- """
170
- Construct a diagonal array.
167
+ """Construct a diagonal array.
171
168
172
169
Parameters
173
170
----------
174
171
x : array
175
- A 1-D array.
172
+ An array having shape (*broadcast_dims, k) .
176
173
offset : int, optional
177
174
Offset from the leading diagonal (default is ``0``).
178
175
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +180,8 @@ def create_diagonal(
183
180
Returns
184
181
-------
185
182
array
186
- A 2-D array with `x` on the diagonal (offset by `offset`).
183
+ An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
184
+ on the diagonal (offset by `offset`).
187
185
188
186
Examples
189
187
--------
@@ -206,25 +204,27 @@ def create_diagonal(
206
204
if xp is None :
207
205
xp = array_namespace (x )
208
206
209
- if x .ndim != 1 :
210
- err_msg = "`x` must be 1-dimensional."
207
+ if x .ndim == 0 :
208
+ err_msg = "`x` must be at least 1-dimensional."
211
209
raise ValueError (err_msg )
212
- n = x .shape [0 ] + abs (offset )
213
- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
214
-
215
- start = offset if offset >= 0 else abs (offset ) * n
216
- stop = min (n * (n - offset ), diag .shape [0 ])
217
- step = n + 1
218
- diag = at (diag )[start :stop :step ].set (x )
219
-
220
- return xp .reshape (diag , (n , n ))
210
+ pre = x .shape [:- 1 ]
211
+ n = x .shape [- 1 ] + abs (offset )
212
+ diag = xp .zeros ((* pre , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
213
+
214
+ target_slice = slice (
215
+ offset if offset >= 0 else abs (offset ) * n ,
216
+ min (n * (n - offset ), diag .shape [- 1 ]),
217
+ n + 1 ,
218
+ )
219
+ for index in ndindex (* pre ):
220
+ diag = at (diag )[(* index , target_slice )].set (x [* index , :])
221
+ return xp .reshape (diag , (* pre , n , n ))
221
222
222
223
223
224
def expand_dims (
224
225
a : Array , / , * , axis : int | tuple [int , ...] = (0 ,), xp : ModuleType | None = None
225
226
) -> Array :
226
- """
227
- Expand the shape of an array.
227
+ """Expand the shape of an array.
228
228
229
229
Insert (a) new axis/axes that will appear at the position(s) specified by
230
230
`axis` in the expanded array shape.
0 commit comments