7
7
import inspect
8
8
from typing import NamedTuple , Optional , Sequence , Tuple , Union
9
9
10
- from ._helpers import array_namespace , _check_device , device , is_cupy_namespace
11
10
from ._typing import Array , Device , DType , Namespace
11
+ from ._helpers import (
12
+ array_namespace ,
13
+ _check_device ,
14
+ device as _get_device ,
15
+ is_cupy_namespace as _is_cupy_namespace
16
+ )
17
+
12
18
13
19
# These functions are modified from the NumPy versions.
14
20
@@ -298,7 +304,7 @@ def cumulative_sum(
298
304
initial_shape = list (x .shape )
299
305
initial_shape [axis ] = 1
300
306
res = xp .concatenate (
301
- [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
307
+ [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
302
308
axis = axis ,
303
309
)
304
310
return res
@@ -328,7 +334,7 @@ def cumulative_prod(
328
334
initial_shape = list (x .shape )
329
335
initial_shape [axis ] = 1
330
336
res = xp .concatenate (
331
- [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
337
+ [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
332
338
axis = axis ,
333
339
)
334
340
return res
@@ -381,7 +387,7 @@ def _isscalar(a):
381
387
if type (max ) is int and max >= wrapped_xp .iinfo (x .dtype ).max :
382
388
max = None
383
389
384
- dev = device (x )
390
+ dev = _get_device (x )
385
391
if out is None :
386
392
out = wrapped_xp .empty (result_shape , dtype = x .dtype , device = dev )
387
393
out [()] = x
@@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
599
605
out = xp .sign (x , ** kwargs )
600
606
# CuPy sign() does not propagate nans. See
601
607
# https://github.com/data-apis/array-api-compat/issues/136
602
- if is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
608
+ if _is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
603
609
out [xp .isnan (x )] = xp .nan
604
610
return out [()]
605
611
@@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
611
617
'reshape' , 'argsort' , 'sort' , 'nonzero' , 'ceil' , 'floor' , 'trunc' ,
612
618
'matmul' , 'matrix_transpose' , 'tensordot' , 'vecdot' , 'isdtype' ,
613
619
'unstack' , 'sign' ]
620
+
621
+ _all_ignore = ['inspect' , 'array_namespace' , 'NamedTuple' ]
0 commit comments