Skip to content

Commit 62507f4

Browse files
authored
Merge pull request #286 from ev-br/test_all
Reenable test_all, fix `_aliases.__all__`
2 parents b5a57eb + 71d90ea commit 62507f4

10 files changed

+44
-16
lines changed

array_api_compat/common/_aliases.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
import inspect
88
from typing import NamedTuple, Optional, Sequence, Tuple, Union
99

10-
from ._helpers import array_namespace, _check_device, device, is_cupy_namespace
1110
from ._typing import Array, Device, DType, Namespace
11+
from ._helpers import (
12+
array_namespace,
13+
_check_device,
14+
device as _get_device,
15+
is_cupy_namespace as _is_cupy_namespace
16+
)
17+
1218

1319
# These functions are modified from the NumPy versions.
1420

@@ -298,7 +304,7 @@ def cumulative_sum(
298304
initial_shape = list(x.shape)
299305
initial_shape[axis] = 1
300306
res = xp.concatenate(
301-
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
307+
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
302308
axis=axis,
303309
)
304310
return res
@@ -328,7 +334,7 @@ def cumulative_prod(
328334
initial_shape = list(x.shape)
329335
initial_shape[axis] = 1
330336
res = xp.concatenate(
331-
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
337+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res],
332338
axis=axis,
333339
)
334340
return res
@@ -381,7 +387,7 @@ def _isscalar(a):
381387
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
382388
max = None
383389

384-
dev = device(x)
390+
dev = _get_device(x)
385391
if out is None:
386392
out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
387393
out[()] = x
@@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
599605
out = xp.sign(x, **kwargs)
600606
# CuPy sign() does not propagate nans. See
601607
# https://github.com/data-apis/array-api-compat/issues/136
602-
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
608+
if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
603609
out[xp.isnan(x)] = xp.nan
604610
return out[()]
605611

@@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array:
611617
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
612618
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
613619
'unstack', 'sign']
620+
621+
_all_ignore = ['inspect', 'array_namespace', 'NamedTuple']

numpy-1-21-xfails.txt

+1
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
192192
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
193193
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
194194
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
195+
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
195196

196197
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that
197198
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]

numpy-1-26-xfails.txt

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
4646
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
4747
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
4848
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
49+
array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
4950

5051
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
5152
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]

tests/test_all.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626
"SupportsBufferProtocol",
2727
))
2828

29-
@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277")
3029
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
3130
def test_all(library):
3231
if library == "common":
3332
import array_api_compat.common # noqa: F401
3433
else:
3534
import_(library, wrapper=True)
3635

37-
for mod_name in sys.modules:
36+
# NB: iterate over a copy to avoid a "dictionary size changed" error
37+
for mod_name in sys.modules.copy():
3838
if not mod_name.startswith('array_api_compat.' + library):
3939
continue
4040

tests/test_array_namespace.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
import sys
33
import warnings
44

5-
import jax
65
import numpy as np
76
import pytest
8-
import torch
97

108
import array_api_compat
119
from array_api_compat import array_namespace
@@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat):
7674
subprocess.run([sys.executable, "-c", code], check=True)
7775

7876
def test_jax_zero_gradient():
77+
jax = import_("jax")
7978
jx = jax.numpy.arange(4)
8079
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
8180
assert array_namespace(jax_zero) is array_namespace(jx)
@@ -89,11 +88,13 @@ def test_array_namespace_errors():
8988
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
9089

9190
def test_array_namespace_errors_torch():
91+
torch = import_("torch")
9292
y = torch.asarray([1, 2])
9393
x = np.asarray([1, 2])
9494
pytest.raises(TypeError, lambda: array_namespace(x, y))
9595

9696
def test_api_version_torch():
97+
torch = import_("torch")
9798
x = torch.asarray([1, 2])
9899
torch_ = import_("torch", wrapper=True)
99100
assert array_namespace(x, api_version="2023.12") == torch_
@@ -118,6 +119,7 @@ def test_get_namespace():
118119
assert array_api_compat.get_namespace is array_namespace
119120

120121
def test_python_scalars():
122+
torch = import_("torch")
121123
a = torch.asarray([1, 2])
122124
xp = import_("torch", wrapper=True)
123125

tests/test_dask.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from contextlib import contextmanager
22

33
import array_api_strict
4-
import dask
54
import numpy as np
65
import pytest
7-
import dask.array as da
6+
7+
try:
8+
import dask
9+
import dask.array as da
10+
except ImportError:
11+
pytestmark = pytest.skip(allow_module_level=True, reason="dask not found")
812

913
from array_api_compat import array_namespace
1014

tests/test_jax.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
import jax
2-
import jax.numpy as jnp
31
from numpy.testing import assert_equal
42
import pytest
53

64
from array_api_compat import device, to_device
75

6+
try:
7+
import jax
8+
import jax.numpy as jnp
9+
except ImportError:
10+
pytestmark = pytest.skip(allow_module_level=True, reason="jax not found")
11+
812
HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"
913

1014

tests/test_torch.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import itertools
44

55
import pytest
6-
import torch
6+
7+
try:
8+
import torch
9+
except ImportError:
10+
pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found")
711

812
from array_api_compat import torch as xp
913

tests/test_vendoring.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ def test_vendoring_cupy():
1616

1717

1818
def test_vendoring_torch():
19+
pytest.importorskip("torch")
1920
from vendor_test import uses_torch
2021

2122
uses_torch._test_torch()
2223

2324

2425
def test_vendoring_dask():
26+
pytest.importorskip("dask")
2527
from vendor_test import uses_dask
2628
uses_dask._test_dask()

torch-xfails.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,12 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc
144144
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
145145
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
146146
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
147+
148+
# https://github.com/pytorch/pytorch/issues/149815
147149
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal]
148-
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq]
150+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal]
149151
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less]
150-
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal]
152+
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal]
151153
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater]
152154
array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal]
153155

0 commit comments

Comments
 (0)