Skip to content

Commit 40603a9

Browse files
authored
Merge pull request #95 from asmeurer/revert-all-changes2
Revert __all__ related changes from #82
2 parents ab74e4a + a73388d commit 40603a9

24 files changed

+450
-1083
lines changed

.github/workflows/ruff.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ jobs:
1616
pip install ruff
1717
# Update output format to enable automatic inline annotations.
1818
- name: Run Ruff
19-
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
19+
run: ruff check --output-format=github .

array_api_compat/_internal.py

-29
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import wraps
66
from inspect import signature
77

8-
98
def get_xp(xp):
109
"""
1110
Decorator to automatically replace xp with the corresponding array module.
@@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs):
4544
return wrapped_f
4645

4746
return inner
48-
49-
50-
def _get_all_public_members(module, exclude=None, extend_all=False):
51-
"""Get all public members of a module.
52-
53-
Parameters
54-
----------
55-
module : module
56-
The module to get members from.
57-
exclude : callable, optional
58-
A callable that takes a name and returns True if the name should be
59-
excluded from the list of members.
60-
extend_all : bool, optional
61-
If True, extend the module's __all__ attribute with the members of the
62-
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
63-
"""
64-
members = getattr(module, "__all__", [])
65-
66-
if members and not extend_all:
67-
return members
68-
69-
if exclude is None:
70-
exclude = lambda name: name.startswith("_") # noqa: E731
71-
72-
members = members + [_ for _ in dir(module) if not exclude(_)]
73-
74-
# remove duplicates
75-
return list(set(members))

array_api_compat/common/__init__.py

+1-27
Original file line numberDiff line numberDiff line change
@@ -1,27 +1 @@
1-
from ._helpers import (
2-
array_namespace,
3-
device,
4-
get_namespace,
5-
is_array_api_obj,
6-
is_cupy_array,
7-
is_dask_array,
8-
is_jax_array,
9-
is_numpy_array,
10-
is_torch_array,
11-
size,
12-
to_device,
13-
)
14-
15-
__all__ = [
16-
"array_namespace",
17-
"device",
18-
"get_namespace",
19-
"is_array_api_obj",
20-
"is_cupy_array",
21-
"is_dask_array",
22-
"is_jax_array",
23-
"is_numpy_array",
24-
"is_torch_array",
25-
"size",
26-
"to_device",
27-
]
1+
from ._helpers import * # noqa: F403

array_api_compat/common/_aliases.py

+11
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def zeros_like(
146146

147147
# The functions here return namedtuples (np.unique() returns a normal
148148
# tuple).
149+
150+
# Note that these named tuples aren't actually part of the standard namespace,
151+
# but I don't see any issue with exporting the names here regardless.
149152
class UniqueAllResult(NamedTuple):
150153
values: ndarray
151154
indices: ndarray
@@ -545,3 +548,11 @@ def isdtype(
545548
# more strict here to match the type annotation? Note that the
546549
# array_api_strict implementation will be very strict.
547550
return dtype == kind
551+
552+
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
553+
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
554+
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
555+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
556+
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
557+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
558+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_helpers.py

+16
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,19 @@ def size(x):
288288
if None in x.shape:
289289
return None
290290
return math.prod(x.shape)
291+
292+
__all__ = [
293+
"array_namespace",
294+
"device",
295+
"get_namespace",
296+
"is_array_api_obj",
297+
"is_cupy_array",
298+
"is_dask_array",
299+
"is_jax_array",
300+
"is_numpy_array",
301+
"is_torch_array",
302+
"size",
303+
"to_device",
304+
]
305+
306+
_all_ignore = ['sys', 'math', 'inspect']

array_api_compat/common/_linalg.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
else:
1212
from numpy.core.numeric import normalize_axis_tuple
1313

14-
from ._aliases import matrix_transpose, isdtype
14+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1515
from .._internal import get_xp
1616

1717
# These are in the main NumPy namespace but not in numpy.linalg
@@ -149,4 +149,10 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
149149
dtype = xp.float64
150150
elif x.dtype == xp.complex64:
151151
dtype = xp.complex128
152-
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
152+
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
153+
154+
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
155+
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
156+
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
157+
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
158+
'trace']

array_api_compat/common/_typing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def __len__(self, /) -> int: ...
2020
SupportsBufferProtocol = Any
2121

2222
Array = Any
23-
Device = Any
23+
Device = Any

array_api_compat/cupy/__init__.py

+7-146
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,14 @@
1-
import cupy as _cp
2-
from cupy import * # noqa: F401, F403
1+
from cupy import * # noqa: F403
32

43
# from cupy import * doesn't overwrite these builtin names
5-
from cupy import abs, max, min, round
6-
7-
from .._internal import _get_all_public_members
8-
from ..common._helpers import (
9-
array_namespace,
10-
device,
11-
get_namespace,
12-
is_array_api_obj,
13-
size,
14-
to_device,
15-
)
4+
from cupy import abs, max, min, round # noqa: F401
165

176
# These imports may overwrite names from the import * above.
18-
from ._aliases import (
19-
UniqueAllResult,
20-
UniqueCountsResult,
21-
UniqueInverseResult,
22-
acos,
23-
acosh,
24-
arange,
25-
argsort,
26-
asarray,
27-
asarray_cupy,
28-
asin,
29-
asinh,
30-
astype,
31-
atan,
32-
atan2,
33-
atanh,
34-
bitwise_invert,
35-
bitwise_left_shift,
36-
bitwise_right_shift,
37-
bool,
38-
ceil,
39-
concat,
40-
empty,
41-
empty_like,
42-
eye,
43-
floor,
44-
full,
45-
full_like,
46-
isdtype,
47-
linspace,
48-
matmul,
49-
matrix_transpose,
50-
nonzero,
51-
ones,
52-
ones_like,
53-
permute_dims,
54-
pow,
55-
prod,
56-
reshape,
57-
sort,
58-
std,
59-
sum,
60-
tensordot,
61-
trunc,
62-
unique_all,
63-
unique_counts,
64-
unique_inverse,
65-
unique_values,
66-
var,
67-
vecdot,
68-
zeros,
69-
zeros_like,
70-
)
71-
72-
__all__ = []
73-
74-
__all__ += _get_all_public_members(_cp)
75-
76-
__all__ += [
77-
"abs",
78-
"max",
79-
"min",
80-
"round",
81-
]
82-
83-
__all__ += [
84-
"array_namespace",
85-
"device",
86-
"get_namespace",
87-
"is_array_api_obj",
88-
"size",
89-
"to_device",
90-
]
91-
92-
__all__ += [
93-
"UniqueAllResult",
94-
"UniqueCountsResult",
95-
"UniqueInverseResult",
96-
"acos",
97-
"acosh",
98-
"arange",
99-
"argsort",
100-
"asarray",
101-
"asarray_cupy",
102-
"asin",
103-
"asinh",
104-
"astype",
105-
"atan",
106-
"atan2",
107-
"atanh",
108-
"bitwise_invert",
109-
"bitwise_left_shift",
110-
"bitwise_right_shift",
111-
"bool",
112-
"ceil",
113-
"concat",
114-
"empty",
115-
"empty_like",
116-
"eye",
117-
"floor",
118-
"full",
119-
"full_like",
120-
"isdtype",
121-
"linspace",
122-
"matmul",
123-
"matrix_transpose",
124-
"nonzero",
125-
"ones",
126-
"ones_like",
127-
"permute_dims",
128-
"pow",
129-
"prod",
130-
"reshape",
131-
"sort",
132-
"std",
133-
"sum",
134-
"tensordot",
135-
"trunc",
136-
"unique_all",
137-
"unique_counts",
138-
"unique_inverse",
139-
"unique_values",
140-
"var",
141-
"zeros",
142-
"zeros_like",
143-
]
144-
145-
__all__ += [
146-
"matrix_transpose",
147-
"vecdot",
148-
]
7+
from ._aliases import * # noqa: F403
1498

1509
# See the comment in the numpy __init__.py
151-
__import__(__package__ + ".linalg")
10+
__import__(__package__ + '.linalg')
11+
12+
from ..common._helpers import * # noqa: F401,F403
15213

153-
__array_api_version__ = "2022.12"
14+
__array_api_version__ = '2022.12'

array_api_compat/cupy/_aliases.py

+5-27
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
import cupy as cp
66

77
from ..common import _aliases
8-
from ..common import _linalg
9-
108
from .._internal import get_xp
119

1210
asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
1311
asarray.__doc__ = _aliases._asarray.__doc__
12+
del partial
1413

1514
bool = cp.bool_
1615

@@ -74,28 +73,7 @@
7473
else:
7574
isdtype = get_xp(cp)(_aliases.isdtype)
7675

77-
78-
cross = get_xp(cp)(_linalg.cross)
79-
outer = get_xp(cp)(_linalg.outer)
80-
EighResult = _linalg.EighResult
81-
QRResult = _linalg.QRResult
82-
SlogdetResult = _linalg.SlogdetResult
83-
SVDResult = _linalg.SVDResult
84-
eigh = get_xp(cp)(_linalg.eigh)
85-
qr = get_xp(cp)(_linalg.qr)
86-
slogdet = get_xp(cp)(_linalg.slogdet)
87-
svd = get_xp(cp)(_linalg.svd)
88-
cholesky = get_xp(cp)(_linalg.cholesky)
89-
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
90-
pinv = get_xp(cp)(_linalg.pinv)
91-
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
92-
svdvals = get_xp(cp)(_linalg.svdvals)
93-
diagonal = get_xp(cp)(_linalg.diagonal)
94-
trace = get_xp(cp)(_linalg.trace)
95-
96-
# These functions are completely new here. If the library already has them
97-
# (i.e., numpy 2.0), use the library version instead of our wrapper.
98-
if hasattr(cp.linalg, 'vector_norm'):
99-
vector_norm = cp.linalg.vector_norm
100-
else:
101-
vector_norm = get_xp(cp)(_linalg.vector_norm)
76+
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
77+
'acosh', 'asin', 'asinh', 'atan', 'atan2',
78+
'atanh', 'bitwise_left_shift', 'bitwise_invert',
79+
'bitwise_right_shift', 'concat', 'pow']

array_api_compat/cupy/_typing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
__all__ = [
4+
"ndarray",
45
"Device",
56
"Dtype",
6-
"ndarray",
77
]
88

99
import sys

0 commit comments

Comments
 (0)