diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index f6b7ae25..22deaa23 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -3,9 +3,12 @@ from inspect import getfullargspec from typing import Any, Dict, Optional, Sequence, Tuple, Union -from . import _array_module as xp +from hypothesis import note + +from . import _array_module as xp, xps from . import dtype_helpers as dh from . import shape_helpers as sh +from . import hypothesis_helpers as hh from . import stubs from . import xp as _xp from .typing import Array, DataType, Scalar, ScalarType, Shape @@ -28,6 +31,7 @@ "assert_0d_equals", "assert_fill", "assert_array_elements", + "assert_kw_copy" ] @@ -483,6 +487,48 @@ def assert_fill( assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def scalar_eq(s1: Scalar, s2: Scalar) -> bool: + if cmath.isnan(s1): + return cmath.isnan(s2) + else: + return s1 == s2 + + +def assert_kw_copy(func_name, x, out, data, copy): + """ + Assert copy=True/False functionality is respected + + TODO: we're not able to check scalars with this approach + """ + if copy is not None and len(x.shape) > 0: + stype = dh.get_scalar_type(x.dtype) + idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") + old_value = stype(x[idx]) + scalar_strat = hh.from_dtype(x.dtype).filter( + lambda n: not scalar_eq(n, old_value) + ) + value = data.draw( + scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)), + label="mutating value", + ) + x[idx] = value + note(f"mutated {x=}") + # sanity check + assert_scalar_equals( + "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x" + ) + new_out_value = stype(out[idx]) + f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" + if copy: + assert scalar_eq( + new_out_value, old_value + ), f"{f_out}, but should be {old_value} even after x was mutated" + else: + assert scalar_eq( + new_out_value, value + ), f"{f_out}, but should be {value} after x was mutated" + + def _has_functional_signbit() -> bool: # signbit can be available but not implemented (e.g., in array-api-strict) if not hasattr(_xp, "signbit"): diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 8c504a2a..4235ecf1 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -282,34 +282,7 @@ def test_asarray_arrays(shape, dtypes, data): ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype) ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape) ph.assert_array_elements("asarray", out=out, expected=x, kw=kw) - copy = kw.get("copy", None) - if copy is not None: - stype = dh.get_scalar_type(x.dtype) - idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") - old_value = stype(x[idx]) - scalar_strat = hh.from_dtype(dtypes.input_dtype).filter( - lambda n: not scalar_eq(n, old_value) - ) - value = data.draw( - scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)), - label="mutating value", - ) - x[idx] = value - note(f"mutated {x=}") - # sanity check - ph.assert_scalar_equals( - "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x" - ) - new_out_value = stype(out[idx]) - f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" - if copy: - assert scalar_eq( - new_out_value, old_value - ), f"{f_out}, but should be {old_value} even after x was mutated" - else: - assert scalar_eq( - new_out_value, value - ), f"{f_out}, but should be {value} after x was mutated" + ph.assert_kw_copy("asarray", x, out, data, kw.get("copy", None)) @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py index f9642f31..90fd7e26 100644 --- a/array_api_tests/test_data_type_functions.py +++ b/array_api_tests/test_data_type_functions.py @@ -81,7 +81,9 @@ def test_astype(x_dtype, dtype, kw, data): ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype) ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw) # TODO: test values - # TODO: test copy + # Check copy is respected (only if input dtype is same as output dtype) + if dtype == x_dtype: + ph.assert_kw_copy("astype", x, out, data, kw.get("copy", None)) @given(