Skip to content

Commit d7f7549

Browse files
committed
ENH: add xp_assert_less
1 parent 054fba0 commit d7f7549

File tree

2 files changed

+68
-16
lines changed

2 files changed

+68
-16
lines changed

src/array_api_extra/_lib/_testing.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def _check_ns_shape_dtype(
7676
"array-ness does not match:\n Actual: "
7777
f"{type(actual)}\n Desired: {type(desired)}"
7878
)
79-
assert (np.isscalar(actual) and np.isscalar(desired)) or (
80-
not np.isscalar(actual) and not np.isscalar(desired)
81-
), _msg
79+
assert np.isscalar(actual) == np.isscalar(desired), _msg
8280

8381
return desired_xp
8482

@@ -139,6 +137,41 @@ def xp_assert_equal(
139137
np.testing.assert_array_equal(actual, desired, err_msg=err_msg)
140138

141139

140+
def xp_assert_less(
141+
x: Array,
142+
y: Array,
143+
*,
144+
err_msg: str = "",
145+
check_dtype: bool = True,
146+
check_shape: bool = True,
147+
check_scalar: bool = False,
148+
) -> None:
149+
"""
150+
Array-API compatible version of `np.testing.assert_array_less`.
151+
152+
Parameters
153+
----------
154+
x, y : Array
155+
The arrays to compare according to ``x < y`` (elementwise).
156+
err_msg : str, optional
157+
Error message to display on failure.
158+
check_dtype, check_shape : bool, default: True
159+
Whether to check agreement between actual and desired dtypes and shapes
160+
check_scalar : bool, default: False
161+
NumPy only: whether to check agreement between actual and desired types -
162+
0d array vs scalar.
163+
164+
See Also
165+
--------
166+
xp_assert_close : Similar function for inexact equality checks.
167+
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
168+
"""
169+
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
170+
x = _prepare_for_test(x, xp)
171+
y = _prepare_for_test(y, xp)
172+
np.testing.assert_array_less(x, y, err_msg=err_msg) # type: ignore[call-overload]
173+
174+
142175
def xp_assert_close(
143176
actual: Array,
144177
desired: Array,
@@ -196,7 +229,6 @@ def xp_assert_close(
196229
desired = _prepare_for_test(desired, xp)
197230

198231
# JAX/Dask arrays work directly with `np.testing`
199-
assert isinstance(rtol, float)
200232
np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue]
201233
actual, # pyright: ignore[reportArgumentType]
202234
desired, # pyright: ignore[reportArgumentType]

tests/test_testing.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
import pytest
88

99
from array_api_extra._lib._backends import Backend
10-
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
10+
from array_api_extra._lib._testing import (
11+
xp_assert_close,
12+
xp_assert_equal,
13+
xp_assert_less,
14+
)
1115
from array_api_extra._lib._utils._compat import (
1216
array_namespace,
1317
is_dask_namespace,
@@ -23,6 +27,7 @@
2327
"func",
2428
[
2529
xp_assert_equal,
30+
xp_assert_less,
2631
pytest.param(
2732
xp_assert_close,
2833
marks=pytest.mark.xfail_xp_backend(
@@ -33,7 +38,8 @@
3338
)
3439

3540

36-
@param_assert_equal_close
41+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
42+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
3743
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
3844
func(xp.asarray(0), xp.asarray(0))
3945
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
@@ -53,8 +59,8 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): #
5359

5460
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
5561
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
56-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
57-
def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
62+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
63+
def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
5864
with pytest.raises(AssertionError, match="namespaces do not match"):
5965
func(xp.asarray(0), np.asarray(0))
6066
with pytest.raises(TypeError, match="Unrecognized array input"):
@@ -65,7 +71,7 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None])
6571

6672
@param_assert_equal_close
6773
@pytest.mark.parametrize("check_shape", [False, True])
68-
def test_assert_close_equal_shape( # type: ignore[explicit-any]
74+
def test_assert_close_equal_less_shape( # type: ignore[explicit-any]
6975
xp: ModuleType,
7076
func: Callable[..., None],
7177
check_shape: bool,
@@ -76,12 +82,12 @@ def test_assert_close_equal_shape( # type: ignore[explicit-any]
7682
else nullcontext()
7783
)
7884
with context:
79-
func(xp.asarray([0, 0]), xp.asarray(0), check_shape=check_shape)
85+
func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)
8086

8187

8288
@param_assert_equal_close
8389
@pytest.mark.parametrize("check_dtype", [False, True])
84-
def test_assert_close_equal_dtype( # type: ignore[explicit-any]
90+
def test_assert_close_equal_less_dtype( # type: ignore[explicit-any]
8591
xp: ModuleType,
8692
func: Callable[..., None],
8793
check_dtype: bool,
@@ -92,12 +98,17 @@ def test_assert_close_equal_dtype( # type: ignore[explicit-any]
9298
else nullcontext()
9399
)
94100
with context:
95-
func(xp.asarray(0.0), xp.asarray(0), check_dtype=check_dtype)
101+
func(
102+
xp.asarray(xp.nan, dtype=xp.float32),
103+
xp.asarray(xp.nan, dtype=xp.float64),
104+
check_dtype=check_dtype,
105+
)
96106

97107

98-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
108+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
99109
@pytest.mark.parametrize("check_scalar", [False, True])
100-
def test_assert_close_equal_scalar( # type: ignore[explicit-any]
110+
def test_assert_close_equal_less_scalar( # type: ignore[explicit-any]
111+
xp: ModuleType,
101112
func: Callable[..., None],
102113
check_scalar: bool,
103114
):
@@ -107,7 +118,7 @@ def test_assert_close_equal_scalar( # type: ignore[explicit-any]
107118
else nullcontext()
108119
)
109120
with context:
110-
func(np.asarray(0), np.asarray(0)[()], check_scalar=check_scalar)
121+
func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar)
111122

112123

113124
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
@@ -121,9 +132,18 @@ def test_assert_close_tolerance(xp: ModuleType):
121132
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
122133

123134

124-
@param_assert_equal_close
135+
def test_assert_less_basic(xp: ModuleType):
136+
xp_assert_less(xp.asarray(-1), xp.asarray(0))
137+
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
138+
with pytest.raises(AssertionError):
139+
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
140+
with pytest.raises(AssertionError, match="hello"):
141+
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello")
142+
143+
125144
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
126145
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
146+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
127147
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
128148
"""On Dask and other lazy backends, test that a shape with NaN's or None's
129149
can be compared to a real shape.

0 commit comments

Comments
 (0)