Skip to content

Revert __all__ related changes from #82 #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
pip install ruff
# Update output format to enable automatic inline annotations.
- name: Run Ruff
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
run: ruff check --output-format=github .
29 changes: 0 additions & 29 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import wraps
from inspect import signature


def get_xp(xp):
"""
Decorator to automatically replace xp with the corresponding array module.
Expand Down Expand Up @@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs):
return wrapped_f

return inner


def _get_all_public_members(module, exclude=None, extend_all=False):
"""Get all public members of a module.

Parameters
----------
module : module
The module to get members from.
exclude : callable, optional
A callable that takes a name and returns True if the name should be
excluded from the list of members.
extend_all : bool, optional
If True, extend the module's __all__ attribute with the members of the
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
"""
members = getattr(module, "__all__", [])

if members and not extend_all:
return members

if exclude is None:
exclude = lambda name: name.startswith("_") # noqa: E731

members = members + [_ for _ in dir(module) if not exclude(_)]

# remove duplicates
return list(set(members))
28 changes: 1 addition & 27 deletions array_api_compat/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1 @@
from ._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
is_cupy_array,
is_dask_array,
is_jax_array,
is_numpy_array,
is_torch_array,
size,
to_device,
)

__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"is_cupy_array",
"is_dask_array",
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"size",
"to_device",
]
from ._helpers import * # noqa: F403
11 changes: 11 additions & 0 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ def zeros_like(

# The functions here return namedtuples (np.unique() returns a normal
# tuple).

# Note that these named tuples aren't actually part of the standard namespace,
# but I don't see any issue with exporting the names here regardless.
class UniqueAllResult(NamedTuple):
values: ndarray
indices: ndarray
Expand Down Expand Up @@ -545,3 +548,11 @@ def isdtype(
# more strict here to match the type annotation? Note that the
# array_api_strict implementation will be very strict.
return dtype == kind

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
16 changes: 16 additions & 0 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,19 @@ def size(x):
if None in x.shape:
return None
return math.prod(x.shape)

__all__ = [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"is_cupy_array",
"is_dask_array",
"is_jax_array",
"is_numpy_array",
"is_torch_array",
"size",
"to_device",
]

_all_ignore = ['sys', 'math', 'inspect']
10 changes: 8 additions & 2 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
else:
from numpy.core.numeric import normalize_axis_tuple

from ._aliases import matrix_transpose, isdtype
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
from .._internal import get_xp

# These are in the main NumPy namespace but not in numpy.linalg
Expand Down Expand Up @@ -149,4 +149,10 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
dtype = xp.float64
elif x.dtype == xp.complex64:
dtype = xp.complex128
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']
2 changes: 1 addition & 1 deletion array_api_compat/common/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ def __len__(self, /) -> int: ...
SupportsBufferProtocol = Any

Array = Any
Device = Any
Device = Any
153 changes: 7 additions & 146 deletions array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,153 +1,14 @@
import cupy as _cp
from cupy import * # noqa: F401, F403
from cupy import * # noqa: F403

# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round

from .._internal import _get_all_public_members
from ..common._helpers import (
array_namespace,
device,
get_namespace,
is_array_api_obj,
size,
to_device,
)
from cupy import abs, max, min, round # noqa: F401

# These imports may overwrite names from the import * above.
from ._aliases import (
UniqueAllResult,
UniqueCountsResult,
UniqueInverseResult,
acos,
acosh,
arange,
argsort,
asarray,
asarray_cupy,
asin,
asinh,
astype,
atan,
atan2,
atanh,
bitwise_invert,
bitwise_left_shift,
bitwise_right_shift,
bool,
ceil,
concat,
empty,
empty_like,
eye,
floor,
full,
full_like,
isdtype,
linspace,
matmul,
matrix_transpose,
nonzero,
ones,
ones_like,
permute_dims,
pow,
prod,
reshape,
sort,
std,
sum,
tensordot,
trunc,
unique_all,
unique_counts,
unique_inverse,
unique_values,
var,
vecdot,
zeros,
zeros_like,
)

__all__ = []

__all__ += _get_all_public_members(_cp)

__all__ += [
"abs",
"max",
"min",
"round",
]

__all__ += [
"array_namespace",
"device",
"get_namespace",
"is_array_api_obj",
"size",
"to_device",
]

__all__ += [
"UniqueAllResult",
"UniqueCountsResult",
"UniqueInverseResult",
"acos",
"acosh",
"arange",
"argsort",
"asarray",
"asarray_cupy",
"asin",
"asinh",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_invert",
"bitwise_left_shift",
"bitwise_right_shift",
"bool",
"ceil",
"concat",
"empty",
"empty_like",
"eye",
"floor",
"full",
"full_like",
"isdtype",
"linspace",
"matmul",
"matrix_transpose",
"nonzero",
"ones",
"ones_like",
"permute_dims",
"pow",
"prod",
"reshape",
"sort",
"std",
"sum",
"tensordot",
"trunc",
"unique_all",
"unique_counts",
"unique_inverse",
"unique_values",
"var",
"zeros",
"zeros_like",
]

__all__ += [
"matrix_transpose",
"vecdot",
]
from ._aliases import * # noqa: F403

# See the comment in the numpy __init__.py
__import__(__package__ + ".linalg")
__import__(__package__ + '.linalg')

from ..common._helpers import * # noqa: F401,F403

__array_api_version__ = "2022.12"
__array_api_version__ = '2022.12'
32 changes: 5 additions & 27 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import cupy as cp

from ..common import _aliases
from ..common import _linalg

from .._internal import get_xp

asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
asarray.__doc__ = _aliases._asarray.__doc__
del partial

bool = cp.bool_

Expand Down Expand Up @@ -74,28 +73,7 @@
else:
isdtype = get_xp(cp)(_aliases.isdtype)


cross = get_xp(cp)(_linalg.cross)
outer = get_xp(cp)(_linalg.outer)
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult
eigh = get_xp(cp)(_linalg.eigh)
qr = get_xp(cp)(_linalg.qr)
slogdet = get_xp(cp)(_linalg.slogdet)
svd = get_xp(cp)(_linalg.svd)
cholesky = get_xp(cp)(_linalg.cholesky)
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
pinv = get_xp(cp)(_linalg.pinv)
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
svdvals = get_xp(cp)(_linalg.svdvals)
diagonal = get_xp(cp)(_linalg.diagonal)
trace = get_xp(cp)(_linalg.trace)

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp.linalg, 'vector_norm'):
vector_norm = cp.linalg.vector_norm
else:
vector_norm = get_xp(cp)(_linalg.vector_norm)
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
2 changes: 1 addition & 1 deletion array_api_compat/cupy/_typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

__all__ = [
"ndarray",
"Device",
"Dtype",
"ndarray",
]

import sys
Expand Down
Loading