Skip to content

Commit 89466a6

Browse files
committed
MAINT: common._aliases.__all__
1 parent 07a3cd4 commit 89466a6

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

Diff for: array_api_compat/common/_aliases.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
import inspect
88
from typing import NamedTuple, Optional, Sequence, Tuple, Union
99

10-
from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
1110
from ._typing import Array, Device, DType, Namespace
11+
from ._helpers import (
12+
array_namespace,
13+
_check_device,
14+
device as _get_device,
15+
is_cupy_namespace as _is_cupy_namespace
16+
)
17+
1218

1319
# These functions are modified from the NumPy versions.
1420

@@ -298,7 +304,7 @@ def cumulative_sum(
298304
initial_shape = list(x.shape)
299305
initial_shape[axis] = 1
300306
res = xp.concatenate(
301-
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
307+
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
302308
axis=axis,
303309
)
304310
return res
@@ -328,7 +334,7 @@ def cumulative_prod(
328334
initial_shape = list(x.shape)
329335
initial_shape[axis] = 1
330336
res = xp.concatenate(
331-
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
337+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
332338
axis=axis,
333339
)
334340
return res
@@ -381,7 +387,7 @@ def _isscalar(a):
381387
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
382388
max = None
383389

384-
dev = device(x)
390+
dev = _get_device(x)
385391
if out is None:
386392
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
387393
out[()] = x
@@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
599605
out = xp.sign(x, **kwargs)
600606
# CuPy sign() does not propagate nans. See
601607
# https://github.com/data-apis/array-api-compat/issues/136
602-
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
608+
if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
603609
out[xp.isnan(x)] = xp.nan
604610
return out[()]
605611

@@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
611617
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
612618
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
613619
'unstack', 'sign']
620+
621+
_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']

Diff for: tests/test_all.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_all(library):
3333
else:
3434
import_(library, wrapper=True)
3535

36-
for mod_name in sys.modules:
36+
for mod_name in sys.modules.copy():
3737
if not mod_name.startswith('array_api_compat.' + library):
3838
continue
3939

0 commit comments

Comments
 (0)