diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 301a851..5645e39 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -5,10 +5,13 @@ See also ..testing for public testing utilities. """ +from __future__ import annotations + import math from types import ModuleType -from typing import cast +from typing import Any, cast +import numpy as np import pytest from ._utils._compat import ( @@ -16,16 +19,23 @@ is_array_api_strict_namespace, is_cupy_namespace, is_dask_namespace, + is_jax_namespace, + is_numpy_namespace, is_pydata_sparse_namespace, is_torch_namespace, + to_device, ) -from ._utils._typing import Array +from ._utils._typing import Array, Device -__all__ = ["xp_assert_close", "xp_assert_equal"] +__all__ = ["as_numpy_array", "xp_assert_close", "xp_assert_equal", "xp_assert_less"] def _check_ns_shape_dtype( - actual: Array, desired: Array + actual: Array, + desired: Array, + check_dtype: bool, + check_shape: bool, + check_scalar: bool, ) -> ModuleType: # numpydoc ignore=RT03 """ Assert that namespace, shape and dtype of the two arrays match. @@ -47,43 +57,67 @@ def _check_ns_shape_dtype( msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" assert actual_xp == desired_xp, msg - actual_shape = actual.shape - desired_shape = desired.shape - if is_dask_namespace(desired_xp): - # Dask uses nan instead of None for unknown shapes - if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)): - actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)): - desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - - msg = f"shapes do not match: {actual_shape} != f{desired_shape}" - assert actual_shape == desired_shape, msg - - msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" - assert actual.dtype == desired.dtype, msg + if check_shape: + actual_shape = actual.shape + desired_shape = desired.shape + if is_dask_namespace(desired_xp): + # Dask uses nan instead of None for unknown shapes + if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)): + actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)): + desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + + msg = f"shapes do not match: {actual_shape} != f{desired_shape}" + assert actual_shape == desired_shape, msg + + if check_dtype: + msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" + assert actual.dtype == desired.dtype, msg + + if is_numpy_namespace(actual_xp) and check_scalar: + # only NumPy distinguishes between scalars and arrays; we do if check_scalar. + _msg = ( + "array-ness does not match:\n Actual: " + f"{type(actual)}\n Desired: {type(desired)}" + ) + assert np.isscalar(actual) == np.isscalar(desired), _msg return desired_xp -def _prepare_for_test(array: Array, xp: ModuleType) -> Array: +def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any] """ - Ensure that the array can be compared with xp.testing or np.testing. - - This involves transferring it from GPU to CPU memory, densifying it, etc. + Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards. """ - if is_torch_namespace(xp): - return array.cpu() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + if is_cupy_namespace(xp): + return xp.asnumpy(array) if is_pydata_sparse_namespace(xp): return array.todense() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + + if is_torch_namespace(xp): + array = to_device(array, "cpu") if is_array_api_strict_namespace(xp): - # Note: we deliberately did not add a `.to_device` method in _typing.pyi - # even if it is required by the standard as many backends don't support it - return array.to_device(xp.Device("CPU_DEVICE")) # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - # Note: nothing to do for CuPy, because it uses a bespoke test function - return array + cpu: Device = xp.Device("CPU_DEVICE") + array = to_device(array, cpu) + if is_jax_namespace(xp): + import jax + # Note: only needed if the transfer guard is enabled + cpu = cast(Device, jax.devices("cpu")[0]) + array = to_device(array, cpu) -def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: + return np.asarray(array) + + +def xp_assert_equal( + actual: Array, + desired: Array, + *, + err_msg: str = "", + check_dtype: bool = True, + check_shape: bool = True, + check_scalar: bool = False, +) -> None: """ Array-API compatible version of `np.testing.assert_array_equal`. @@ -95,34 +129,56 @@ def xp_assert_equal(actual: Array, desired: Array, err_msg: str = "") -> None: The expected array (typically hardcoded). err_msg : str, optional Error message to display on failure. + check_dtype, check_shape : bool, default: True + Whether to check agreement between actual and desired dtypes and shapes + check_scalar : bool, default: False + NumPy only: whether to check agreement between actual and desired types - + 0d array vs scalar. See Also -------- xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(actual, desired) - actual = _prepare_for_test(actual, xp) - desired = _prepare_for_test(desired, xp) + xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + actual_np = as_numpy_array(actual, xp=xp) + desired_np = as_numpy_array(desired, xp=xp) + np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg) - if is_cupy_namespace(xp): - xp.testing.assert_array_equal(actual, desired, err_msg=err_msg) - elif is_torch_namespace(xp): - # PyTorch recommends using `rtol=0, atol=0` like this - # to test for exact equality - xp.testing.assert_close( - actual, - desired, - rtol=0, - atol=0, - equal_nan=True, - check_dtype=False, - msg=err_msg or None, - ) - else: - import numpy as np # pylint: disable=import-outside-toplevel - np.testing.assert_array_equal(actual, desired, err_msg=err_msg) +def xp_assert_less( + x: Array, + y: Array, + *, + err_msg: str = "", + check_dtype: bool = True, + check_shape: bool = True, + check_scalar: bool = False, +) -> None: + """ + Array-API compatible version of `np.testing.assert_array_less`. + + Parameters + ---------- + x, y : Array + The arrays to compare according to ``x < y`` (elementwise). + err_msg : str, optional + Error message to display on failure. + check_dtype, check_shape : bool, default: True + Whether to check agreement between actual and desired dtypes and shapes + check_scalar : bool, default: False + NumPy only: whether to check agreement between actual and desired types - + 0d array vs scalar. + + See Also + -------- + xp_assert_close : Similar function for inexact equality checks. + numpy.testing.assert_array_equal : Similar function for NumPy arrays. + """ + xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + x_np = as_numpy_array(x, xp=xp) + y_np = as_numpy_array(y, xp=xp) + np.testing.assert_array_less(x_np, y_np, err_msg=err_msg) def xp_assert_close( @@ -132,6 +188,9 @@ def xp_assert_close( rtol: float | None = None, atol: float = 0, err_msg: str = "", + check_dtype: bool = True, + check_shape: bool = True, + check_scalar: bool = False, ) -> None: """ Array-API compatible version of `np.testing.assert_allclose`. @@ -148,6 +207,11 @@ def xp_assert_close( Absolute tolerance. Default: 0. err_msg : str, optional Error message to display on failure. + check_dtype, check_shape : bool, default: True + Whether to check agreement between actual and desired dtypes and shapes + check_scalar : bool, default: False + NumPy only: whether to check agreement between actual and desired types - + 0d array vs scalar. See Also -------- @@ -159,40 +223,26 @@ def xp_assert_close( ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. """ - xp = _check_ns_shape_dtype(actual, desired) - - floating = xp.isdtype(actual.dtype, ("real floating", "complex floating")) - if rtol is None and floating: - # multiplier of 4 is used as for `np.float64` this puts the default `rtol` - # roughly half way between sqrt(eps) and the default for - # `numpy.testing.assert_allclose`, 1e-7 - rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 - elif rtol is None: - rtol = 1e-7 - - actual = _prepare_for_test(actual, xp) - desired = _prepare_for_test(desired, xp) - - if is_cupy_namespace(xp): - xp.testing.assert_allclose( - actual, desired, rtol=rtol, atol=atol, err_msg=err_msg - ) - elif is_torch_namespace(xp): - xp.testing.assert_close( - actual, desired, rtol=rtol, atol=atol, equal_nan=True, msg=err_msg or None - ) - else: - import numpy as np # pylint: disable=import-outside-toplevel - - # JAX/Dask arrays work directly with `np.testing` - assert isinstance(rtol, float) - np.testing.assert_allclose( # type: ignore[call-overload] # pyright: ignore[reportCallIssue] - actual, # pyright: ignore[reportArgumentType] - desired, # pyright: ignore[reportArgumentType] - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + + if rtol is None: + if xp.isdtype(actual.dtype, ("real floating", "complex floating")): + # multiplier of 4 is used as for `np.float64` this puts the default `rtol` + # roughly half way between sqrt(eps) and the default for + # `numpy.testing.assert_allclose`, 1e-7 + rtol = xp.finfo(actual.dtype).eps ** 0.5 * 4 + else: + rtol = 1e-7 + + actual_np = as_numpy_array(actual, xp=xp) + desired_np = as_numpy_array(desired, xp=xp) + np.testing.assert_allclose( # pyright: ignore[reportCallIssue] + actual_np, + desired_np, + rtol=rtol, # pyright: ignore[reportArgumentType] + atol=atol, + err_msg=err_msg, + ) def xfail( diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 0cee0b4..5f34bd6 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -196,7 +196,6 @@ def test_device(self, xp: ModuleType, device: Device): y = apply_where(x % 2 == 0, x, self.f1, fill_value=x) assert get_device(y) == device - @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype") @pytest.mark.filterwarnings("ignore::RuntimeWarning") # overflows, etc. @hypothesis.settings( # The xp and library fixtures are not regenerated between hypothesis iterations diff --git a/tests/test_testing.py b/tests/test_testing.py index fb9ba58..97585c9 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from contextlib import nullcontext from types import ModuleType from typing import cast @@ -6,13 +7,18 @@ import pytest from array_api_extra._lib._backends import Backend -from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal +from array_api_extra._lib._testing import ( + as_numpy_array, + xp_assert_close, + xp_assert_equal, + xp_assert_less, +) from array_api_extra._lib._utils._compat import ( array_namespace, is_dask_namespace, is_jax_namespace, ) -from array_api_extra._lib._utils._typing import Array +from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function # mypy: disable-error-code=decorated-any @@ -22,15 +28,25 @@ "func", [ xp_assert_equal, + xp_assert_less, pytest.param( xp_assert_close, - marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"), + marks=pytest.mark.xfail_xp_backend( + Backend.SPARSE, reason="no isdtype", strict=False + ), ), ], ) -@param_assert_equal_close +def test_as_numpy_array(xp: ModuleType, device: Device): + x = xp.asarray([1, 2, 3], device=device) + y = as_numpy_array(x, xp=xp) + assert isinstance(y, np.ndarray) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False) +@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close]) def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any] func(xp.asarray(0), xp.asarray(0)) func(xp.asarray([1, 2]), xp.asarray([1, 2])) @@ -50,8 +66,8 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # @pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy") @pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy") -@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close]) -def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any] +@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less]) +def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any] with pytest.raises(AssertionError, match="namespaces do not match"): func(xp.asarray(0), np.asarray(0)) with pytest.raises(TypeError, match="Unrecognized array input"): @@ -60,6 +76,58 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None]) func(xp.asarray([0]), [0]) +@param_assert_equal_close +@pytest.mark.parametrize("check_shape", [False, True]) +def test_assert_close_equal_less_shape( # type: ignore[explicit-any] + xp: ModuleType, + func: Callable[..., None], + check_shape: bool, +): + context = ( + pytest.raises(AssertionError, match="shapes do not match") + if check_shape + else nullcontext() + ) + with context: + func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape) + + +@param_assert_equal_close +@pytest.mark.parametrize("check_dtype", [False, True]) +def test_assert_close_equal_less_dtype( # type: ignore[explicit-any] + xp: ModuleType, + func: Callable[..., None], + check_dtype: bool, +): + context = ( + pytest.raises(AssertionError, match="dtypes do not match") + if check_dtype + else nullcontext() + ) + with context: + func( + xp.asarray(xp.nan, dtype=xp.float32), + xp.asarray(xp.nan, dtype=xp.float64), + check_dtype=check_dtype, + ) + + +@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less]) +@pytest.mark.parametrize("check_scalar", [False, True]) +def test_assert_close_equal_less_scalar( # type: ignore[explicit-any] + xp: ModuleType, + func: Callable[..., None], + check_scalar: bool, +): + context = ( + pytest.raises(AssertionError, match="array-ness does not match") + if check_scalar + else nullcontext() + ) + with context: + func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar) + + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype") def test_assert_close_tolerance(xp: ModuleType): xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03) @@ -71,9 +139,18 @@ def test_assert_close_tolerance(xp: ModuleType): xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1) -@param_assert_equal_close +def test_assert_less_basic(xp: ModuleType): + xp_assert_less(xp.asarray(-1), xp.asarray(0)) + xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3])) + with pytest.raises(AssertionError): + xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1])) + with pytest.raises(AssertionError, match="hello"): + xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello") + + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array") @pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing") +@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close]) def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any] """On Dask and other lazy backends, test that a shape with NaN's or None's can be compared to a real shape.