7
7
from ._lib ._typing import Array , ModuleType
8
8
9
9
from ._lib import _utils
10
+ from ._lib ._compat import array_namespace
10
11
11
12
__all__ = [
12
13
"atleast_nd" ,
19
20
]
20
21
21
22
22
- def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType ) -> Array :
23
+ def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType | None = None ) -> Array :
23
24
"""
24
25
Recursively expand the dimension of an array to at least `ndim`.
25
26
@@ -28,8 +29,8 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
28
29
x : array
29
30
ndim : int
30
31
The minimum number of dimensions for the result.
31
- xp : array_namespace
32
- The standard-compatible namespace for `x`.
32
+ xp : array_namespace, optional
33
+ The standard-compatible namespace for `x`. Default: infer
33
34
34
35
Returns
35
36
-------
@@ -53,13 +54,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
53
54
True
54
55
55
56
"""
57
+ if xp is None :
58
+ xp = array_namespace (x )
59
+
56
60
if x .ndim < ndim :
57
61
x = xp .expand_dims (x , axis = 0 )
58
62
x = atleast_nd (x , ndim = ndim , xp = xp )
59
63
return x
60
64
61
65
62
- def cov (m : Array , / , * , xp : ModuleType ) -> Array :
66
+ def cov (m : Array , / , * , xp : ModuleType | None = None ) -> Array :
63
67
"""
64
68
Estimate a covariance matrix.
65
69
@@ -77,8 +81,8 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
77
81
A 1-D or 2-D array containing multiple variables and observations.
78
82
Each row of `m` represents a variable, and each column a single
79
83
observation of all those variables.
80
- xp : array_namespace
81
- The standard-compatible namespace for `m`.
84
+ xp : array_namespace, optional
85
+ The standard-compatible namespace for `m`. Default: infer
82
86
83
87
Returns
84
88
-------
@@ -125,6 +129,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
125
129
Array(2.14413333, dtype=array_api_strict.float64)
126
130
127
131
"""
132
+ if xp is None :
133
+ xp = array_namespace (m )
134
+
128
135
m = xp .asarray (m , copy = True )
129
136
dtype = (
130
137
xp .float64 if xp .isdtype (m .dtype , "integral" ) else xp .result_type (m , xp .float64 )
@@ -150,7 +157,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
150
157
return xp .squeeze (c , axis = axes )
151
158
152
159
153
- def create_diagonal (x : Array , / , * , offset : int = 0 , xp : ModuleType ) -> Array :
160
+ def create_diagonal (
161
+ x : Array , / , * , offset : int = 0 , xp : ModuleType | None = None
162
+ ) -> Array :
154
163
"""
155
164
Construct a diagonal array.
156
165
@@ -162,8 +171,8 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
162
171
Offset from the leading diagonal (default is ``0``).
163
172
Use positive ints for diagonals above the leading diagonal,
164
173
and negative ints for diagonals below the leading diagonal.
165
- xp : array_namespace
166
- The standard-compatible namespace for `x`.
174
+ xp : array_namespace, optional
175
+ The standard-compatible namespace for `x`. Default: infer
167
176
168
177
Returns
169
178
-------
@@ -189,6 +198,9 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
189
198
[0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
190
199
191
200
"""
201
+ if xp is None :
202
+ xp = array_namespace (x )
203
+
192
204
if x .ndim != 1 :
193
205
err_msg = "`x` must be 1-dimensional."
194
206
raise ValueError (err_msg )
@@ -200,7 +212,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
200
212
201
213
202
214
def expand_dims (
203
- a : Array , / , * , axis : int | tuple [int , ...] = (0 ,), xp : ModuleType
215
+ a : Array , / , * , axis : int | tuple [int , ...] = (0 ,), xp : ModuleType | None = None
204
216
) -> Array :
205
217
"""
206
218
Expand the shape of an array.
@@ -220,8 +232,8 @@ def expand_dims(
220
232
given by a positive index could also be referred to by a negative index -
221
233
that will also result in an error).
222
234
Default: ``(0,)``.
223
- xp : array_namespace
224
- The standard-compatible namespace for `a`.
235
+ xp : array_namespace, optional
236
+ The standard-compatible namespace for `a`. Default: infer
225
237
226
238
Returns
227
239
-------
@@ -265,6 +277,9 @@ def expand_dims(
265
277
[2]]], dtype=array_api_strict.int64)
266
278
267
279
"""
280
+ if xp is None :
281
+ xp = array_namespace (a )
282
+
268
283
if not isinstance (axis , tuple ):
269
284
axis = (axis ,)
270
285
ndim = a .ndim + len (axis )
@@ -282,7 +297,7 @@ def expand_dims(
282
297
return a
283
298
284
299
285
- def kron (a : Array , b : Array , / , * , xp : ModuleType ) -> Array :
300
+ def kron (a : Array , b : Array , / , * , xp : ModuleType | None = None ) -> Array :
286
301
"""
287
302
Kronecker product of two arrays.
288
303
@@ -294,8 +309,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
294
309
Parameters
295
310
----------
296
311
a, b : array
297
- xp : array_namespace
298
- The standard-compatible namespace for `a` and `b`.
312
+ xp : array_namespace, optional
313
+ The standard-compatible namespace for `a` and `b`. Default: infer
299
314
300
315
Returns
301
316
-------
@@ -357,6 +372,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
357
372
Array(True, dtype=array_api_strict.bool)
358
373
359
374
"""
375
+ if xp is None :
376
+ xp = array_namespace (a , b )
360
377
361
378
b = xp .asarray (b )
362
379
singletons = (1 ,) * (b .ndim - a .ndim )
@@ -390,7 +407,12 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
390
407
391
408
392
409
def setdiff1d (
393
- x1 : Array , x2 : Array , / , * , assume_unique : bool = False , xp : ModuleType
410
+ x1 : Array ,
411
+ x2 : Array ,
412
+ / ,
413
+ * ,
414
+ assume_unique : bool = False ,
415
+ xp : ModuleType | None = None ,
394
416
) -> Array :
395
417
"""
396
418
Find the set difference of two arrays.
@@ -406,8 +428,8 @@ def setdiff1d(
406
428
assume_unique : bool
407
429
If ``True``, the input arrays are both assumed to be unique, which
408
430
can speed up the calculation. Default is ``False``.
409
- xp : array_namespace
410
- The standard-compatible namespace for `x1` and `x2`.
431
+ xp : array_namespace, optional
432
+ The standard-compatible namespace for `x1` and `x2`. Default: infer
411
433
412
434
Returns
413
435
-------
@@ -427,6 +449,8 @@ def setdiff1d(
427
449
Array([1, 2], dtype=array_api_strict.int64)
428
450
429
451
"""
452
+ if xp is None :
453
+ xp = array_namespace (x1 , x2 )
430
454
431
455
if assume_unique :
432
456
x1 = xp .reshape (x1 , (- 1 ,))
@@ -436,7 +460,7 @@ def setdiff1d(
436
460
return x1 [_utils .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
437
461
438
462
439
- def sinc (x : Array , / , * , xp : ModuleType ) -> Array :
463
+ def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
440
464
r"""
441
465
Return the normalized sinc function.
442
466
@@ -456,8 +480,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
456
480
x : array
457
481
Array (possibly multi-dimensional) of values for which to calculate
458
482
``sinc(x)``. Must have a real floating point dtype.
459
- xp : array_namespace
460
- The standard-compatible namespace for `x`.
483
+ xp : array_namespace, optional
484
+ The standard-compatible namespace for `x`. Default: infer
461
485
462
486
Returns
463
487
-------
@@ -511,6 +535,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
511
535
-3.89817183e-17], dtype=array_api_strict.float64)
512
536
513
537
"""
538
+ if xp is None :
539
+ xp = array_namespace (x )
540
+
514
541
if not xp .isdtype (x .dtype , "real floating" ):
515
542
err_msg = "`x` must have a real floating data type."
516
543
raise ValueError (err_msg )
0 commit comments