Skip to content

Commit 4bee37d

Browse files
committed
MAINT: common._aliases.__all__
1 parent becc7ab commit 4bee37d

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

array_api_compat/common/_aliases.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
15+
from ._helpers import (
16+
array_namespace,
17+
_check_device,
18+
device as _get_device,
19+
is_cupy_namespace as _is_cupy_namespace
20+
)
1621

1722
# These functions are modified from the NumPy versions.
1823

@@ -287,7 +292,7 @@ def cumulative_sum(
287292
initial_shape = list(x.shape)
288293
initial_shape[axis] = 1
289294
res = xp.concatenate(
290-
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
295+
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
291296
axis=axis,
292297
)
293298
return res
@@ -317,7 +322,7 @@ def cumulative_prod(
317322
initial_shape = list(x.shape)
318323
initial_shape[axis] = 1
319324
res = xp.concatenate(
320-
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
325+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
321326
axis=axis,
322327
)
323328
return res
@@ -369,7 +374,7 @@ def _isscalar(a):
369374
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
370375
max = None
371376

372-
dev = device(x)
377+
dev = _get_device(x)
373378
if out is None:
374379
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
375380
out[()] = x
@@ -567,7 +572,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
567572
out = xp.sign(x, **kwargs)
568573
# CuPy sign() does not propagate nans. See
569574
# https://github.com/data-apis/array-api-compat/issues/136
570-
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
575+
if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
571576
out[xp.isnan(x)] = xp.nan
572577
return out[()]
573578

@@ -579,3 +584,5 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
579584
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
580585
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
581586
'unstack', 'sign']
587+
588+
_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']

tests/test_all.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_all(library):
2323
else:
2424
import_(library, wrapper=True)
2525

26-
for mod_name in sys.modules:
26+
for mod_name in sys.modules.copy():
2727
if not mod_name.startswith('array_api_compat.' + library):
2828
continue
2929

0 commit comments

Comments
 (0)