Skip to content

Commit f4657a8

Browse files
committed
Add a test for __all__ self-consistency
1 parent 967c883 commit f4657a8

File tree

8 files changed

+85
-27
lines changed

8 files changed

+85
-27
lines changed

array_api_compat/common/_aliases.py

+3
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def zeros_like(
146146

147147
# The functions here return namedtuples (np.unique() returns a normal
148148
# tuple).
149+
150+
# Note that these named tuples aren't actually part of the standard namespace,
151+
# but I don't see any issue with exporting the names here regardless.
149152
class UniqueAllResult(NamedTuple):
150153
values: ndarray
151154
indices: ndarray

array_api_compat/common/_helpers.py

+2
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,5 @@ def size(x):
302302
"size",
303303
"to_device",
304304
]
305+
306+
_all_ignore = ['sys', 'math', 'inspect']

array_api_compat/dask/array/_aliases.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
# an error with dask
5151

5252
# TODO: delete the xp stuff, it shouldn't be necessary
53-
def dask_arange(
53+
def _dask_arange(
5454
start: Union[int, float],
5555
/,
5656
stop: Optional[Union[int, float]] = None,
@@ -72,7 +72,7 @@ def dask_arange(
7272
args.append(step)
7373
return xp.arange(*args, dtype=dtype, **kwargs)
7474

75-
arange = get_xp(da)(dask_arange)
75+
arange = get_xp(da)(_dask_arange)
7676
eye = get_xp(da)(_aliases.eye)
7777

7878
from functools import partial
@@ -142,4 +142,4 @@ def dask_arange(
142142
'int16', 'int32', 'int64', 'uint8', 'uint16', 'uint32', 'uint64',
143143
'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'result_type']
144144

145-
del da, partial, common_aliases, _da_unsupported,
145+
_all_ignore = ['get_xp', 'da', 'partial', 'common_aliases', 'np']

array_api_compat/dask/array/linalg.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,4 @@ def svdvals(x: ndarray) -> Union[ndarray, Tuple[ndarray, ...]]:
5151
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
5252
"vector_norm", "diagonal"]
5353

54-
del get_xp
55-
del da
56-
del _linalg
54+
_all_ignore = ['get_xp', 'da', 'linalg_all']

array_api_compat/numpy/_aliases.py

+2
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,5 @@
7777
'acosh', 'asin', 'asinh', 'atan', 'atan2',
7878
'atanh', 'bitwise_left_shift', 'bitwise_invert',
7979
'bitwise_right_shift', 'concat', 'pow']
80+
81+
_all_ignore = ['np', 'get_xp']

array_api_compat/torch/_aliases.py

+27-20
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from __future__ import annotations
22

3-
from functools import wraps
4-
from builtins import all as builtin_all, any as builtin_any
3+
from functools import wraps as _wraps
4+
from builtins import all as _builtin_all, any as _builtin_any
55

6-
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
7-
UniqueInverseResult,
8-
matrix_transpose as _aliases_matrix_transpose,
6+
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
97
vecdot as _aliases_vecdot)
108
from .._internal import get_xp
119

@@ -86,7 +84,7 @@
8684

8785

8886
def _two_arg(f):
89-
@wraps(f)
87+
@_wraps(f)
9088
def _f(x1, x2, /, **kwargs):
9189
x1, x2 = _fix_promotion(x1, x2)
9290
return f(x1, x2, **kwargs)
@@ -509,7 +507,7 @@ def arange(start: Union[int, float],
509507
start, stop = 0, start
510508
if step > 0 and stop <= start or step < 0 and stop >= start:
511509
if dtype is None:
512-
if builtin_all(isinstance(i, int) for i in [start, stop, step]):
510+
if _builtin_all(isinstance(i, int) for i in [start, stop, step]):
513511
dtype = torch.int64
514512
else:
515513
dtype = torch.float32
@@ -601,6 +599,11 @@ def broadcast_arrays(*arrays: array) -> List[array]:
601599
shape = torch.broadcast_shapes(*[a.shape for a in arrays])
602600
return [torch.broadcast_to(a, shape) for a in arrays]
603601

602+
# Note that these named tuples aren't actually part of the standard namespace,
603+
# but I don't see any issue with exporting the names here regardless.
604+
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
605+
UniqueInverseResult)
606+
604607
# https://github.com/pytorch/pytorch/issues/70920
605608
def unique_all(x: array) -> UniqueAllResult:
606609
# torch.unique doesn't support returning indices.
@@ -665,7 +668,7 @@ def isdtype(
665668
for more details
666669
"""
667670
if isinstance(kind, tuple) and _tuple:
668-
return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
671+
return _builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
669672
elif isinstance(kind, str):
670673
if kind == 'bool':
671674
return dtype == torch.bool
@@ -693,15 +696,19 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
693696
axis = 0
694697
return torch.index_select(x, axis, indices, **kwargs)
695698

696-
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis',
697-
'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
698-
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
699-
'floor_divide', 'greater', 'greater_equal', 'less', 'less_equal',
700-
'logaddexp', 'multiply', 'not_equal', 'pow', 'remainder',
701-
'subtract', 'max', 'min', 'sort', 'prod', 'sum', 'any', 'all',
702-
'mean', 'std', 'var', 'concat', 'squeeze', 'broadcast_to', 'flip', 'roll',
703-
'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full',
704-
'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
705-
'broadcast_arrays', 'unique_all', 'unique_counts',
706-
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
707-
'vecdot', 'tensordot', 'isdtype', 'take']
699+
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
700+
'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
701+
'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
702+
'equal', 'floor_divide', 'greater', 'greater_equal', 'less',
703+
'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
704+
'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
705+
'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
706+
'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
707+
'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
708+
'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
709+
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
710+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
711+
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype',
712+
'take']
713+
714+
_all_ignore = ['torch', 'get_xp']

array_api_compat/torch/linalg.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from torch import dtype as Dtype
88
from typing import Optional
99

10+
from ._aliases import _fix_promotion, sum
11+
1012
from torch.linalg import * # noqa: F403
1113

1214
# torch.linalg doesn't define __all__
@@ -16,7 +18,7 @@
1618

1719
# outer is implemented in torch but aren't in the linalg namespace
1820
from torch import outer
19-
from ._aliases import _fix_promotion, matrix_transpose, tensordot, sum
21+
from ._aliases import matrix_transpose, tensordot
2022

2123
# Note: torch.linalg.cross does not default to axis=-1 (it defaults to the
2224
# first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743
@@ -59,4 +61,6 @@ def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> arr
5961
__all__ = linalg_all + ['outer', 'trace', 'matrix_transpose', 'tensordot',
6062
'vecdot', 'solve']
6163

64+
_all_ignore = ['torch_linalg', 'sum']
65+
6266
del linalg_all

tests/test_all.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
Test that files that define __all__ aren't missing any exports.
3+
4+
You can add names that shouldn't be exported to _all_ignore, like
5+
6+
_all_ignore = ['sys']
7+
8+
This is preferable to del-ing the names as this will break any name that is
9+
used inside of a function. Note that names starting with an underscore are automatically ignored.
10+
"""
11+
12+
13+
import sys
14+
15+
from ._helpers import import_
16+
17+
import pytest
18+
19+
@pytest.mark.parametrize("library", ["common", "cupy", "numpy", "torch", "dask.array"])
20+
def test_all(library):
21+
import_(library, wrapper=True)
22+
23+
for mod_name in sys.modules:
24+
if 'array_api_compat.' + library not in mod_name:
25+
continue
26+
27+
module = sys.modules[mod_name]
28+
29+
# TODO: We should define __all__ in the __init__.py files and test it
30+
# there too.
31+
if not hasattr(module, '__all__'):
32+
continue
33+
34+
dir_names = [n for n in dir(module) if not n.startswith('_')]
35+
ignore_all_names = getattr(module, '_all_ignore', [])
36+
ignore_all_names += ['annotations', 'TYPE_CHECKING']
37+
dir_names = set(dir_names) - set(ignore_all_names)
38+
all_names = module.__all__
39+
40+
if set(dir_names) != set(all_names):
41+
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
42+
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"

0 commit comments

Comments
 (0)