Skip to content

2023.12 coverage #255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions array_api_tests/test_inspection_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
from hypothesis import given, strategies as st

from . import xp

pytestmark = pytest.mark.min_version("2023.12")


def test_array_namespace_info():
out = xp.__array_namespace_info__()

capabilities = out.capabilities()
assert isinstance(capabilities, dict)

out.default_device()

default_dtypes = out.default_dtypes()
assert isinstance(default_dtypes, dict)
assert {"real floating", "complex floating", "integral", "indexing"}.issubset(set(default_dtypes.keys()))

devices = out.devices()
assert isinstance(devices, list)


atomic_kinds = [
"bool",
"signed integer",
"unsigned integer",
"real floating",
"complex floating",
]


@given(
st.one_of(
st.none(),
st.sampled_from(atomic_kinds + ["integral", "numeric"]),
st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple),
)
)
def test_array_namespace_info_dtypes(kind):
out = xp.__array_namespace_info__().dtypes(kind=kind)
assert isinstance(out, dict)
58 changes: 58 additions & 0 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,31 @@ def test_expand_dims(x, axis):
)


@pytest.mark.min_version("2023.12")
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data())
def test_moveaxis(x, data):
source = data.draw(
st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source"
)
if isinstance(source, int):
destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination")
else:
assert isinstance(source, tuple) # sanity check
destination = data.draw(
st.lists(
st.integers(-x.ndim, x.ndim - 1),
min_size=len(source),
max_size=len(source),
unique_by=lambda n: n if n >= 0 else x.ndim + n,
).map(tuple),
label="destination"
)

out = xp.moveaxis(x, source, destination)

ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shape and values testing

@pytest.mark.unvectorized
@given(
x=hh.arrays(
Expand Down Expand Up @@ -253,6 +278,20 @@ def reshape_shapes(draw, shape):
return tuple(rshape)


@pytest.mark.min_version("2023.12")
@given(
x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)),
repeats=st.integers(1, 4),
)
def test_repeat(x, repeats):
# TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
out = xp.repeat(x, repeats)
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
expected_shape = (math.prod(x.shape) * repeats,)
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
# TODO: values testing


@pytest.mark.unvectorized
@pytest.mark.skip("flaky") # TODO: fix!
@given(
Expand Down Expand Up @@ -371,3 +410,22 @@ def test_stack(shape, dtypes, kw, data):
out_val=out[out_idx],
kw=kw,
)


@pytest.mark.min_version("2023.12")
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data())
def test_tile(x, data):
repetitions = data.draw(st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), label="repetitions")
out = xp.tile(x, repetitions)
ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shapes and values testing


@pytest.mark.min_version("2023.12")
@given(x=hh.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_dims=1)), data=st.data())
def test_unstack(x, data):
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shapes and values testing
55 changes: 55 additions & 0 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,16 @@ def test_ceil(x):
unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
def test_clip(x):
# TODO: test min/max kwargs, adjust values testing accordingly
out = xp.clip(x)
ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("clip", out_shape=out.shape, expected=x.shape)
ph.assert_array_elements("clip", out=out, expected=x)


if api_version >= "2022.12":

@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
Expand All @@ -943,6 +953,15 @@ def test_conj(x):
unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_copysign(x1, x2):
out = xp.copysign(x1, x2)
ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
# TODO: values testing


@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
def test_cos(x):
out = xp.cos(x)
Expand Down Expand Up @@ -1095,6 +1114,15 @@ def test_greater_equal(ctx, data):
)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_hypot(x1, x2):
out = xp.hypot(x1, x2)
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)


if api_version >= "2022.12":

@given(hh.arrays(dtype=xps.complex_dtypes(), shape=hh.shapes()))
Expand Down Expand Up @@ -1261,6 +1289,24 @@ def test_logical_xor(x1, x2):
)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_maximum(x1, x2):
out = xp.maximum(x1, x2)
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)


@pytest.mark.min_version("2023.12")
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
def test_minimum(x1, x2):
out = xp.minimum(x1, x2)
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)


@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@given(data=st.data())
def test_multiply(ctx, data):
Expand Down Expand Up @@ -1380,6 +1426,15 @@ def test_round(x):
unary_assert_against_refimpl("round", x, out, round, strict_check=True)


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
def test_signbit(x):
out = xp.signbit(x)
ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape)
# TODO: values testing


@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw))
def test_sign(x):
out = xp.sign(x)
Expand Down
38 changes: 37 additions & 1 deletion array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math

import pytest
from hypothesis import given
from hypothesis import given, note
from hypothesis import strategies as st

from . import _array_module as xp
Expand Down Expand Up @@ -167,3 +167,39 @@ def test_where(shapes, dtypes, data):
out_repr=f"out[{idx}]",
out_val=out[idx]
)


@pytest.mark.min_version("2023.12")
@given(data=st.data())
def test_searchsorted(data):
# TODO: test side="right"
_x1 = data.draw(
st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
label="_x1",
)
x1 = xp.asarray(_x1, dtype=dh.default_float)
if data.draw(st.booleans(), label="use sorter?"):
sorter = data.draw(
st.permutations(_x1).map(lambda o: xp.asarray(o, dtype=dh.default_float)),
label="sorter",
)
else:
sorter = None
x1 = xp.sort(x1)
note(f"{x1=}")
x2 = data.draw(
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
lambda o: xp.asarray(o, dtype=dh.default_float)
),
label="x2",
)

out = xp.searchsorted(x1, x2, sorter=sorter)

ph.assert_dtype(
"searchsorted",
in_dtype=[x1.dtype, x2.dtype],
out_dtype=out.dtype,
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
)
# TODO: shapes and values testing
10 changes: 10 additions & 0 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@
from .typing import DataType


@pytest.mark.min_version("2023.12")
@given(hh.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(min_dims=1, max_dims=1)))
def test_cumulative_sum(x):
# TODO: test kwargs + diff shapes, adjust shape and values testing accordingly
out = xp.cumulative_sum(x)
# TODO: assert dtype
ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=x.shape)
# TODO: assert values


def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
Expand Down
Loading