12
12
from typing import NamedTuple
13
13
import inspect
14
14
15
- from ._helpers import array_namespace , _check_device , device , is_cupy_namespace
15
+ from ._helpers import (
16
+ array_namespace ,
17
+ _check_device ,
18
+ device as _get_device ,
19
+ is_cupy_namespace as _is_cupy_namespace
20
+ )
16
21
17
22
# These functions are modified from the NumPy versions.
18
23
@@ -287,7 +292,7 @@ def cumulative_sum(
287
292
initial_shape = list (x .shape )
288
293
initial_shape [axis ] = 1
289
294
res = xp .concatenate (
290
- [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
295
+ [wrapped_xp .zeros (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
291
296
axis = axis ,
292
297
)
293
298
return res
@@ -317,7 +322,7 @@ def cumulative_prod(
317
322
initial_shape = list (x .shape )
318
323
initial_shape [axis ] = 1
319
324
res = xp .concatenate (
320
- [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = device (res )), res ],
325
+ [wrapped_xp .ones (shape = initial_shape , dtype = res .dtype , device = _get_device (res )), res ],
321
326
axis = axis ,
322
327
)
323
328
return res
@@ -369,7 +374,7 @@ def _isscalar(a):
369
374
if type (max ) is int and max >= wrapped_xp .iinfo (x .dtype ).max :
370
375
max = None
371
376
372
- dev = device (x )
377
+ dev = _get_device (x )
373
378
if out is None :
374
379
out = wrapped_xp .empty (result_shape , dtype = x .dtype , device = dev )
375
380
out [()] = x
@@ -567,7 +572,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
567
572
out = xp .sign (x , ** kwargs )
568
573
# CuPy sign() does not propagate nans. See
569
574
# https://github.com/data-apis/array-api-compat/issues/136
570
- if is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
575
+ if _is_cupy_namespace (xp ) and isdtype (x .dtype , 'real floating' , xp = xp ):
571
576
out [xp .isnan (x )] = xp .nan
572
577
return out [()]
573
578
@@ -579,3 +584,5 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
579
584
'reshape' , 'argsort' , 'sort' , 'nonzero' , 'ceil' , 'floor' , 'trunc' ,
580
585
'matmul' , 'matrix_transpose' , 'tensordot' , 'vecdot' , 'isdtype' ,
581
586
'unstack' , 'sign' ]
587
+
588
+ _all_ignore = ['inspect' , 'array_namespace' , 'NamedTuple' ]
0 commit comments