diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 950b6d4c..d9fdabd5 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -64,7 +64,7 @@ def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]: @wraps(xps.arrays) -def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]: +def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]: """xps.arrays() without the crazy large numbers.""" if isinstance(dtype, SearchStrategy): return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs)) @@ -77,6 +77,19 @@ def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]: return xps.arrays(dtype, *args, elements=elements, **kwargs) +def _f(a, flag): + return a[()] if a.ndim==0 and flag else a + + +@wraps(xps.arrays) +def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]: + """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars. + + Is only relevant for numpy: on all other libraries, array[()] is no-op. + """ + return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans()) + + _dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes] _sorted_dtypes = [d for category in _dtype_categories for d in category] diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 60014b74..8c504a2a 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -263,7 +263,8 @@ def scalar_eq(s1: Scalar, s2: Scalar) -> bool: data=st.data(), ) def test_asarray_arrays(shape, dtypes, data): - x = data.draw(hh.arrays(dtype=dtypes.input_dtype, shape=shape), label="x") + # generate arrays only since we draw the copy= kwd below (and np.asarray(scalar, copy=False) error out) + x = data.draw(hh.arrays_no_scalars(dtype=dtypes.input_dtype, shape=shape), label="x") dtypes_strat = st.just(dtypes.input_dtype) if dtypes.input_dtype == dtypes.result_dtype: dtypes_strat |= st.none()