Skip to content

Commit cddc9ef

Browse files
authored
ENH: Review exported symbols; redesign test_all (#315)
Review and discussion at #315
1 parent 2b559e6 commit cddc9ef

23 files changed

+435
-197
lines changed

array_api_compat/_internal.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Internal helpers
33
"""
44

5+
import importlib
56
from collections.abc import Callable
67
from functools import wraps
78
from inspect import signature
@@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
5253
return inner
5354

5455

55-
__all__ = ["get_xp"]
56+
def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]:
57+
"""Import everything from module, updating globals().
58+
Returns __all__.
59+
"""
60+
mod = importlib.import_module(mod_name)
61+
# Neither of these two methods is sufficient by itself,
62+
# depending on various idiosyncrasies of the libraries we're wrapping.
63+
objs = {}
64+
exec(f"from {mod.__name__} import *", objs)
65+
66+
for n in dir(mod):
67+
if not n.startswith("_") and hasattr(mod, n):
68+
objs[n] = getattr(mod, n)
69+
70+
globals_.update(objs)
71+
return list(objs)
72+
5673

74+
__all__ = ["get_xp", "clone_module"]
5775

5876
def __dir__() -> list[str]:
5977
return __all__

array_api_compat/common/_aliases.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
721721
"finfo",
722722
"iinfo",
723723
]
724-
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
725-
726724

727725
def __dir__() -> list[str]:
728726
return __all__

array_api_compat/common/_helpers.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,5 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
10621062
"to_device",
10631063
]
10641064

1065-
_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']
1066-
10671065
def __dir__() -> list[str]:
10681066
return __all__

array_api_compat/common/_linalg.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,6 @@ def trace(
225225
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
226226
'trace']
227227

228-
_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']
229-
230228

231229
def __dir__() -> list[str]:
232230
return __all__

array_api_compat/cupy/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
1+
from typing import Final
12
from cupy import * # noqa: F403
23

34
# from cupy import * doesn't overwrite these builtin names
45
from cupy import abs, max, min, round # noqa: F401
56

67
# These imports may overwrite names from the import * above.
78
from ._aliases import * # noqa: F403
9+
from ._info import __array_namespace_info__ # noqa: F401
810

911
# See the comment in the numpy __init__.py
1012
__import__(__package__ + '.linalg')
1113
__import__(__package__ + '.fft')
1214

13-
__array_api_version__ = '2024.12'
15+
__array_api_version__: Final = '2024.12'
16+
17+
__all__ = sorted(
18+
{name for name in globals() if not name.startswith("__")}
19+
- {"Final", "_aliases", "_info", "_typing"}
20+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
21+
)
22+
23+
def __dir__() -> list[str]:
24+
return __all__

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from ..common import _aliases, _helpers
88
from ..common._typing import NestedSequence, SupportsBufferProtocol
99
from .._internal import get_xp
10-
from ._info import __array_namespace_info__
1110
from ._typing import Array, Device, DType
1211

1312
bool = cp.bool_
@@ -141,7 +140,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array:
141140
else:
142141
unstack = get_xp(cp)(_aliases.unstack)
143142

144-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
143+
__all__ = _aliases.__all__ + ['asarray', 'astype',
145144
'acos', 'acosh', 'asin', 'asinh', 'atan',
146145
'atan2', 'atanh', 'bitwise_left_shift',
147146
'bitwise_invert', 'bitwise_right_shift',

array_api_compat/cupy/_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
__all__ = ["Array", "DType", "Device"]
4-
_all_ignore = ["cp"]
54

65
from typing import TYPE_CHECKING
76

array_api_compat/cupy/fft.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
__all__ = fft_all + _fft.__all__
3333

34-
del get_xp
35-
del cp
36-
del fft_all
37-
del _fft
34+
def __dir__() -> list[str]:
35+
return __all__
36+

array_api_compat/cupy/linalg.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,5 @@
4343

4444
__all__ = linalg_all + _linalg.__all__
4545

46-
del get_xp
47-
del cp
48-
del linalg_all
49-
del _linalg
46+
def __dir__() -> list[str]:
47+
return __all__
Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
11
from typing import Final
22

3-
from dask.array import * # noqa: F403
3+
from ..._internal import clone_module
4+
5+
__all__ = clone_module("dask.array", globals())
46

57
# These imports may overwrite names from the import * above.
8+
from . import _aliases
69
from ._aliases import * # type: ignore[assignment] # noqa: F403
10+
from ._info import __array_namespace_info__ # noqa: F401
711

812
__array_api_version__: Final = "2024.12"
13+
del Final
914

1015
# See the comment in the numpy __init__.py
1116
__import__(__package__ + '.linalg')
1217
__import__(__package__ + '.fft')
18+
19+
__all__ = sorted(
20+
set(__all__)
21+
| set(_aliases.__all__)
22+
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
23+
)
24+
25+
def __dir__() -> list[str]:
26+
return __all__

0 commit comments

Comments
 (0)