Skip to content

Commit 206baa2

Browse files
committed
Test xpx.testing
1 parent 57a38f9 commit 206baa2

File tree

4 files changed

+90
-11
lines changed

4 files changed

+90
-11
lines changed

pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ checks = [
315315
exclude = [ # don't report on objects that match any of these regex
316316
'.*test_at.*',
317317
'.*test_funcs.*',
318+
'.*test_testing.*',
318319
'.*test_utils.*',
319320
'.*test_version.*',
320321
'.*test_vendor.*',

src/array_api_extra/testing.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,42 @@
66
is_pydata_sparse_namespace,
77
is_torch_namespace,
88
)
9-
from ._lib._typing import Array
9+
from ._lib._typing import Array, ModuleType
1010

1111
__all__ = ["xp_assert_close", "xp_assert_equal"]
1212

1313

14-
def _check_shape_dtype(actual: Array, desired: Array) -> None:
14+
def _check_ns_shape_dtype(
15+
actual: Array, desired: Array
16+
) -> ModuleType: # numpydoc ignore=RT03
1517
"""
16-
Assert that shape and dtype of the two arrays match.
18+
Assert that namespace, shape and dtype of the two arrays match.
1719
1820
Parameters
1921
----------
2022
actual : Array
2123
The array produced by the tested function.
2224
desired : Array
2325
The expected array (typically hardcoded).
26+
27+
Returns
28+
-------
29+
Arrays namespace.
2430
"""
31+
actual_xp = array_namespace(actual) # Raises on scalars and lists
32+
desired_xp = array_namespace(desired)
33+
34+
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
35+
assert actual_xp == desired_xp, msg
36+
2537
msg = f"shapes do not match: {actual.shape} != f{desired.shape}"
2638
assert actual.shape == desired.shape, msg
2739

28-
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}".format(
29-
actual.dtype, desired.dtype
30-
)
40+
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
3141
assert actual.dtype == desired.dtype, msg
3242

43+
return desired_xp
44+
3345

3446
def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
3547
"""
@@ -44,8 +56,7 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None:
4456
err_msg : str, optional
4557
Error message to display on failure.
4658
"""
47-
xp = array_namespace(actual, desired)
48-
_check_shape_dtype(actual, desired)
59+
xp = _check_ns_shape_dtype(actual, desired)
4960

5061
if is_cupy_namespace(xp):
5162
xp.testing.assert_array_equal(actual, desired, err_msg=err_msg)
@@ -96,8 +107,7 @@ def xp_assert_close(
96107
err_msg : str, optional
97108
Error message to display on failure.
98109
"""
99-
xp = array_namespace(actual, desired)
100-
_check_shape_dtype(actual, desired)
110+
xp = _check_ns_shape_dtype(actual, desired)
101111

102112
floating = xp.isdtype(actual.dtype, ("real floating", "complex floating"))
103113
if rtol is None and floating:

tests/test_testing.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import numpy as np
2+
import pytest
3+
4+
from array_api_extra.testing import xp_assert_close, xp_assert_equal
5+
6+
from .conftest import Library
7+
8+
# mypy: disable-error-code=no-any-decorated
9+
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
10+
11+
12+
@pytest.mark.parametrize(
13+
"func",
14+
[
15+
xp_assert_equal,
16+
pytest.param(
17+
xp_assert_close,
18+
marks=pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype"),
19+
),
20+
],
21+
)
22+
def test_assert_close_equal_basic(xp, func):
23+
func(xp.asarray(0), xp.asarray(0))
24+
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
25+
26+
with pytest.raises(AssertionError, match="shapes do not match"):
27+
func(xp.asarray([0]), xp.asarray([[0]]))
28+
29+
with pytest.raises(AssertionError, match="dtypes do not match"):
30+
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
31+
32+
with pytest.raises(AssertionError):
33+
func(xp.asarray([1, 2]), xp.asarray([1, 3]))
34+
35+
with pytest.raises(AssertionError, match="hello"):
36+
func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello")
37+
38+
39+
@pytest.mark.skip_xp_backend(Library.NUMPY)
40+
@pytest.mark.skip_xp_backend(Library.NUMPY_READONLY)
41+
@pytest.mark.parametrize(
42+
"func",
43+
[
44+
xp_assert_equal,
45+
pytest.param(
46+
xp_assert_close,
47+
marks=pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype"),
48+
),
49+
],
50+
)
51+
def test_assert_close_equal_namespace(xp, func):
52+
with pytest.raises(AssertionError):
53+
func(xp.asarray(0), np.asarray(0))
54+
with pytest.raises(TypeError):
55+
func(xp.asarray(0), 0)
56+
with pytest.raises(TypeError):
57+
func(xp.asarray([0]), [0])
58+
59+
60+
@pytest.mark.skip_xp_backend(Library.SPARSE, reason="no isdtype")
61+
def test_assert_close_tolerance(xp):
62+
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)
63+
with pytest.raises(AssertionError):
64+
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.01)
65+
66+
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
67+
with pytest.raises(AssertionError):
68+
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)

0 commit comments

Comments
 (0)